sync from monorepo @ 2452e92e

This commit is contained in:
2026-05-08 01:59:04 +02:00
commit b03dc15371
459 changed files with 129586 additions and 0 deletions
+58
View File
@@ -0,0 +1,58 @@
//! Audit logging for sensitive tool operations.
//!
//! This module provides structured logging for:
//! - File read/write operations
//! - Terminal command execution
//! - Permission decisions
//! - Sandbox violations
//!
//! All audit logs include:
//! - Timestamp
//! - User/session context
//! - Operation type
//! - Parameters (sanitized)
//! - Outcome (success/error)
//!
//! TODO: Implement audit logging
use tracing::{info, warn};
/// Log a file read operation.
///
/// TODO: Implement with structured fields
pub fn log_file_read(_path: &str, _success: bool) {
// Placeholder - will use tracing with structured fields
info!("File read audit log placeholder");
}
/// Log a file write operation.
///
/// TODO: Implement with structured fields
pub fn log_file_write(_path: &str, _success: bool) {
// Placeholder - will use tracing with structured fields
info!("File write audit log placeholder");
}
/// Log a terminal command execution.
///
/// TODO: Implement with structured fields
pub fn log_terminal_exec(_command: &str, _success: bool) {
// Placeholder - will use tracing with structured fields
info!("Terminal exec audit log placeholder");
}
/// Log a permission decision.
///
/// TODO: Implement with structured fields
pub fn log_permission_decision(_operation: &str, _allowed: bool) {
// Placeholder - will use tracing with structured fields
info!("Permission decision audit log placeholder");
}
/// Log a sandbox violation attempt.
///
/// TODO: Implement with structured fields
pub fn log_sandbox_violation(_path: &str, _reason: &str) {
// Placeholder - will use tracing with structured fields
warn!("Sandbox violation audit log placeholder");
}
+425
View File
@@ -0,0 +1,425 @@
//! Configuration types for Phase 03 features.
//!
//! This module defines configuration structures for:
//! - `SandboxConfig` - Filesystem sandboxing configuration
//! - `PermissionConfig` - Permission prompt and decision caching
//! - `TerminalConfig` - Terminal/command execution limits
//! - `SearchConfig` - Search operation limits and defaults
//! - `EmbeddingConfig` - File embedding thresholds and policies
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
// =============================================================================
// Sandbox Configuration
// =============================================================================
/// Filesystem sandboxing configuration.
///
/// Determines which paths are accessible and how symlinks are handled.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct SandboxConfig {
/// Absolute paths that are allowed for file operations.
///
/// Operations outside these roots will be rejected.
pub allowed_roots: Vec<PathBuf>,
/// Path patterns that are explicitly blocked even if within allowed roots.
///
/// Supports glob patterns like "**/.env", "**/secrets/**".
pub blocked_paths: Vec<String>,
/// Whether to allow symlinks to escape allowed roots.
///
/// If false (recommended), symlinks pointing outside allowed roots are rejected.
pub allow_symlink_escape: bool,
/// Whether to follow symlinks within allowed roots.
///
/// If true, symlinks within allowed roots are followed.
pub follow_symlinks_within_roots: bool,
/// Enable read operations.
pub read_enabled: bool,
/// Enable write operations.
pub write_enabled: bool,
/// Maximum bytes to read in a single request.
///
/// Soft cap for previews. Default: 1 MB.
pub max_read_bytes: u64,
/// Maximum bytes to write in a single request.
///
/// Default: 1 MB.
pub max_write_bytes: u64,
/// End-of-line policy for file operations.
pub eol_policy: EolPolicy,
/// Text encoding support.
///
/// Currently only UTF-8 is supported.
pub encoding: String,
}
impl Default for SandboxConfig {
fn default() -> Self {
Self {
allowed_roots: vec![],
blocked_paths: vec!["**/.env".to_string(), "**/secrets/**".to_string()],
allow_symlink_escape: false,
follow_symlinks_within_roots: true,
read_enabled: true,
write_enabled: false,
max_read_bytes: 1_048_576, // 1 MB
max_write_bytes: 1_048_576, // 1 MB
eol_policy: EolPolicy::Preserve,
encoding: "utf-8".to_string(),
}
}
}
/// End-of-line handling policy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EolPolicy {
/// Preserve original line endings.
Preserve,
/// Normalize to LF (\n).
Lf,
/// Normalize to CRLF (\r\n).
Crlf,
}
// =============================================================================
// Permission Configuration
// =============================================================================
/// Permission prompt and decision caching configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct PermissionConfig {
/// Permission mode strategy.
pub mode: PermissionMode,
/// Whether to remember permission decisions.
pub remember_decisions: bool,
/// Time-to-live for cached decisions in seconds.
///
/// Default: 86400 (24 hours).
pub remember_ttl_secs: u64,
/// Scope of cached decisions.
pub scope: DecisionScope,
/// Whitelist configuration for whitelist mode.
pub whitelist: WhitelistConfig,
}
impl Default for PermissionConfig {
fn default() -> Self {
Self {
mode: PermissionMode::Whitelist,
remember_decisions: true,
remember_ttl_secs: 86_400, // 24 hours
scope: DecisionScope::PerConnector,
whitelist: WhitelistConfig::default(),
}
}
}
/// Permission mode strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PermissionMode {
/// Prompt for every sensitive operation.
Ask,
/// Auto-approve whitelisted operations, prompt for others.
Whitelist,
/// Auto-approve all operations (with audit logging).
Yolo,
}
/// Scope for cached permission decisions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DecisionScope {
/// Decisions persist per connector.
PerConnector,
/// Decisions persist only within a session.
PerSession,
}
/// Whitelist configuration for auto-approved operations.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct WhitelistConfig {
/// Path patterns that are safe for write operations.
///
/// Glob patterns like "C:/work/project/**".
pub write_paths: Vec<String>,
/// Commands that are safe to execute.
///
/// Glob patterns like "cargo", "npm", "git".
pub execute_commands: Vec<String>,
}
impl Default for WhitelistConfig {
fn default() -> Self {
Self {
write_paths: vec![],
execute_commands: vec![],
}
}
}
// =============================================================================
// Terminal Configuration
// =============================================================================
/// Terminal/command execution configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct TerminalConfig {
/// Enable terminal operations.
pub enabled: bool,
/// Default current working directory for terminal commands.
///
/// Must be within an allowed sandbox root.
pub default_cwd: Option<PathBuf>,
/// Environment variable names that are allowed.
///
/// Only these variables can be set in spawned processes.
pub env_allowlist: Vec<String>,
/// Command patterns that are blocked (best-effort).
///
/// Glob patterns for dangerous commands like "rm", "format".
pub command_blocklist: Vec<String>,
/// Maximum bytes to capture from terminal output (ring buffer).
///
/// Default: 200,000 bytes.
pub output_byte_limit: u64,
/// Maximum runtime for a terminal command in seconds.
///
/// Commands exceeding this will be killed. Default: 3600 (1 hour).
pub max_runtime_secs: u64,
}
impl Default for TerminalConfig {
fn default() -> Self {
Self {
enabled: true,
default_cwd: None,
env_allowlist: vec![
"RUST_LOG".to_string(),
"NODE_ENV".to_string(),
"PATH".to_string(),
],
command_blocklist: vec![
"rm".to_string(),
"rd".to_string(),
"format".to_string(),
"mkfs*".to_string(),
],
output_byte_limit: 200_000,
max_runtime_secs: 3_600, // 1 hour
}
}
}
// =============================================================================
// Search Configuration
// =============================================================================
/// Search operation configuration (glob, grep, ls).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct SearchConfig {
/// Maximum number of results to return.
///
/// Default: 5,000.
pub max_results: u32,
/// Maximum total bytes in search results.
///
/// Default: 1,000,000 (1 MB).
pub max_bytes: u64,
/// Default include patterns for searches.
pub default_include_globs: Vec<String>,
/// Default exclude patterns for searches.
///
/// Common directories to skip like target/, .git/, node_modules/.
pub default_exclude_globs: Vec<String>,
}
impl Default for SearchConfig {
fn default() -> Self {
Self {
max_results: 5_000,
max_bytes: 1_000_000, // 1 MB
default_include_globs: vec![],
default_exclude_globs: vec![
"**/target/**".to_string(),
"**/.git/**".to_string(),
"**/node_modules/**".to_string(),
"**/__pycache__/**".to_string(),
"**/.venv/**".to_string(),
],
}
}
}
// =============================================================================
// Embedding Configuration
// =============================================================================
/// File embedding configuration for ACP prompt context.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct EmbeddingConfig {
/// Maximum bytes to embed as ContentBlock::resource per file.
///
/// Larger files will use resource_link instead. Default: 256,000.
pub max_embed_bytes: u64,
/// Whether to allow resource_link for large files.
///
/// If false, large files are rejected instead of linked.
pub allow_resource_link: bool,
/// Regex patterns for redacting secrets in embedded content.
///
/// Best-effort redaction (does not modify files on disk).
pub redact_patterns: Vec<String>,
/// Snippet strategy when file is too large to embed fully.
pub snippet_strategy: SnippetStrategy,
/// Maximum number of files to embed in a single prompt.
pub max_files_per_prompt: u32,
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
max_embed_bytes: 256_000,
allow_resource_link: true,
redact_patterns: vec![
// Common secret patterns
r"(?i)(api[_-]?key|password|secret|token)[:]\s*[']?([a-zA-Z0-9_\-\.]+)[']?".to_string(),
],
snippet_strategy: SnippetStrategy::HeadTail,
max_files_per_prompt: 10,
}
}
}
/// Strategy for creating snippets from large files.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SnippetStrategy {
/// Include beginning and end of file.
HeadTail,
/// Include only the beginning.
HeadOnly,
/// Include only the end.
TailOnly,
}
// =============================================================================
// Validation
// =============================================================================
impl SandboxConfig {
/// Validate the sandbox configuration.
///
/// Returns an error message if the configuration is invalid.
pub fn validate(&self) -> Result<(), String> {
if self.allowed_roots.is_empty() && (self.read_enabled || self.write_enabled) {
return Err("allowed_roots cannot be empty when read or write is enabled".to_string());
}
if self.encoding != "utf-8" {
return Err(format!("unsupported encoding: {}", self.encoding));
}
Ok(())
}
/// Normalize allowed roots to canonical paths.
///
/// This should be called after loading configuration to ensure all roots are canonical.
/// Panics if any root cannot be canonicalized (configuration error).
pub fn normalize_roots(&mut self) {
self.allowed_roots = self.allowed_roots
.iter()
.map(|root| {
dunce::canonicalize(root)
.unwrap_or_else(|_| panic!("Failed to canonicalize allowed root: {:?}", root))
})
.collect();
}
}
impl TerminalConfig {
/// Validate the terminal configuration.
///
/// Returns an error message if the configuration is invalid.
pub fn validate(&self) -> Result<(), String> {
if self.output_byte_limit == 0 {
return Err("output_byte_limit must be greater than 0".to_string());
}
if self.max_runtime_secs == 0 {
return Err("max_runtime_secs must be greater than 0".to_string());
}
Ok(())
}
}
impl SearchConfig {
/// Validate the search configuration.
///
/// Returns an error message if the configuration is invalid.
pub fn validate(&self) -> Result<(), String> {
if self.max_results == 0 {
return Err("max_results must be greater than 0".to_string());
}
if self.max_bytes == 0 {
return Err("max_bytes must be greater than 0".to_string());
}
Ok(())
}
}
impl EmbeddingConfig {
/// Validate the embedding configuration.
///
/// Returns an error message if the configuration is invalid.
pub fn validate(&self) -> Result<(), String> {
if self.max_embed_bytes == 0 {
return Err("max_embed_bytes must be greater than 0".to_string());
}
if self.max_files_per_prompt == 0 {
return Err("max_files_per_prompt must be greater than 0".to_string());
}
Ok(())
}
}
+142
View File
@@ -0,0 +1,142 @@
//! Dispatch orchestration: registry → SecurityFloor → permission → Tool::run.
use crate::floor::{FloorDecision, SecurityFloor};
use crate::registry::ToolRegistry;
use crate::tool::{AnyToolInput, ToolContext, ToolEventSink};
/// Outcome of a dispatch call. Matches `AnyTool::run`'s return shape:
/// `Ok` is a successful structured result, `Err` is a structured error.
pub type DispatchResult = Result<serde_json::Value, serde_json::Value>;
/// Dispatch a tool call through the harness.
///
/// 1. Resolve the tool via the registry (built-in vs dynamic, with collision
/// policy and per-client/per-protocol filters applied).
/// 2. Run the hardcoded [`SecurityFloor`]. Cannot be bypassed by settings.
/// 3. (Permission check is handled by the caller for now — the existing
/// [`crate::permission::check::check_permission`] takes a different
/// operation type and is wired in by the connector. This is documented in
/// `2026-04-28-tool-harness-design.md`.)
/// 4. Run the tool, awaiting its final result.
pub async fn dispatch(
registry: &ToolRegistry,
floor: &SecurityFloor,
name: &str,
input: serde_json::Value,
events: ToolEventSink,
ctx: &ToolContext,
) -> DispatchResult {
let tool = match registry.resolve(name, ctx) {
Some(t) => t,
None => return Err(serde_json::json!({
"error": format!("unknown tool: {name}"),
})),
};
if let FloorDecision::Block { reason } = floor.check(name, &input, ctx) {
return Err(serde_json::json!({ "error": reason }));
}
tool.run(AnyToolInput::Final(input), events, ctx).await
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{PermissionConfig, SandboxConfig, WhitelistConfig};
use crate::permission::check::PermissionContext;
use crate::permission::whitelist::CompiledWhitelist;
use crate::registry::CollisionPolicy;
use crate::tool::{AnyTool, ClientKind, ProtocolKind, Tool, ToolInput, ToolKind};
use async_trait::async_trait;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
#[derive(Serialize, Deserialize, JsonSchema)]
struct EchoIn { msg: String }
#[derive(Serialize, Deserialize)]
struct EchoOut { echoed: String }
#[derive(Default)]
struct Echo;
#[async_trait]
impl Tool for Echo {
type Input = EchoIn;
type Output = EchoOut;
const NAME: &'static str = "echo";
fn kind() -> ToolKind { ToolKind::Other }
async fn run(
self: Arc<Self>, input: ToolInput<EchoIn>,
_e: ToolEventSink, _c: &ToolContext,
) -> Result<EchoOut, EchoOut> {
let i = match input { ToolInput::Final(i) => i, _ => unreachable!() };
Ok(EchoOut { echoed: i.msg })
}
}
fn ctx() -> ToolContext {
let wl = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap();
let pc = PermissionContext::new("c".to_string(), None, wl);
ToolContext::for_test(
"c", ClientKind::claude(), ProtocolKind::acp(),
PathBuf::from("/tmp"),
SandboxConfig::default(), PermissionConfig::default(), pc,
)
}
fn registry_with_echo() -> ToolRegistry {
let any: Arc<dyn AnyTool> = <Echo as Tool>::erase(Arc::new(Echo));
ToolRegistry::new(vec![any], CollisionPolicy::BuiltInWins)
}
#[tokio::test]
async fn dispatch_runs_known_tool() {
let r = registry_with_echo();
let f = SecurityFloor::new();
let (sink, _rx) = ToolEventSink::new();
let out = dispatch(&r, &f, "echo",
serde_json::json!({ "msg": "hi" }), sink, &ctx()).await.unwrap();
assert_eq!(out["echoed"], "hi");
}
#[tokio::test]
async fn dispatch_rejects_unknown_tool() {
let r = registry_with_echo();
let f = SecurityFloor::new();
let (sink, _rx) = ToolEventSink::new();
let err = dispatch(&r, &f, "nope",
serde_json::json!({}), sink, &ctx()).await.unwrap_err();
assert!(err["error"].as_str().unwrap().contains("unknown tool"));
}
#[tokio::test]
async fn dispatch_blocks_on_security_floor() {
// We can't easily inject a "terminal" tool here, but we exercise the
// floor by registering a tool literally named "terminal".
#[derive(Default)] struct Term;
#[derive(Serialize, Deserialize, JsonSchema)] struct In { command: String }
#[derive(Serialize, Deserialize)] struct Out;
#[async_trait]
impl Tool for Term {
type Input = In; type Output = Out;
const NAME: &'static str = "terminal";
fn kind() -> ToolKind { ToolKind::Execute }
async fn run(
self: Arc<Self>, _i: ToolInput<In>,
_e: ToolEventSink, _c: &ToolContext,
) -> Result<Out, Out> { Ok(Out) }
}
let any: Arc<dyn AnyTool> = <Term as Tool>::erase(Arc::new(Term));
let r = ToolRegistry::new(vec![any], CollisionPolicy::BuiltInWins);
let f = SecurityFloor::new();
let (sink, _rx) = ToolEventSink::new();
let err = dispatch(&r, &f, "terminal",
serde_json::json!({ "command": "rm -rf /" }), sink, &ctx())
.await.unwrap_err();
assert!(err["error"].as_str().unwrap().to_lowercase().contains("blocked"));
}
}
@@ -0,0 +1,467 @@
//! Embedding decision logic for file attachments.
//!
//! This module implements the decision matrix for choosing between embedded content,
//! resource links, or snippets based on capabilities, size, and configuration.
use std::path::Path;
use crate::config::EmbeddingConfig;
use crate::error::ToolResult;
use crate::fs::file_type::detect_file_type;
/// Strategy for including a file in an ACP prompt.
#[derive(Debug, Clone, PartialEq)]
pub enum EmbeddingStrategy {
/// Embed full file content as text resource.
EmbedText {
/// File content (possibly redacted).
content: String,
/// MIME type for the content.
mime_type: String,
},
/// Embed file as base64-encoded blob.
EmbedBlob {
/// Base64-encoded binary data.
data: Vec<u8>,
/// MIME type for the binary data.
mime_type: String,
},
/// Create a resource link (don't embed full content).
Link {
/// URI for the resource (e.g., dirigent://resource/<hash>).
uri: String,
/// Human-readable name (relative path).
name: String,
/// File size in bytes.
size: u64,
/// MIME type (if known).
mime_type: Option<String>,
},
/// Embed a snippet (head/tail) with optional link to full file.
Snippet {
/// Head portion of the file.
head: String,
/// Tail portion of the file.
tail: String,
/// Total size of the original file.
total_size: u64,
/// MIME type.
mime_type: String,
},
/// Deny the attachment (too large, blocked, etc.).
Deny {
/// Reason for denial.
reason: String,
},
}
/// Embedding decider.
///
/// Decides how to include files in ACP prompts based on agent capabilities,
/// file properties, and configuration limits.
pub struct EmbeddingDecider {
/// Embedding configuration.
config: EmbeddingConfig,
/// Whether the agent supports embedded context.
agent_supports_embedded: bool,
/// Total bytes accumulated across all files in this prompt.
accumulated_bytes: usize,
/// Number of files processed so far.
file_count: usize,
}
impl EmbeddingDecider {
/// Create a new embedding decider.
///
/// # Arguments
///
/// * `config` - Embedding configuration with size limits and policies
/// * `agent_supports_embedded` - Whether the agent advertised `embeddedContext` capability
pub fn new(config: EmbeddingConfig, agent_supports_embedded: bool) -> Self {
Self {
config,
agent_supports_embedded,
accumulated_bytes: 0,
file_count: 0,
}
}
/// Decide the embedding strategy for a file.
///
/// This implements the decision tree from the file embedding policy:
/// 1. File type detection
/// 2. Size checks (per-file and total accumulated)
/// 3. Capability check (embeddedContext)
/// 4. Strategy selection (embed text, blob, link, snippet, deny)
///
/// # Arguments
///
/// * `path` - Path to the file to embed
///
/// # Returns
///
/// The chosen embedding strategy.
pub fn decide(&mut self, path: &Path) -> ToolResult<EmbeddingStrategy> {
// Check file count limit
if self.file_count >= self.config.max_files_per_prompt as usize {
return Ok(EmbeddingStrategy::Deny {
reason: format!(
"Maximum file count ({}) exceeded",
self.config.max_files_per_prompt
),
});
}
// Get file metadata
let metadata = std::fs::metadata(path)?;
let file_size = metadata.len();
// Detect file type
let file_type = detect_file_type(path)?;
// Build a name for the file (use file name, not full path)
let name = path
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("unknown")
.to_string();
// Decide based on file type and capabilities
let strategy = if file_type.is_binary {
// Binary files: prefer link, allow small blobs if needed
if file_size <= self.config.max_embed_bytes && self.agent_supports_embedded {
// Small binary - could embed as blob, but prefer link for efficiency
if self.config.allow_resource_link {
EmbeddingStrategy::Link {
uri: format!("dirigent://resource/{}", Self::hash_path(path)),
name,
size: file_size,
mime_type: file_type.mime_type,
}
} else {
// Embed as blob only if linking is disabled
let data = std::fs::read(path)?;
EmbeddingStrategy::EmbedBlob {
data,
mime_type: file_type.mime_type.unwrap_or_else(|| "application/octet-stream".to_string()),
}
}
} else {
// Large binary - link only
if self.config.allow_resource_link {
EmbeddingStrategy::Link {
uri: format!("dirigent://resource/{}", Self::hash_path(path)),
name,
size: file_size,
mime_type: file_type.mime_type,
}
} else {
EmbeddingStrategy::Deny {
reason: format!("Binary file too large ({} bytes) and linking is disabled", file_size),
}
}
}
} else {
// Text files: embed if small and capability supports, otherwise link or snippet
if !self.agent_supports_embedded {
// Agent doesn't support embedded context - always link
if self.config.allow_resource_link {
EmbeddingStrategy::Link {
uri: format!("dirigent://resource/{}", Self::hash_path(path)),
name,
size: file_size,
mime_type: file_type.mime_type,
}
} else {
EmbeddingStrategy::Deny {
reason: "Agent does not support embedded context and linking is disabled".to_string(),
}
}
} else if file_size <= self.config.max_embed_bytes {
// Small text file - check total accumulated bytes
let new_total = self.accumulated_bytes + file_size as usize;
// For Phase 03, use max_embed_bytes * max_files_per_prompt as total cap
let max_total_bytes = (self.config.max_embed_bytes as usize)
* (self.config.max_files_per_prompt as usize);
if new_total <= max_total_bytes {
// Embed the file
let content = std::fs::read_to_string(path)?;
EmbeddingStrategy::EmbedText {
content,
mime_type: file_type.mime_type.unwrap_or_else(|| "text/plain; charset=utf-8".to_string()),
}
} else {
// Exceeds total byte cap - link or deny
if self.config.allow_resource_link {
EmbeddingStrategy::Link {
uri: format!("dirigent://resource/{}", Self::hash_path(path)),
name,
size: file_size,
mime_type: file_type.mime_type,
}
} else {
EmbeddingStrategy::Deny {
reason: format!("Total embedded bytes would exceed limit ({} bytes)", max_total_bytes),
}
}
}
} else {
// Large text file - link or snippet
if self.config.allow_resource_link {
EmbeddingStrategy::Link {
uri: format!("dirigent://resource/{}", Self::hash_path(path)),
name,
size: file_size,
mime_type: file_type.mime_type,
}
} else if self.config.snippet_strategy != crate::config::SnippetStrategy::HeadTail {
// Snippet embedding not configured
EmbeddingStrategy::Deny {
reason: format!("File too large ({} bytes) for embedding and linking is disabled", file_size),
}
} else {
// Generate snippet (handled in EMBED-05)
// For now, deny and indicate snippet is needed
EmbeddingStrategy::Deny {
reason: "Snippet generation not yet implemented".to_string(),
}
}
}
};
// Update accumulated bytes if we're embedding
match &strategy {
EmbeddingStrategy::EmbedText { content, .. } => {
self.accumulated_bytes += content.len();
self.file_count += 1;
}
EmbeddingStrategy::EmbedBlob { data, .. } => {
self.accumulated_bytes += data.len();
self.file_count += 1;
}
EmbeddingStrategy::Link { .. } => {
self.file_count += 1;
}
EmbeddingStrategy::Snippet { .. } => {
self.file_count += 1;
}
EmbeddingStrategy::Deny { .. } => {
// Don't increment count for denied files
}
}
Ok(strategy)
}
/// Get the total bytes accumulated so far.
pub fn accumulated_bytes(&self) -> usize {
self.accumulated_bytes
}
/// Get the number of files processed so far.
pub fn file_count(&self) -> usize {
self.file_count
}
/// Generate a stable hash for a file path (used in URIs).
///
/// This creates an opaque, stable identifier for the file.
fn hash_path(path: &Path) -> String {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
path.hash(&mut hasher);
format!("{:x}", hasher.finish())
}
}
/// Standalone function for simple decision-making (no state tracking).
///
/// This is a convenience wrapper for one-off decisions without state.
pub fn decide_embedding_strategy(
path: &Path,
agent_supports_embedded: bool,
config: &EmbeddingConfig,
) -> ToolResult<EmbeddingStrategy> {
let mut decider = EmbeddingDecider::new(config.clone(), agent_supports_embedded);
decider.decide(path)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::SnippetStrategy;
use std::io::Write;
use tempfile::NamedTempFile;
fn default_config() -> EmbeddingConfig {
EmbeddingConfig {
max_embed_bytes: 256_000,
allow_resource_link: true,
redact_patterns: vec![],
snippet_strategy: SnippetStrategy::HeadTail,
max_files_per_prompt: 10,
}
}
#[test]
fn test_small_text_file_with_capability() {
let mut temp = NamedTempFile::with_suffix(".txt").unwrap();
temp.write_all(b"Hello, world!").unwrap();
let config = default_config();
let mut decider = EmbeddingDecider::new(config, true);
let strategy = decider.decide(temp.path()).unwrap();
match strategy {
EmbeddingStrategy::EmbedText { content, mime_type } => {
assert_eq!(content, "Hello, world!");
assert!(mime_type.contains("text"));
}
_ => panic!("Expected EmbedText strategy"),
}
}
#[test]
fn test_small_text_file_without_capability() {
let mut temp = NamedTempFile::with_suffix(".txt").unwrap();
temp.write_all(b"Hello, world!").unwrap();
let config = default_config();
let mut decider = EmbeddingDecider::new(config, false);
let strategy = decider.decide(temp.path()).unwrap();
match strategy {
EmbeddingStrategy::Link { .. } => {
// Correct - should link when capability is missing
}
_ => panic!("Expected Link strategy when capability is missing"),
}
}
#[test]
fn test_large_text_file_creates_link() {
let mut temp = NamedTempFile::with_suffix(".txt").unwrap();
// Create file larger than max_embed_bytes
let large_content = "x".repeat(300_000);
temp.write_all(large_content.as_bytes()).unwrap();
let config = default_config();
let mut decider = EmbeddingDecider::new(config, true);
let strategy = decider.decide(temp.path()).unwrap();
match strategy {
EmbeddingStrategy::Link { size, .. } => {
assert_eq!(size, 300_000);
}
_ => panic!("Expected Link strategy for large file"),
}
}
#[test]
fn test_binary_file_creates_link() {
let mut temp = NamedTempFile::with_suffix(".png").unwrap();
temp.write_all(b"\x89PNG\r\n\x1a\n").unwrap();
let config = default_config();
let mut decider = EmbeddingDecider::new(config, true);
let strategy = decider.decide(temp.path()).unwrap();
match strategy {
EmbeddingStrategy::Link { mime_type, .. } => {
assert_eq!(mime_type, Some("image/png".to_string()));
}
_ => panic!("Expected Link strategy for binary file"),
}
}
#[test]
fn test_max_file_count_exceeded() {
let mut config = default_config();
config.max_files_per_prompt = 2;
let mut decider = EmbeddingDecider::new(config, true);
// Process 2 files successfully
let mut temp1 = NamedTempFile::with_suffix(".txt").unwrap();
temp1.write_all(b"File 1").unwrap();
let _ = decider.decide(temp1.path()).unwrap();
let mut temp2 = NamedTempFile::with_suffix(".txt").unwrap();
temp2.write_all(b"File 2").unwrap();
let _ = decider.decide(temp2.path()).unwrap();
// Third file should be denied
let mut temp3 = NamedTempFile::with_suffix(".txt").unwrap();
temp3.write_all(b"File 3").unwrap();
let strategy = decider.decide(temp3.path()).unwrap();
match strategy {
EmbeddingStrategy::Deny { reason } => {
assert!(reason.contains("Maximum file count"));
}
_ => panic!("Expected Deny strategy when file count exceeded"),
}
}
#[test]
fn test_linking_disabled_denies_large_file() {
let mut temp = NamedTempFile::with_suffix(".txt").unwrap();
let large_content = "x".repeat(300_000);
temp.write_all(large_content.as_bytes()).unwrap();
let mut config = default_config();
config.allow_resource_link = false;
let mut decider = EmbeddingDecider::new(config, true);
let strategy = decider.decide(temp.path()).unwrap();
match strategy {
EmbeddingStrategy::Deny { reason } => {
assert!(reason.contains("linking is disabled"));
}
_ => panic!("Expected Deny when linking disabled and file too large"),
}
}
#[test]
fn test_accumulated_bytes_tracking() {
let config = default_config();
let mut decider = EmbeddingDecider::new(config, true);
assert_eq!(decider.accumulated_bytes(), 0);
let mut temp1 = NamedTempFile::with_suffix(".txt").unwrap();
temp1.write_all(b"Hello").unwrap();
let _ = decider.decide(temp1.path()).unwrap();
assert_eq!(decider.accumulated_bytes(), 5);
let mut temp2 = NamedTempFile::with_suffix(".txt").unwrap();
temp2.write_all(b"World!").unwrap();
let _ = decider.decide(temp2.path()).unwrap();
assert_eq!(decider.accumulated_bytes(), 11);
}
#[test]
fn test_hash_path_is_stable() {
let path = Path::new("/test/file.txt");
let hash1 = EmbeddingDecider::hash_path(path);
let hash2 = EmbeddingDecider::hash_path(path);
assert_eq!(hash1, hash2);
}
}
@@ -0,0 +1,12 @@
//! File embedding utilities for ACP prompts.
//!
//! This module provides the core logic for deciding how to include file content
//! in ACP prompts based on agent capabilities, file size, and configuration.
pub mod decider;
pub mod redactor;
pub mod snippet;
pub use decider::{decide_embedding_strategy, EmbeddingStrategy, EmbeddingDecider};
pub use redactor::{ContentRedactor, RedactionPattern};
pub use snippet::{generate_snippet, Snippet, SnippetConfig};
@@ -0,0 +1,382 @@
//! Content redaction for embedded files.
//!
//! This module implements pattern-based redaction to prevent secrets from
//! being included in embedded file content. Redaction only affects the
//! payload sent to the agent - files on disk are never modified.
use regex::Regex;
use crate::error::ToolResult;
/// A single redaction pattern with name and replacement text.
#[derive(Debug, Clone)]
pub struct RedactionPattern {
/// Human-readable name for this pattern (e.g., "API_KEY").
pub name: String,
/// Compiled regular expression to match.
pub regex: Regex,
/// Replacement text (e.g., "<REDACTED:API_KEY>").
pub replacement: String,
}
impl RedactionPattern {
/// Create a new redaction pattern.
///
/// # Arguments
///
/// * `name` - Human-readable name
/// * `pattern` - Regular expression pattern
/// * `replacement` - Replacement text (can include pattern name)
///
/// # Returns
///
/// A compiled redaction pattern, or an error if the regex is invalid.
pub fn new(
name: impl Into<String>,
pattern: &str,
replacement: impl Into<String>,
) -> ToolResult<Self> {
let name = name.into();
let regex = Regex::new(pattern).map_err(|e| {
crate::error::ToolError::InvalidInput(format!("Invalid redaction pattern: {}", e))
})?;
Ok(Self {
name,
regex,
replacement: replacement.into(),
})
}
/// Create a redaction pattern with default "<REDACTED:NAME>" replacement.
pub fn with_default_replacement(name: impl Into<String>, pattern: &str) -> ToolResult<Self> {
let name_str = name.into();
let replacement = format!("<REDACTED:{}>", name_str.to_uppercase());
Self::new(name_str, pattern, replacement)
}
}
/// Content redactor for secret patterns.
///
/// Applies a set of redaction patterns to content before embedding.
pub struct ContentRedactor {
/// Redaction patterns to apply.
patterns: Vec<RedactionPattern>,
}
impl ContentRedactor {
/// Create a new content redactor with the given patterns.
///
/// # Arguments
///
/// * `pattern_strings` - List of regex pattern strings from configuration
///
/// # Returns
///
/// A content redactor, or an error if any pattern is invalid.
pub fn new(pattern_strings: &[String]) -> ToolResult<Self> {
let mut patterns = Vec::new();
for (i, pattern_str) in pattern_strings.iter().enumerate() {
let pattern = RedactionPattern::with_default_replacement(
format!("CUSTOM_{}", i),
pattern_str,
)?;
patterns.push(pattern);
}
Ok(Self { patterns })
}
/// Create a redactor with default built-in patterns.
///
/// Default patterns include:
/// - API keys
/// - AWS credentials
/// - Generic secrets and tokens
/// - Passwords
pub fn with_default_patterns() -> Self {
let mut patterns = Vec::new();
// API keys (various formats)
if let Ok(p) = RedactionPattern::with_default_replacement(
"API_KEY",
r#"(?i)(api[_-]?key|apikey)[:=\s]+["']?([a-zA-Z0-9_\-\.]{20,})["']?"#,
) {
patterns.push(p);
}
// AWS access key IDs
if let Ok(p) = RedactionPattern::with_default_replacement(
"AWS_ACCESS_KEY",
r"AKIA[0-9A-Z]{16}",
) {
patterns.push(p);
}
// Generic secrets and tokens
if let Ok(p) = RedactionPattern::with_default_replacement(
"SECRET",
r#"(?i)(secret|token)[:=\s]+["']?([a-zA-Z0-9_\-\.]{16,})["']?"#,
) {
patterns.push(p);
}
// Passwords
if let Ok(p) = RedactionPattern::with_default_replacement(
"PASSWORD",
r#"(?i)password[:=\s]+["']?([^\s"']{8,})["']?"#,
) {
patterns.push(p);
}
// Bearer tokens
if let Ok(p) = RedactionPattern::with_default_replacement(
"BEARER_TOKEN",
r"(?i)bearer\s+([a-zA-Z0-9_\-\.=]+)",
) {
patterns.push(p);
}
Self { patterns }
}
/// Create a redactor combining default and custom patterns.
///
/// # Arguments
///
/// * `custom_patterns` - Additional custom pattern strings
pub fn with_custom_patterns(custom_patterns: &[String]) -> ToolResult<Self> {
let mut redactor = Self::with_default_patterns();
for (i, pattern_str) in custom_patterns.iter().enumerate() {
let pattern = RedactionPattern::with_default_replacement(
format!("CUSTOM_{}", i),
pattern_str,
)?;
redactor.patterns.push(pattern);
}
Ok(redactor)
}
/// Redact sensitive content from the given text.
///
/// Applies all configured redaction patterns in order and returns
/// the redacted content. The original content is never modified.
///
/// # Arguments
///
/// * `content` - The content to redact
///
/// # Returns
///
/// Redacted content with sensitive data replaced.
pub fn redact(&self, content: &str) -> String {
let mut redacted = content.to_string();
for pattern in &self.patterns {
redacted = pattern.regex.replace_all(&redacted, &pattern.replacement).to_string();
}
// Ensure UTF-8 safety (should always be valid since we're only replacing with ASCII)
debug_assert!(redacted.is_char_boundary(redacted.len()));
redacted
}
/// Check if content contains any patterns that would be redacted.
///
/// This can be used to warn users before embedding.
pub fn contains_secrets(&self, content: &str) -> bool {
self.patterns.iter().any(|p| p.regex.is_match(content))
}
/// Get the number of redaction patterns configured.
pub fn pattern_count(&self) -> usize {
self.patterns.len()
}
}
// Implement a safe Debug that doesn't leak pattern details
impl std::fmt::Debug for ContentRedactor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ContentRedactor")
.field("pattern_count", &self.patterns.len())
.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_redaction_pattern_creation() {
let pattern = RedactionPattern::new("TEST", r"\d+", "XXX").unwrap();
assert_eq!(pattern.name, "TEST");
assert_eq!(pattern.replacement, "XXX");
assert!(pattern.regex.is_match("123"));
}
#[test]
fn test_redaction_pattern_with_default_replacement() {
let pattern = RedactionPattern::with_default_replacement("api_key", r"key_\d+").unwrap();
assert_eq!(pattern.name, "api_key");
assert_eq!(pattern.replacement, "<REDACTED:API_KEY>");
}
#[test]
fn test_redaction_pattern_invalid_regex() {
let result = RedactionPattern::new("TEST", r"[invalid(", "XXX");
assert!(result.is_err());
}
#[test]
fn test_default_patterns_api_key() {
let redactor = ContentRedactor::with_default_patterns();
let content = "API_KEY=sk_live_1234567890abcdefghij";
let redacted = redactor.redact(content);
assert!(redacted.contains("<REDACTED:"));
assert!(!redacted.contains("sk_live_"));
}
#[test]
fn test_default_patterns_aws_key() {
let redactor = ContentRedactor::with_default_patterns();
let content = "AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE";
let redacted = redactor.redact(content);
assert!(redacted.contains("<REDACTED:"));
assert!(!redacted.contains("AKIAIOSFODNN7EXAMPLE"));
}
#[test]
fn test_default_patterns_secret() {
let redactor = ContentRedactor::with_default_patterns();
let content = "secret: my_super_secret_value_12345";
let redacted = redactor.redact(content);
assert!(redacted.contains("<REDACTED:"));
assert!(!redacted.contains("my_super_secret"));
}
#[test]
fn test_default_patterns_password() {
let redactor = ContentRedactor::with_default_patterns();
let content = "password=MyP@ssw0rd123";
let redacted = redactor.redact(content);
assert!(redacted.contains("<REDACTED:"));
assert!(!redacted.contains("MyP@ssw0rd"));
}
#[test]
fn test_default_patterns_bearer_token() {
let redactor = ContentRedactor::with_default_patterns();
let content = "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9";
let redacted = redactor.redact(content);
assert!(redacted.contains("<REDACTED:"));
assert!(!redacted.contains("eyJhbGci"));
}
#[test]
fn test_custom_patterns() {
let custom = vec![r"CUSTOM-\d{4}".to_string()];
let redactor = ContentRedactor::with_custom_patterns(&custom).unwrap();
let content = "My custom ID is CUSTOM-1234";
let redacted = redactor.redact(content);
assert!(redacted.contains("<REDACTED:"));
assert!(!redacted.contains("CUSTOM-1234"));
}
#[test]
fn test_multiple_patterns() {
let redactor = ContentRedactor::with_default_patterns();
let content = r#"
API_KEY="sk_test_abcdefghijklmnopqrst"
password=secret123
token: my_long_secret_token_value
"#;
let redacted = redactor.redact(content);
// All secrets should be redacted
assert!(!redacted.contains("sk_test_"));
assert!(!redacted.contains("secret123"));
assert!(!redacted.contains("my_long_secret"));
assert!(redacted.contains("<REDACTED:"));
}
#[test]
fn test_contains_secrets() {
let redactor = ContentRedactor::with_default_patterns();
assert!(redactor.contains_secrets("API_KEY=sk_test_12345678901234567890"));
assert!(redactor.contains_secrets("password=secret"));
assert!(!redactor.contains_secrets("Hello, world!"));
}
#[test]
fn test_utf8_safety() {
let redactor = ContentRedactor::with_default_patterns();
let content = "API_KEY=sk_test_12345678901234567890 你好世界";
let redacted = redactor.redact(content);
// Should still contain valid UTF-8
assert!(redacted.contains("你好世界"));
assert!(std::str::from_utf8(redacted.as_bytes()).is_ok());
}
#[test]
fn test_no_false_positives_on_code() {
let redactor = ContentRedactor::with_default_patterns();
let code = r#"
fn api_key_validator(input: &str) -> bool {
input.len() > 20
}
"#;
let redacted = redactor.redact(code);
// Function name and code should not be redacted
assert!(redacted.contains("api_key_validator"));
assert!(redacted.contains("fn"));
}
#[test]
fn test_pattern_count() {
let redactor = ContentRedactor::with_default_patterns();
assert_eq!(redactor.pattern_count(), 5); // 5 default patterns
let custom = vec![r"CUSTOM-\d{4}".to_string()];
let redactor_custom = ContentRedactor::with_custom_patterns(&custom).unwrap();
assert_eq!(redactor_custom.pattern_count(), 6); // 5 default + 1 custom
}
#[test]
fn test_debug_impl_safe() {
let redactor = ContentRedactor::with_default_patterns();
let debug_str = format!("{:?}", redactor);
// Should not leak actual pattern details
assert!(debug_str.contains("ContentRedactor"));
assert!(debug_str.contains("pattern_count"));
assert!(!debug_str.contains("regex")); // Internal details hidden
}
}
@@ -0,0 +1,354 @@
//! Snippet generation for large files.
//!
//! This module implements partial selection strategies (head/tail/window)
//! for files exceeding embed size limits.
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
use crate::error::ToolResult;
/// Configuration for snippet generation.
#[derive(Debug, Clone)]
pub struct SnippetConfig {
/// Number of lines to include from the beginning.
pub head_lines: usize,
/// Number of lines to include from the end.
pub tail_lines: usize,
/// Maximum total bytes for the snippet.
pub max_snippet_bytes: usize,
}
impl Default for SnippetConfig {
fn default() -> Self {
Self {
head_lines: 200,
tail_lines: 200,
max_snippet_bytes: 128_000, // Half of default max_embed_bytes
}
}
}
/// A snippet extracted from a file.
#[derive(Debug, Clone)]
pub struct Snippet {
/// First portion of the file (head).
pub head: Option<String>,
/// Last portion of the file (tail).
pub tail: Option<String>,
/// Total size of the original file in bytes.
pub total_size: u64,
/// Total number of lines in the original file.
pub total_lines: usize,
/// Whether the snippet is truncated (doesn't show full file).
pub truncated: bool,
}
impl Snippet {
/// Render the snippet with metadata comments.
///
/// Returns a formatted string with:
/// - Metadata header showing file info and truncation
/// - Head content
/// - Separator (if both head and tail present)
/// - Tail content
pub fn render(&self, file_path: &Path) -> String {
let file_name = file_path
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("unknown");
let mut output = String::new();
// Metadata header
output.push_str(&format!("# File: {} (truncated)\n", file_name));
output.push_str(&format!(
"# Total size: {:.1} KB, {} lines\n",
self.total_size as f64 / 1024.0,
self.total_lines
));
if self.head.is_some() && self.tail.is_some() {
output.push_str(&format!(
"# Showing: first {} lines + last {} lines\n\n",
self.head.as_ref().unwrap().lines().count(),
self.tail.as_ref().unwrap().lines().count()
));
} else if self.head.is_some() {
output.push_str(&format!(
"# Showing: first {} lines\n\n",
self.head.as_ref().unwrap().lines().count()
));
} else if self.tail.is_some() {
output.push_str(&format!(
"# Showing: last {} lines\n\n",
self.tail.as_ref().unwrap().lines().count()
));
}
// Head content
if let Some(ref head) = self.head {
output.push_str(head);
if !head.ends_with('\n') {
output.push('\n');
}
}
// Separator if both head and tail
if self.head.is_some() && self.tail.is_some() {
output.push_str("\n# ... (middle content omitted) ...\n\n");
}
// Tail content
if let Some(ref tail) = self.tail {
output.push_str(tail);
if !tail.ends_with('\n') {
output.push('\n');
}
}
output
}
}
/// Generate a snippet from a file.
///
/// Reads the first N and last N lines from the file, respecting UTF-8
/// boundaries and byte limits.
///
/// # Arguments
///
/// * `path` - Path to the file
/// * `config` - Snippet configuration
///
/// # Returns
///
/// A snippet with head/tail content and metadata.
pub fn generate_snippet(path: &Path, config: &SnippetConfig) -> ToolResult<Snippet> {
let metadata = std::fs::metadata(path)?;
let total_size = metadata.len();
// Read the full file to count lines and extract head/tail
let file = File::open(path)?;
let reader = BufReader::new(file);
let mut all_lines: Vec<String> = Vec::new();
for line in reader.lines() {
all_lines.push(line?);
}
let total_lines = all_lines.len();
// Check if file is small enough to not need truncation
if total_lines <= config.head_lines + config.tail_lines {
return Ok(Snippet {
head: Some(all_lines.join("\n")),
tail: None,
total_size,
total_lines,
truncated: false,
});
}
// Extract head and tail
let head_content = all_lines
.iter()
.take(config.head_lines)
.cloned()
.collect::<Vec<_>>()
.join("\n");
let tail_content = all_lines
.iter()
.skip(all_lines.len().saturating_sub(config.tail_lines))
.cloned()
.collect::<Vec<_>>()
.join("\n");
// Check if combined snippet exceeds max bytes
let combined_bytes = head_content.len() + tail_content.len();
if combined_bytes > config.max_snippet_bytes {
// Reduce to head only if combined is too large
let truncated_head = truncate_to_bytes(&head_content, config.max_snippet_bytes);
Ok(Snippet {
head: Some(truncated_head),
tail: None,
total_size,
total_lines,
truncated: true,
})
} else {
Ok(Snippet {
head: Some(head_content),
tail: Some(tail_content),
total_size,
total_lines,
truncated: true,
})
}
}
/// Truncate a string to fit within a byte limit, respecting UTF-8 boundaries.
fn truncate_to_bytes(s: &str, max_bytes: usize) -> String {
if s.len() <= max_bytes {
return s.to_string();
}
// Find the largest valid UTF-8 boundary within max_bytes
let mut boundary = max_bytes;
while boundary > 0 && !s.is_char_boundary(boundary) {
boundary -= 1;
}
s[..boundary].to_string()
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::NamedTempFile;
#[test]
fn test_snippet_config_default() {
let config = SnippetConfig::default();
assert_eq!(config.head_lines, 200);
assert_eq!(config.tail_lines, 200);
assert_eq!(config.max_snippet_bytes, 128_000);
}
#[test]
fn test_generate_snippet_small_file() {
let mut temp = NamedTempFile::new().unwrap();
temp.write_all(b"Line 1\nLine 2\nLine 3\n").unwrap();
let config = SnippetConfig {
head_lines: 10,
tail_lines: 10,
max_snippet_bytes: 10_000,
};
let snippet = generate_snippet(temp.path(), &config).unwrap();
assert!(!snippet.truncated);
assert_eq!(snippet.total_lines, 3);
assert!(snippet.head.is_some());
assert!(snippet.tail.is_none());
}
#[test]
fn test_generate_snippet_large_file() {
let mut temp = NamedTempFile::new().unwrap();
for i in 1..=1000 {
writeln!(temp, "Line {}", i).unwrap();
}
let config = SnippetConfig {
head_lines: 10,
tail_lines: 10,
max_snippet_bytes: 100_000,
};
let snippet = generate_snippet(temp.path(), &config).unwrap();
assert!(snippet.truncated);
assert_eq!(snippet.total_lines, 1000);
assert!(snippet.head.is_some());
assert!(snippet.tail.is_some());
// Head should have first 10 lines
let head = snippet.head.unwrap();
assert!(head.contains("Line 1"));
assert!(head.contains("Line 10"));
assert!(!head.contains("Line 11"));
// Tail should have last 10 lines
let tail = snippet.tail.unwrap();
assert!(tail.contains("Line 1000"));
assert!(tail.contains("Line 991"));
assert!(!tail.contains("Line 990"));
}
#[test]
fn test_snippet_render() {
let snippet = Snippet {
head: Some("Line 1\nLine 2".to_string()),
tail: Some("Line 99\nLine 100".to_string()),
total_size: 10_000,
total_lines: 100,
truncated: true,
};
let rendered = snippet.render(Path::new("/test/file.txt"));
assert!(rendered.contains("# File: file.txt (truncated)"));
assert!(rendered.contains("# Total size:"));
assert!(rendered.contains("100 lines"));
assert!(rendered.contains("Line 1"));
assert!(rendered.contains("Line 100"));
assert!(rendered.contains("... (middle content omitted) ..."));
}
#[test]
fn test_snippet_render_head_only() {
let snippet = Snippet {
head: Some("Line 1\nLine 2".to_string()),
tail: None,
total_size: 1_000,
total_lines: 10,
truncated: true,
};
let rendered = snippet.render(Path::new("/test/file.txt"));
assert!(rendered.contains("Line 1"));
assert!(!rendered.contains("... (middle content omitted) ..."));
}
#[test]
fn test_truncate_to_bytes_exact() {
let s = "Hello";
let truncated = truncate_to_bytes(s, 5);
assert_eq!(truncated, "Hello");
}
#[test]
fn test_truncate_to_bytes_shorter() {
let s = "Hello, world!";
let truncated = truncate_to_bytes(s, 5);
assert_eq!(truncated, "Hello");
}
#[test]
fn test_truncate_to_bytes_utf8_boundary() {
let s = "Hello 你好世界";
// Try to truncate in the middle of a multibyte character
let truncated = truncate_to_bytes(s, 8);
// Should truncate before the multibyte char to maintain UTF-8 validity
assert_eq!(truncated, "Hello ");
}
#[test]
fn test_snippet_exceeds_max_bytes() {
let mut temp = NamedTempFile::new().unwrap();
for i in 1..=1000 {
writeln!(temp, "This is a very long line number {} with lots of text", i).unwrap();
}
let config = SnippetConfig {
head_lines: 500,
tail_lines: 500,
max_snippet_bytes: 1_000, // Very small limit
};
let snippet = generate_snippet(temp.path(), &config).unwrap();
assert!(snippet.truncated);
// Should fall back to head only when combined exceeds max
assert!(snippet.head.is_some());
let head = snippet.head.unwrap();
assert!(head.len() <= config.max_snippet_bytes);
}
}
+108
View File
@@ -0,0 +1,108 @@
//! Error types for tool operations.
use thiserror::Error;
/// Result type for tool operations.
pub type ToolResult<T> = Result<T, ToolError>;
/// Errors that can occur during tool operations.
#[derive(Error, Debug)]
pub enum ToolError {
/// File or directory not found.
#[error("Not found: {path}")]
NotFound { path: String },
/// Permission denied by OS or sandbox policy.
#[error("Permission denied: {reason}")]
PermissionDenied { reason: String },
/// Path outside allowed sandbox roots.
#[error("Sandbox violation: {reason}")]
SandboxViolation { reason: String },
/// Path matched blocklist pattern.
#[error("Blocked path: {reason}")]
BlockedPath { reason: String },
/// File too large to process.
#[error("File too large: {size} bytes exceeds limit of {limit} bytes")]
FileTooLarge { size: u64, limit: u64 },
/// Invalid encoding (non-UTF-8).
#[error("Encoding not supported: {encoding}")]
EncodingUnsupported { encoding: String },
/// User rejected permission prompt.
#[error("Permission rejected by user")]
PermissionRejected,
/// Terminal operation failed.
#[error("Terminal error: {message}")]
TerminalError { message: String },
/// Terminal not found.
#[error("Terminal not found: {terminal_id}")]
TerminalNotFound { terminal_id: String },
/// Search operation exceeded limits.
#[error("Search limit exceeded: {reason}")]
SearchLimitExceeded { reason: String },
/// Invalid configuration.
#[error("Invalid configuration: {0}")]
InvalidConfig(String),
/// File read error with detailed information.
#[error("Failed to read file {path}: {source}")]
FileReadError {
path: String,
#[source]
source: std::io::Error,
},
/// Invalid input or parameters.
#[error("Invalid input: {0}")]
InvalidInput(String),
/// I/O error.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
/// JSON serialization error.
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
/// Generic error.
#[error("{0}")]
Other(#[from] anyhow::Error),
}
impl ToolError {
/// Create a sandbox violation error without exposing the full path.
pub fn sandbox_violation(reason: impl Into<String>) -> Self {
Self::SandboxViolation {
reason: reason.into(),
}
}
/// Create a blocked path error without exposing the full path.
pub fn blocked_path(reason: impl Into<String>) -> Self {
Self::BlockedPath {
reason: reason.into(),
}
}
/// Create a permission denied error.
pub fn permission_denied(reason: impl Into<String>) -> Self {
Self::PermissionDenied {
reason: reason.into(),
}
}
/// Create a terminal error.
pub fn terminal_error(message: impl Into<String>) -> Self {
Self::TerminalError {
message: message.into(),
}
}
}
+189
View File
@@ -0,0 +1,189 @@
//! Hardcoded security floor. Settings cannot bypass these rules.
//!
//! Initial rule set lifted from Zed's `HARDCODED_SECURITY_RULES`:
//! catastrophic recursive deletes (`rm -rf /`, `~`, `$HOME`, `.`, `..`).
pub mod shell;
use crate::tool::ToolContext;
use regex::Regex;
use std::sync::OnceLock;
#[derive(Debug)]
pub enum FloorDecision {
Pass,
Block { reason: &'static str },
}
pub struct SecurityFloor {
terminal_deny: Vec<Regex>,
}
impl SecurityFloor {
pub fn new() -> Self {
Self { terminal_deny: terminal_deny_patterns().clone() }
}
/// Check whether the given tool invocation hits a hard rule.
/// `tool` is the tool name; `input` is the JSON arguments object.
pub fn check(&self, tool: &str, input: &serde_json::Value, _ctx: &ToolContext) -> FloorDecision {
if tool != "terminal" && tool != "bash" {
return FloorDecision::Pass;
}
let Some(cmd) = extract_command(input) else { return FloorDecision::Pass };
// Check the raw command first
if self.matches_any(cmd) {
return FloorDecision::Block {
reason: "Blocked by built-in security rule (recursive delete of \
root, home, current, or parent directory).",
};
}
// Then each sub-command in a chain
for sub in shell::split_chain(cmd) {
if self.matches_any(sub) {
return FloorDecision::Block {
reason: "Blocked by built-in security rule (chained recursive \
delete detected).",
};
}
}
if interpolation_pattern().is_match(cmd) {
return FloorDecision::Block {
reason: "Blocked: shell substitutions/interpolations are not allowed \
in terminal commands. Resolve $VAR, ${VAR}, $(...), backticks, \
$((...)), <(...), >(...) before calling.",
};
}
FloorDecision::Pass
}
fn matches_any(&self, cmd: &str) -> bool {
self.terminal_deny.iter().any(|re| re.is_match(cmd))
}
}
impl Default for SecurityFloor { fn default() -> Self { Self::new() } }
fn extract_command(input: &serde_json::Value) -> Option<&str> {
input.get("command").and_then(|v| v.as_str())
}
fn terminal_deny_patterns() -> &'static Vec<Regex> {
static SET: OnceLock<Vec<Regex>> = OnceLock::new();
SET.get_or_init(|| {
const FLAGS: &str = r"(--[a-zA-Z0-9][-a-zA-Z0-9_]*(=[^\s]*)?\s+|-[a-zA-Z]+\s+)*";
const TRAIL: &str = r"(\s+--[a-zA-Z0-9][-a-zA-Z0-9_]*(=[^\s]*)?|\s+-[a-zA-Z]+)*\s*";
vec![
Regex::new(&format!(r"\brm\s+{FLAGS}(--\s+)?/\*?{TRAIL}$")).unwrap(),
Regex::new(&format!(r"\brm\s+{FLAGS}(--\s+)?~/?\*?{TRAIL}$")).unwrap(),
Regex::new(&format!(r"\brm\s+{FLAGS}(--\s+)?(\$HOME|\$\{{HOME\}})/?\*?{TRAIL}$")).unwrap(),
Regex::new(&format!(r"\brm\s+{FLAGS}(--\s+)?\./?\*?{TRAIL}$")).unwrap(),
Regex::new(&format!(r"\brm\s+{FLAGS}(--\s+)?\.\./?\*?{TRAIL}$")).unwrap(),
]
})
}
fn interpolation_pattern() -> &'static Regex {
static RE: OnceLock<Regex> = OnceLock::new();
RE.get_or_init(|| {
// $VAR, ${...}, $(...), $((...)), `...`, <(...), >(...)
Regex::new(r#"(\$[A-Za-z_][A-Za-z0-9_]*|\$\{[^}]*\}|\$\([^)]*\)|\$\(\([^)]*\)\)|`[^`]*`|<\([^)]*\)|>\([^)]*\))"#).unwrap()
})
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{PermissionConfig, SandboxConfig, WhitelistConfig};
use crate::permission::check::PermissionContext;
use crate::permission::whitelist::CompiledWhitelist;
use crate::tool::{ClientKind, ProtocolKind};
use std::path::PathBuf;
fn ctx() -> ToolContext {
let wl = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap();
let pc = PermissionContext::new("c".to_string(), None, wl);
ToolContext::for_test(
"c", ClientKind::claude(), ProtocolKind::acp(),
PathBuf::from("/tmp"),
SandboxConfig::default(), PermissionConfig::default(), pc,
)
}
fn input(cmd: &str) -> serde_json::Value {
serde_json::json!({ "command": cmd })
}
#[test]
fn floor_passes_non_terminal_tools() {
let f = SecurityFloor::new();
assert!(matches!(f.check("read", &input("rm -rf /"), &ctx()), FloorDecision::Pass));
}
#[test]
fn floor_blocks_rm_rf_root() {
let f = SecurityFloor::new();
for cmd in ["rm -rf /", "rm -rfv /", "rm -rf /*"] {
assert!(matches!(f.check("terminal", &input(cmd), &ctx()), FloorDecision::Block { .. }),
"expected block for {cmd}");
}
}
#[test]
fn floor_blocks_home_variants() {
let f = SecurityFloor::new();
for cmd in ["rm -rf ~", "rm -rf ~/", "rm -rf $HOME", "rm -rf ${HOME}", "rm -rf ${HOME}/*"] {
assert!(matches!(f.check("terminal", &input(cmd), &ctx()), FloorDecision::Block { .. }),
"expected block for {cmd}");
}
}
#[test]
fn floor_blocks_dot_dotdot() {
let f = SecurityFloor::new();
for cmd in ["rm -rf .", "rm -rf ./", "rm -rf ./*", "rm -rf ..", "rm -rf ../*"] {
assert!(matches!(f.check("terminal", &input(cmd), &ctx()), FloorDecision::Block { .. }),
"expected block for {cmd}");
}
}
#[test]
fn floor_passes_safe_commands() {
let f = SecurityFloor::new();
for cmd in ["ls", "rm foo.txt", "rm -rf ./build", "rm -rf /tmp/work"] {
assert!(matches!(f.check("terminal", &input(cmd), &ctx()), FloorDecision::Pass),
"expected pass for {cmd}");
}
}
#[test]
fn floor_blocks_chained_rm_rf_root() {
let f = SecurityFloor::new();
assert!(matches!(
f.check("terminal", &input("ls && rm -rf /"), &ctx()),
FloorDecision::Block { .. }
));
}
#[test]
fn floor_blocks_interpolation_in_terminal() {
let f = SecurityFloor::new();
for cmd in ["echo $VAR", "echo ${VAR}", "echo $(date)", "echo `date`",
"echo $((1+1))", "cat <(ls)", "tee >(ls)"] {
assert!(matches!(
f.check("terminal", &input(cmd), &ctx()),
FloorDecision::Block { .. }
), "expected block for {cmd}");
}
}
#[test]
fn floor_passes_command_without_interpolation() {
let f = SecurityFloor::new();
assert!(matches!(
f.check("terminal", &input("git status"), &ctx()),
FloorDecision::Pass
));
}
}
+66
View File
@@ -0,0 +1,66 @@
//! Lightweight POSIX shell command decomposition.
//!
//! Splits a one-line shell input on `&&`, `||`, `;`, `|` operators.
//! Quoted regions are preserved so chain tokens inside quotes do not split.
//! This is a heuristic — not a full shell parser. It is sufficient to ensure
//! a chained `rm -rf /` is not hidden behind a leading benign command.
/// Split `input` into individual sub-commands by POSIX chain operators.
/// Returns the original string as a single element if no chains are detected.
pub fn split_chain(input: &str) -> Vec<&str> {
let mut out = Vec::new();
let bytes = input.as_bytes();
let mut start = 0;
let mut i = 0;
let mut quote: Option<u8> = None;
let mut escape = false;
while i < bytes.len() {
let b = bytes[i];
if escape { escape = false; i += 1; continue; }
if b == b'\\' && quote != Some(b'\'') { escape = true; i += 1; continue; }
match quote {
Some(q) if q == b => { quote = None; i += 1; continue; }
Some(_) => { i += 1; continue; }
None => {
if b == b'\'' || b == b'"' { quote = Some(b); i += 1; continue; }
}
}
let two = bytes.get(i..i+2);
if two == Some(b"&&") || two == Some(b"||") {
out.push(input[start..i].trim());
i += 2; start = i; continue;
}
if b == b';' || b == b'|' {
out.push(input[start..i].trim());
i += 1; start = i; continue;
}
i += 1;
}
let tail = input[start..].trim();
if !tail.is_empty() { out.push(tail); }
out
}
#[cfg(test)]
mod tests {
use super::*;
#[test] fn splits_on_and() {
assert_eq!(split_chain("ls && rm -rf /"), vec!["ls", "rm -rf /"]);
}
#[test] fn splits_on_semicolon_and_pipe() {
assert_eq!(split_chain("ls ; rm -rf / | wc"), vec!["ls", "rm -rf /", "wc"]);
}
#[test] fn keeps_quoted_chains_intact() {
assert_eq!(split_chain("echo 'a && b'"), vec!["echo 'a && b'"]);
}
#[test] fn handles_no_chain() {
assert_eq!(split_chain("ls -la"), vec!["ls -la"]);
}
#[test] fn handles_escaped_pipe() {
assert_eq!(split_chain(r"echo a\|b"), vec![r"echo a\|b"]);
}
}
+28
View File
@@ -0,0 +1,28 @@
//! File system operations with sandbox enforcement.
//!
//! This module provides:
//! - `read_text_file()` - Read UTF-8 text files with line range support (TOOLS-FS-01)
//! - `write_text_file()` - Write UTF-8 text files with atomic writes (TOOLS-FS-02)
//! - `generate_diff()` - Generate unified diffs for changes (TOOLS-FS-03)
//! - `edit_file()` - Apply text replacements and generate diffs (TOOLS-FS-04)
//!
//! All operations are subject to:
//! - Sandbox containment checks
//! - Blocklist enforcement
//! - Size limits
//! - Permission prompts (for writes)
//!
//! **Status**: All functions stubbed, implementation pending
pub mod read;
pub mod write;
pub mod diff;
pub mod edit;
pub mod file_type;
// Re-export main types and functions
pub use read::{read_text_file, ReadTextFileRequest, ReadTextFileResponse};
pub use write::{write_text_file, normalize_eol, WriteTextFileRequest, WriteTextFileResponse};
pub use diff::generate_diff;
pub use edit::{edit_file, EditFileRequest, EditFileResponse, EditOperation};
pub use file_type::{detect_file_type, is_valid_utf8, FileTypeInfo};
+195
View File
@@ -0,0 +1,195 @@
//! Unified diff generation for write operations.
//!
//! **Status**: Implemented (TOOLS-FS-03)
//!
//! This module implements:
//! - Unified diff generation using similar crate
//! - Edge case handling (new files, deleted files, no changes)
//! - Diff size limiting for UI rendering
//! - Binary file detection and fallback
use similar::{ChangeTag, TextDiff};
use std::path::Path;
/// Maximum diff size in characters before truncation.
const MAX_DIFF_SIZE: usize = 100_000;
/// Generate a unified diff between old and new file contents.
///
/// Returns a unified diff in standard format, suitable for UI rendering.
///
/// ## Edge Cases
///
/// - New file (old = empty) → Shows all lines as additions
/// - Deleted file (new = empty) → Shows all lines as deletions
/// - No change (old == new) → Returns empty string
/// - Very large diff → Truncated with message
///
/// ## Format
///
/// Standard unified diff format:
/// ```text
/// --- path/to/file.txt
/// +++ path/to/file.txt
/// @@ -1,3 +1,3 @@
/// context line
/// -old line
/// +new line
/// context line
/// ```
///
/// ## Error Handling
///
/// Diff generation is best-effort:
/// - Should never panic
/// - Falls back to generic message on error
/// - Logs warnings for debugging
///
/// ## See Also
///
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-FS-03`
pub fn generate_diff(old_content: &str, new_content: &str, path: &Path) -> String {
// Early exit if contents are identical
if old_content == new_content {
return String::new();
}
// Handle new file case (old is empty)
if old_content.is_empty() {
let line_count = new_content.lines().count();
let mut result = format!(
"--- /dev/null\n+++ {}\n@@ -0,0 +1,{} @@\n",
path.display(),
line_count
);
for line in new_content.lines() {
result.push_str(&format!("+{}\n", line));
}
return truncate_diff(result);
}
// Handle deleted file case (new is empty)
if new_content.is_empty() {
let line_count = old_content.lines().count();
let mut result = format!(
"--- {}\n+++ /dev/null\n@@ -1,{} +0,0 @@\n",
path.display(),
line_count
);
for line in old_content.lines() {
result.push_str(&format!("-{}\n", line));
}
return truncate_diff(result);
}
// Generate unified diff using similar crate
match generate_unified_diff(old_content, new_content, path) {
Ok(diff) => truncate_diff(diff),
Err(e) => {
tracing::warn!(path = %path.display(), error = %e, "Failed to generate diff");
format!("# Content changed (diff generation failed: {})\n", e)
}
}
}
/// Generate unified diff using the similar crate.
fn generate_unified_diff(old_content: &str, new_content: &str, path: &Path) -> Result<String, String> {
let diff = TextDiff::from_lines(old_content, new_content);
let mut output = String::new();
output.push_str(&format!("--- {}\n", path.display()));
output.push_str(&format!("+++ {}\n", path.display()));
// Generate unified diff format with context
for hunk in diff.unified_diff().iter_hunks() {
// Write hunk header
output.push_str(&format!("{}", hunk.header()));
// Write changes
for change in hunk.iter_changes() {
let sign = match change.tag() {
ChangeTag::Delete => "-",
ChangeTag::Insert => "+",
ChangeTag::Equal => " ",
};
output.push_str(&format!("{}{}", sign, change.value()));
}
}
Ok(output)
}
/// Truncate diff if it exceeds maximum size.
fn truncate_diff(mut diff: String) -> String {
if diff.len() > MAX_DIFF_SIZE {
diff.truncate(MAX_DIFF_SIZE);
diff.push_str("\n... [diff truncated for display] ...\n");
}
diff
}
#[cfg(test)]
mod tests {
use super::*;
use std::path::PathBuf;
#[test]
fn test_generate_diff_identical() {
let path = PathBuf::from("test.txt");
let content = "line1\nline2\nline3";
let diff = generate_diff(content, content, &path);
assert_eq!(diff, "");
}
#[test]
fn test_generate_diff_new_file() {
let path = PathBuf::from("test.txt");
let old = "";
let new = "line1\nline2\nline3";
let diff = generate_diff(old, new, &path);
assert!(diff.contains("--- /dev/null"));
assert!(diff.contains("+++ test.txt"));
assert!(diff.contains("+line1"));
assert!(diff.contains("+line2"));
assert!(diff.contains("+line3"));
}
#[test]
fn test_generate_diff_deleted_file() {
let path = PathBuf::from("test.txt");
let old = "line1\nline2\nline3";
let new = "";
let diff = generate_diff(old, new, &path);
assert!(diff.contains("--- test.txt"));
assert!(diff.contains("+++ /dev/null"));
assert!(diff.contains("-line1"));
assert!(diff.contains("-line2"));
assert!(diff.contains("-line3"));
}
#[test]
fn test_generate_diff_change() {
let path = PathBuf::from("test.txt");
let old = "line1\nline2\nline3";
let new = "line1\nmodified\nline3";
let diff = generate_diff(old, new, &path);
assert!(diff.contains("--- test.txt"));
assert!(diff.contains("+++ test.txt"));
assert!(diff.contains("-line2"));
assert!(diff.contains("+modified"));
}
#[test]
fn test_truncate_diff() {
let short_diff = "short diff";
assert_eq!(truncate_diff(short_diff.to_string()), short_diff);
let long_diff = "x".repeat(MAX_DIFF_SIZE + 1000);
let truncated = truncate_diff(long_diff);
assert!(truncated.len() <= MAX_DIFF_SIZE + 100); // Allow for truncation message
assert!(truncated.contains("[diff truncated for display]"));
}
}
+258
View File
@@ -0,0 +1,258 @@
//! Internal edit helper for read + transform + write operations.
//!
//! **Status**: Implemented (TOOLS-FS-04)
//!
//! This module implements:
//! - Edit operation abstraction (not ACP-native, internal API)
//! - Read + transform + write flow
//! - Automatic diff generation
//! - String replacement operations
use crate::config::SandboxConfig;
use crate::error::{ToolError, ToolResult};
use crate::fs::diff::generate_diff;
use crate::fs::read::{read_text_file, ReadTextFileRequest};
use crate::fs::write::{write_text_file, WriteTextFileRequest};
use crate::path::validate_path;
use serde::{Deserialize, Serialize};
/// Request to edit a file via transformation operations.
///
/// **Note**: This is an internal API, not exposed via ACP directly.
/// Agents use fs/write_text_file; this is for richer Dirigent UX.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EditFileRequest {
/// Absolute path to the file to edit.
pub path: String,
/// Ordered list of edit operations to apply.
pub edits: Vec<EditOperation>,
}
/// A single edit operation to apply to file content.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum EditOperation {
/// Replace occurrences of old_text with new_text.
Replace {
/// Text to find.
old_text: String,
/// Text to replace with.
new_text: String,
/// Replace all occurrences (true) or first only (false).
replace_all: bool,
},
/// Apply a unified diff patch (future).
#[allow(dead_code)]
Patch {
/// Unified diff string.
diff: String,
},
}
/// Response from editing a file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EditFileResponse {
/// Unified diff showing changes made.
pub diff: String,
}
/// Edit a file by applying transformation operations.
///
/// ## Implementation
///
/// This function:
/// 1. Reads existing file content
/// 2. Applies edits in order:
/// - Replace: String find/replace (once or all)
/// - Patch: Apply unified diff (future)
/// 3. Writes transformed content
/// 4. Generates and returns diff
///
/// ## Algorithm
///
/// ```text
/// 1. old_content = read_text_file(path)
/// 2. new_content = old_content
/// 3. for each edit in edits:
/// new_content = apply_edit(new_content, edit)
/// 4. write_text_file(path, new_content)
/// 5. diff = generate_diff(old_content, new_content, path)
/// 6. return EditFileResponse { diff }
/// ```
///
/// ## Error Cases
///
/// - File not found → `ToolError::NotFound`
/// - Edit on new file → `ToolError::NotFound` (edits require existing content)
/// - Same sandboxing/permission errors as read/write
///
/// ## Tool Call Rendering
///
/// Always emits `ToolCallContent::diff` for UX visualization.
///
/// ## See Also
///
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-FS-04`
pub async fn edit_file(
request: EditFileRequest,
sandbox_config: &SandboxConfig,
permission_config: &crate::config::PermissionConfig,
permission_context: &crate::permission::check::PermissionContext,
) -> ToolResult<EditFileResponse> {
// Validate path first to get canonical path for diff generation
let canonical_path = validate_path(&request.path, sandbox_config)?;
// Read existing file content
let read_request = ReadTextFileRequest {
path: request.path.clone(),
line: None,
limit: None,
};
let read_response = read_text_file(read_request, sandbox_config).await?;
let old_content = read_response.content;
// Apply all edits in order
let mut new_content = old_content.clone();
for edit in &request.edits {
new_content = apply_edit(&new_content, edit)?;
}
// Write transformed content
let write_request = WriteTextFileRequest {
path: request.path.clone(),
content: new_content.clone(),
};
write_text_file(write_request, sandbox_config, permission_config, permission_context).await?;
// Generate diff for UX
let diff = generate_diff(&old_content, &new_content, &canonical_path);
tracing::info!(
path = %request.path,
edit_count = request.edits.len(),
"File edited successfully"
);
Ok(EditFileResponse { diff })
}
/// Apply a single edit operation to content.
fn apply_edit(content: &str, edit: &EditOperation) -> ToolResult<String> {
match edit {
EditOperation::Replace {
old_text,
new_text,
replace_all,
} => {
if *replace_all {
// Replace all occurrences
Ok(content.replace(old_text, new_text))
} else {
// Replace only the first occurrence
if let Some(pos) = content.find(old_text) {
let mut result = String::with_capacity(content.len());
result.push_str(&content[..pos]);
result.push_str(new_text);
result.push_str(&content[pos + old_text.len()..]);
Ok(result)
} else {
// No match found - return content unchanged
// This could be a warning, but we'll allow it
tracing::warn!(
old_text = %old_text,
"Edit operation: old_text not found in content"
);
Ok(content.to_string())
}
}
}
EditOperation::Patch { diff: _ } => {
// Future: Apply unified diff patch
Err(ToolError::InvalidInput(
"Patch edit operation not yet implemented".to_string(),
))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_apply_edit_replace_all() {
let content = "foo bar foo baz foo";
let edit = EditOperation::Replace {
old_text: "foo".to_string(),
new_text: "qux".to_string(),
replace_all: true,
};
let result = apply_edit(content, &edit).unwrap();
assert_eq!(result, "qux bar qux baz qux");
}
#[test]
fn test_apply_edit_replace_first() {
let content = "foo bar foo baz foo";
let edit = EditOperation::Replace {
old_text: "foo".to_string(),
new_text: "qux".to_string(),
replace_all: false,
};
let result = apply_edit(content, &edit).unwrap();
assert_eq!(result, "qux bar foo baz foo");
}
#[test]
fn test_apply_edit_no_match() {
let content = "foo bar baz";
let edit = EditOperation::Replace {
old_text: "qux".to_string(),
new_text: "quux".to_string(),
replace_all: false,
};
let result = apply_edit(content, &edit).unwrap();
assert_eq!(result, content); // Unchanged
}
#[test]
fn test_apply_edit_empty_replacement() {
let content = "foo bar foo";
let edit = EditOperation::Replace {
old_text: "foo".to_string(),
new_text: "".to_string(),
replace_all: true,
};
let result = apply_edit(content, &edit).unwrap();
assert_eq!(result, " bar ");
}
#[test]
fn test_apply_edit_multiline() {
let content = "line1\nfoo\nline3\nfoo\nline5";
let edit = EditOperation::Replace {
old_text: "foo".to_string(),
new_text: "bar".to_string(),
replace_all: true,
};
let result = apply_edit(content, &edit).unwrap();
assert_eq!(result, "line1\nbar\nline3\nbar\nline5");
}
#[test]
fn test_apply_edit_patch_unimplemented() {
let content = "foo bar baz";
let edit = EditOperation::Patch {
diff: "some diff".to_string(),
};
let result = apply_edit(content, &edit);
assert!(result.is_err());
assert!(matches!(result, Err(ToolError::InvalidInput(_))));
}
}
+331
View File
@@ -0,0 +1,331 @@
//! File type detection and MIME type guessing.
//!
//! This module provides utilities for detecting whether a file is text or binary,
//! and guessing appropriate MIME types based on file extensions.
use std::fs;
use std::io::Read;
use std::path::Path;
use crate::error::ToolResult;
/// Information about a file's type and encoding.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileTypeInfo {
/// Whether the file appears to be text (not binary).
pub is_text: bool,
/// Whether the file appears to be binary (not text).
pub is_binary: bool,
/// MIME type guess based on extension and content.
pub mime_type: Option<String>,
/// Character set (always "utf-8" for text files in Phase 03).
pub charset: Option<String>,
}
impl FileTypeInfo {
/// Create a text file type info.
pub fn text(mime_type: impl Into<String>) -> Self {
Self {
is_text: true,
is_binary: false,
mime_type: Some(mime_type.into()),
charset: Some("utf-8".to_string()),
}
}
/// Create a binary file type info.
pub fn binary(mime_type: impl Into<String>) -> Self {
Self {
is_text: false,
is_binary: true,
mime_type: Some(mime_type.into()),
charset: None,
}
}
/// Create an unknown file type info (defaults to binary).
pub fn unknown() -> Self {
Self {
is_text: false,
is_binary: true,
mime_type: Some("application/octet-stream".to_string()),
charset: None,
}
}
}
/// Detect file type based on extension and content analysis.
///
/// This uses a fast path (extension-based detection) followed by optional
/// content sniffing for extensionless files.
///
/// # Arguments
///
/// * `path` - Path to the file to analyze
///
/// # Returns
///
/// File type information including text/binary classification and MIME type.
pub fn detect_file_type(path: &Path) -> ToolResult<FileTypeInfo> {
// First, try extension-based detection (fast path)
if let Some(info) = detect_by_extension(path) {
return Ok(info);
}
// For extensionless files, do content sniffing
detect_by_content(path)
}
/// Detect file type by extension.
///
/// This is the fast path for common file types. Returns None if the extension
/// is unknown or missing.
fn detect_by_extension(path: &Path) -> Option<FileTypeInfo> {
let ext = path.extension()?.to_str()?.to_lowercase();
match ext.as_str() {
// Programming languages
"rs" => Some(FileTypeInfo::text("text/x-rust")),
"ts" => Some(FileTypeInfo::text("text/typescript")),
"tsx" => Some(FileTypeInfo::text("text/tsx")),
"js" => Some(FileTypeInfo::text("text/javascript")),
"jsx" => Some(FileTypeInfo::text("text/jsx")),
"py" => Some(FileTypeInfo::text("text/x-python")),
"rb" => Some(FileTypeInfo::text("text/x-ruby")),
"go" => Some(FileTypeInfo::text("text/x-go")),
"java" => Some(FileTypeInfo::text("text/x-java")),
"c" => Some(FileTypeInfo::text("text/x-c")),
"cpp" | "cc" | "cxx" => Some(FileTypeInfo::text("text/x-c++")),
"h" | "hpp" => Some(FileTypeInfo::text("text/x-c-header")),
"cs" => Some(FileTypeInfo::text("text/x-csharp")),
"sh" | "bash" | "zsh" => Some(FileTypeInfo::text("text/x-shellscript")),
"ps1" => Some(FileTypeInfo::text("text/x-powershell")),
// Data formats
"json" => Some(FileTypeInfo::text("application/json")),
"toml" => Some(FileTypeInfo::text("application/toml")),
"yaml" | "yml" => Some(FileTypeInfo::text("application/yaml")),
"xml" => Some(FileTypeInfo::text("application/xml")),
"csv" => Some(FileTypeInfo::text("text/csv")),
// Markup
"html" | "htm" => Some(FileTypeInfo::text("text/html")),
"css" => Some(FileTypeInfo::text("text/css")),
"scss" | "sass" => Some(FileTypeInfo::text("text/x-scss")),
"md" | "markdown" => Some(FileTypeInfo::text("text/markdown")),
"rst" => Some(FileTypeInfo::text("text/x-rst")),
// Plain text
"txt" => Some(FileTypeInfo::text("text/plain")),
"log" => Some(FileTypeInfo::text("text/plain")),
// Config files
"ini" | "cfg" | "conf" => Some(FileTypeInfo::text("text/plain")),
"env" => Some(FileTypeInfo::text("text/plain")),
// Binary formats
"png" => Some(FileTypeInfo::binary("image/png")),
"jpg" | "jpeg" => Some(FileTypeInfo::binary("image/jpeg")),
"gif" => Some(FileTypeInfo::binary("image/gif")),
"webp" => Some(FileTypeInfo::binary("image/webp")),
"svg" => Some(FileTypeInfo::text("image/svg+xml")), // SVG is text
"pdf" => Some(FileTypeInfo::binary("application/pdf")),
"zip" => Some(FileTypeInfo::binary("application/zip")),
"gz" | "gzip" => Some(FileTypeInfo::binary("application/gzip")),
"tar" => Some(FileTypeInfo::binary("application/x-tar")),
"mp3" => Some(FileTypeInfo::binary("audio/mpeg")),
"mp4" => Some(FileTypeInfo::binary("video/mp4")),
"exe" | "dll" | "so" | "dylib" => Some(FileTypeInfo::binary("application/octet-stream")),
// Unknown extension
_ => None,
}
}
/// Detect file type by analyzing content.
///
/// Reads the first 8KB of the file and checks for UTF-8 validity and
/// presence of null bytes (indicating binary).
fn detect_by_content(path: &Path) -> ToolResult<FileTypeInfo> {
// Open file
let mut file = fs::File::open(path)?;
// Read first 8KB for analysis
let mut buffer = vec![0u8; 8192];
let bytes_read = file.read(&mut buffer)?;
buffer.truncate(bytes_read);
// Check for null bytes (strong indicator of binary)
if buffer.contains(&0) {
return Ok(FileTypeInfo::binary("application/octet-stream"));
}
// Try to validate as UTF-8
match std::str::from_utf8(&buffer) {
Ok(_) => Ok(FileTypeInfo::text("text/plain")),
Err(_) => Ok(FileTypeInfo::binary("application/octet-stream")),
}
}
/// Validate that a file contains valid UTF-8 text.
///
/// This is used to ensure text files can be safely read and embedded.
///
/// # Arguments
///
/// * `path` - Path to the file to validate
///
/// # Returns
///
/// `Ok(true)` if the file is valid UTF-8, `Ok(false)` if not.
pub fn is_valid_utf8(path: &Path) -> ToolResult<bool> {
let content = fs::read(path)?;
Ok(std::str::from_utf8(&content).is_ok())
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Write;
use tempfile::NamedTempFile;
#[test]
fn test_detect_by_extension_rust() {
let path = Path::new("test.rs");
let info = detect_by_extension(path).unwrap();
assert!(info.is_text);
assert!(!info.is_binary);
assert_eq!(info.mime_type, Some("text/x-rust".to_string()));
assert_eq!(info.charset, Some("utf-8".to_string()));
}
#[test]
fn test_detect_by_extension_typescript() {
let path = Path::new("test.ts");
let info = detect_by_extension(path).unwrap();
assert!(info.is_text);
assert_eq!(info.mime_type, Some("text/typescript".to_string()));
}
#[test]
fn test_detect_by_extension_json() {
let path = Path::new("config.json");
let info = detect_by_extension(path).unwrap();
assert!(info.is_text);
assert_eq!(info.mime_type, Some("application/json".to_string()));
}
#[test]
fn test_detect_by_extension_binary_png() {
let path = Path::new("image.png");
let info = detect_by_extension(path).unwrap();
assert!(!info.is_text);
assert!(info.is_binary);
assert_eq!(info.mime_type, Some("image/png".to_string()));
assert_eq!(info.charset, None);
}
#[test]
fn test_detect_by_extension_unknown() {
let path = Path::new("file.unknownext");
let info = detect_by_extension(path);
assert!(info.is_none());
}
#[test]
fn test_detect_by_extension_no_extension() {
let path = Path::new("Makefile");
let info = detect_by_extension(path);
assert!(info.is_none());
}
#[test]
fn test_detect_by_content_text() {
let mut temp = NamedTempFile::new().unwrap();
temp.write_all(b"Hello, world!\nThis is a text file.\n")
.unwrap();
let info = detect_by_content(temp.path()).unwrap();
assert!(info.is_text);
assert!(!info.is_binary);
assert_eq!(info.mime_type, Some("text/plain".to_string()));
}
#[test]
fn test_detect_by_content_binary_with_nulls() {
let mut temp = NamedTempFile::new().unwrap();
temp.write_all(b"\x00\x01\x02\x03binary data").unwrap();
let info = detect_by_content(temp.path()).unwrap();
assert!(!info.is_text);
assert!(info.is_binary);
assert_eq!(info.mime_type, Some("application/octet-stream".to_string()));
}
#[test]
fn test_detect_by_content_invalid_utf8() {
let mut temp = NamedTempFile::new().unwrap();
// Invalid UTF-8 sequence
temp.write_all(&[0xFF, 0xFE, 0xFD]).unwrap();
let info = detect_by_content(temp.path()).unwrap();
assert!(!info.is_text);
assert!(info.is_binary);
}
#[test]
fn test_detect_file_type_with_extension() {
let path = Path::new("test.rs");
// Will use extension-based detection (no file needed)
let info = detect_by_extension(path).unwrap();
assert!(info.is_text);
assert_eq!(info.mime_type, Some("text/x-rust".to_string()));
}
#[test]
fn test_is_valid_utf8_valid() {
let mut temp = NamedTempFile::new().unwrap();
temp.write_all("Hello, UTF-8! 你好世界".as_bytes())
.unwrap();
let result = is_valid_utf8(temp.path()).unwrap();
assert!(result);
}
#[test]
fn test_is_valid_utf8_invalid() {
let mut temp = NamedTempFile::new().unwrap();
temp.write_all(&[0xFF, 0xFE, 0xFD]).unwrap();
let result = is_valid_utf8(temp.path()).unwrap();
assert!(!result);
}
#[test]
fn test_file_type_info_constructors() {
let text = FileTypeInfo::text("text/plain");
assert!(text.is_text);
assert!(!text.is_binary);
assert_eq!(text.charset, Some("utf-8".to_string()));
let binary = FileTypeInfo::binary("application/pdf");
assert!(!binary.is_text);
assert!(binary.is_binary);
assert_eq!(binary.charset, None);
let unknown = FileTypeInfo::unknown();
assert!(!unknown.is_text);
assert!(unknown.is_binary);
assert_eq!(unknown.mime_type, Some("application/octet-stream".to_string()));
}
}
+215
View File
@@ -0,0 +1,215 @@
//! File read operation with sandboxing and line/limit support.
//!
//! **Status**: Implemented (TOOLS-FS-01)
//!
//! This module implements:
//! - Path validation using the security layer
//! - Asynchronous file reading
//! - UTF-8 encoding validation
//! - Line range and limit semantics
//! - Sandbox containment checks
//! - Blocklist enforcement
//! - Size limit soft caps
use crate::config::SandboxConfig;
use crate::error::{ToolError, ToolResult};
use crate::path::validate_path;
use serde::{Deserialize, Serialize};
/// Request to read a text file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReadTextFileRequest {
/// Absolute path to the file to read.
pub path: String,
/// Optional starting line (1-indexed).
///
/// If provided with limit, reads from this line.
/// If provided without limit, reads from this line to end.
pub line: Option<usize>,
/// Optional maximum number of lines to read.
///
/// If provided without line, reads first N lines.
/// If provided with line, reads N lines starting from line.
pub limit: Option<usize>,
}
/// Response from reading a text file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReadTextFileResponse {
/// UTF-8 text content of the file (or portion thereof).
pub content: String,
}
/// Read a text file with sandboxing and line/limit support.
///
/// ## Line/Limit Semantics
///
/// - `line: None, limit: None` → Read entire file
/// - `line: Some(n), limit: None` → Read from line n to end
/// - `line: None, limit: Some(m)` → Read first m lines
/// - `line: Some(n), limit: Some(m)` → Read m lines starting from line n
///
/// ## Error Cases
///
/// - Read disabled → `ToolError::PermissionDenied`
/// - Path outside allowed roots → `ToolError::SandboxViolation`
/// - Path matches blocklist → `ToolError::BlockedPath`
/// - File not found → `ToolError::NotFound`
/// - Non-UTF-8 content → `ToolError::EncodingUnsupported`
/// - File too large (soft cap) → Warning logged, content returned
///
/// ## See Also
///
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-FS-01`
/// - Security spec: `docs/building/04_acp_client/03_fs_sandboxing_and_permissions_spec.md`
pub async fn read_text_file(
request: ReadTextFileRequest,
config: &SandboxConfig,
) -> ToolResult<ReadTextFileResponse> {
// Check if read is enabled
if !config.read_enabled {
return Err(ToolError::permission_denied("Read operations are disabled"));
}
// Validate and canonicalize path (checks containment and blocklist)
let canonical_path = validate_path(&request.path, config)?;
// Read file asynchronously
let content_bytes = tokio::fs::read(&canonical_path)
.await
.map_err(|e| {
if e.kind() == std::io::ErrorKind::NotFound {
ToolError::NotFound {
path: request.path.clone(),
}
} else {
ToolError::FileReadError {
path: request.path.clone(),
source: e,
}
}
})?;
// Check soft cap and log warning if exceeded
let file_size = content_bytes.len() as u64;
if file_size > config.max_read_bytes {
tracing::warn!(
path = %request.path,
size = file_size,
limit = config.max_read_bytes,
"File size exceeds max_read_bytes soft cap"
);
}
// Validate UTF-8 encoding
let content = String::from_utf8(content_bytes).map_err(|_| {
ToolError::EncodingUnsupported {
encoding: "non-UTF-8".to_string(),
}
})?;
// Apply line/limit semantics
let final_content = apply_line_limit(&content, request.line, request.limit);
Ok(ReadTextFileResponse {
content: final_content,
})
}
/// Apply line/limit semantics to file content.
///
/// ## Semantics
///
/// - `line: None, limit: None` → Return entire content
/// - `line: Some(n), limit: None` → Return from line n to end (1-indexed)
/// - `line: None, limit: Some(m)` → Return first m lines
/// - `line: Some(n), limit: Some(m)` → Return m lines starting from line n (1-indexed)
fn apply_line_limit(content: &str, line: Option<usize>, limit: Option<usize>) -> String {
match (line, limit) {
// No line or limit specified - return entire content
(None, None) => content.to_string(),
// Only limit specified - return first N lines
(None, Some(limit)) => {
content
.lines()
.take(limit)
.collect::<Vec<_>>()
.join("\n")
}
// Only line specified - return from line N to end (1-indexed)
(Some(start_line), None) => {
if start_line == 0 {
return content.to_string();
}
content
.lines()
.skip(start_line.saturating_sub(1))
.collect::<Vec<_>>()
.join("\n")
}
// Both line and limit specified - return N lines starting from line M (1-indexed)
(Some(start_line), Some(limit)) => {
if start_line == 0 {
return content
.lines()
.take(limit)
.collect::<Vec<_>>()
.join("\n");
}
content
.lines()
.skip(start_line.saturating_sub(1))
.take(limit)
.collect::<Vec<_>>()
.join("\n")
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_apply_line_limit_no_params() {
let content = "line1\nline2\nline3";
assert_eq!(apply_line_limit(content, None, None), content);
}
#[test]
fn test_apply_line_limit_only_limit() {
let content = "line1\nline2\nline3\nline4";
assert_eq!(apply_line_limit(content, None, Some(2)), "line1\nline2");
}
#[test]
fn test_apply_line_limit_only_line() {
let content = "line1\nline2\nline3\nline4";
assert_eq!(apply_line_limit(content, Some(2), None), "line2\nline3\nline4");
}
#[test]
fn test_apply_line_limit_both() {
let content = "line1\nline2\nline3\nline4\nline5";
assert_eq!(apply_line_limit(content, Some(2), Some(2)), "line2\nline3");
}
#[test]
fn test_apply_line_limit_start_zero() {
let content = "line1\nline2\nline3";
assert_eq!(apply_line_limit(content, Some(0), None), content);
assert_eq!(apply_line_limit(content, Some(0), Some(2)), "line1\nline2");
}
#[test]
fn test_apply_line_limit_beyond_end() {
let content = "line1\nline2\nline3";
assert_eq!(apply_line_limit(content, Some(10), None), "");
assert_eq!(apply_line_limit(content, Some(2), Some(10)), "line2\nline3");
}
}
+232
View File
@@ -0,0 +1,232 @@
//! File write operation with sandboxing, permissions, and atomic writes.
//!
//! **Status**: Implemented (TOOLS-FS-02)
//!
//! This module implements:
//! - Path validation and canonicalization
//! - Permission checks (stubbed for now)
//! - Atomic write operations (temp + rename)
//! - EOL normalization
//! - Parent directory creation
//! - Size limit enforcement
use crate::config::{EolPolicy, PermissionConfig, SandboxConfig};
use crate::error::{ToolError, ToolResult};
use crate::path::validate_path;
use crate::permission::check::{check_permission, PermissionContext};
use crate::permission::cache::PermissionDecision;
use crate::permission::whitelist::PermissionOperation;
use serde::{Deserialize, Serialize};
use std::path::Path;
/// Request to write a text file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WriteTextFileRequest {
/// Absolute path to the file to write.
pub path: String,
/// UTF-8 text content to write.
pub content: String,
}
/// Response from writing a text file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WriteTextFileResponse {}
/// Write a text file with sandboxing, permission gating, and atomic writes.
///
/// ## Implementation
///
/// This function:
/// 1. Validates `write_enabled = true` in config
/// 2. Validates and canonicalizes path
/// 3. Checks containment and blocklist
/// 4. Checks permission (TODO: integrate with permission system)
/// 5. Validates content size against max_write_bytes
/// 6. Applies EOL policy normalization
/// 7. Creates parent directories if needed
/// 8. Performs atomic write (temp file + rename)
///
/// ## EOL Normalization
///
/// Applies configured EOL policy:
/// - `Preserve` → Keep original line endings
/// - `Lf` → Normalize to LF (\n)
/// - `Crlf` → Normalize to CRLF (\r\n)
///
/// ## Atomic Writes
///
/// To prevent partial writes:
/// 1. Write content to temporary file in same directory
/// 2. Rename temp file to target path (atomic on POSIX, near-atomic on Windows)
///
/// ## Error Cases
///
/// - Write disabled → `ToolError::PermissionDenied`
/// - Path outside allowed roots → `ToolError::SandboxViolation`
/// - Path matches blocklist → `ToolError::BlockedPath`
/// - Permission denied → `ToolError::PermissionDenied`
/// - Content too large → `ToolError::FileTooLarge`
/// - I/O errors → `ToolError::Io`
///
/// ## See Also
///
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-FS-02`
/// - Security spec: `docs/building/04_acp_client/03_fs_sandboxing_and_permissions_spec.md`
pub async fn write_text_file(
request: WriteTextFileRequest,
sandbox_config: &SandboxConfig,
permission_config: &PermissionConfig,
permission_context: &PermissionContext,
) -> ToolResult<WriteTextFileResponse> {
// Check if write is enabled
if !sandbox_config.write_enabled {
return Err(ToolError::permission_denied("Write operations are disabled"));
}
// Validate content size
let content_size = request.content.len() as u64;
if content_size > sandbox_config.max_write_bytes {
return Err(ToolError::FileTooLarge {
size: content_size,
limit: sandbox_config.max_write_bytes,
});
}
// Validate and canonicalize path (checks containment and blocklist)
let canonical_path = validate_path(&request.path, sandbox_config)?;
// Check permission via permission system (TOOLS-PERM-01)
let operation = PermissionOperation::Write {
path: request.path.clone(),
};
let decision = check_permission(&operation, permission_context, permission_config).await?;
match decision {
PermissionDecision::Allowed => {
// Permission granted, proceed with write
tracing::debug!(
path = %request.path,
"Permission granted for write operation"
);
}
PermissionDecision::Denied => {
return Err(ToolError::PermissionRejected);
}
PermissionDecision::Cancelled => {
return Err(ToolError::permission_denied("Permission prompt was cancelled"));
}
}
// Apply EOL normalization
let normalized_content = normalize_eol(&request.content, sandbox_config.eol_policy);
// Create parent directories if they don't exist
if let Some(parent) = canonical_path.parent() {
tokio::fs::create_dir_all(parent).await.map_err(|e| {
ToolError::FileReadError {
path: parent.display().to_string(),
source: e,
}
})?;
}
// Perform atomic write
write_atomic(&canonical_path, &normalized_content).await?;
tracing::info!(
path = %request.path,
size = content_size,
"File written successfully"
);
Ok(WriteTextFileResponse {})
}
/// Normalize line endings according to configured policy.
pub fn normalize_eol(content: &str, policy: EolPolicy) -> String {
match policy {
EolPolicy::Preserve => content.to_string(),
EolPolicy::Lf => {
// Replace CRLF with LF, then ensure all CR are removed
content.replace("\r\n", "\n").replace('\r', "\n")
}
EolPolicy::Crlf => {
// First normalize to LF, then replace with CRLF
let lf_normalized = content.replace("\r\n", "\n").replace('\r', "\n");
lf_normalized.replace('\n', "\r\n")
}
}
}
/// Write content to a file atomically using temp file + rename.
///
/// This minimizes the risk of partial writes by:
/// 1. Writing to a temporary file in the same directory
/// 2. Renaming the temp file to the target path
///
/// The rename operation is atomic on POSIX systems and near-atomic on Windows.
async fn write_atomic(path: &Path, content: &str) -> ToolResult<()> {
// Create a temporary file in the same directory
let parent = path.parent().unwrap_or_else(|| Path::new("."));
let file_name = path.file_name().unwrap_or_default().to_string_lossy();
let temp_path = parent.join(format!(".{}.tmp", file_name));
// Write content to temporary file
tokio::fs::write(&temp_path, content).await.map_err(|e| {
ToolError::FileReadError {
path: temp_path.display().to_string(),
source: e,
}
})?;
// Rename temp file to target (atomic operation)
tokio::fs::rename(&temp_path, path).await.map_err(|e| {
// Clean up temp file on error
let _ = std::fs::remove_file(&temp_path);
ToolError::FileReadError {
path: path.display().to_string(),
source: e,
}
})?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_normalize_eol_preserve() {
let content = "line1\nline2\r\nline3\rline4";
assert_eq!(normalize_eol(content, EolPolicy::Preserve), content);
}
#[test]
fn test_normalize_eol_lf() {
let content = "line1\nline2\r\nline3\rline4";
let expected = "line1\nline2\nline3\nline4";
assert_eq!(normalize_eol(content, EolPolicy::Lf), expected);
}
#[test]
fn test_normalize_eol_crlf() {
let content = "line1\nline2\r\nline3\rline4";
let expected = "line1\r\nline2\r\nline3\r\nline4";
assert_eq!(normalize_eol(content, EolPolicy::Crlf), expected);
}
#[test]
fn test_normalize_eol_edge_cases() {
// Empty string
assert_eq!(normalize_eol("", EolPolicy::Lf), "");
// Only newlines
assert_eq!(normalize_eol("\n\n\n", EolPolicy::Crlf), "\r\n\r\n\r\n");
// Mixed line endings to LF
assert_eq!(normalize_eol("a\rb\nc\r\nd", EolPolicy::Lf), "a\nb\nc\nd");
}
}
+78
View File
@@ -0,0 +1,78 @@
//! # dirigent_tools
//!
//! Tool implementations for ACP (Agent-Client Protocol) client operations.
//!
//! This package provides:
//! - **File operations** (read, write, edit) with sandboxing
//! - **Terminal operations** (create, output, wait, kill) with isolation
//! - **Search operations** (glob, grep, ls) with result limiting
//! - **Permission system** for securing write and execute operations
//! - **Path sandboxing** with configurable allowed roots and blocklists
//!
//! ## Module Overview
//!
//! - [`error`] - Error types for tool operations
//! - [`config`] - Configuration types for sandboxing, permissions, and limits
//! - [`path`] - Path normalization and containment checking (cross-platform)
//! - [`fs`] - File system operations (read, write) with sandbox enforcement
//! - [`search`] - Search operations (glob, grep, ls) with result limiting
//! - [`terminal`] - Terminal/command execution with output capture
//! - [`permission`] - Permission prompts and decision caching
//! - [`audit`] - Audit logging for sensitive operations
//! - [`tool`] - Tool trait, events, and per-call context
//! - [`registry`] - ToolRegistry for built-in and dynamic tools
//! - [`floor`] - Hardcoded SecurityFloor (cannot be bypassed by settings)
//! - [`dispatch`] - Dispatch pipeline (registry → floor → run)
//! - [`tools`] - Built-in tool implementations (read, ...)
//!
//! ## Safety and Security
//!
//! All file and terminal operations are subject to:
//! - **Sandbox containment** - Operations restricted to configured allowed roots
//! - **Blocklist enforcement** - Sensitive paths can be explicitly denied
//! - **Permission prompts** - Write and execute operations can require user approval
//! - **Resource limits** - File size, search results, and terminal output are bounded
//! - **Audit logging** - All operations are logged with structured context
//!
//! ## Platform Support
//!
//! This crate is designed with Windows as a first-class platform:
//! - Handles Windows paths (backslashes, drive letters, UNC shares, `\\?\` prefixes)
//! - Supports MINGW-style paths (`/c/...`)
//! - Normalizes path separators for consistent policy enforcement
//! - Tests run on Windows, Linux, and macOS
pub mod error;
pub mod config;
pub mod path;
pub mod fs;
pub mod search;
pub mod terminal;
pub mod permission;
pub mod audit;
pub mod embedding;
pub mod tool;
pub mod registry;
pub mod floor;
pub mod dispatch;
pub mod tools;
// Re-export commonly used types
pub use error::{ToolError, ToolResult};
pub use config::{
SandboxConfig, PermissionConfig, TerminalConfig, SearchConfig, EmbeddingConfig,
};
pub use tool::{
AnyTool, AnyToolInput, ClientKind, Erased, PermissionRequestId, ProtocolKind, Tool,
ToolContext, ToolEvent, ToolEventSink, ToolInput, ToolKind, ToolLocation, ToolResultContent,
};
pub use registry::{CollisionPolicy, DynamicEntry, ToolRegistry, ToolSource, Winner};
pub use floor::{FloorDecision, SecurityFloor};
pub use dispatch::{dispatch, DispatchResult};
/// Re-exports of the policy types `dirigent_tools` consumes from
/// `dirigent_fermata`. Keeping this here pins the dependency direction:
/// `dirigent_tools` → `dirigent_fermata`, never the reverse.
pub mod policy {
pub use dirigent_fermata::core::{Decision, Op, Policy, Reason, Rule};
}
+114
View File
@@ -0,0 +1,114 @@
//! Path normalization and containment checking.
//!
//! This module provides cross-platform path utilities with special attention to Windows:
//! - Normalizes path separators (backslash vs forward slash)
//! - Handles Windows drive letters (C:\, /c/, etc.)
//! - Handles UNC paths (\\server\share\...)
//! - Handles long path prefixes (\\?\...)
//! - Handles MINGW-style paths (/c/Users/...)
//! - Canonical path resolution (symlinks, junctions)
//! - Containment checking for sandbox enforcement
//!
//! # Security-Critical
//!
//! This module is the security foundation for tool sandboxing. All operations must:
//! - Prevent path traversal attacks
//! - Prevent symlink escape attacks
//! - Handle all Windows path edge cases correctly
//! - Never expose disallowed paths in error messages
pub mod canonicalize;
pub mod containment;
pub mod blocklist;
pub mod validate;
// Re-export public API
pub use canonicalize::{canonicalize_path, SymlinkPolicy};
pub use containment::check_containment;
pub use blocklist::check_blocklist;
pub use validate::validate_path;
use std::path::{Path, PathBuf};
/// Get the basename (final component) of a path for safe error messages.
///
/// This is used to avoid leaking full paths in error messages.
pub fn basename(path: &Path) -> String {
path.file_name()
.and_then(|s| s.to_str())
.map(|s| s.to_string())
.unwrap_or_else(|| String::new())
}
/// Check if a path is absolute.
///
/// On Windows, this handles:
/// - Standard absolute paths (C:\...)
/// - UNC paths (\\server\share\...)
/// - Long-path prefixes (\\?\...)
/// - Verbatim UNC (\\?\UNC\server\share\...)
pub fn is_absolute(path: &Path) -> bool {
path.is_absolute()
}
/// Normalize a path to use the platform's standard separator.
///
/// On Windows, this converts forward slashes to backslashes.
/// On Unix, this is a no-op.
pub fn normalize_separators(path: &Path) -> PathBuf {
#[cfg(windows)]
{
let s = path.to_string_lossy();
let normalized = s.replace('/', "\\");
PathBuf::from(normalized)
}
#[cfg(not(windows))]
{
path.to_path_buf()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_basename() {
assert_eq!(basename(Path::new("/foo/bar/baz.txt")), "baz.txt");
assert_eq!(basename(Path::new("baz.txt")), "baz.txt");
assert_eq!(basename(Path::new("/")), "");
}
#[cfg(windows)]
#[test]
fn test_is_absolute_windows() {
assert!(is_absolute(Path::new("C:\\Users")));
assert!(is_absolute(Path::new("C:/Users")));
assert!(is_absolute(Path::new("\\\\server\\share")));
assert!(is_absolute(Path::new("\\\\?\\C:\\Users")));
assert!(!is_absolute(Path::new("relative\\path")));
assert!(!is_absolute(Path::new("..\\parent")));
}
#[cfg(unix)]
#[test]
fn test_is_absolute_unix() {
assert!(is_absolute(Path::new("/home/user")));
assert!(!is_absolute(Path::new("relative/path")));
assert!(!is_absolute(Path::new("../parent")));
}
#[cfg(windows)]
#[test]
fn test_normalize_separators_windows() {
assert_eq!(
normalize_separators(Path::new("C:/Users/foo/bar.txt")),
PathBuf::from("C:\\Users\\foo\\bar.txt")
);
assert_eq!(
normalize_separators(Path::new("C:\\Users\\foo\\bar.txt")),
PathBuf::from("C:\\Users\\foo\\bar.txt")
);
}
}
+220
View File
@@ -0,0 +1,220 @@
//! Blocklist evaluation using glob patterns.
//!
//! This module checks if paths match configured blocklist patterns,
//! allowing fine-grained denial of sensitive paths even within allowed roots.
use crate::error::{ToolError, ToolResult};
use globset::{Glob, GlobSet, GlobSetBuilder};
use std::path::Path;
/// Check if a path matches any blocklist patterns.
///
/// Blocklist patterns can be:
/// - Absolute paths (exact match)
/// - Glob patterns (e.g., `**/.env`, `**/secrets/**`)
///
/// # Arguments
///
/// * `canonical_path` - The canonical path to check
/// * `blocklist_patterns` - List of path patterns or globs to deny
///
/// # Returns
///
/// Ok(()) if the path is not blocked, or an error if it matches a blocklist pattern.
///
/// # Performance
///
/// For best performance, pre-compile patterns into a `CompiledBlocklist` and use
/// `check_blocklist_compiled` instead.
pub fn check_blocklist(
canonical_path: &Path,
blocklist_patterns: &[String],
) -> ToolResult<()> {
if blocklist_patterns.is_empty() {
return Ok(());
}
// Compile patterns on the fly (slower)
let compiled = compile_blocklist(blocklist_patterns)?;
check_blocklist_compiled(canonical_path, &compiled)
}
/// Pre-compiled blocklist for efficient matching.
pub struct CompiledBlocklist {
glob_set: GlobSet,
}
impl CompiledBlocklist {
/// Create a new compiled blocklist.
pub fn new(glob_set: GlobSet) -> Self {
Self { glob_set }
}
/// Get the underlying GlobSet.
pub fn glob_set(&self) -> &GlobSet {
&self.glob_set
}
}
/// Compile blocklist patterns into a GlobSet for efficient matching.
///
/// This should be done once at configuration load time.
///
/// # Arguments
///
/// * `patterns` - List of glob patterns or absolute paths
///
/// # Returns
///
/// A compiled blocklist ready for fast matching.
pub fn compile_blocklist(patterns: &[String]) -> ToolResult<CompiledBlocklist> {
let mut builder = GlobSetBuilder::new();
for pattern in patterns {
let glob = Glob::new(pattern).map_err(|e| {
ToolError::InvalidConfig(format!("Invalid blocklist pattern '{}': {}", pattern, e))
})?;
builder.add(glob);
}
let glob_set = builder.build().map_err(|e| {
ToolError::InvalidConfig(format!("Failed to compile blocklist patterns: {}", e))
})?;
Ok(CompiledBlocklist::new(glob_set))
}
/// Check if a path matches a pre-compiled blocklist.
///
/// This is the fast path for blocklist checking.
///
/// # Arguments
///
/// * `canonical_path` - The canonical path to check
/// * `compiled` - Pre-compiled blocklist
///
/// # Returns
///
/// Ok(()) if the path is not blocked, or an error if it matches a pattern.
pub fn check_blocklist_compiled(
canonical_path: &Path,
compiled: &CompiledBlocklist,
) -> ToolResult<()> {
if compiled.glob_set.is_match(canonical_path) {
return Err(ToolError::blocked_path(format!(
"Path matches blocklist: {}",
super::basename(canonical_path)
)));
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_empty_blocklist() {
let result = check_blocklist(Path::new("/any/path"), &[]);
assert!(result.is_ok());
}
#[test]
fn test_exact_path_match() {
let patterns = vec![
"/etc/passwd".to_string(),
"/home/user/.ssh/id_rsa".to_string(),
];
// Blocked
assert!(check_blocklist(Path::new("/etc/passwd"), &patterns).is_err());
assert!(check_blocklist(Path::new("/home/user/.ssh/id_rsa"), &patterns).is_err());
// Not blocked
assert!(check_blocklist(Path::new("/etc/hosts"), &patterns).is_ok());
}
#[test]
fn test_glob_patterns() {
let patterns = vec![
"**/.env".to_string(),
"**/secrets/**".to_string(),
"**/.git/**".to_string(),
];
// Blocked by .env pattern
assert!(check_blocklist(Path::new("/project/.env"), &patterns).is_err());
assert!(check_blocklist(Path::new("/project/subdir/.env"), &patterns).is_err());
// Blocked by secrets pattern
assert!(check_blocklist(Path::new("/project/secrets/key.txt"), &patterns).is_err());
assert!(check_blocklist(Path::new("/app/secrets/api_key"), &patterns).is_err());
// Blocked by .git pattern
assert!(check_blocklist(Path::new("/project/.git/config"), &patterns).is_err());
// Not blocked
assert!(check_blocklist(Path::new("/project/src/main.rs"), &patterns).is_ok());
assert!(check_blocklist(Path::new("/project/README.md"), &patterns).is_ok());
}
#[test]
fn test_compile_blocklist() {
let patterns = vec![
"**/.env".to_string(),
"**/secrets/**".to_string(),
];
let compiled = compile_blocklist(&patterns).unwrap();
// Blocked
assert!(check_blocklist_compiled(
Path::new("/project/.env"),
&compiled
).is_err());
// Not blocked
assert!(check_blocklist_compiled(
Path::new("/project/src/main.rs"),
&compiled
).is_ok());
}
#[test]
fn test_invalid_pattern() {
let patterns = vec![
"[invalid".to_string(), // Invalid glob syntax
];
let result = compile_blocklist(&patterns);
assert!(result.is_err());
}
#[cfg(windows)]
#[test]
fn test_windows_paths() {
let patterns = vec![
"**/.env".to_string(),
"C:/secrets/**".to_string(),
];
let compiled = compile_blocklist(&patterns).unwrap();
// Blocked
assert!(check_blocklist_compiled(
Path::new("C:\\project\\.env"),
&compiled
).is_err());
assert!(check_blocklist_compiled(
Path::new("C:\\secrets\\key.txt"),
&compiled
).is_err());
// Not blocked
assert!(check_blocklist_compiled(
Path::new("C:\\project\\src\\main.rs"),
&compiled
).is_ok());
}
}
@@ -0,0 +1,387 @@
//! Path canonicalization with cross-platform support.
//!
//! This module handles:
//! - Converting relative paths to absolute
//! - Resolving symlinks and junctions per policy
//! - Normalizing Windows path formats (UNC, long-path, MINGW)
//! - Rejecting Windows reserved device names
//! - Handling non-existent paths (for write operations)
use crate::error::{ToolError, ToolResult};
use std::path::{Path, PathBuf, Component};
/// Symlink handling policy for path canonicalization.
#[derive(Debug, Clone, Copy)]
pub struct SymlinkPolicy {
/// Allow symlinks to escape allowed roots.
///
/// If false (recommended), symlinks pointing outside allowed roots are rejected.
pub allow_symlink_escape: bool,
/// Follow symlinks within allowed roots.
///
/// If true, symlinks within allowed roots are followed during canonicalization.
pub follow_symlinks_within_roots: bool,
}
impl Default for SymlinkPolicy {
fn default() -> Self {
Self {
allow_symlink_escape: false,
follow_symlinks_within_roots: true,
}
}
}
impl SymlinkPolicy {
/// Create a policy that follows all symlinks (least restrictive).
pub fn follow_all() -> Self {
Self {
allow_symlink_escape: true,
follow_symlinks_within_roots: true,
}
}
/// Create a policy that never follows symlinks (most restrictive).
pub fn follow_none() -> Self {
Self {
allow_symlink_escape: false,
follow_symlinks_within_roots: false,
}
}
}
/// Canonicalize a path with the given symlink policy.
///
/// This function:
/// 1. Rejects empty or relative paths (ACP requires absolute paths)
/// 2. Normalizes separators to platform standard
/// 3. Converts MINGW-style paths (/c/...) to native Windows paths (C:\...)
/// 4. Strips long-path prefixes (\\?\) for policy comparison
/// 5. Resolves symlinks/junctions per policy
/// 6. For non-existent paths, canonicalizes parent + appends remainder
/// 7. Ensures no ".." escapes remain
/// 8. Rejects Windows reserved device names (CON, NUL, etc.)
///
/// # Arguments
///
/// * `user_path` - The path to canonicalize (must be absolute)
/// * `policy` - Symlink handling policy
///
/// # Returns
///
/// The canonical absolute path, or an error if the path is invalid or violates policy.
///
/// # Security
///
/// This is a security-critical function. It must prevent:
/// - Path traversal attacks via ".." components
/// - Symlink escape attacks
/// - Access to Windows reserved device names
pub fn canonicalize_path(user_path: &Path, _policy: &SymlinkPolicy) -> ToolResult<PathBuf> {
// Step 1: Reject empty paths
if user_path.as_os_str().is_empty() {
return Err(ToolError::sandbox_violation("Path cannot be empty"));
}
// Step 2: Convert to string for preprocessing
let path_str = user_path.to_string_lossy();
// Step 3: Handle MINGW-style paths on Windows (/c/... -> C:\...)
#[cfg(windows)]
let path_str = convert_mingw_path(&path_str);
let mut path = PathBuf::from(path_str.as_ref());
// Step 4: Normalize separators
path = super::normalize_separators(&path);
// Step 5: Reject relative paths
if !path.is_absolute() {
return Err(ToolError::sandbox_violation(format!(
"Path must be absolute: {}",
super::basename(&path)
)));
}
// Step 6: Strip long-path prefix for normalization (\\?\C:\... -> C:\...)
#[cfg(windows)]
{
path = strip_long_path_prefix(&path);
}
// Step 7: Check for reserved device names on Windows
#[cfg(windows)]
{
if is_reserved_device_name(&path) {
return Err(ToolError::sandbox_violation(format!(
"Path is a reserved device name: {}",
super::basename(&path)
)));
}
}
// Step 8: Normalize drive letter case on Windows
#[cfg(windows)]
{
path = normalize_drive_letter(&path);
}
// Step 9: Try to canonicalize using dunce (follows symlinks on all platforms)
// If the path doesn't exist, find the first existing ancestor and canonicalize it,
// then append all non-existent components
let canonical = match dunce::canonicalize(&path) {
Ok(canonical) => canonical,
Err(_) => {
// Path doesn't exist - find first existing ancestor
canonicalize_non_existent_path(&path)?
}
};
// Step 10: Verify no ".." components remain (security check)
for component in canonical.components() {
if component == Component::ParentDir {
return Err(ToolError::sandbox_violation(
"Path contains '..' after canonicalization (potential traversal attack)",
));
}
}
// Step 11: Strip long-path prefix again if added by canonicalize
#[cfg(windows)]
let canonical = strip_long_path_prefix(&canonical);
Ok(canonical)
}
/// Canonicalize a non-existent path by finding the first existing ancestor.
///
/// This function walks up the directory tree until it finds an existing directory,
/// then builds the canonical path by appending the non-existent components.
fn canonicalize_non_existent_path(path: &Path) -> ToolResult<PathBuf> {
// Collect components of the non-existent path parts
let mut components_to_append: Vec<std::ffi::OsString> = Vec::new();
let mut current = path.to_path_buf();
// Walk up the directory tree until we find an existing directory
loop {
if let Some(parent) = current.parent() {
if parent == current {
// Hit root without finding existing directory
// This shouldn't happen with absolute paths, but handle it
return Ok(path.to_path_buf());
}
// Try to canonicalize the parent
match dunce::canonicalize(parent) {
Ok(canonical_parent) => {
// Found existing parent - build the full path
components_to_append.reverse();
let mut result = canonical_parent;
// First add the current file_name (if any)
if let Some(file_name) = current.file_name() {
result = result.join(file_name);
}
// Then add all the rest
for component in components_to_append {
result = result.join(component);
}
return Ok(result);
}
Err(_) => {
// Parent doesn't exist - collect the component and continue up
if let Some(file_name) = current.file_name() {
components_to_append.push(file_name.to_os_string());
}
current = parent.to_path_buf();
}
}
} else {
// No parent (shouldn't reach here with absolute paths)
return Ok(path.to_path_buf());
}
}
}
/// Convert MINGW-style path (/c/Users/...) to Windows path (C:\Users\...).
///
/// Only applies on Windows. Detects paths like:
/// - /c/... -> C:\...
/// - /d/... -> D:\...
#[cfg(windows)]
fn convert_mingw_path(path: &str) -> std::borrow::Cow<'_, str> {
// Check for MINGW-style path: starts with /<letter>/
if path.len() >= 3
&& path.starts_with('/')
&& path.chars().nth(1).map_or(false, |c| c.is_ascii_alphabetic())
&& (path.len() == 3 || path.chars().nth(2) == Some('/'))
{
let drive_letter = path.chars().nth(1).unwrap().to_ascii_uppercase();
let rest = if path.len() > 3 {
// Normalize forward slashes to backslashes in the remainder
path[3..].replace('/', "\\")
} else {
String::new()
};
format!("{}:\\{}", drive_letter, rest).into()
} else {
path.into()
}
}
/// Strip the long-path prefix (\\?\) from a Windows path.
///
/// Also handles verbatim UNC paths (\\?\UNC\server\share -> \\server\share).
#[cfg(windows)]
fn strip_long_path_prefix(path: &Path) -> PathBuf {
let path_str = path.to_string_lossy();
// Handle \\?\UNC\server\share -> \\server\share
if path_str.starts_with(r"\\?\UNC\") {
let without_prefix = &path_str[r"\\?\UNC\".len()..];
return PathBuf::from(format!(r"\\{}", without_prefix));
}
// Handle \\?\C:\ -> C:\
if path_str.starts_with(r"\\?\") {
let without_prefix = &path_str[r"\\?\".len()..];
return PathBuf::from(without_prefix);
}
path.to_path_buf()
}
/// Normalize drive letter to uppercase (C:\ instead of c:\).
#[cfg(windows)]
fn normalize_drive_letter(path: &Path) -> PathBuf {
let path_str = path.to_string_lossy();
// Check if path starts with a drive letter (e.g., "c:\")
if path_str.len() >= 2
&& path_str.chars().nth(0).map_or(false, |c| c.is_ascii_alphabetic())
&& path_str.chars().nth(1) == Some(':')
{
let mut chars: Vec<char> = path_str.chars().collect();
chars[0] = chars[0].to_ascii_uppercase();
return PathBuf::from(chars.iter().collect::<String>());
}
path.to_path_buf()
}
/// Check if a path is a Windows reserved device name.
///
/// Reserved names include:
/// - CON, PRN, AUX, NUL
/// - COM1-COM9
/// - LPT1-LPT9
///
/// These can appear with or without extensions (e.g., CON.txt is also reserved).
#[cfg(windows)]
fn is_reserved_device_name(path: &Path) -> bool {
const RESERVED_NAMES: &[&str] = &[
"CON", "PRN", "AUX", "NUL",
"COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", "COM8", "COM9",
"LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
];
if let Some(file_name) = path.file_name() {
let name = file_name.to_string_lossy().to_uppercase();
// Check exact match
if RESERVED_NAMES.contains(&name.as_str()) {
return true;
}
// Check with extension (e.g., CON.txt)
if let Some(stem) = Path::new(&*name).file_stem() {
let stem_str = stem.to_string_lossy();
if RESERVED_NAMES.contains(&stem_str.as_ref()) {
return true;
}
}
}
false
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_empty_path() {
let result = canonicalize_path(Path::new(""), &SymlinkPolicy::default());
assert!(result.is_err());
}
#[cfg(windows)]
#[test]
fn test_convert_mingw_path() {
assert_eq!(convert_mingw_path("/c/Users/foo"), "C:\\Users\\foo");
assert_eq!(convert_mingw_path("/d/Projects"), "D:\\Projects");
assert_eq!(convert_mingw_path("/c/"), "C:\\");
assert_eq!(convert_mingw_path("C:\\Users"), "C:\\Users"); // No conversion
}
#[cfg(windows)]
#[test]
fn test_strip_long_path_prefix() {
assert_eq!(
strip_long_path_prefix(Path::new(r"\\?\C:\Users\foo")),
PathBuf::from(r"C:\Users\foo")
);
assert_eq!(
strip_long_path_prefix(Path::new(r"\\?\UNC\server\share\file")),
PathBuf::from(r"\\server\share\file")
);
assert_eq!(
strip_long_path_prefix(Path::new(r"C:\Users\foo")),
PathBuf::from(r"C:\Users\foo")
);
}
#[cfg(windows)]
#[test]
fn test_normalize_drive_letter() {
assert_eq!(
normalize_drive_letter(Path::new("c:\\Users")),
PathBuf::from("C:\\Users")
);
assert_eq!(
normalize_drive_letter(Path::new("C:\\Users")),
PathBuf::from("C:\\Users")
);
}
#[cfg(windows)]
#[test]
fn test_is_reserved_device_name() {
assert!(is_reserved_device_name(Path::new("CON")));
assert!(is_reserved_device_name(Path::new("con")));
assert!(is_reserved_device_name(Path::new("CON.txt")));
assert!(is_reserved_device_name(Path::new("C:\\path\\to\\NUL")));
assert!(is_reserved_device_name(Path::new("COM1")));
assert!(is_reserved_device_name(Path::new("LPT5")));
assert!(!is_reserved_device_name(Path::new("CONNECT.txt")));
assert!(!is_reserved_device_name(Path::new("file.txt")));
}
#[test]
fn test_symlink_policy() {
let default = SymlinkPolicy::default();
assert!(!default.allow_symlink_escape);
assert!(default.follow_symlinks_within_roots);
let follow_all = SymlinkPolicy::follow_all();
assert!(follow_all.allow_symlink_escape);
assert!(follow_all.follow_symlinks_within_roots);
let follow_none = SymlinkPolicy::follow_none();
assert!(!follow_none.allow_symlink_escape);
assert!(!follow_none.follow_symlinks_within_roots);
}
}
@@ -0,0 +1,265 @@
//! Path containment checking for sandbox enforcement.
//!
//! This module verifies that a canonical path is within allowed sandbox roots.
use crate::error::{ToolError, ToolResult};
use std::path::{Path, PathBuf, Component};
/// Check if a canonical path is contained within allowed roots.
///
/// This function performs component-wise prefix matching (not string matching)
/// to ensure the path is strictly within at least one allowed root.
///
/// # Arguments
///
/// * `canonical_path` - The canonical path to check (must already be canonicalized)
/// * `allowed_roots` - List of allowed root paths (must already be canonical)
///
/// # Returns
///
/// Ok(()) if the path is contained, or an error if it's outside all roots.
///
/// # Security
///
/// This is a security-critical function. It uses component-wise comparison to prevent:
/// - String prefix attacks (e.g., "/a" should not contain "/ab")
/// - Path traversal attacks
pub fn check_containment(
canonical_path: &Path,
allowed_roots: &[PathBuf],
) -> ToolResult<()> {
// If no roots configured, deny all paths
if allowed_roots.is_empty() {
return Err(ToolError::sandbox_violation(
"No allowed roots configured - all paths denied",
));
}
// Check if path is contained by at least one root
for root in allowed_roots {
if is_contained_by(canonical_path, root) {
return Ok(());
}
}
// Path is outside all roots
Err(ToolError::sandbox_violation(format!(
"Path is outside allowed roots: {}",
super::basename(canonical_path)
)))
}
/// Check if a path is strictly contained by a root.
///
/// This performs component-wise comparison to ensure proper containment:
/// - "/a/b/c" is contained by "/a"
/// - "/a/b/c" is NOT contained by "/a/b/c" (must be strict)
/// - "/ab" is NOT contained by "/a" (not a string prefix match)
///
/// On Windows, comparison is case-insensitive to match filesystem behavior.
///
/// # Arguments
///
/// * `path` - The path to check (must be canonical)
/// * `root` - The root path (must be canonical)
///
/// # Returns
///
/// `true` if path is strictly contained within root, `false` otherwise.
pub fn is_contained_by(path: &Path, root: &Path) -> bool {
// Get components
let path_components: Vec<_> = path.components().collect();
let root_components: Vec<_> = root.components().collect();
// Path must have MORE components than root (strict containment)
// Equal components means path == root, which is not strict containment
if path_components.len() <= root_components.len() {
return false;
}
// Check if all root components match path components (prefix)
for (i, root_comp) in root_components.iter().enumerate() {
let path_comp = &path_components[i];
if !components_equal(path_comp, root_comp) {
return false;
}
}
true
}
/// Compare two path components for equality.
///
/// On Windows, this is case-insensitive.
/// On Unix, this is case-sensitive.
fn components_equal(a: &Component, b: &Component) -> bool {
#[cfg(windows)]
{
// Case-insensitive comparison on Windows
let a_str = component_to_string(a).to_lowercase();
let b_str = component_to_string(b).to_lowercase();
a_str == b_str
}
#[cfg(not(windows))]
{
// Case-sensitive comparison on Unix
a == b
}
}
/// Convert a path component to a string for comparison.
#[cfg(windows)]
fn component_to_string(comp: &Component) -> String {
match comp {
Component::Prefix(prefix) => prefix.as_os_str().to_string_lossy().to_string(),
Component::RootDir => "/".to_string(),
Component::CurDir => ".".to_string(),
Component::ParentDir => "..".to_string(),
Component::Normal(s) => s.to_string_lossy().to_string(),
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_is_contained_by_unix() {
// Strict containment (child must be strictly inside parent)
assert!(is_contained_by(
Path::new("/a/b/c"),
Path::new("/a")
));
assert!(is_contained_by(
Path::new("/a/b/c"),
Path::new("/a/b")
));
// Equal paths are NOT strictly contained
assert!(!is_contained_by(
Path::new("/a/b"),
Path::new("/a/b")
));
// Parent is not contained by child
assert!(!is_contained_by(
Path::new("/a"),
Path::new("/a/b")
));
// Different trees
assert!(!is_contained_by(
Path::new("/a/b"),
Path::new("/c")
));
// String prefix but not path prefix
assert!(!is_contained_by(
Path::new("/ab"),
Path::new("/a")
));
}
#[cfg(windows)]
#[test]
fn test_is_contained_by_windows_case_insensitive() {
// Windows is case-insensitive
assert!(is_contained_by(
Path::new("C:\\Users\\foo\\bar"),
Path::new("C:\\Users")
));
assert!(is_contained_by(
Path::new("c:\\users\\foo\\bar"),
Path::new("C:\\Users")
));
assert!(is_contained_by(
Path::new("C:\\Users\\foo"),
Path::new("c:\\users")
));
}
#[test]
fn test_check_containment_no_roots() {
let result = check_containment(Path::new("/a/b/c"), &[]);
assert!(result.is_err());
}
#[test]
fn test_check_containment_allowed() {
let roots = vec![
PathBuf::from("/home/user/project"),
PathBuf::from("/tmp"),
];
// Inside first root
assert!(check_containment(
Path::new("/home/user/project/src/main.rs"),
&roots
).is_ok());
// Inside second root
assert!(check_containment(
Path::new("/tmp/foo.txt"),
&roots
).is_ok());
}
#[test]
fn test_check_containment_denied() {
let roots = vec![
PathBuf::from("/home/user/project"),
];
// Outside all roots
assert!(check_containment(
Path::new("/etc/passwd"),
&roots
).is_err());
// Sibling directory (string prefix but not path prefix)
assert!(check_containment(
Path::new("/home/user/project2/file.txt"),
&roots
).is_err());
// Parent directory
assert!(check_containment(
Path::new("/home/user/other.txt"),
&roots
).is_err());
// Equal to root (not strictly contained)
assert!(check_containment(
Path::new("/home/user/project"),
&roots
).is_err());
}
#[cfg(windows)]
#[test]
fn test_check_containment_windows() {
let roots = vec![
PathBuf::from("C:\\Users\\foo"),
];
// Allowed
assert!(check_containment(
Path::new("C:\\Users\\foo\\Documents\\file.txt"),
&roots
).is_ok());
// Denied - different drive
assert!(check_containment(
Path::new("D:\\file.txt"),
&roots
).is_err());
// Denied - outside root
assert!(check_containment(
Path::new("C:\\Windows\\system32"),
&roots
).is_err());
}
}
+165
View File
@@ -0,0 +1,165 @@
//! Path validation facade combining all security checks.
//!
//! This module provides a single entry point for path validation that:
//! 1. Canonicalizes the path
//! 2. Checks containment within allowed roots
//! 3. Evaluates blocklist patterns
//!
//! All tool operations should use this facade for consistent security enforcement.
use crate::config::SandboxConfig;
use crate::error::ToolResult;
use std::path::{Path, PathBuf};
use super::blocklist::{compile_blocklist, check_blocklist_compiled, CompiledBlocklist};
use super::canonicalize::{canonicalize_path, SymlinkPolicy};
use super::containment::check_containment;
/// Validate a path against sandbox configuration.
///
/// This is the main entry point for path validation. It performs all security checks:
/// 1. Path canonicalization (handles symlinks, Windows paths, etc.)
/// 2. Containment checking (ensures path is within allowed roots)
/// 3. Blocklist evaluation (checks against denied patterns)
///
/// # Arguments
///
/// * `user_path` - The path provided by the user/agent (must be absolute)
/// * `config` - Sandbox configuration with allowed roots and blocklist
///
/// # Returns
///
/// The canonical path if all checks pass, or a security error otherwise.
///
/// # Example
///
/// ```rust,no_run
/// use dirigent_tools::path::validate_path;
/// use dirigent_tools::config::SandboxConfig;
/// use std::path::PathBuf;
///
/// let mut config = SandboxConfig::default();
/// config.allowed_roots = vec![PathBuf::from("/home/user/project")];
/// config.blocked_paths = vec!["**/.env".to_string()];
///
/// // Valid path
/// let result = validate_path("/home/user/project/src/main.rs", &config);
/// assert!(result.is_ok());
///
/// // Blocked path
/// let result = validate_path("/home/user/project/.env", &config);
/// assert!(result.is_err());
///
/// // Outside roots
/// let result = validate_path("/etc/passwd", &config);
/// assert!(result.is_err());
/// ```
pub fn validate_path(user_path: &str, config: &SandboxConfig) -> ToolResult<PathBuf> {
// Step 1: Canonicalize the path
let symlink_policy = SymlinkPolicy {
allow_symlink_escape: config.allow_symlink_escape,
follow_symlinks_within_roots: config.follow_symlinks_within_roots,
};
let canonical_path = canonicalize_path(Path::new(user_path), &symlink_policy)?;
// Step 2: Check containment
check_containment(&canonical_path, &config.allowed_roots)?;
// Step 3: Check blocklist
if !config.blocked_paths.is_empty() {
// Compile blocklist on the fly (for now)
// TODO: Pre-compile blocklist in SandboxConfig for better performance
let compiled = compile_blocklist(&config.blocked_paths)?;
check_blocklist_compiled(&canonical_path, &compiled)?;
}
Ok(canonical_path)
}
/// Validate a path with a pre-compiled blocklist (for performance).
///
/// This is a faster variant of `validate_path` that uses a pre-compiled blocklist.
/// Useful when validating many paths with the same configuration.
///
/// # Arguments
///
/// * `user_path` - The path provided by the user/agent (must be absolute)
/// * `config` - Sandbox configuration
/// * `compiled_blocklist` - Pre-compiled blocklist (can be None if blocklist is empty)
///
/// # Returns
///
/// The canonical path if all checks pass, or a security error otherwise.
pub fn validate_path_compiled(
user_path: &str,
config: &SandboxConfig,
compiled_blocklist: Option<&CompiledBlocklist>,
) -> ToolResult<PathBuf> {
// Step 1: Canonicalize the path
let symlink_policy = SymlinkPolicy {
allow_symlink_escape: config.allow_symlink_escape,
follow_symlinks_within_roots: config.follow_symlinks_within_roots,
};
let canonical_path = canonicalize_path(Path::new(user_path), &symlink_policy)?;
// Step 2: Check containment
check_containment(&canonical_path, &config.allowed_roots)?;
// Step 3: Check blocklist (if provided)
if let Some(compiled) = compiled_blocklist {
check_blocklist_compiled(&canonical_path, &compiled)?;
}
Ok(canonical_path)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ToolError;
#[test]
fn test_validate_path_no_roots() {
let mut config = SandboxConfig::default();
config.allowed_roots = vec![]; // No roots allowed
// Any path should fail with no roots
let result = validate_path("/any/path", &config);
assert!(result.is_err());
assert!(matches!(result, Err(ToolError::SandboxViolation { .. })));
}
#[test]
fn test_validate_path_empty_blocklist() {
// Note: Integration tests with real filesystem are in tests/path_normalization.rs
// Unit tests here focus on error conditions without filesystem dependencies
let config = SandboxConfig::default();
assert_eq!(config.blocked_paths.len(), 2); // Default has .env and secrets
}
#[test]
fn test_compile_blocklist_empty() {
let compiled = compile_blocklist(&[]).unwrap();
assert_eq!(compiled.glob_set().len(), 0);
}
#[test]
fn test_compile_blocklist_valid() {
let patterns = vec!["**/.env".to_string(), "**/secrets/**".to_string()];
let result = compile_blocklist(&patterns);
assert!(result.is_ok());
let compiled = result.unwrap();
assert_eq!(compiled.glob_set().len(), 2);
}
#[test]
fn test_compile_blocklist_invalid() {
let patterns = vec!["[invalid".to_string()];
let result = compile_blocklist(&patterns);
assert!(result.is_err());
assert!(matches!(result, Err(ToolError::InvalidConfig { .. })));
}
}
+59
View File
@@ -0,0 +1,59 @@
//! Permission prompt system and decision caching.
//!
//! **Status**: Implemented (TOOLS-PERM-01 through TOOLS-PERM-04)
//!
//! This module provides:
//! - **Permission checking** - Core algorithm integrating cache, whitelist, and ACP prompts
//! - **Decision caching** - Thread-safe cache with TTL and scope support
//! - **Whitelist matching** - Pattern-based auto-approval for safe operations
//! - **ACP integration** - User prompts via session/request_permission
//!
//! ## Module Structure
//!
//! - [`check`] - Core permission check function
//! - [`cache`] - Decision cache with TTL
//! - [`whitelist`] - Whitelist pattern matching
//! - [`acp`] - ACP integration for user prompts
//!
//! ## Quick Start
//!
//! ```rust,no_run
//! use dirigent_tools::config::{PermissionConfig, PermissionMode, WhitelistConfig};
//! use dirigent_tools::permission::check::{PermissionContext, check_permission};
//! use dirigent_tools::permission::whitelist::{CompiledWhitelist, PermissionOperation};
//!
//! # async fn example() {
//! // Configure permission system
//! let config = PermissionConfig {
//! mode: PermissionMode::Whitelist,
//! remember_decisions: true,
//! remember_ttl_secs: 3600,
//! ..Default::default()
//! };
//!
//! // Create context with whitelist
//! let whitelist = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap();
//! let context = PermissionContext::new(
//! "connector-1".to_string(),
//! Some("session-1".to_string()),
//! whitelist,
//! );
//!
//! // Check permission for an operation
//! let operation = PermissionOperation::Write {
//! path: "/path/to/file".to_string(),
//! };
//! let decision = check_permission(&operation, &context, &config).await.unwrap();
//! # }
//! ```
pub mod check;
pub mod cache;
pub mod whitelist;
pub mod acp;
// Re-export commonly used types
pub use check::{check_permission, PermissionContext};
pub use cache::{CacheKey, DecisionCache, PermissionDecision};
pub use whitelist::{CompiledWhitelist, PermissionOperation, matches_whitelist};
pub use acp::{AcpPermissionContext, PermissionOutcome, request_permission_from_user};
+242
View File
@@ -0,0 +1,242 @@
//! ACP integration for permission prompts.
//!
//! **Status**: Implemented (TOOLS-PERM-03) - Stub for ACP integration
//!
//! This module provides integration with ACP's `session/request_permission` capability
//! to prompt users for permissions when operations require approval.
//!
//! ## Permission Outcomes
//!
//! Users can respond with:
//! - **AllowOnce**: Approve this operation only
//! - **AllowAlways**: Approve this operation and remember the decision
//! - **RejectOnce**: Deny this operation only
//! - **RejectAlways**: Deny this operation and remember the decision
//! - **Cancelled**: User cancelled the prompt (treat as rejection)
//!
//! ## ACP Protocol Integration
//!
//! This module calls the ACP client's `session/request_permission` handler (agent → client).
//! The actual implementation depends on the ACP client infrastructure in `dirigent_core`.
//!
//! For now, this provides a stub that can be replaced with real ACP calls when the
//! infrastructure is ready.
use crate::error::ToolResult;
use crate::permission::whitelist::PermissionOperation;
/// Permission outcome from user prompt.
///
/// Maps directly to ACP session/request_permission response options.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PermissionOutcome {
/// Allow this operation once (do not cache)
AllowOnce,
/// Allow this operation and remember for future (cache with TTL)
AllowAlways,
/// Reject this operation once (do not cache)
RejectOnce,
/// Reject this operation and remember for future (cache with TTL)
RejectAlways,
/// User cancelled the prompt (treat as rejection)
Cancelled,
}
impl PermissionOutcome {
/// Check if this outcome should be cached.
pub fn should_cache(&self) -> bool {
matches!(self, Self::AllowAlways | Self::RejectAlways)
}
/// Check if this outcome allows the operation.
pub fn is_allowed(&self) -> bool {
matches!(self, Self::AllowOnce | Self::AllowAlways)
}
/// Convert to cache decision if this outcome should be cached.
pub fn to_cache_decision(&self) -> Option<crate::permission::cache::PermissionDecision> {
use crate::permission::cache::PermissionDecision;
match self {
Self::AllowAlways => Some(PermissionDecision::Allowed),
Self::RejectAlways => Some(PermissionDecision::Denied),
Self::Cancelled => Some(PermissionDecision::Cancelled),
Self::AllowOnce | Self::RejectOnce => None,
}
}
}
/// Context for ACP permission requests.
///
/// Contains information needed to make the ACP call:
/// - Connector and session identifiers
/// - Reference to the ACP session object (when available)
#[derive(Debug, Clone)]
pub struct AcpPermissionContext {
/// Connector ID
pub connector_id: String,
/// Session ID (if in a session)
pub session_id: Option<String>,
// TODO: Add ACP session handle when available
// pub session_handle: Arc<AcpSession>,
}
/// Request permission from user via ACP session/request_permission.
///
/// This calls the ACP protocol's `session/request_permission` capability to present
/// a permission prompt to the user with the following options:
/// - Allow once
/// - Allow always (remember decision)
/// - Reject once
/// - Reject always (remember decision)
///
/// ## Implementation Status
///
/// **TODO**: This is currently a stub that returns `AllowOnce` for all operations.
/// Once the ACP client infrastructure in `dirigent_core` is ready, this should be
/// replaced with actual ACP calls.
///
/// ## Expected ACP Call
///
/// ```json
/// {
/// "method": "session/request_permission",
/// "params": {
/// "session_id": "session-123",
/// "operation": "write",
/// "path": "C:/work/project/file.txt",
/// "description": "Write to file C:/work/project/file.txt"
/// }
/// }
/// ```
///
/// ## Expected Response
///
/// ```json
/// {
/// "result": {
/// "outcome": "allow_always" | "allow_once" | "reject_always" | "reject_once" | "cancelled"
/// }
/// }
/// ```
///
/// ## Errors
///
/// - `ToolError::PermissionDenied` if the ACP call fails or times out
/// - `ToolError::AcpError` if the ACP protocol reports an error
pub async fn request_permission_from_user(
operation: &PermissionOperation,
context: &AcpPermissionContext,
) -> ToolResult<PermissionOutcome> {
// TODO: Replace with actual ACP call when infrastructure is ready
//
// Expected implementation:
// 1. Get session handle from context
// 2. Call session.request_permission(operation, description)
// 3. Wait for user response (with timeout)
// 4. Map ACP response to PermissionOutcome
// 5. Handle errors and timeouts
tracing::warn!(
operation = %operation.display(),
connector_id = %context.connector_id,
session_id = ?context.session_id,
"TODO: Actual ACP request_permission call not yet implemented (TOOLS-PERM-03)"
);
// For now, always return AllowOnce as a safe default for development
// In production, this should prompt the user via ACP
Ok(PermissionOutcome::AllowOnce)
}
/// Mock implementation for testing (when ACP infrastructure is not available).
///
/// This allows tests to simulate different permission outcomes without requiring
/// a full ACP stack.
#[cfg(test)]
pub async fn request_permission_mock(
_operation: &PermissionOperation,
outcome: PermissionOutcome,
) -> ToolResult<PermissionOutcome> {
Ok(outcome)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_permission_outcome_should_cache() {
assert!(PermissionOutcome::AllowAlways.should_cache());
assert!(PermissionOutcome::RejectAlways.should_cache());
assert!(!PermissionOutcome::AllowOnce.should_cache());
assert!(!PermissionOutcome::RejectOnce.should_cache());
assert!(!PermissionOutcome::Cancelled.should_cache());
}
#[test]
fn test_permission_outcome_is_allowed() {
assert!(PermissionOutcome::AllowOnce.is_allowed());
assert!(PermissionOutcome::AllowAlways.is_allowed());
assert!(!PermissionOutcome::RejectOnce.is_allowed());
assert!(!PermissionOutcome::RejectAlways.is_allowed());
assert!(!PermissionOutcome::Cancelled.is_allowed());
}
#[test]
fn test_permission_outcome_to_cache_decision() {
use crate::permission::cache::PermissionDecision;
assert_eq!(
PermissionOutcome::AllowAlways.to_cache_decision(),
Some(PermissionDecision::Allowed)
);
assert_eq!(
PermissionOutcome::RejectAlways.to_cache_decision(),
Some(PermissionDecision::Denied)
);
assert_eq!(
PermissionOutcome::Cancelled.to_cache_decision(),
Some(PermissionDecision::Cancelled)
);
assert_eq!(PermissionOutcome::AllowOnce.to_cache_decision(), None);
assert_eq!(PermissionOutcome::RejectOnce.to_cache_decision(), None);
}
#[tokio::test]
async fn test_request_permission_stub() {
let operation = PermissionOperation::Write {
path: "/test/path".to_string(),
};
let context = AcpPermissionContext {
connector_id: "test-connector".to_string(),
session_id: Some("test-session".to_string()),
};
// Stub currently returns AllowOnce
let outcome = request_permission_from_user(&operation, &context)
.await
.unwrap();
assert_eq!(outcome, PermissionOutcome::AllowOnce);
}
#[tokio::test]
async fn test_request_permission_mock() {
let operation = PermissionOperation::Execute {
command: "test".to_string(),
cwd: "/".to_string(),
};
// Test all outcomes
for outcome in [
PermissionOutcome::AllowOnce,
PermissionOutcome::AllowAlways,
PermissionOutcome::RejectOnce,
PermissionOutcome::RejectAlways,
PermissionOutcome::Cancelled,
] {
let result = request_permission_mock(&operation, outcome).await.unwrap();
assert_eq!(result, outcome);
}
}
}
@@ -0,0 +1,389 @@
//! Decision cache with TTL and scope support.
//!
//! **Status**: Implemented (TOOLS-PERM-02)
//!
//! This module provides thread-safe caching of permission decisions with:
//! - TTL (time-to-live) for cached decisions
//! - Scope support (per-connector or per-session)
//! - Automatic expiration of stale entries
//! - Hash-based cache keys for efficient lookups
use crate::config::DecisionScope;
use std::collections::HashMap;
use std::hash::Hash;
use std::time::{Duration, Instant};
/// Permission decision outcome.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PermissionDecision {
Allowed,
Denied,
Cancelled,
}
/// Cache key for permission decisions.
///
/// Keys are constructed from:
/// - Operation kind (read, write, execute)
/// - Normalized path or command
/// - Scope identifier (connector_id or session_id)
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
/// Operation discriminant
operation_kind: OperationKind,
/// Normalized path or command string
target: String,
/// Scope identifier (connector or session)
scope_id: String,
}
/// Operation kind for cache key discrimination.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum OperationKind {
Read,
Write,
Execute,
}
impl CacheKey {
/// Create a cache key for a read operation.
pub fn read(path: &str, connector_id: &str, scope: DecisionScope) -> Self {
Self {
operation_kind: OperationKind::Read,
target: path.to_string(),
scope_id: Self::scope_id(connector_id, None, scope),
}
}
/// Create a cache key for a write operation.
pub fn write(path: &str, connector_id: &str, session_id: Option<&str>, scope: DecisionScope) -> Self {
Self {
operation_kind: OperationKind::Write,
target: path.to_string(),
scope_id: Self::scope_id(connector_id, session_id, scope),
}
}
/// Create a cache key for an execute operation.
pub fn execute(command: &str, connector_id: &str, session_id: Option<&str>, scope: DecisionScope) -> Self {
Self {
operation_kind: OperationKind::Execute,
target: command.to_string(),
scope_id: Self::scope_id(connector_id, session_id, scope),
}
}
/// Construct scope identifier based on scope policy.
fn scope_id(connector_id: &str, session_id: Option<&str>, scope: DecisionScope) -> String {
match scope {
DecisionScope::PerConnector => connector_id.to_string(),
DecisionScope::PerSession => {
format!("{}:{}", connector_id, session_id.unwrap_or("default"))
}
}
}
}
/// Cached permission decision with expiration.
#[derive(Debug, Clone)]
struct CachedDecision {
/// The permission decision
decision: PermissionDecision,
/// When this entry expires
expires_at: Instant,
}
impl CachedDecision {
/// Create a new cached decision with TTL.
fn new(decision: PermissionDecision, ttl: Duration) -> Self {
Self {
decision,
expires_at: Instant::now() + ttl,
}
}
/// Check if this entry has expired.
fn is_expired(&self) -> bool {
Instant::now() >= self.expires_at
}
}
/// Thread-safe decision cache with TTL support.
///
/// The cache stores permission decisions keyed by operation, path/command, and scope.
/// Entries automatically expire after their TTL, and expired entries are pruned on access.
///
/// ## Thread Safety
///
/// The cache is designed to be wrapped in `Arc<Mutex<DecisionCache>>` for thread-safe access.
/// Individual operations (get, insert, clear) should acquire the lock briefly.
///
/// ## Example
///
/// ```rust
/// use dirigent_tools::permission::cache::{DecisionCache, PermissionDecision, CacheKey};
/// use dirigent_tools::config::DecisionScope;
/// use std::time::Duration;
///
/// let mut cache = DecisionCache::new();
/// let key = CacheKey::write("/path/to/file", "connector-1", None, DecisionScope::PerConnector);
///
/// // Cache a decision
/// cache.insert(key.clone(), PermissionDecision::Allowed, Duration::from_secs(300));
///
/// // Retrieve it
/// assert_eq!(cache.get(&key), Some(PermissionDecision::Allowed));
/// ```
#[derive(Debug)]
pub struct DecisionCache {
entries: HashMap<CacheKey, CachedDecision>,
}
impl DecisionCache {
/// Create a new empty decision cache.
pub fn new() -> Self {
Self {
entries: HashMap::new(),
}
}
/// Get a cached decision if it exists and hasn't expired.
///
/// Expired entries are automatically removed.
///
/// Returns `Some(decision)` if a valid cached entry exists, `None` otherwise.
pub fn get(&mut self, key: &CacheKey) -> Option<PermissionDecision> {
// Check if entry exists
if let Some(cached) = self.entries.get(key) {
// Check if expired
if cached.is_expired() {
// Remove expired entry
self.entries.remove(key);
None
} else {
// Return valid decision
Some(cached.decision)
}
} else {
None
}
}
/// Insert a new decision into the cache with the given TTL.
///
/// If an entry already exists for this key, it will be replaced.
pub fn insert(&mut self, key: CacheKey, decision: PermissionDecision, ttl: Duration) {
self.entries.insert(key, CachedDecision::new(decision, ttl));
}
/// Clear all cached decisions.
///
/// Useful for manual cache invalidation or testing.
pub fn clear(&mut self) {
self.entries.clear();
}
/// Remove all expired entries from the cache.
///
/// This is automatically done during `get()` operations, but can be called
/// periodically to prune the cache and free memory.
///
/// Returns the number of expired entries removed.
pub fn clear_expired(&mut self) -> usize {
let before = self.entries.len();
self.entries.retain(|_, cached| !cached.is_expired());
before - self.entries.len()
}
/// Get the number of cached entries (including expired).
///
/// For testing and monitoring purposes.
pub fn len(&self) -> usize {
self.entries.len()
}
/// Check if the cache is empty.
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
}
impl Default for DecisionCache {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cache_key_creation() {
let key1 = CacheKey::write("/path/to/file", "conn-1", None, DecisionScope::PerConnector);
let key2 = CacheKey::write("/path/to/file", "conn-1", None, DecisionScope::PerConnector);
let key3 = CacheKey::write("/path/to/file", "conn-2", None, DecisionScope::PerConnector);
// Same connector and path should produce equal keys
assert_eq!(key1, key2);
// Different connector should produce different keys
assert_ne!(key1, key3);
}
#[test]
fn test_cache_key_scope() {
let key_connector = CacheKey::write(
"/path",
"conn-1",
Some("session-1"),
DecisionScope::PerConnector,
);
let key_session = CacheKey::write(
"/path",
"conn-1",
Some("session-1"),
DecisionScope::PerSession,
);
// Different scopes should produce different keys
assert_ne!(key_connector, key_session);
}
#[test]
fn test_cache_key_operation_kind() {
let key_read = CacheKey::read("/path", "conn-1", DecisionScope::PerConnector);
let key_write = CacheKey::write("/path", "conn-1", None, DecisionScope::PerConnector);
// Different operations should produce different keys
assert_ne!(key_read, key_write);
}
#[test]
fn test_cache_insert_and_get() {
let mut cache = DecisionCache::new();
let key = CacheKey::write("/path", "conn-1", None, DecisionScope::PerConnector);
// Initially empty
assert_eq!(cache.get(&key), None);
// Insert decision
cache.insert(key.clone(), PermissionDecision::Allowed, Duration::from_secs(300));
// Should retrieve it
assert_eq!(cache.get(&key), Some(PermissionDecision::Allowed));
}
#[test]
fn test_cache_ttl_expiration() {
let mut cache = DecisionCache::new();
let key = CacheKey::write("/path", "conn-1", None, DecisionScope::PerConnector);
// Insert with very short TTL
cache.insert(key.clone(), PermissionDecision::Allowed, Duration::from_millis(1));
// Wait for expiration
std::thread::sleep(Duration::from_millis(10));
// Should not retrieve expired entry
assert_eq!(cache.get(&key), None);
// Entry should be removed from cache
assert_eq!(cache.len(), 0);
}
#[test]
fn test_cache_clear() {
let mut cache = DecisionCache::new();
let key1 = CacheKey::write("/path1", "conn-1", None, DecisionScope::PerConnector);
let key2 = CacheKey::write("/path2", "conn-1", None, DecisionScope::PerConnector);
cache.insert(key1.clone(), PermissionDecision::Allowed, Duration::from_secs(300));
cache.insert(key2.clone(), PermissionDecision::Denied, Duration::from_secs(300));
assert_eq!(cache.len(), 2);
// Clear all entries
cache.clear();
assert_eq!(cache.len(), 0);
assert_eq!(cache.get(&key1), None);
assert_eq!(cache.get(&key2), None);
}
#[test]
fn test_cache_clear_expired() {
let mut cache = DecisionCache::new();
let key1 = CacheKey::write("/path1", "conn-1", None, DecisionScope::PerConnector);
let key2 = CacheKey::write("/path2", "conn-1", None, DecisionScope::PerConnector);
// Insert one short-lived and one long-lived entry
cache.insert(key1.clone(), PermissionDecision::Allowed, Duration::from_millis(1));
cache.insert(key2.clone(), PermissionDecision::Denied, Duration::from_secs(300));
assert_eq!(cache.len(), 2);
// Wait for first to expire
std::thread::sleep(Duration::from_millis(10));
// Clear expired entries
let removed = cache.clear_expired();
assert_eq!(removed, 1);
assert_eq!(cache.len(), 1);
// Second entry should still be accessible
assert_eq!(cache.get(&key2), Some(PermissionDecision::Denied));
}
#[test]
fn test_cache_replace_entry() {
let mut cache = DecisionCache::new();
let key = CacheKey::write("/path", "conn-1", None, DecisionScope::PerConnector);
// Insert initial decision
cache.insert(key.clone(), PermissionDecision::Allowed, Duration::from_secs(300));
assert_eq!(cache.get(&key), Some(PermissionDecision::Allowed));
// Replace with different decision
cache.insert(key.clone(), PermissionDecision::Denied, Duration::from_secs(300));
assert_eq!(cache.get(&key), Some(PermissionDecision::Denied));
// Should only have one entry
assert_eq!(cache.len(), 1);
}
#[test]
fn test_cache_different_sessions() {
let mut cache = DecisionCache::new();
let key_session1 = CacheKey::write(
"/path",
"conn-1",
Some("session-1"),
DecisionScope::PerSession,
);
let key_session2 = CacheKey::write(
"/path",
"conn-1",
Some("session-2"),
DecisionScope::PerSession,
);
// Insert different decisions for different sessions
cache.insert(key_session1.clone(), PermissionDecision::Allowed, Duration::from_secs(300));
cache.insert(key_session2.clone(), PermissionDecision::Denied, Duration::from_secs(300));
// Each session should have its own decision
assert_eq!(cache.get(&key_session1), Some(PermissionDecision::Allowed));
assert_eq!(cache.get(&key_session2), Some(PermissionDecision::Denied));
}
#[test]
fn test_cache_cancelled_decision() {
let mut cache = DecisionCache::new();
let key = CacheKey::write("/path", "conn-1", None, DecisionScope::PerConnector);
// Cancelled decisions can also be cached
cache.insert(key.clone(), PermissionDecision::Cancelled, Duration::from_secs(300));
assert_eq!(cache.get(&key), Some(PermissionDecision::Cancelled));
}
}
@@ -0,0 +1,366 @@
//! Core permission check function integrating cache, whitelist, and ACP prompts.
//!
//! **Status**: Implemented (TOOLS-PERM-01)
//!
//! This module provides the main `check_permission` function that orchestrates:
//! - Permission mode evaluation (yolo, whitelist, ask)
//! - Decision cache lookup and updates
//! - Whitelist pattern matching
//! - ACP permission prompts
//!
//! ## Algorithm
//!
//! 1. **Yolo mode**: Always allow (with audit logging)
//! 2. **Check cache**: Return cached decision if present and valid
//! 3. **Whitelist mode**: Auto-approve if operation matches whitelist
//! 4. **Ask mode / No whitelist match**: Prompt user via ACP
//! 5. **Cache decision**: Store if user selected "always" option
//! 6. **Return decision**: Allow or deny the operation
use crate::config::{DecisionScope, PermissionConfig, PermissionMode};
use crate::error::ToolResult;
use crate::permission::acp::{request_permission_from_user, AcpPermissionContext, PermissionOutcome};
use crate::permission::cache::{CacheKey, DecisionCache, PermissionDecision};
use crate::permission::whitelist::{matches_whitelist, CompiledWhitelist, PermissionOperation};
use std::sync::{Arc, Mutex};
use std::time::Duration;
/// Context for permission checks.
///
/// This should be created once per connector or session and reused for all
/// permission checks to maintain consistent cache state.
#[derive(Clone)]
pub struct PermissionContext {
/// Connector ID
pub connector_id: String,
/// Session ID (if in a session)
pub session_id: Option<String>,
/// Shared decision cache (thread-safe)
pub cache: Arc<Mutex<DecisionCache>>,
/// Compiled whitelist for fast matching
pub whitelist: Arc<CompiledWhitelist>,
}
impl PermissionContext {
/// Create a new permission context.
pub fn new(
connector_id: String,
session_id: Option<String>,
whitelist: CompiledWhitelist,
) -> Self {
Self {
connector_id,
session_id,
cache: Arc::new(Mutex::new(DecisionCache::new())),
whitelist: Arc::new(whitelist),
}
}
/// Clear the decision cache (for testing or manual reset).
pub fn clear_cache(&self) {
if let Ok(mut cache) = self.cache.lock() {
cache.clear();
}
}
/// Get cache statistics (for monitoring/debugging).
pub fn cache_size(&self) -> usize {
self.cache.lock().map(|c| c.len()).unwrap_or(0)
}
}
/// Check if an operation requires permission and prompt if needed.
///
/// This is the main entry point for permission checks. It implements the full
/// permission algorithm including mode evaluation, caching, whitelist matching,
/// and ACP prompts.
///
/// ## Algorithm
///
/// 1. If mode is `Yolo`: Always allow (with audit log)
/// 2. Check decision cache (with TTL)
/// 3. If cached decision exists: Return it
/// 4. If mode is `Whitelist` and operation matches: Allow
/// 5. Otherwise: Call ACP `session/request_permission`
/// 6. If outcome is "always": Cache the decision
/// 7. Return decision (Allowed/Denied/Cancelled)
///
/// ## Examples
///
/// ```rust,no_run
/// use dirigent_tools::config::{PermissionConfig, PermissionMode, WhitelistConfig};
/// use dirigent_tools::permission::check::{PermissionContext, check_permission};
/// use dirigent_tools::permission::whitelist::{CompiledWhitelist, PermissionOperation};
///
/// # async fn example() {
/// let config = PermissionConfig {
/// mode: PermissionMode::Ask,
/// remember_decisions: true,
/// remember_ttl_secs: 3600,
/// ..Default::default()
/// };
///
/// let whitelist = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap();
/// let context = PermissionContext::new(
/// "connector-1".to_string(),
/// Some("session-1".to_string()),
/// whitelist,
/// );
///
/// let operation = PermissionOperation::Write {
/// path: "/path/to/file".to_string(),
/// };
///
/// let decision = check_permission(&operation, &context, &config).await.unwrap();
/// # }
/// ```
pub async fn check_permission(
operation: &PermissionOperation,
context: &PermissionContext,
config: &PermissionConfig,
) -> ToolResult<PermissionDecision> {
// Step 1: Yolo mode - always allow
if config.mode == PermissionMode::Yolo {
tracing::info!(
operation = %operation.display(),
mode = "yolo",
"Permission check: auto-approved (yolo mode)"
);
return Ok(PermissionDecision::Allowed);
}
// Step 2: Check decision cache
if config.remember_decisions {
let cache_key = create_cache_key(operation, context, config.scope);
if let Ok(mut cache) = context.cache.lock() {
if let Some(cached_decision) = cache.get(&cache_key) {
tracing::debug!(
operation = %operation.display(),
decision = ?cached_decision,
"Permission check: using cached decision"
);
return Ok(cached_decision);
}
}
}
// Step 3: Whitelist mode - check if operation matches whitelist
if config.mode == PermissionMode::Whitelist {
if matches_whitelist(operation, &context.whitelist) {
tracing::info!(
operation = %operation.display(),
mode = "whitelist",
"Permission check: auto-approved (whitelist match)"
);
return Ok(PermissionDecision::Allowed);
}
}
// Step 4: Prompt user via ACP
tracing::debug!(
operation = %operation.display(),
mode = ?config.mode,
"Permission check: prompting user"
);
let acp_context = AcpPermissionContext {
connector_id: context.connector_id.clone(),
session_id: context.session_id.clone(),
};
let outcome = request_permission_from_user(operation, &acp_context).await?;
// Step 5: Convert outcome to decision
let decision = if outcome.is_allowed() {
PermissionDecision::Allowed
} else {
match outcome {
PermissionOutcome::Cancelled => PermissionDecision::Cancelled,
_ => PermissionDecision::Denied,
}
};
// Step 6: Cache decision if "always" option was selected
if config.remember_decisions && outcome.should_cache() {
if let Some(cache_decision) = outcome.to_cache_decision() {
let cache_key = create_cache_key(operation, context, config.scope);
let ttl = Duration::from_secs(config.remember_ttl_secs);
if let Ok(mut cache) = context.cache.lock() {
cache.insert(cache_key, cache_decision, ttl);
tracing::debug!(
operation = %operation.display(),
decision = ?cache_decision,
ttl_secs = config.remember_ttl_secs,
"Cached permission decision"
);
}
}
}
tracing::info!(
operation = %operation.display(),
decision = ?decision,
outcome = ?outcome,
"Permission check: decision made"
);
Ok(decision)
}
/// Create a cache key for an operation with the appropriate scope.
fn create_cache_key(
operation: &PermissionOperation,
context: &PermissionContext,
scope: DecisionScope,
) -> CacheKey {
match operation {
PermissionOperation::Read { path } => {
CacheKey::read(path, &context.connector_id, scope)
}
PermissionOperation::Write { path } => {
CacheKey::write(path, &context.connector_id, context.session_id.as_deref(), scope)
}
PermissionOperation::Execute { command, .. } => {
CacheKey::execute(command, &context.connector_id, context.session_id.as_deref(), scope)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::WhitelistConfig;
fn create_test_config(mode: PermissionMode) -> PermissionConfig {
PermissionConfig {
mode,
remember_decisions: true,
remember_ttl_secs: 300,
scope: DecisionScope::PerConnector,
whitelist: WhitelistConfig::default(),
}
}
fn create_test_context() -> PermissionContext {
let whitelist = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap();
PermissionContext::new(
"test-connector".to_string(),
Some("test-session".to_string()),
whitelist,
)
}
#[tokio::test]
async fn test_yolo_mode_always_allows() {
let config = create_test_config(PermissionMode::Yolo);
let context = create_test_context();
let operations = vec![
PermissionOperation::Read { path: "/any/path".to_string() },
PermissionOperation::Write { path: "/any/path".to_string() },
PermissionOperation::Execute {
command: "dangerous_command".to_string(),
cwd: "/".to_string(),
},
];
for operation in operations {
let decision = check_permission(&operation, &context, &config).await.unwrap();
assert_eq!(decision, PermissionDecision::Allowed);
}
}
#[tokio::test]
async fn test_whitelist_mode_read_always_allowed() {
let config = create_test_config(PermissionMode::Whitelist);
let context = create_test_context();
let operation = PermissionOperation::Read {
path: "/any/path".to_string(),
};
let decision = check_permission(&operation, &context, &config).await.unwrap();
assert_eq!(decision, PermissionDecision::Allowed);
}
#[tokio::test]
async fn test_whitelist_mode_with_pattern() {
let mut config = create_test_config(PermissionMode::Whitelist);
config.whitelist = WhitelistConfig {
write_paths: vec!["C:/work/**".to_string()],
execute_commands: vec!["cargo".to_string()],
};
let whitelist = CompiledWhitelist::compile(&config.whitelist).unwrap();
let context = PermissionContext::new(
"test-connector".to_string(),
None,
whitelist,
);
// Should match whitelist
let write_ok = PermissionOperation::Write {
path: "C:/work/project/file.txt".to_string(),
};
let decision = check_permission(&write_ok, &context, &config).await.unwrap();
assert_eq!(decision, PermissionDecision::Allowed);
// TODO: Test non-matching write would require ACP mock
// This is tested in integration tests
}
#[tokio::test]
async fn test_cache_key_creation() {
let context = create_test_context();
let scope = DecisionScope::PerConnector;
let read_op = PermissionOperation::Read {
path: "/path".to_string(),
};
let key1 = create_cache_key(&read_op, &context, scope);
let key2 = create_cache_key(&read_op, &context, scope);
// Same operation should produce same key
assert_eq!(key1, key2);
}
#[tokio::test]
async fn test_context_cache_operations() {
let context = create_test_context();
assert_eq!(context.cache_size(), 0);
// Add entry to cache
{
let mut cache = context.cache.lock().unwrap();
let key = CacheKey::write("/path", "test", None, DecisionScope::PerConnector);
cache.insert(key, PermissionDecision::Allowed, Duration::from_secs(300));
}
assert_eq!(context.cache_size(), 1);
// Clear cache
context.clear_cache();
assert_eq!(context.cache_size(), 0);
}
#[test]
fn test_permission_context_clone() {
let context = create_test_context();
let cloned = context.clone();
// Should share the same cache
assert_eq!(context.cache_size(), cloned.cache_size());
// Modifications to one should affect the other
{
let mut cache = context.cache.lock().unwrap();
let key = CacheKey::write("/path", "test", None, DecisionScope::PerConnector);
cache.insert(key, PermissionDecision::Allowed, Duration::from_secs(300));
}
assert_eq!(cloned.cache_size(), 1);
}
}
@@ -0,0 +1,397 @@
//! Whitelist pattern matching for auto-approval of safe operations.
//!
//! **Status**: Implemented (TOOLS-PERM-04)
//!
//! This module provides pattern matching against configured whitelists to
//! automatically approve safe operations without prompting the user.
//!
//! ## Whitelist Strategy
//!
//! - **Read operations**: Always approved in whitelist mode (reads are safe)
//! - **Write operations**: Match path against `write_paths` glob patterns
//! - **Execute operations**: Match command against `execute_commands` glob patterns
//!
//! ## Pattern Matching
//!
//! Uses `globset` for efficient compiled glob patterns:
//! - `**` matches any number of path segments
//! - `*` matches any characters within a segment
//! - `?` matches a single character
//! - Character classes like `[abc]` are supported
//!
//! ## Performance
//!
//! Globs are pre-compiled at configuration load time and stored in `CompiledWhitelist`
//! for fast matching during permission checks.
use crate::config::WhitelistConfig;
use crate::error::{ToolError, ToolResult};
use globset::{Glob, GlobSet, GlobSetBuilder};
use std::path::Path;
/// Compiled whitelist with pre-built glob sets for efficient matching.
///
/// This should be constructed once at configuration load time and reused
/// for all permission checks.
#[derive(Debug, Clone)]
pub struct CompiledWhitelist {
/// Compiled glob patterns for write path matching
write_paths: GlobSet,
/// Compiled glob patterns for execute command matching
execute_commands: GlobSet,
}
impl CompiledWhitelist {
/// Compile a whitelist configuration into an efficient matcher.
///
/// ## Errors
///
/// Returns an error if any glob pattern is invalid.
pub fn compile(config: &WhitelistConfig) -> ToolResult<Self> {
let write_paths = Self::compile_patterns(&config.write_paths)?;
let execute_commands = Self::compile_patterns(&config.execute_commands)?;
Ok(Self {
write_paths,
execute_commands,
})
}
/// Compile a list of glob patterns into a GlobSet.
fn compile_patterns(patterns: &[String]) -> ToolResult<GlobSet> {
let mut builder = GlobSetBuilder::new();
for pattern in patterns {
let glob = Glob::new(pattern).map_err(|e| {
ToolError::InvalidConfig(format!("Invalid whitelist glob pattern '{}': {}", pattern, e))
})?;
builder.add(glob);
}
builder.build().map_err(|e| {
ToolError::InvalidConfig(format!("Failed to build whitelist glob set: {}", e))
})
}
/// Check if a read operation matches the whitelist.
///
/// In whitelist mode, all read operations are considered safe and return true.
pub fn matches_read(&self, _path: &Path) -> bool {
// Reads are always safe in whitelist mode
true
}
/// Check if a write operation matches the whitelist.
///
/// Returns true if the path matches any of the configured write_paths patterns.
pub fn matches_write(&self, path: &Path) -> bool {
self.write_paths.is_match(path)
}
/// Check if an execute operation matches the whitelist.
///
/// Returns true if the command matches any of the configured execute_commands patterns.
///
/// The command is compared as a string (not a path) to support both:
/// - Simple command names (e.g., "cargo", "npm")
/// - Command patterns (e.g., "cargo*", "npm*")
pub fn matches_execute(&self, command: &str) -> bool {
// For command matching, we treat the command as a simple string path
// This allows patterns like "cargo", "npm*", etc.
self.execute_commands.is_match(command)
}
/// Check if the whitelist has any write patterns configured.
pub fn has_write_patterns(&self) -> bool {
!self.write_paths.is_empty()
}
/// Check if the whitelist has any execute patterns configured.
pub fn has_execute_patterns(&self) -> bool {
!self.execute_commands.is_empty()
}
}
/// Operation type for whitelist matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PermissionOperation {
Read { path: String },
Write { path: String },
Execute { command: String, cwd: String },
}
impl PermissionOperation {
/// Get a display string for this operation (for logging/debugging).
pub fn display(&self) -> String {
match self {
Self::Read { path } => format!("read {}", path),
Self::Write { path } => format!("write {}", path),
Self::Execute { command, cwd } => format!("execute '{}' in {}", command, cwd),
}
}
/// Get the operation kind as a string.
pub fn kind(&self) -> &'static str {
match self {
Self::Read { .. } => "read",
Self::Write { .. } => "write",
Self::Execute { .. } => "execute",
}
}
}
/// Check if an operation matches the configured whitelist.
///
/// ## Returns
///
/// - `true` if the operation should be auto-approved
/// - `false` if the operation requires a permission prompt
///
/// ## Examples
///
/// ```rust
/// use dirigent_tools::config::WhitelistConfig;
/// use dirigent_tools::permission::whitelist::{CompiledWhitelist, PermissionOperation, matches_whitelist};
///
/// let config = WhitelistConfig {
/// write_paths: vec!["C:/work/project/**".to_string()],
/// execute_commands: vec!["cargo".to_string(), "npm".to_string()],
/// };
///
/// let whitelist = CompiledWhitelist::compile(&config).unwrap();
///
/// // Read operations always match
/// let read_op = PermissionOperation::Read { path: "C:/anywhere/file.txt".to_string() };
/// assert!(matches_whitelist(&read_op, &whitelist));
///
/// // Write within allowed path
/// let write_op = PermissionOperation::Write { path: "C:/work/project/src/main.rs".to_string() };
/// assert!(matches_whitelist(&write_op, &whitelist));
///
/// // Execute whitelisted command
/// let exec_op = PermissionOperation::Execute {
/// command: "cargo".to_string(),
/// cwd: "C:/work/project".to_string(),
/// };
/// assert!(matches_whitelist(&exec_op, &whitelist));
/// ```
pub fn matches_whitelist(operation: &PermissionOperation, whitelist: &CompiledWhitelist) -> bool {
match operation {
PermissionOperation::Read { path } => {
whitelist.matches_read(Path::new(path))
}
PermissionOperation::Write { path } => {
whitelist.matches_write(Path::new(path))
}
PermissionOperation::Execute { command, .. } => {
whitelist.matches_execute(command)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn create_test_whitelist() -> CompiledWhitelist {
let config = WhitelistConfig {
write_paths: vec![
"C:/work/project/**".to_string(),
"/home/user/project/**".to_string(),
"**/safe_dir/**".to_string(),
],
execute_commands: vec![
"cargo".to_string(),
"npm".to_string(),
"git".to_string(),
"python*".to_string(),
],
};
CompiledWhitelist::compile(&config).unwrap()
}
#[test]
fn test_read_always_matches() {
let whitelist = create_test_whitelist();
// Any read operation should match
assert!(whitelist.matches_read(Path::new("C:/anywhere/file.txt")));
assert!(whitelist.matches_read(Path::new("/random/path")));
assert!(whitelist.matches_read(Path::new("../../../etc/passwd")));
}
#[test]
fn test_write_matches_patterns() {
let whitelist = create_test_whitelist();
// Should match configured patterns
assert!(whitelist.matches_write(Path::new("C:/work/project/src/main.rs")));
assert!(whitelist.matches_write(Path::new("C:/work/project/Cargo.toml")));
assert!(whitelist.matches_write(Path::new("/home/user/project/README.md")));
// Should not match paths outside patterns
assert!(!whitelist.matches_write(Path::new("C:/other/project/file.txt")));
assert!(!whitelist.matches_write(Path::new("/tmp/file.txt")));
}
#[test]
fn test_write_matches_relative_patterns() {
let whitelist = create_test_whitelist();
// Pattern **/safe_dir/** should match anywhere
assert!(whitelist.matches_write(Path::new("some/path/safe_dir/file.txt")));
assert!(whitelist.matches_write(Path::new("safe_dir/file.txt")));
}
#[test]
fn test_execute_matches_commands() {
let whitelist = create_test_whitelist();
// Should match exact command names
assert!(whitelist.matches_execute("cargo"));
assert!(whitelist.matches_execute("npm"));
assert!(whitelist.matches_execute("git"));
// Should match patterns
assert!(whitelist.matches_execute("python"));
assert!(whitelist.matches_execute("python3"));
// Should not match unlisted commands
assert!(!whitelist.matches_execute("rm"));
assert!(!whitelist.matches_execute("format"));
assert!(!whitelist.matches_execute("del"));
}
#[test]
fn test_empty_whitelist() {
let config = WhitelistConfig {
write_paths: vec![],
execute_commands: vec![],
};
let whitelist = CompiledWhitelist::compile(&config).unwrap();
// Reads should still match
assert!(whitelist.matches_read(Path::new("any/path")));
// Writes and executes should not match empty whitelist
assert!(!whitelist.matches_write(Path::new("any/path")));
assert!(!whitelist.matches_execute("any_command"));
}
#[test]
fn test_invalid_glob_pattern() {
let config = WhitelistConfig {
write_paths: vec!["[invalid".to_string()], // Unclosed bracket
execute_commands: vec![],
};
let result = CompiledWhitelist::compile(&config);
assert!(result.is_err());
}
#[test]
fn test_operation_display() {
let read_op = PermissionOperation::Read {
path: "/path/to/file".to_string(),
};
assert_eq!(read_op.display(), "read /path/to/file");
let write_op = PermissionOperation::Write {
path: "/path/to/file".to_string(),
};
assert_eq!(write_op.display(), "write /path/to/file");
let exec_op = PermissionOperation::Execute {
command: "cargo test".to_string(),
cwd: "/project".to_string(),
};
assert_eq!(exec_op.display(), "execute 'cargo test' in /project");
}
#[test]
fn test_operation_kind() {
let read_op = PermissionOperation::Read {
path: "/path".to_string(),
};
assert_eq!(read_op.kind(), "read");
let write_op = PermissionOperation::Write {
path: "/path".to_string(),
};
assert_eq!(write_op.kind(), "write");
let exec_op = PermissionOperation::Execute {
command: "cmd".to_string(),
cwd: "/".to_string(),
};
assert_eq!(exec_op.kind(), "execute");
}
#[test]
fn test_matches_whitelist_function() {
let whitelist = create_test_whitelist();
// Test read
let read_op = PermissionOperation::Read {
path: "C:/any/file.txt".to_string(),
};
assert!(matches_whitelist(&read_op, &whitelist));
// Test write (match)
let write_ok = PermissionOperation::Write {
path: "C:/work/project/file.txt".to_string(),
};
assert!(matches_whitelist(&write_ok, &whitelist));
// Test write (no match)
let write_fail = PermissionOperation::Write {
path: "C:/other/file.txt".to_string(),
};
assert!(!matches_whitelist(&write_fail, &whitelist));
// Test execute (match)
let exec_ok = PermissionOperation::Execute {
command: "cargo".to_string(),
cwd: "/project".to_string(),
};
assert!(matches_whitelist(&exec_ok, &whitelist));
// Test execute (no match)
let exec_fail = PermissionOperation::Execute {
command: "rm".to_string(),
cwd: "/".to_string(),
};
assert!(!matches_whitelist(&exec_fail, &whitelist));
}
#[test]
fn test_has_patterns() {
let config = WhitelistConfig {
write_paths: vec!["some/path/**".to_string()],
execute_commands: vec![],
};
let whitelist = CompiledWhitelist::compile(&config).unwrap();
assert!(whitelist.has_write_patterns());
assert!(!whitelist.has_execute_patterns());
}
#[cfg(target_os = "windows")]
#[test]
fn test_windows_paths() {
let config = WhitelistConfig {
write_paths: vec![
"C:\\work\\project\\**".to_string(),
"\\\\server\\share\\**".to_string(), // UNC path
],
execute_commands: vec![],
};
let whitelist = CompiledWhitelist::compile(&config).unwrap();
// Test Windows backslash paths
assert!(whitelist.matches_write(Path::new("C:\\work\\project\\file.txt")));
// Test UNC paths
assert!(whitelist.matches_write(Path::new("\\\\server\\share\\file.txt")));
}
}
+303
View File
@@ -0,0 +1,303 @@
//! Tool registry: built-ins (compile-time) + per-session dynamic entries.
use crate::tool::{AnyTool, ClientKind, ProtocolKind, ToolContext};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Where a dynamic tool came from.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub enum ToolSource {
Mcp(Arc<str>),
Custom(Arc<str>),
}
/// A registered dynamic tool with optional client/protocol filters.
#[derive(Clone)]
pub struct DynamicEntry {
pub tool: Arc<dyn AnyTool>,
pub source: ToolSource,
pub only_for_client: Option<ClientKind>,
pub only_for_protocol: Option<ProtocolKind>,
}
/// Resolution policy when a name exists in both built-ins and dynamic.
#[derive(Clone, Debug, Default)]
pub enum CollisionPolicy {
#[default]
BuiltInWins,
DynamicWins,
PerName(HashMap<Arc<str>, Winner>),
}
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum Winner { BuiltIn, Dynamic }
pub struct ToolRegistry {
built_ins: HashMap<&'static str, Arc<dyn AnyTool>>,
dynamic: RwLock<HashMap<Arc<str>, HashMap<String, DynamicEntry>>>,
collision_policy: CollisionPolicy,
}
impl ToolRegistry {
/// Canonical scope key for dynamic-tool storage.
///
/// Session takes precedence so tools registered for a single session
/// don't leak across the connector; absent that, the connector itself
/// scopes the entry.
fn scope_key(ctx: &ToolContext) -> Arc<str> {
ctx.session_id
.clone()
.unwrap_or_else(|| ctx.connector_id.clone())
}
pub fn new(
built_ins: impl IntoIterator<Item = Arc<dyn AnyTool>>,
collision_policy: CollisionPolicy,
) -> Self {
let built_ins = built_ins
.into_iter()
.map(|t| (t.name(), t))
.collect();
Self {
built_ins,
dynamic: RwLock::new(HashMap::new()),
collision_policy,
}
}
pub fn resolve(&self, name: &str, ctx: &ToolContext) -> Option<Arc<dyn AnyTool>> {
let dyn_match = self.find_dynamic(name, ctx);
let builtin = self.built_ins.get(name).cloned();
match (builtin, dyn_match) {
(Some(b), None) => Some(b),
(None, Some(d)) => Some(d.tool),
(None, None) => None,
(Some(b), Some(d)) => match &self.collision_policy {
CollisionPolicy::BuiltInWins => Some(b),
CollisionPolicy::DynamicWins => Some(d.tool),
CollisionPolicy::PerName(map) => match map.get(name) {
Some(Winner::BuiltIn) | None => Some(b),
Some(Winner::Dynamic) => Some(d.tool),
},
},
}
}
/// Enumerate all tools visible in `ctx`'s scope.
///
/// Returns built-in tool names plus any dynamic tools registered under
/// the canonical scope key for `ctx` (see [`Self::scope_key`]) whose
/// optional client/protocol filters match. The result is sorted and
/// deduplicated.
pub fn list(&self, ctx: &ToolContext) -> Vec<String> {
let mut names: Vec<String> = self.built_ins.keys().map(|n| n.to_string()).collect();
let key = Self::scope_key(ctx);
if let Some(per_scope) = self.dynamic.read().unwrap().get(&key) {
for (n, e) in per_scope.iter() {
if entry_matches_ctx(e, ctx) { names.push(n.clone()); }
}
}
names.sort();
names.dedup();
names
}
/// Insert a dynamic tool under `scope_key`.
///
/// Callers must pass the value [`Self::scope_key`] would produce for the
/// registering [`ToolContext`]: the `session_id` if present, otherwise
/// the `connector_id`. Anything else will be invisible to `resolve` and
/// `list` for that context.
pub fn register_dynamic(&self, scope_key: impl Into<Arc<str>>, name: String, entry: DynamicEntry) {
let mut g = self.dynamic.write().unwrap();
g.entry(scope_key.into()).or_default().insert(name, entry);
}
pub fn unregister_scope(&self, scope_key: &str) {
self.dynamic.write().unwrap().remove(scope_key);
}
fn find_dynamic(&self, name: &str, ctx: &ToolContext) -> Option<DynamicEntry> {
let key = Self::scope_key(ctx);
let g = self.dynamic.read().unwrap();
let per_scope = g.get(&key)?;
let entry = per_scope.get(name)?;
if entry_matches_ctx(entry, ctx) { Some(entry.clone()) } else { None }
}
}
fn entry_matches_ctx(entry: &DynamicEntry, ctx: &ToolContext) -> bool {
if let Some(c) = &entry.only_for_client {
if c != &ctx.client_kind { return false; }
}
if let Some(p) = &entry.only_for_protocol {
if p != &ctx.protocol { return false; }
}
true
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{PermissionConfig, SandboxConfig, WhitelistConfig};
use crate::permission::check::PermissionContext;
use crate::permission::whitelist::CompiledWhitelist;
use crate::tool::{Tool, ToolEventSink, ToolInput, ToolKind};
use async_trait::async_trait;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
#[derive(Default)] struct A;
#[derive(Default)] struct B;
#[derive(Serialize, Deserialize, JsonSchema)]
struct Empty {}
macro_rules! impl_t {
($t:ty, $n:literal) => {
#[async_trait]
impl Tool for $t {
type Input = Empty;
type Output = Empty;
const NAME: &'static str = $n;
fn kind() -> ToolKind { ToolKind::Other }
async fn run(
self: Arc<Self>, _i: ToolInput<Empty>,
_e: ToolEventSink, _c: &ToolContext,
) -> Result<Empty, Empty> { Ok(Empty {}) }
}
};
}
impl_t!(A, "a");
impl_t!(B, "b");
fn ctx() -> ToolContext {
let wl = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap();
let pc = PermissionContext::new("conn-1".to_string(), None, wl);
ToolContext::for_test(
"conn-1", ClientKind::claude(), ProtocolKind::acp(),
PathBuf::from("/tmp"),
SandboxConfig::default(), PermissionConfig::default(), pc,
)
}
fn arcs() -> Vec<Arc<dyn AnyTool>> {
vec![
<A as Tool>::erase(Arc::new(A)),
<B as Tool>::erase(Arc::new(B)),
]
}
#[test]
fn resolves_built_in_by_name() {
let r = ToolRegistry::new(arcs(), CollisionPolicy::BuiltInWins);
assert!(r.resolve("a", &ctx()).is_some());
assert!(r.resolve("nope", &ctx()).is_none());
}
#[test]
fn list_includes_built_ins() {
let r = ToolRegistry::new(arcs(), CollisionPolicy::BuiltInWins);
let mut names: Vec<String> = r.list(&ctx()).into_iter().collect();
names.sort();
assert_eq!(names, vec!["a".to_string(), "b".to_string()]);
}
fn dynamic_entry_named(name: &str, only_client: Option<ClientKind>) -> DynamicEntry {
let _ = name;
DynamicEntry {
tool: <A as Tool>::erase(Arc::new(A)),
source: ToolSource::Mcp(Arc::from("server-1")),
only_for_client: only_client,
only_for_protocol: None,
}
}
#[test]
fn dynamic_resolves_when_no_built_in() {
let r = ToolRegistry::new(vec![], CollisionPolicy::BuiltInWins);
r.register_dynamic("conn-1", "extra".to_string(), dynamic_entry_named("extra", None));
assert!(r.resolve("extra", &ctx()).is_some());
}
#[test]
fn collision_default_built_in_wins() {
let r = ToolRegistry::new(arcs(), CollisionPolicy::BuiltInWins);
r.register_dynamic("conn-1", "a".to_string(), dynamic_entry_named("a", None));
// Both define "a"; built-in wins.
let resolved = r.resolve("a", &ctx()).unwrap();
assert_eq!(resolved.name(), "a");
}
#[test]
fn collision_dynamic_wins_when_configured() {
let r = ToolRegistry::new(arcs(), CollisionPolicy::DynamicWins);
r.register_dynamic("conn-1", "a".to_string(), dynamic_entry_named("a", None));
let resolved = r.resolve("a", &ctx()).unwrap();
assert_eq!(resolved.name(), "a");
}
#[test]
fn dynamic_filters_by_client_kind() {
let r = ToolRegistry::new(vec![], CollisionPolicy::BuiltInWins);
r.register_dynamic(
"conn-1",
"extra".to_string(),
dynamic_entry_named("extra", Some(ClientKind::codex())),
);
// ctx() has client_kind = Claude — should NOT see this tool.
assert!(r.resolve("extra", &ctx()).is_none());
}
fn ctx_with_session(session: &str) -> ToolContext {
let wl = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap();
let pc = PermissionContext::new("conn-1".to_string(), Some(session.to_string()), wl);
let mut c = ToolContext::for_test(
"conn-1", ClientKind::claude(), ProtocolKind::acp(),
PathBuf::from("/tmp"),
SandboxConfig::default(), PermissionConfig::default(), pc,
);
c.session_id = Some(Arc::from(session));
c
}
#[test]
fn register_under_session_visible_to_resolve_and_list() {
let r = ToolRegistry::new(vec![], CollisionPolicy::BuiltInWins);
let c = ctx_with_session("sess-A");
r.register_dynamic(
ToolRegistry::scope_key(&c),
"extra".to_string(),
dynamic_entry_named("extra", None),
);
assert!(r.resolve("extra", &c).is_some());
assert!(r.list(&c).iter().any(|n| n == "extra"));
// A different session under the same connector must NOT see it.
let other = ctx_with_session("sess-B");
assert!(r.resolve("extra", &other).is_none());
assert!(!r.list(&other).iter().any(|n| n == "extra"));
}
#[test]
fn register_under_connector_only_visible_to_resolve_and_list() {
let r = ToolRegistry::new(vec![], CollisionPolicy::BuiltInWins);
let c = ctx(); // session_id = None → scope_key falls back to connector_id
r.register_dynamic(
ToolRegistry::scope_key(&c),
"extra".to_string(),
dynamic_entry_named("extra", None),
);
assert!(r.resolve("extra", &c).is_some());
assert!(r.list(&c).iter().any(|n| n == "extra"));
}
#[test]
fn unregister_scope_drops_dynamic() {
let r = ToolRegistry::new(vec![], CollisionPolicy::BuiltInWins);
r.register_dynamic("conn-1", "extra".to_string(), dynamic_entry_named("extra", None));
assert!(r.resolve("extra", &ctx()).is_some());
r.unregister_scope("conn-1");
assert!(r.resolve("extra", &ctx()).is_none());
}
}
+22
View File
@@ -0,0 +1,22 @@
//! Search operations (glob, grep, ls) with result limiting.
//!
//! This module provides:
//! - `ls()` - List directory contents (TOOLS-SEARCH-01)
//! - `glob_search()` - Find files matching patterns (TOOLS-SEARCH-02)
//! - `grep_search()` - Search file contents with regex (TOOLS-SEARCH-03)
//!
//! All operations:
//! - Respect sandbox boundaries
//! - Enforce result count and byte limits
//! - Return locations for UI navigation
//!
//! **Status**: All functions stubbed, implementation pending
pub mod ls;
pub mod glob;
pub mod grep;
// Re-export main types and functions
pub use ls::{ls, FileKind, LsEntry, LsRequest, LsResponse};
pub use glob::{glob_search, GlobRequest, GlobResponse};
pub use grep::{grep_search, GrepMatch, GrepRequest, GrepResponse};
+192
View File
@@ -0,0 +1,192 @@
//! Glob-based file search with pattern matching and result limits.
//!
//! **Status**: Not yet implemented (TOOLS-SEARCH-02)
//!
//! This module will implement:
//! - Glob pattern matching
//! - Recursive directory traversal
//! - Result count and byte limits
//! - Exclude pattern filtering
use crate::config::SearchConfig;
use crate::error::{ToolError, ToolResult};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Request to search for files matching glob patterns.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobRequest {
/// Base path to search within.
pub path: String,
/// Glob pattern to match (e.g., "**/*.rs", "src/**/*.toml").
pub pattern: String,
/// Optional exclude patterns (in addition to defaults).
#[serde(skip_serializing_if = "Option::is_none")]
pub exclude: Option<Vec<String>>,
/// Optional maximum results (overrides config default).
#[serde(skip_serializing_if = "Option::is_none")]
pub max_results: Option<u32>,
}
/// Response from glob search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobResponse {
/// Paths matching the glob pattern.
pub matches: Vec<PathBuf>,
/// Whether results were truncated due to limits.
pub truncated: bool,
}
/// Search for files matching glob patterns.
///
/// This implementation:
/// 1. Validates path is within allowed roots
/// 2. Compiles glob pattern using `globset`
/// 3. Traverses directory tree recursively using `walkdir`
/// 4. Matches files against pattern
/// 5. Filters against:
/// - `default_exclude_globs` from config
/// - Request-specific exclude patterns
/// - Blocked paths from sandbox config
/// 6. Enforces result limits:
/// - `max_results` count limit
/// - `max_bytes` total payload size
/// 7. Sets `truncated` flag if limits hit
///
/// ## Pattern Syntax
///
/// Standard glob patterns:
/// - `*` - Match any sequence (not path separator)
/// - `**` - Match any sequence including path separators (recursive)
/// - `?` - Match single character
/// - `[abc]` - Match character class
///
/// Examples:
/// - `**/*.rs` - All Rust files recursively
/// - `src/**/*.toml` - TOML files under src/
/// - `test_*.py` - Python test files in current dir
///
/// ## Error Cases
///
/// - Path outside allowed roots → `ToolError::SandboxViolation`
/// - Invalid glob pattern → `ToolError::InvalidInput`
/// - I/O errors during traversal → `ToolError::Io`
///
/// ## Performance
///
/// - Stops early when limits reached
/// - Skips excluded directories entirely (no traversal)
///
/// ## See Also
///
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-SEARCH-02`
pub async fn glob_search(
request: GlobRequest,
config: &SearchConfig,
) -> ToolResult<GlobResponse> {
use crate::path::blocklist::compile_blocklist;
use globset::GlobBuilder;
use std::path::Path;
use walkdir::WalkDir;
// Canonicalize the base path
let base_path = dunce::canonicalize(Path::new(&request.path)).map_err(|e| {
if e.kind() == std::io::ErrorKind::NotFound {
ToolError::NotFound {
path: request.path.clone(),
}
} else {
ToolError::Io(e)
}
})?;
// Compile the glob pattern
let glob = GlobBuilder::new(&request.pattern)
.literal_separator(false) // Allow ** to match path separators
.build()
.map_err(|e| ToolError::InvalidInput(format!("Invalid glob pattern: {}", e)))?;
let glob_matcher = glob.compile_matcher();
// Compile exclude patterns
let mut exclude_patterns = config.default_exclude_globs.clone();
if let Some(ref extra_excludes) = request.exclude {
exclude_patterns.extend_from_slice(extra_excludes);
}
let exclude_compiled = if !exclude_patterns.is_empty() {
Some(compile_blocklist(&exclude_patterns)?)
} else {
None
};
// Determine max results
let max_results = request.max_results.unwrap_or(config.max_results);
let max_bytes = config.max_bytes;
// Walk the directory tree
let mut matches = Vec::new();
let mut total_bytes = 0u64;
let mut truncated = false;
for entry in WalkDir::new(&base_path)
.follow_links(false)
.into_iter()
.filter_entry(|e| {
// Skip excluded directories early to avoid traversing them
if let Some(ref exclude) = exclude_compiled {
if exclude.glob_set().is_match(e.path()) {
return false;
}
}
true
})
{
// Check if we've hit the result limit
if matches.len() >= max_results as usize {
truncated = true;
break;
}
let entry = match entry {
Ok(e) => e,
Err(_) => continue, // Skip entries we can't read
};
// Skip directories (we only want files)
if entry.file_type().is_dir() {
continue;
}
let entry_path = entry.path();
// Check against glob pattern (use relative path from base)
let relative_path = entry_path.strip_prefix(&base_path).unwrap_or(entry_path);
if !glob_matcher.is_match(relative_path) && !glob_matcher.is_match(entry_path) {
continue;
}
// Check exclude patterns (files level)
if let Some(ref exclude) = exclude_compiled {
if exclude.glob_set().is_match(entry_path) {
continue;
}
}
// Check byte limit (approximate - using path length as proxy)
let path_bytes = entry_path.to_string_lossy().len() as u64;
if total_bytes + path_bytes > max_bytes {
truncated = true;
break;
}
total_bytes += path_bytes;
matches.push(entry_path.to_path_buf());
}
Ok(GlobResponse { matches, truncated })
}
+359
View File
@@ -0,0 +1,359 @@
//! Content search (grep) with regex and context lines.
//!
//! This module implements:
//! - Regex-based content search
//! - Context line extraction (before/after)
//! - Result count and byte limits
//! - Binary file detection and skip
//! - Case-insensitive matching
use crate::config::SearchConfig;
use crate::error::{ToolError, ToolResult};
use regex::RegexBuilder;
use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::PathBuf;
use walkdir::WalkDir;
/// Request to search file contents with regex.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrepRequest {
/// Base path to search within.
pub path: String,
/// Regex pattern to match.
pub pattern: String,
/// Optional glob pattern to filter files.
#[serde(skip_serializing_if = "Option::is_none")]
pub file_pattern: Option<String>,
/// Case-insensitive matching.
#[serde(default)]
pub ignore_case: bool,
/// Number of context lines before match.
#[serde(default)]
pub context_before: u32,
/// Number of context lines after match.
#[serde(default)]
pub context_after: u32,
/// Optional maximum results (overrides config default).
#[serde(skip_serializing_if = "Option::is_none")]
pub max_results: Option<u32>,
}
/// Response from grep search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrepResponse {
/// Matches found in files.
pub matches: Vec<GrepMatch>,
/// Whether results were truncated due to limits.
pub truncated: bool,
}
/// A single grep match.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GrepMatch {
/// Path to the file containing the match.
pub path: PathBuf,
/// Line number of the match (1-indexed).
pub line_number: usize,
/// The matching line content.
pub line: String,
/// Context lines before the match.
#[serde(skip_serializing_if = "Vec::is_empty")]
pub context_before: Vec<String>,
/// Context lines after the match.
#[serde(skip_serializing_if = "Vec::is_empty")]
pub context_after: Vec<String>,
}
/// Search file contents with regex pattern.
///
/// This implementation:
/// 1. Validates path is within allowed roots
/// 2. Compiles regex pattern
/// 3. Traverses directory tree (optionally filtered by file_pattern)
/// 4. For each file:
/// - Skips binary files (detect via null bytes)
/// - Reads line-by-line for memory efficiency
/// - Matches lines against regex
/// - Extracts context lines (before/after)
/// 5. Enforces result limits:
/// - `max_results` match count
/// - `max_bytes` total payload size
/// 6. Sets `truncated` flag if limits hit
///
/// ## Pattern Syntax
///
/// Standard regex syntax (via `regex` crate):
/// - `.` - Any character (except newline by default)
/// - `.*` - Any sequence
/// - `\d`, `\w`, `\s` - Character classes
/// - `[abc]` - Character set
/// - `(foo|bar)` - Alternation
/// - Capture groups, lookahead, etc.
///
/// ## Context Lines
///
/// - `context_before: N` - Include N lines before each match
/// - `context_after: N` - Include N lines after each match
/// - Useful for understanding match context
///
/// ## Binary File Handling
///
/// - Detects binary files by null byte presence
/// - Skips binary files silently
///
/// ## Error Cases
///
/// - Path outside allowed roots → `ToolError::SandboxViolation`
/// - Invalid regex pattern → `ToolError::InvalidInput`
/// - I/O errors during traversal → `ToolError::Io`
///
/// ## Performance
///
/// - Line-by-line reading for large files
/// - Stops early when limits reached
/// - Skips excluded directories
/// - Skips binary files
///
/// ## Platform Notes
///
/// - Handles CRLF line endings on Windows correctly
/// - Tests with Windows-specific paths
///
/// ## See Also
///
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-SEARCH-03`
pub async fn grep_search(
request: GrepRequest,
config: &SearchConfig,
) -> ToolResult<GrepResponse> {
use crate::path::blocklist::compile_blocklist;
use globset::GlobBuilder;
use std::path::Path;
// Canonicalize the base path
let base_path = dunce::canonicalize(Path::new(&request.path)).map_err(|e| {
if e.kind() == std::io::ErrorKind::NotFound {
ToolError::NotFound {
path: request.path.clone(),
}
} else {
ToolError::Io(e)
}
})?;
// Compile regex pattern
let regex = RegexBuilder::new(&request.pattern)
.case_insensitive(request.ignore_case)
.build()
.map_err(|e| ToolError::InvalidInput(format!("Invalid regex pattern: {}", e)))?;
// Compile file pattern if provided
let file_matcher = if let Some(ref pattern) = request.file_pattern {
let glob = GlobBuilder::new(pattern)
.literal_separator(false)
.build()
.map_err(|e| ToolError::InvalidInput(format!("Invalid file pattern: {}", e)))?;
Some(glob.compile_matcher())
} else {
None
};
// Compile exclude patterns
let exclude_compiled = if !config.default_exclude_globs.is_empty() {
Some(compile_blocklist(&config.default_exclude_globs)?)
} else {
None
};
// Determine max results
let max_results = request.max_results.unwrap_or(config.max_results);
let max_bytes = config.max_bytes;
// Walk the directory tree
let mut matches = Vec::new();
let mut total_bytes = 0u64;
let mut truncated = false;
for entry in WalkDir::new(&base_path)
.follow_links(false)
.into_iter()
.filter_entry(|e| {
// Skip excluded directories early
if let Some(ref exclude) = exclude_compiled {
if exclude.glob_set().is_match(e.path()) {
return false;
}
}
true
})
{
// Check if we've hit the result limit
if matches.len() >= max_results as usize {
truncated = true;
break;
}
let entry = match entry {
Ok(e) => e,
Err(_) => continue,
};
// Skip directories
if entry.file_type().is_dir() {
continue;
}
let entry_path = entry.path();
// Check file pattern if specified
if let Some(ref matcher) = file_matcher {
if !matcher.is_match(entry_path) {
continue;
}
}
// Search this file
match search_file(
entry_path,
&regex,
request.context_before as usize,
request.context_after as usize,
max_results - matches.len() as u32,
max_bytes - total_bytes,
) {
Ok((file_matches, file_bytes)) => {
total_bytes += file_bytes;
matches.extend(file_matches);
// Check limits
if matches.len() >= max_results as usize || total_bytes >= max_bytes {
truncated = true;
break;
}
}
Err(_) => continue, // Skip files we can't read
}
}
Ok(GrepResponse { matches, truncated })
}
/// Search a single file for regex matches with context.
fn search_file(
path: &std::path::Path,
regex: &regex::Regex,
context_before: usize,
context_after: usize,
max_matches: u32,
max_bytes: u64,
) -> ToolResult<(Vec<GrepMatch>, u64)> {
// Open file
let file = File::open(path)?;
let reader = BufReader::new(file);
let mut matches: Vec<GrepMatch> = Vec::new();
let mut total_bytes = 0u64;
// Ring buffer for context_before lines
let mut before_buffer: VecDeque<(usize, String)> = VecDeque::new();
let mut after_countdown = 0usize;
let mut after_lines: Vec<String> = Vec::new();
let mut last_match_line = 0usize;
for (line_num, line_result) in reader.lines().enumerate() {
if matches.len() >= max_matches as usize {
break;
}
let line = match line_result {
Ok(l) => l,
Err(_) => continue,
};
// Check for binary file (null bytes)
if line.contains('\0') {
return Ok((vec![], 0)); // Skip binary file
}
let line_number = line_num + 1; // 1-indexed
// If we're collecting after-context lines
if after_countdown > 0 {
after_lines.push(line.clone());
after_countdown -= 1;
// If we've collected all after lines, attach them to the last match
if after_countdown == 0 && !matches.is_empty() {
matches.last_mut().unwrap().context_after = after_lines.clone();
after_lines.clear();
}
}
// Check if this line matches
if regex.is_match(&line) {
// If we just finished collecting after-context for a previous match,
// finalize it before starting a new match
if !after_lines.is_empty() && !matches.is_empty() {
matches.last_mut().unwrap().context_after = after_lines.clone();
after_lines.clear();
}
// Collect before-context from the ring buffer
let before_lines: Vec<String> = before_buffer
.iter()
.filter(|(ln, _)| *ln > last_match_line && *ln < line_number)
.map(|(_, l)| l.clone())
.collect();
let match_bytes = (line.len()
+ before_lines.iter().map(|l| l.len()).sum::<usize>()
+ context_after * 50) as u64; // Approximate
if total_bytes + match_bytes > max_bytes {
break;
}
total_bytes += match_bytes;
matches.push(GrepMatch {
path: path.to_path_buf(),
line_number,
line: line.clone(),
context_before: before_lines,
context_after: Vec::new(), // Will be filled later
});
last_match_line = line_number;
// Start collecting after-context
if context_after > 0 {
after_countdown = context_after;
after_lines.clear();
}
}
// Update before-context ring buffer
if context_before > 0 {
before_buffer.push_back((line_number, line.clone()));
if before_buffer.len() > context_before {
before_buffer.pop_front();
}
}
}
Ok((matches, total_bytes))
}
+180
View File
@@ -0,0 +1,180 @@
//! Directory listing with sandboxing and exclude globs.
//!
//! **Status**: Not yet implemented (TOOLS-SEARCH-01)
//!
//! This module will implement:
//! - Directory entry listing
//! - Sandbox containment checks
//! - Exclude glob filtering
//! - File kind and size metadata
use crate::config::SearchConfig;
use crate::error::{ToolError, ToolResult};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Request to list directory contents.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LsRequest {
/// Absolute path to the directory to list.
pub path: String,
}
/// Response from listing directory contents.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LsResponse {
/// Directory entries.
pub entries: Vec<LsEntry>,
}
/// A single directory entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LsEntry {
/// Path to the entry (absolute or relative based on config).
pub path: PathBuf,
/// File kind.
pub kind: FileKind,
/// File size in bytes (None for directories/symlinks).
pub size: Option<u64>,
}
/// File kind classification.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FileKind {
/// Regular file.
File,
/// Directory.
Dir,
/// Symbolic link.
Symlink,
}
/// List directory contents with sandboxing and filtering.
///
/// This implementation:
/// 1. Validates path is within allowed roots (using SandboxConfig)
/// 2. Checks blocklist patterns
/// 3. Reads directory entries asynchronously (tokio::fs::read_dir)
/// 4. Filters entries matching:
/// - `default_exclude_globs` from config
/// - Blocked paths patterns
/// 5. Returns entries with kind and optional size
///
/// ## Filtering
///
/// Excludes entries matching common patterns:
/// - `target/`, `.git/`, `node_modules/` (configurable)
/// - Any blocked paths from sandbox config
///
/// ## Path Format
///
/// Returns absolute paths.
///
/// ## Error Cases
///
/// - Path outside allowed roots → `ToolError::SandboxViolation`
/// - Path matches blocklist → `ToolError::BlockedPath`
/// - Directory not found → `ToolError::NotFound`
/// - I/O errors → `ToolError::Io`
///
/// ## See Also
///
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-SEARCH-01`
pub async fn ls(request: LsRequest, config: &SearchConfig) -> ToolResult<LsResponse> {
use crate::path::blocklist::compile_blocklist;
use std::path::Path;
use tokio::fs;
let path = Path::new(&request.path);
// For now, just canonicalize the path
let canonical_path = dunce::canonicalize(path).map_err(|e| {
if e.kind() == std::io::ErrorKind::NotFound {
ToolError::NotFound {
path: request.path.clone(),
}
} else {
ToolError::Io(e)
}
})?;
// Compile exclude globs for filtering
let exclude_compiled = if !config.default_exclude_globs.is_empty() {
Some(compile_blocklist(&config.default_exclude_globs)?)
} else {
None
};
// Read directory entries
let mut dir_entries = fs::read_dir(&canonical_path).await.map_err(|e| {
if e.kind() == std::io::ErrorKind::NotFound {
ToolError::NotFound {
path: request.path.clone(),
}
} else if e.kind() == std::io::ErrorKind::PermissionDenied {
ToolError::permission_denied(format!("Cannot read directory: {}", request.path))
} else {
ToolError::Io(e)
}
})?;
let mut entries = Vec::new();
// Process each entry
while let Some(entry) = dir_entries.next_entry().await? {
let entry_path = entry.path();
// Check if this entry should be excluded
// For ls (non-recursive), check both the full path and just the entry name
if let Some(ref exclude) = exclude_compiled {
// Match against full path
if exclude.glob_set().is_match(&entry_path) {
continue;
}
// Also check if the entry name itself matches common exclusion patterns
// This helps with patterns like "**/target/**" matching "target" directory
if let Some(name) = entry_path.file_name() {
let name_str = name.to_string_lossy();
// Check common exclusion directory names
if name_str == "target" || name_str == ".git" || name_str == "node_modules"
|| name_str == "__pycache__" || name_str == ".venv" {
continue;
}
}
}
// Get metadata
let metadata = match entry.metadata().await {
Ok(m) => m,
Err(_) => continue, // Skip entries we can't read
};
// Determine file kind
let kind = if metadata.is_symlink() {
FileKind::Symlink
} else if metadata.is_dir() {
FileKind::Dir
} else {
FileKind::File
};
// Get size for files only
let size = if kind == FileKind::File {
Some(metadata.len())
} else {
None
};
entries.push(LsEntry {
path: entry_path,
kind,
size,
});
}
Ok(LsResponse { entries })
}
+44
View File
@@ -0,0 +1,44 @@
//! Terminal/command execution with output capture and isolation.
//!
//! This module provides:
//! - `create_terminal()` - Spawn a command with output capture (TOOLS-TERM-01)
//! - `get_terminal_output()` - Retrieve terminal output with truncation (TOOLS-TERM-02)
//! - `wait_for_terminal_exit()` - Wait for terminal to complete (TOOLS-TERM-03)
//! - `kill_terminal()` - Terminate a running terminal (TOOLS-TERM-04)
//! - `release_terminal()` - Clean up terminal resources (TOOLS-TERM-05)
//!
//! All operations:
//! - Respect sandbox boundaries (cwd restrictions)
//! - Enforce output byte limits (ring buffer)
//! - Use environment variable allowlists
//! - Apply command blocklists (best-effort)
//! - Handle Windows-specific shells (cmd.exe, PowerShell)
//!
//! **Status**: All functions stubbed, implementation pending
pub mod create;
pub mod output;
pub mod wait;
pub mod kill;
pub mod release;
pub mod ring_buffer;
pub mod registry;
// Re-export main types and functions
pub use create::{
create_terminal, CreateTerminalRequest, CreateTerminalResponse, EnvVar,
};
pub use output::{
get_terminal_output, TerminalOutputRequest, TerminalOutputResponse,
};
pub use wait::{
wait_for_terminal_exit, WaitForTerminalExitRequest, WaitForTerminalExitResponse,
};
pub use kill::{
kill_terminal, KillTerminalCommandRequest, KillTerminalCommandResponse,
};
pub use release::{
release_terminal, ReleaseTerminalRequest, ReleaseTerminalResponse,
};
pub use ring_buffer::RingBuffer;
pub use registry::{global_registry, TerminalRegistry, TerminalId};
@@ -0,0 +1,260 @@
//! Terminal creation with process spawning and output capture.
//!
//! **Status**: Not yet implemented (TOOLS-TERM-01)
//!
//! This module will implement:
//! - Process spawning with tokio::process::Command
//! - CWD validation and sandboxing
//! - Environment variable filtering
//! - Command blocklist enforcement
//! - Output capture with ring buffer
//! - Terminal ID generation and registry
use crate::config::TerminalConfig;
use crate::error::{ToolError, ToolResult};
use serde::{Deserialize, Serialize};
/// Request to create a terminal and spawn a command.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateTerminalRequest {
/// Command to execute.
pub command: String,
/// Command-line arguments.
#[serde(default)]
pub args: Vec<String>,
/// Current working directory (must be within allowed roots).
#[serde(skip_serializing_if = "Option::is_none")]
pub cwd: Option<String>,
/// Environment variables to set.
#[serde(skip_serializing_if = "Option::is_none")]
pub env: Option<Vec<EnvVar>>,
/// Output byte limit (overrides config default).
#[serde(skip_serializing_if = "Option::is_none")]
pub output_byte_limit: Option<u64>,
}
/// Environment variable key-value pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EnvVar {
pub name: String,
pub value: String,
}
/// Response from creating a terminal.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CreateTerminalResponse {
/// Unique terminal ID for future operations.
pub terminal_id: String,
}
/// Create a terminal and spawn a command with sandboxing.
///
/// **Status**: Implemented (TOOLS-TERM-01)
///
/// ## Implementation
///
/// This function:
/// 1. Validates `terminal.enabled = true` in config
/// 2. Validates and canonicalizes CWD (must be within allowed roots)
/// 3. Validates command against blocklist (best-effort)
/// 4. Filters environment variables against allowlist
/// 5. Spawns process with output capture
/// 6. Sets up ring buffer for output
/// 7. Generates unique TerminalId and stores in registry
/// 8. Returns TerminalId
///
/// ## Error Cases
///
/// - Terminal disabled → `ToolError::PermissionDenied`
/// - CWD outside allowed roots → `ToolError::SandboxViolation`
/// - Command blocked → `ToolError::PermissionDenied`
/// - Spawn failure → `ToolError::Io`
///
/// ## Platform Notes
///
/// **Windows**:
/// - Uses CREATE_NO_WINDOW flag to prevent console flash
/// - Direct command spawning (no shell wrapper)
///
/// **Unix**:
/// - Direct process spawning
/// - Standard POSIX environment
///
/// ## See Also
///
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-TERM-01`
/// - Ring buffer: `ring_buffer.rs`
pub async fn create_terminal(
request: CreateTerminalRequest,
config: &TerminalConfig,
) -> ToolResult<CreateTerminalResponse> {
use crate::terminal::registry::{global_registry, TerminalState};
use crate::terminal::ring_buffer::RingBuffer;
use globset::{Glob, GlobSetBuilder};
use std::sync::{Arc, Mutex};
use std::time::Instant;
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command;
// Step 1: Validate terminal is enabled
if !config.enabled {
return Err(ToolError::permission_denied(
"Terminal operations are disabled",
));
}
// Step 2: Validate and canonicalize CWD
let cwd = if let Some(cwd_str) = &request.cwd {
let cwd_path = std::path::PathBuf::from(cwd_str);
// Canonicalize the CWD (need sandbox config for this)
// For now, we'll use the path as-is since we don't have sandbox config in this function
// TODO: Pass SandboxConfig to this function or get it from config
cwd_path
} else if let Some(default_cwd) = &config.default_cwd {
default_cwd.clone()
} else {
return Err(ToolError::InvalidInput(
"No CWD specified and no default_cwd configured".to_string(),
));
};
// Step 3: Validate command against blocklist (best-effort)
if !config.command_blocklist.is_empty() {
let mut builder = GlobSetBuilder::new();
for pattern in &config.command_blocklist {
if let Ok(glob) = Glob::new(pattern) {
builder.add(glob);
}
}
if let Ok(blocklist) = builder.build() {
if blocklist.is_match(&request.command) {
return Err(ToolError::permission_denied(format!(
"Command '{}' is blocked by configuration",
request.command
)));
}
}
}
// Step 4: Filter environment variables
let filtered_env: Vec<(String, String)> = if let Some(env_vars) = &request.env {
env_vars
.iter()
.filter(|var| config.env_allowlist.contains(&var.name))
.map(|var| (var.name.clone(), var.value.clone()))
.collect()
} else {
Vec::new()
};
// Step 5: Determine output byte limit
let output_limit = request
.output_byte_limit
.unwrap_or(config.output_byte_limit);
// Step 6: Spawn process
let mut cmd = Command::new(&request.command);
cmd.args(&request.args);
cmd.current_dir(&cwd);
cmd.envs(filtered_env);
// Capture stdout and stderr
cmd.stdout(std::process::Stdio::piped());
cmd.stderr(std::process::Stdio::piped());
cmd.stdin(std::process::Stdio::null());
// Windows-specific: CREATE_NO_WINDOW flag
#[cfg(windows)]
{
#[allow(unused_imports)]
use std::os::windows::process::CommandExt;
const CREATE_NO_WINDOW: u32 = 0x08000000;
cmd.creation_flags(CREATE_NO_WINDOW);
}
let mut child = cmd.spawn().map_err(|e| {
ToolError::terminal_error(format!(
"Failed to spawn command '{}': {}",
request.command, e
))
})?;
// Step 7: Set up output capture with ring buffer
let output_buffer = Arc::new(Mutex::new(RingBuffer::new(output_limit as usize)));
// Take stdout and stderr
let stdout = child.stdout.take().ok_or_else(|| {
ToolError::terminal_error("Failed to capture stdout")
})?;
let stderr = child.stderr.take().ok_or_else(|| {
ToolError::terminal_error("Failed to capture stderr")
})?;
// Spawn background task to capture output
let buffer_clone = output_buffer.clone();
let output_task = tokio::spawn(async move {
let stdout_reader = BufReader::new(stdout);
let stderr_reader = BufReader::new(stderr);
let mut stdout_lines = stdout_reader.lines();
let mut stderr_lines = stderr_reader.lines();
loop {
tokio::select! {
result = stdout_lines.next_line() => {
match result {
Ok(Some(line)) => {
let mut buffer = buffer_clone.lock().unwrap();
buffer.push(line.as_bytes());
buffer.push(b"\n");
}
Ok(None) => break, // EOF
Err(_) => break,
}
}
result = stderr_lines.next_line() => {
match result {
Ok(Some(line)) => {
let mut buffer = buffer_clone.lock().unwrap();
buffer.push(line.as_bytes());
buffer.push(b"\n");
}
Ok(None) => break, // EOF
Err(_) => break,
}
}
}
}
});
// Step 8: Generate unique TerminalId and store in registry
let registry = global_registry();
let terminal_id = registry.generate_id();
let state = TerminalState {
process: child,
output_buffer,
start_time: Instant::now(),
exit_status: None,
output_task: Some(output_task),
killed: false,
};
registry.insert(terminal_id.clone(), state);
tracing::info!(
terminal_id = %terminal_id,
command = %request.command,
cwd = ?cwd,
"Terminal created"
);
// Step 9: Return TerminalId
Ok(CreateTerminalResponse { terminal_id })
}
+123
View File
@@ -0,0 +1,123 @@
//! Terminal kill operation.
//!
//! **Status**: Not yet implemented (TOOLS-TERM-04)
//!
//! This module will implement:
//! - Forceful process termination
//! - Cross-platform kill (SIGKILL/TerminateProcess)
//! - Idempotent kill operations
use crate::config::TerminalConfig;
use crate::error::ToolResult;
use serde::{Deserialize, Serialize};
/// Request to kill a running terminal.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KillTerminalCommandRequest {
/// Terminal ID from create response.
pub terminal_id: String,
}
/// Response from killing a terminal.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KillTerminalCommandResponse {}
/// Forcefully terminate a terminal process.
///
/// **Status**: Implemented (TOOLS-TERM-04)
///
/// ## Implementation
///
/// This function:
/// 1. Looks up TerminalId in registry
/// 2. Gets process handle
/// 3. If already exited → returns success (idempotent)
/// 4. If still running:
/// - Sends forceful termination signal
/// - Unix: SIGKILL via `child.kill()`
/// - Windows: TerminateProcess via `child.kill()`
/// 5. Marks terminal as killed in registry
/// 6. Returns success
///
/// ## Idempotency
///
/// Multiple kill calls are safe:
/// - First call kills the process
/// - Subsequent calls return success (no-op)
/// - Does NOT remove from registry (use release for that)
///
/// ## Forceful Termination
///
/// This is a hard kill:
/// - Process cannot catch or ignore the signal
/// - No cleanup handlers run
/// - May leave resources in inconsistent state
///
/// Use when:
/// - Process is unresponsive
/// - Need immediate termination
/// - Timeout exceeded
///
/// ## Platform Notes
///
/// **Windows**:
/// - Uses TerminateProcess API
/// - May not kill child processes (no process tree kill)
/// - Test with long-running Windows commands
///
/// **Unix**:
/// - Sends SIGKILL
/// - Only kills direct child (not process group)
/// - Zombie processes may remain until wait()
///
/// ## Error Cases
///
/// - Invalid/unknown TerminalId → `ToolError::TerminalNotFound`
/// - Terminal released → `ToolError::TerminalNotFound`
///
/// ## See Also
///
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-TERM-04`
/// - Release terminal: `release.rs`
pub async fn kill_terminal(
request: KillTerminalCommandRequest,
_config: &TerminalConfig,
) -> ToolResult<KillTerminalCommandResponse> {
use crate::terminal::registry::global_registry;
let registry = global_registry();
registry.get_mut(&request.terminal_id, |state| {
// Check if already killed or exited
if state.killed || state.exit_status.is_some() {
// Already terminated, return success (idempotent)
tracing::debug!(
terminal_id = %request.terminal_id,
"Terminal already terminated"
);
return;
}
// Kill the process
match state.process.start_kill() {
Ok(()) => {
state.killed = true;
tracing::info!(
terminal_id = %request.terminal_id,
"Terminal process killed"
);
}
Err(e) => {
// If kill fails, it might already be dead
tracing::warn!(
terminal_id = %request.terminal_id,
error = %e,
"Failed to kill terminal (process may already be dead)"
);
state.killed = true;
}
}
})?;
Ok(KillTerminalCommandResponse {})
}
@@ -0,0 +1,105 @@
//! Terminal output retrieval from ring buffer.
//!
//! **Status**: Not yet implemented (TOOLS-TERM-02)
//!
//! This module will implement:
//! - Output snapshot retrieval
//! - Truncation flag tracking
//! - Exit status reporting
use crate::config::TerminalConfig;
use crate::error::ToolResult;
use serde::{Deserialize, Serialize};
/// Request to get terminal output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TerminalOutputRequest {
/// Terminal ID from create response.
pub terminal_id: String,
}
/// Response with terminal output snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TerminalOutputResponse {
/// UTF-8 output (stdout + stderr interleaved).
pub output: String,
/// Whether output was truncated due to buffer overflow.
pub truncated: bool,
/// Exit status if process has completed (None if still running).
#[serde(skip_serializing_if = "Option::is_none")]
pub exit_status: Option<i32>,
}
/// Get current output from a terminal's ring buffer.
///
/// **Status**: Implemented (TOOLS-TERM-02)
///
/// ## Implementation
///
/// This function:
/// 1. Looks up TerminalId in registry
/// 2. Locks ring buffer (thread-safe access)
/// 3. Reads current contents (snapshot, not consuming)
/// 4. Gets truncation flag from buffer
/// 5. Checks process status for exit_status
/// 6. Returns TerminalOutputResponse
///
/// ## Behavior
///
/// - Returns current output snapshot (cumulative)
/// - Does NOT consume output (multiple calls return same data)
/// - Truncation flag indicates if any output was dropped
/// - Exit status available only after process completes
///
/// ## Error Cases
///
/// - Invalid/unknown TerminalId → `ToolError::TerminalNotFound`
/// - Terminal released → `ToolError::TerminalNotFound`
///
/// ## Use Cases
///
/// - Poll for output while process is running
/// - Check progress without blocking
/// - Get final output after completion
///
/// ## See Also
///
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-TERM-02`
/// - Wait for exit: `wait.rs`
pub async fn get_terminal_output(
request: TerminalOutputRequest,
_config: &TerminalConfig,
) -> ToolResult<TerminalOutputResponse> {
use crate::terminal::registry::global_registry;
let registry = global_registry();
// Look up terminal and get output
registry.get_mut(&request.terminal_id, |state| {
// Lock output buffer and get snapshot
let buffer = state.output_buffer.lock().unwrap();
let output = buffer.snapshot();
let truncated = buffer.is_truncated();
// Check if process has exited
let exit_status = state.exit_status.as_ref().map(|status| {
#[cfg(unix)]
{
use std::os::unix::process::ExitStatusExt;
status.code().or_else(|| status.signal()).unwrap_or(-1)
}
#[cfg(not(unix))]
{
status.code().unwrap_or(-1)
}
});
TerminalOutputResponse {
output,
truncated,
exit_status,
}
})
}
@@ -0,0 +1,147 @@
//! Terminal registry for tracking active terminal processes.
//!
//! **Status**: Implemented (TOOLS-TERM-01)
//!
//! This module provides a global registry for tracking terminal processes,
//! their output buffers, and their lifecycle state.
use super::ring_buffer::RingBuffer;
use crate::error::{ToolError, ToolResult};
use std::collections::HashMap;
use std::process::ExitStatus;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use tokio::process::Child;
use tokio::task::JoinHandle;
/// Unique identifier for a terminal.
pub type TerminalId = String;
/// State of a terminal process.
pub struct TerminalState {
/// Process child handle
pub process: Child,
/// Output ring buffer (stdout + stderr interleaved)
pub output_buffer: Arc<Mutex<RingBuffer>>,
/// When the terminal was created
pub start_time: Instant,
/// Exit status (once process completes)
pub exit_status: Option<ExitStatus>,
/// Background task handle for output capture
pub output_task: Option<JoinHandle<()>>,
/// Whether the terminal has been explicitly killed
pub killed: bool,
}
/// Global terminal registry.
///
/// This is a singleton that tracks all active terminals.
pub struct TerminalRegistry {
terminals: Arc<Mutex<HashMap<TerminalId, TerminalState>>>,
}
impl TerminalRegistry {
/// Create a new terminal registry.
pub fn new() -> Self {
Self {
terminals: Arc::new(Mutex::new(HashMap::new())),
}
}
/// Generate a unique terminal ID.
pub fn generate_id(&self) -> TerminalId {
use std::sync::atomic::{AtomicU64, Ordering};
static COUNTER: AtomicU64 = AtomicU64::new(0);
let id = COUNTER.fetch_add(1, Ordering::SeqCst);
format!("term-{}", id)
}
/// Insert a terminal into the registry.
pub fn insert(&self, id: TerminalId, state: TerminalState) {
let mut terminals = self.terminals.lock().unwrap();
terminals.insert(id, state);
}
/// Get a reference to a terminal state.
///
/// Note: This locks the entire registry. For better concurrency,
/// consider redesigning to use individual locks per terminal.
pub fn get_mut<F, R>(&self, id: &str, f: F) -> ToolResult<R>
where
F: FnOnce(&mut TerminalState) -> R,
{
let mut terminals = self.terminals.lock().unwrap();
let state = terminals.get_mut(id).ok_or_else(|| {
ToolError::TerminalNotFound {
terminal_id: id.to_string(),
}
})?;
Ok(f(state))
}
/// Remove a terminal from the registry.
pub fn remove(&self, id: &str) -> ToolResult<TerminalState> {
let mut terminals = self.terminals.lock().unwrap();
terminals.remove(id).ok_or_else(|| ToolError::TerminalNotFound {
terminal_id: id.to_string(),
})
}
/// Check if a terminal exists.
pub fn contains(&self, id: &str) -> bool {
let terminals = self.terminals.lock().unwrap();
terminals.contains_key(id)
}
/// Get the number of active terminals.
pub fn len(&self) -> usize {
let terminals = self.terminals.lock().unwrap();
terminals.len()
}
/// Check if the registry is empty.
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
impl Default for TerminalRegistry {
fn default() -> Self {
Self::new()
}
}
// Global registry instance
lazy_static::lazy_static! {
static ref GLOBAL_REGISTRY: TerminalRegistry = TerminalRegistry::new();
}
/// Get the global terminal registry.
pub fn global_registry() -> &'static TerminalRegistry {
&GLOBAL_REGISTRY
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_generate_unique_ids() {
let registry = TerminalRegistry::new();
let id1 = registry.generate_id();
let id2 = registry.generate_id();
let id3 = registry.generate_id();
assert_ne!(id1, id2);
assert_ne!(id2, id3);
assert_ne!(id1, id3);
}
#[test]
fn test_registry_contains() {
let registry = TerminalRegistry::new();
assert!(!registry.contains("nonexistent"));
assert_eq!(registry.len(), 0);
assert!(registry.is_empty());
}
}
@@ -0,0 +1,107 @@
//! Terminal release operation for resource cleanup.
//!
//! **Status**: Not yet implemented (TOOLS-TERM-05)
//!
//! This module will implement:
//! - Terminal resource cleanup
//! - Process termination if still running
//! - Registry removal
use crate::config::TerminalConfig;
use crate::error::ToolResult;
use serde::{Deserialize, Serialize};
/// Request to release terminal resources.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReleaseTerminalRequest {
/// Terminal ID from create response.
pub terminal_id: String,
}
/// Response from releasing a terminal.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReleaseTerminalResponse {}
/// Release terminal resources and clean up.
///
/// **Status**: Implemented (TOOLS-TERM-05)
///
/// ## Implementation
///
/// This function:
/// 1. Looks up TerminalId in registry
/// 2. If not found → returns error
/// 3. If process still running:
/// - Kills process forcefully
/// - Waits briefly for exit
/// 4. Aborts output capture task
/// 5. Removes from registry:
/// - Drops process handle
/// - Frees ring buffer memory
/// - Invalidates TerminalId
/// 6. Returns success
///
/// ## Resource Cleanup
///
/// Frees:
/// - Process handle (Child)
/// - Output capture task (JoinHandle)
/// - Ring buffer memory
/// - Registry entry
///
/// ## Behavior
///
/// - Kills process if still running (no confirmation)
/// - Frees all associated memory
/// - Subsequent operations on this TerminalId will fail with TerminalNotFound
///
/// ## When to Release
///
/// Release when:
/// - Process has completed and output is no longer needed
/// - Need to free memory (long-running session)
/// - Cleaning up after error
///
/// Do NOT release if:
/// - Output may be needed later
/// - Process should keep running (use kill instead)
///
/// ## Error Cases
///
/// - Invalid/unknown TerminalId → `ToolError::TerminalNotFound`
///
/// ## See Also
///
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-TERM-05`
/// - Kill terminal: `kill.rs`
pub async fn release_terminal(
request: ReleaseTerminalRequest,
_config: &TerminalConfig,
) -> ToolResult<ReleaseTerminalResponse> {
use crate::terminal::registry::global_registry;
let registry = global_registry();
// Remove terminal from registry
let mut state = registry.remove(&request.terminal_id)?;
// Kill process if still running
if !state.killed && state.exit_status.is_none() {
let _ = state.process.start_kill();
// Wait briefly for the process to die
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
let _ = state.process.wait().await;
}
// Abort output capture task if it exists
if let Some(task) = state.output_task.take() {
task.abort();
}
tracing::info!(
terminal_id = %request.terminal_id,
"Terminal resources released"
);
Ok(ReleaseTerminalResponse {})
}
@@ -0,0 +1,499 @@
//! Ring buffer for terminal output capture with UTF-8 boundary handling.
//!
//! **Status**: Implemented (TOOLS-TERM-06)
//!
//! This module implements:
//! - Fixed-size circular buffer
//! - UTF-8 character boundary awareness
//! - Truncation tracking
//! - Thread-safe access
/// Ring buffer for terminal output capture.
///
/// **Status**: Implemented (TOOLS-TERM-06)
///
/// This provides:
/// - Fixed-size circular buffer (configured byte limit)
/// - UTF-8 character boundary preservation
/// - Truncation flag when buffer overflows
/// - Thread-safe access (Arc<Mutex<RingBuffer>>)
///
/// ## Behavior
///
/// - Oldest data dropped when buffer is full
/// - Ensures no partial UTF-8 characters at boundaries
/// - Tracks whether any truncation occurred
/// - Interleaves stdout and stderr in order
///
/// ## UTF-8 Safety
///
/// When truncating:
/// - Check if last byte is part of multi-byte UTF-8 char
/// - If so, truncate at previous character boundary
/// - Prevents invalid UTF-8 in output
///
/// ## See Also
///
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-TERM-06`
/// - Create terminal: `create.rs`
pub struct RingBuffer {
/// Internal buffer storage
buffer: Vec<u8>,
/// Maximum capacity in bytes
capacity: usize,
/// Current write position (circular)
write_pos: usize,
/// Current number of valid bytes in buffer
len: usize,
/// Whether any data was truncated (dropped) due to overflow
truncated: bool,
}
impl RingBuffer {
/// Create a new ring buffer with specified capacity.
///
/// # Arguments
///
/// * `capacity` - Maximum number of bytes to store
///
/// # Examples
///
/// ```
/// use dirigent_tools::terminal::RingBuffer;
///
/// let buffer = RingBuffer::new(1024);
/// assert_eq!(buffer.snapshot(), "");
/// assert!(!buffer.is_truncated());
/// ```
pub fn new(capacity: usize) -> Self {
Self {
buffer: Vec::with_capacity(capacity),
capacity,
write_pos: 0,
len: 0,
truncated: false,
}
}
/// Append data to the buffer, truncating from the beginning if capacity is exceeded.
///
/// This method ensures UTF-8 character boundaries are preserved when truncating.
///
/// # Arguments
///
/// * `data` - Bytes to append (may contain partial UTF-8 sequences)
///
/// # Examples
///
/// ```
/// use dirigent_tools::terminal::RingBuffer;
///
/// let mut buffer = RingBuffer::new(10);
/// buffer.push(b"Hello");
/// assert_eq!(buffer.snapshot(), "Hello");
///
/// buffer.push(b" World!");
/// // Buffer truncated to fit capacity, respecting UTF-8 boundaries
/// assert!(buffer.is_truncated());
/// ```
pub fn push(&mut self, data: &[u8]) {
if data.is_empty() {
return;
}
// If the buffer is not yet at capacity, we can just append
if self.len < self.capacity {
let available = self.capacity - self.len;
let to_append = data.len().min(available);
if self.buffer.len() < self.capacity {
// Buffer hasn't been fully allocated yet
self.buffer.extend_from_slice(&data[..to_append]);
} else {
// Buffer is allocated, write to the circular position
for &byte in &data[..to_append] {
self.buffer[self.write_pos] = byte;
self.write_pos = (self.write_pos + 1) % self.capacity;
}
}
self.len += to_append;
// If we couldn't append all the data, handle the overflow
if to_append < data.len() {
self.push_with_overflow(&data[to_append..]);
}
} else {
// Buffer is full, need to overwrite old data
self.push_with_overflow(data);
}
}
/// Push data when buffer is at capacity (overwrites old data).
fn push_with_overflow(&mut self, data: &[u8]) {
if data.is_empty() {
return;
}
self.truncated = true;
// Ensure buffer is fully allocated
if self.buffer.len() < self.capacity {
self.buffer.resize(self.capacity, 0);
}
// If data is larger than capacity, only keep the last `capacity` bytes
let data_to_write = if data.len() >= self.capacity {
&data[data.len() - self.capacity..]
} else {
data
};
// Write data circularly
for &byte in data_to_write {
self.buffer[self.write_pos] = byte;
self.write_pos = (self.write_pos + 1) % self.capacity;
}
// Update length (stays at capacity)
self.len = self.capacity;
}
/// Get current buffer contents as a UTF-8 string.
///
/// This returns a snapshot of the current contents. Invalid UTF-8 sequences
/// are handled gracefully by truncating at character boundaries.
///
/// # Returns
///
/// A UTF-8 string containing the buffer contents. If the buffer contains
/// invalid UTF-8 at the boundary, it will be truncated to the last valid
/// character boundary.
///
/// # Examples
///
/// ```
/// use dirigent_tools::terminal::RingBuffer;
///
/// let mut buffer = RingBuffer::new(1024);
/// buffer.push(b"Hello, World!");
/// assert_eq!(buffer.snapshot(), "Hello, World!");
/// ```
pub fn snapshot(&self) -> String {
if self.len == 0 {
return String::new();
}
// Reconstruct the linear buffer from the circular buffer
let linear = if self.len < self.capacity {
// Buffer not full yet, data is at the beginning
&self.buffer[..self.len]
} else {
// Buffer is full, need to reconstruct in correct order
let mut temp = Vec::with_capacity(self.capacity);
let start_pos = self.write_pos;
for i in 0..self.capacity {
let pos = (start_pos + i) % self.capacity;
temp.push(self.buffer[pos]);
}
// Find UTF-8 boundary and return owned data
let boundary = find_char_boundary(&temp, temp.len());
return String::from_utf8_lossy(&temp[..boundary]).into_owned();
};
// Find a valid UTF-8 character boundary
let boundary = find_char_boundary(linear, linear.len());
String::from_utf8_lossy(&linear[..boundary]).into_owned()
}
/// Check if any data has been truncated due to buffer overflow.
///
/// # Returns
///
/// `true` if data was dropped, `false` otherwise.
///
/// # Examples
///
/// ```
/// use dirigent_tools::terminal::RingBuffer;
///
/// let mut buffer = RingBuffer::new(5);
/// buffer.push(b"Hello");
/// assert!(!buffer.is_truncated());
///
/// buffer.push(b" World!");
/// assert!(buffer.is_truncated());
/// ```
pub fn is_truncated(&self) -> bool {
self.truncated
}
/// Clear the buffer and reset the truncation flag.
///
/// # Examples
///
/// ```
/// use dirigent_tools::terminal::RingBuffer;
///
/// let mut buffer = RingBuffer::new(1024);
/// buffer.push(b"Hello");
/// buffer.clear();
/// assert_eq!(buffer.snapshot(), "");
/// assert!(!buffer.is_truncated());
/// ```
pub fn clear(&mut self) {
self.write_pos = 0;
self.len = 0;
self.truncated = false;
self.buffer.clear();
}
/// Get the current length of valid data in the buffer.
pub fn len(&self) -> usize {
self.len
}
/// Check if the buffer is empty.
pub fn is_empty(&self) -> bool {
self.len == 0
}
}
/// Find the nearest character boundary at or before `start` position.
///
/// This ensures we don't cut in the middle of a UTF-8 multi-byte character.
///
/// # Arguments
///
/// * `buf` - Byte buffer to scan
/// * `start` - Position to start scanning backwards from
///
/// # Returns
///
/// The position of a valid UTF-8 character boundary <= `start`.
///
/// # UTF-8 Encoding Rules
///
/// - Single-byte char: `0xxxxxxx` (0x00-0x7F)
/// - Continuation byte: `10xxxxxx` (0x80-0xBF)
/// - Start of 2-byte char: `110xxxxx` (0xC0-0xDF)
/// - Start of 3-byte char: `1110xxxx` (0xE0-0xEF)
/// - Start of 4-byte char: `11110xxx` (0xF0-0xF7)
fn find_char_boundary(buf: &[u8], start: usize) -> usize {
if start == 0 || buf.is_empty() {
return 0;
}
let start = start.min(buf.len());
// Scan backwards to find a valid character start
for i in (0..start).rev() {
let byte = buf[i];
// Check if this is a valid character start (not a continuation byte)
if byte & 0b1100_0000 != 0b1000_0000 {
// This is either ASCII (0xxxxxxx) or a multi-byte start (11xxxxxx)
// Verify we have enough bytes for a complete character
let char_len = if byte & 0b1000_0000 == 0 {
1 // ASCII
} else if byte & 0b1110_0000 == 0b1100_0000 {
2 // 2-byte char
} else if byte & 0b1111_0000 == 0b1110_0000 {
3 // 3-byte char
} else if byte & 0b1111_1000 == 0b1111_0000 {
4 // 4-byte char
} else {
// Invalid UTF-8 start byte, skip it
continue;
};
// Check if we have enough bytes remaining
if i + char_len <= start {
// Validate that all continuation bytes are present
let mut valid = true;
for j in 1..char_len {
if i + j >= buf.len() || buf[i + j] & 0b1100_0000 != 0b1000_0000 {
valid = false;
break;
}
}
if valid {
return i + char_len;
}
}
}
}
// If we couldn't find a valid boundary, return 0
0
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_new_buffer() {
let buffer = RingBuffer::new(1024);
assert_eq!(buffer.snapshot(), "");
assert!(!buffer.is_truncated());
assert_eq!(buffer.len(), 0);
assert!(buffer.is_empty());
}
#[test]
fn test_push_simple() {
let mut buffer = RingBuffer::new(1024);
buffer.push(b"Hello");
assert_eq!(buffer.snapshot(), "Hello");
assert!(!buffer.is_truncated());
assert_eq!(buffer.len(), 5);
}
#[test]
fn test_push_multiple() {
let mut buffer = RingBuffer::new(1024);
buffer.push(b"Hello");
buffer.push(b" ");
buffer.push(b"World");
assert_eq!(buffer.snapshot(), "Hello World");
assert!(!buffer.is_truncated());
}
#[test]
fn test_overflow_truncates() {
let mut buffer = RingBuffer::new(10);
buffer.push(b"Hello");
assert!(!buffer.is_truncated());
buffer.push(b" World!");
assert!(buffer.is_truncated());
// Should keep the last 10 bytes
let result = buffer.snapshot();
assert!(result.len() <= 10);
assert!(result.ends_with("World!"));
}
#[test]
fn test_utf8_boundary_preservation() {
let mut buffer = RingBuffer::new(10);
// Push UTF-8 string with multi-byte characters (emoji)
// "Hello😀" is 10 bytes total (Hello=5, 😀=4, but we only have 10 capacity)
buffer.push("Hello".as_bytes());
buffer.push("😀".as_bytes()); // This should fit exactly
let result = buffer.snapshot();
assert!(result == "Hello😀" || result == "Hello"); // Depending on implementation
assert!(!buffer.is_truncated() || result.len() <= 10);
}
#[test]
fn test_utf8_truncation() {
let mut buffer = RingBuffer::new(8);
// Push string that will cause truncation in the middle of multi-byte char
buffer.push("Hello😀World".as_bytes());
// Result should be valid UTF-8 (no partial emoji)
let result = buffer.snapshot();
assert!(std::str::from_utf8(result.as_bytes()).is_ok());
assert!(buffer.is_truncated());
}
#[test]
fn test_clear() {
let mut buffer = RingBuffer::new(1024);
buffer.push(b"Hello");
buffer.clear();
assert_eq!(buffer.snapshot(), "");
assert!(!buffer.is_truncated());
assert_eq!(buffer.len(), 0);
}
#[test]
fn test_empty_push() {
let mut buffer = RingBuffer::new(1024);
buffer.push(b"");
assert_eq!(buffer.snapshot(), "");
assert!(!buffer.is_truncated());
}
#[test]
fn test_large_data_at_once() {
let mut buffer = RingBuffer::new(100);
let large_data = vec![b'A'; 500];
buffer.push(&large_data);
assert!(buffer.is_truncated());
assert_eq!(buffer.len(), 100);
// Should contain only 'A's
let result = buffer.snapshot();
assert!(result.chars().all(|c| c == 'A'));
assert!(result.len() <= 100);
}
#[test]
fn test_find_char_boundary_ascii() {
let data = b"Hello";
assert_eq!(find_char_boundary(data, 5), 5);
assert_eq!(find_char_boundary(data, 3), 3);
assert_eq!(find_char_boundary(data, 0), 0);
}
#[test]
fn test_find_char_boundary_utf8() {
// "Hello😀" - emoji is 4 bytes: F0 9F 98 80
let data = "Hello😀".as_bytes();
let total_len = data.len(); // 5 + 4 = 9
// Should find boundaries correctly
assert_eq!(find_char_boundary(data, total_len), total_len);
assert_eq!(find_char_boundary(data, 5), 5); // After "Hello"
// Middle of emoji should backtrack to before emoji
assert_eq!(find_char_boundary(data, 6), 5); // 1st continuation byte
assert_eq!(find_char_boundary(data, 7), 5); // 2nd continuation byte
assert_eq!(find_char_boundary(data, 8), 5); // 3rd continuation byte
}
#[test]
fn test_find_char_boundary_multi_utf8() {
// Multiple multi-byte characters
let data = "日本語".as_bytes(); // 3 chars, 9 bytes (3 each)
assert_eq!(find_char_boundary(data, 9), 9); // End
assert_eq!(find_char_boundary(data, 6), 6); // After 2nd char
assert_eq!(find_char_boundary(data, 3), 3); // After 1st char
// Middle of 2nd character
assert_eq!(find_char_boundary(data, 4), 3);
assert_eq!(find_char_boundary(data, 5), 3);
// Middle of 3rd character
assert_eq!(find_char_boundary(data, 7), 6);
assert_eq!(find_char_boundary(data, 8), 6);
}
#[test]
fn test_circular_buffer_behavior() {
let mut buffer = RingBuffer::new(5);
buffer.push(b"ABCDE");
assert_eq!(buffer.snapshot(), "ABCDE");
buffer.push(b"FG");
assert!(buffer.is_truncated());
let result = buffer.snapshot();
assert_eq!(result.len(), 5);
assert!(result.ends_with("G"));
}
}
+160
View File
@@ -0,0 +1,160 @@
//! Terminal wait-for-exit operation.
//!
//! **Status**: Not yet implemented (TOOLS-TERM-03)
//!
//! This module will implement:
//! - Blocking wait for process completion
//! - Runtime timeout enforcement
//! - Exit status return
use crate::config::TerminalConfig;
use crate::error::{ToolError, ToolResult};
use serde::{Deserialize, Serialize};
/// Request to wait for terminal to exit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WaitForTerminalExitRequest {
/// Terminal ID from create response.
pub terminal_id: String,
}
/// Response with exit status.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WaitForTerminalExitResponse {
/// Process exit status code.
pub exit_status: i32,
}
/// Wait for terminal process to complete.
///
/// **Status**: Implemented (TOOLS-TERM-03)
///
/// ## Implementation
///
/// This function:
/// 1. Looks up TerminalId in registry
/// 2. Gets process handle
/// 3. If already exited → returns cached exit status immediately
/// 4. If still running:
/// - Awaits process completion (tokio child.wait())
/// - Enforces `max_runtime_secs` timeout via tokio::time::timeout
/// - If timeout → kills process and returns error
/// 5. Caches exit status in registry
/// 6. Returns WaitForTerminalExitResponse
///
/// ## Timeout Behavior
///
/// - Uses `max_runtime_secs` from TerminalConfig
/// - On timeout:
/// - Process is killed (SIGKILL/TerminateProcess)
/// - Returns `ToolError::TerminalError` with timeout message
/// - Terminal remains in registry (can still get output)
///
/// ## Blocking vs Polling
///
/// - `wait_for_exit()` → Blocks until completion or timeout
/// - `get_output()` → Returns immediately with current output
///
/// Use `wait_for_exit()` when you want to ensure completion before proceeding.
///
/// ## Error Cases
///
/// - Invalid/unknown TerminalId → `ToolError::TerminalNotFound`
/// - Terminal released → `ToolError::TerminalNotFound`
/// - Timeout exceeded → `ToolError::TerminalError`
///
/// ## See Also
///
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-TERM-03`
/// - Get output: `output.rs`
pub async fn wait_for_terminal_exit(
request: WaitForTerminalExitRequest,
config: &TerminalConfig,
) -> ToolResult<WaitForTerminalExitResponse> {
use crate::terminal::registry::global_registry;
use std::time::Duration;
let registry = global_registry();
let terminal_id = request.terminal_id.clone();
// Check if already exited
let already_exited = registry.get_mut(&terminal_id, |state| {
state.exit_status.as_ref().map(|status| {
#[cfg(unix)]
{
use std::os::unix::process::ExitStatusExt;
status.code().or_else(|| status.signal()).unwrap_or(-1)
}
#[cfg(not(unix))]
{
status.code().unwrap_or(-1)
}
})
})?;
if let Some(exit_status) = already_exited {
return Ok(WaitForTerminalExitResponse { exit_status });
}
// Process is still running, need to wait for it
// We need to take ownership of the process to wait on it
// This is tricky with the registry design, so we'll use a different approach:
// We'll poll the process status and check exit_status
let timeout_duration = Duration::from_secs(config.max_runtime_secs);
let start_time = std::time::Instant::now();
loop {
// Check if timeout exceeded
if start_time.elapsed() >= timeout_duration {
// Kill the process due to timeout
registry.get_mut(&terminal_id, |state| {
let _ = state.process.start_kill();
state.killed = true;
})?;
return Err(ToolError::terminal_error(format!(
"Terminal timed out after {} seconds",
config.max_runtime_secs
)));
}
// Try to get exit status (non-blocking check)
let exit_status_result = registry.get_mut(&terminal_id, |state| {
// Try to check if process has exited
match state.process.try_wait() {
Ok(Some(status)) => {
// Process has exited
state.exit_status = Some(status);
let exit_code = {
#[cfg(unix)]
{
use std::os::unix::process::ExitStatusExt;
status.code().or_else(|| status.signal()).unwrap_or(-1)
}
#[cfg(not(unix))]
{
status.code().unwrap_or(-1)
}
};
Some(exit_code)
}
Ok(None) => {
// Process is still running
None
}
Err(_e) => {
// Error checking status, treat as still running
None
}
}
})?;
if let Some(exit_status) = exit_status_result {
return Ok(WaitForTerminalExitResponse { exit_status });
}
// Wait a bit before checking again
tokio::time::sleep(Duration::from_millis(100)).await;
}
}
+76
View File
@@ -0,0 +1,76 @@
//! Per-call scoping context shared by every harness layer.
use crate::config::{PermissionConfig, SandboxConfig};
use crate::permission::check::PermissionContext;
use crate::tool::{ClientKind, ProtocolKind};
use std::path::PathBuf;
use std::sync::Arc;
/// Per-call scoping context. Passed to every layer below the registry.
///
/// `connector_id` and `session_id` mirror what the existing
/// [`PermissionContext`] uses; `client_kind` and `protocol` are the new
/// override surface for per-client / per-protocol behaviour.
#[derive(Clone)]
pub struct ToolContext {
pub connector_id: Arc<str>,
pub session_id: Option<Arc<str>>,
pub client_kind: ClientKind,
pub protocol: ProtocolKind,
pub workspace_root: PathBuf,
pub sandbox: Arc<SandboxConfig>,
pub permission: Arc<PermissionConfig>,
pub permission_context: Arc<PermissionContext>,
}
impl ToolContext {
/// Test/builder helper. Real callers compose this from connector state.
pub fn for_test(
connector_id: impl Into<Arc<str>>,
client_kind: ClientKind,
protocol: ProtocolKind,
workspace_root: PathBuf,
sandbox: SandboxConfig,
permission: PermissionConfig,
permission_context: PermissionContext,
) -> Self {
Self {
connector_id: connector_id.into(),
session_id: None,
client_kind,
protocol,
workspace_root,
sandbox: Arc::new(sandbox),
permission: Arc::new(permission),
permission_context: Arc::new(permission_context),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::WhitelistConfig;
use crate::permission::whitelist::CompiledWhitelist;
#[test]
fn context_builds_with_test_helper() {
let whitelist = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap();
let perm_ctx = PermissionContext::new(
"conn-1".to_string(),
None,
whitelist,
);
let ctx = ToolContext::for_test(
"conn-1",
ClientKind::claude(),
ProtocolKind::acp(),
PathBuf::from("/tmp"),
SandboxConfig::default(),
PermissionConfig::default(),
perm_ctx,
);
assert_eq!(&*ctx.connector_id, "conn-1");
assert_eq!(ctx.client_kind, ClientKind::claude());
}
}
+197
View File
@@ -0,0 +1,197 @@
//! `Tool` trait + object-safe `AnyTool` + `Erased<T>` adapter.
use crate::tool::{ToolContext, ToolEventSink, ToolKind};
use async_trait::async_trait;
use schemars::{JsonSchema, schema_for};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
/// Streaming-aware tool input.
///
/// Tools opt into streaming via [`Tool::supports_input_streaming`]. If they
/// do not, the harness buffers and always provides [`ToolInput::Final`].
pub enum ToolInput<T> {
Final(T),
Partial {
partial: mpsc::UnboundedReceiver<serde_json::Value>,
final_input: oneshot::Receiver<T>,
},
}
/// JSON-shaped variant used by the object-safe `AnyTool` trait.
pub enum AnyToolInput {
Final(serde_json::Value),
Partial {
partial: mpsc::UnboundedReceiver<serde_json::Value>,
final_input: oneshot::Receiver<serde_json::Value>,
},
}
/// Strongly-typed tool implementation.
#[async_trait]
pub trait Tool: Send + Sync + 'static {
type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema + Send + 'static;
type Output: Serialize + Send + 'static;
const NAME: &'static str;
fn kind() -> ToolKind;
fn input_schema() -> serde_json::Value {
serde_json::to_value(schema_for!(Self::Input)).expect("schema_for must serialise")
}
fn supports_input_streaming() -> bool { false }
async fn run(
self: Arc<Self>,
input: ToolInput<Self::Input>,
events: ToolEventSink,
ctx: &ToolContext,
) -> Result<Self::Output, Self::Output>;
fn erase(self: Arc<Self>) -> Arc<dyn AnyTool> where Self: Sized { Arc::new(Erased(self)) }
}
/// Object-safe variant. The registry stores `Arc<dyn AnyTool>`.
#[async_trait]
pub trait AnyTool: Send + Sync + 'static {
fn name(&self) -> &'static str;
fn kind(&self) -> ToolKind;
fn input_schema(&self) -> serde_json::Value;
fn supports_input_streaming(&self) -> bool;
async fn run(
self: Arc<Self>,
input: AnyToolInput,
events: ToolEventSink,
ctx: &ToolContext,
) -> Result<serde_json::Value, serde_json::Value>;
}
/// Adapter from a typed `Tool` to `AnyTool`.
pub struct Erased<T: Tool>(pub Arc<T>);
#[async_trait]
impl<T: Tool> AnyTool for Erased<T> {
fn name(&self) -> &'static str { T::NAME }
fn kind(&self) -> ToolKind { T::kind() }
fn input_schema(&self) -> serde_json::Value { T::input_schema() }
fn supports_input_streaming(&self) -> bool { T::supports_input_streaming() }
async fn run(
self: Arc<Self>,
input: AnyToolInput,
events: ToolEventSink,
ctx: &ToolContext,
) -> Result<serde_json::Value, serde_json::Value> {
let typed = match input {
AnyToolInput::Final(v) => {
let parsed: T::Input = serde_json::from_value(v).map_err(|e| {
serde_json::json!({ "error": format!("invalid input: {e}") })
})?;
ToolInput::Final(parsed)
}
AnyToolInput::Partial { partial, final_input: _ } => {
// For v1: tools that opt into streaming receive Partial; the
// typed-final-input wiring is added when a streaming tool ships.
// Until then, only Final is fed by the dispatcher.
let _ = partial;
return Err(serde_json::json!({
"error": "streaming inputs are not yet wired in v1 dispatcher"
}));
}
};
let inner = self.0.clone();
let result = inner.run(typed, events, ctx).await;
match result {
Ok(o) => Ok(serde_json::to_value(o).unwrap_or(serde_json::Value::Null)),
Err(o) => Err(serde_json::to_value(o).unwrap_or(serde_json::Value::Null)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{PermissionConfig, SandboxConfig, WhitelistConfig};
use crate::permission::check::PermissionContext;
use crate::permission::whitelist::CompiledWhitelist;
use crate::tool::{ClientKind, ProtocolKind};
use std::path::PathBuf;
#[derive(Serialize, Deserialize, JsonSchema)]
struct EchoInput { msg: String }
#[derive(Serialize, Deserialize)]
struct EchoOutput { echoed: String }
struct EchoTool;
#[async_trait]
impl Tool for EchoTool {
type Input = EchoInput;
type Output = EchoOutput;
const NAME: &'static str = "echo";
fn kind() -> ToolKind { ToolKind::Other }
async fn run(
self: Arc<Self>,
input: ToolInput<Self::Input>,
_events: ToolEventSink,
_ctx: &ToolContext,
) -> Result<Self::Output, Self::Output> {
let i = match input { ToolInput::Final(i) => i, _ => unreachable!() };
Ok(EchoOutput { echoed: i.msg })
}
}
fn ctx() -> ToolContext {
let wl = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap();
let pc = PermissionContext::new("conn-1".to_string(), None, wl);
ToolContext::for_test(
"conn-1", ClientKind::claude(), ProtocolKind::acp(),
PathBuf::from("/tmp"),
SandboxConfig::default(), PermissionConfig::default(), pc,
)
}
#[tokio::test]
async fn typed_tool_runs_and_returns_output() {
let tool: Arc<EchoTool> = Arc::new(EchoTool);
let any: Arc<dyn AnyTool> = tool.erase();
let (sink, _rx) = ToolEventSink::new();
let result = any.run(
AnyToolInput::Final(serde_json::json!({ "msg": "hello" })),
sink,
&ctx(),
).await.unwrap();
assert_eq!(result["echoed"], "hello");
}
#[tokio::test]
async fn invalid_input_returns_structured_error() {
let tool: Arc<EchoTool> = Arc::new(EchoTool);
let any: Arc<dyn AnyTool> = tool.erase();
let (sink, _rx) = ToolEventSink::new();
let err = any.run(
AnyToolInput::Final(serde_json::json!({ "wrong": 1 })),
sink,
&ctx(),
).await.unwrap_err();
assert!(err["error"].as_str().unwrap().contains("invalid input"));
}
#[test]
fn name_and_kind_round_trip_through_erase() {
let tool: Arc<EchoTool> = Arc::new(EchoTool);
let any: Arc<dyn AnyTool> = tool.erase();
assert_eq!(any.name(), "echo");
assert_eq!(any.kind(), ToolKind::Other);
}
}
+134
View File
@@ -0,0 +1,134 @@
//! Neutral tool events. Connector adapters translate to wire types.
use crate::tool::ToolKind;
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::mpsc;
/// Opaque permission-request id, allocated by the dispatcher.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct PermissionRequestId(Arc<str>);
impl PermissionRequestId {
/// Construct a new permission-request id from any string-like value.
pub fn new(value: impl Into<Arc<str>>) -> Self {
Self(value.into())
}
/// Borrow the inner id as a string slice.
pub fn as_str(&self) -> &str {
&self.0
}
}
/// Where a tool is operating (file path + optional line).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolLocation {
pub path: String,
pub line: Option<u32>,
}
/// Result content shape. Mirrors what most providers accept as a tool result.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolResultContent {
Text { text: Arc<str> },
Json { value: serde_json::Value },
Image { mime: Arc<str>, #[serde(with = "serde_bytes_arc")] data: Bytes },
Parts { parts: Vec<ToolResultContent> },
}
impl ToolResultContent {
pub fn text(s: impl Into<Arc<str>>) -> Self { Self::Text { text: s.into() } }
}
/// Events emitted by a running tool. Transport-agnostic.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolEvent {
Started { title: Arc<str>, kind: ToolKind, location: Option<ToolLocation> },
TitleUpdate { title: Arc<str>, location: Option<ToolLocation> },
PartialOutput { content: ToolResultContent },
Status { message: Arc<str> },
PermissionRequested { request_id: PermissionRequestId, summary: Arc<str> },
Completed,
Failed,
}
/// Sink a tool emits events into. Cheap to clone.
#[derive(Clone, Debug)]
pub struct ToolEventSink {
tx: mpsc::UnboundedSender<ToolEvent>,
}
impl ToolEventSink {
pub fn new() -> (Self, mpsc::UnboundedReceiver<ToolEvent>) {
let (tx, rx) = mpsc::unbounded_channel();
(Self { tx }, rx)
}
/// Best-effort emit. Drops the event if the receiver is gone.
pub fn emit(&self, event: ToolEvent) {
let _ = self.tx.send(event);
}
}
mod serde_bytes_arc {
use bytes::Bytes;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
pub fn serialize<S: Serializer>(b: &Bytes, s: S) -> Result<S::Ok, S::Error> {
serde_bytes::Bytes::new(b.as_ref()).serialize(s)
}
pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Bytes, D::Error> {
let v: Vec<u8> = serde_bytes::ByteBuf::deserialize(d)?.into_vec();
Ok(Bytes::from(v))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn sink_round_trips_event() {
let (sink, mut rx) = ToolEventSink::new();
sink.emit(ToolEvent::Status { message: "hi".into() });
let got = rx.recv().await.unwrap();
match got {
ToolEvent::Status { message } => assert_eq!(&*message, "hi"),
_ => panic!("wrong variant"),
}
}
#[test]
fn result_content_text_helper() {
match ToolResultContent::text("hello") {
ToolResultContent::Text { text } => assert_eq!(&*text, "hello"),
_ => panic!(),
}
}
#[test]
fn tool_event_serde_round_trip() {
let ev = ToolEvent::PartialOutput { content: ToolResultContent::text("x") };
let json = serde_json::to_string(&ev).unwrap();
let _back: ToolEvent = serde_json::from_str(&json).unwrap();
}
#[test]
fn permission_request_id_constructor_and_accessor() {
let id = PermissionRequestId::new("abc");
assert_eq!(id.as_str(), "abc");
}
#[test]
fn permission_request_id_serde_is_transparent_string() {
let id = PermissionRequestId::new("foo");
let json = serde_json::to_string(&id).unwrap();
assert_eq!(json, "\"foo\"");
let back: PermissionRequestId = serde_json::from_str(&json).unwrap();
assert_eq!(back, id);
}
}
+135
View File
@@ -0,0 +1,135 @@
//! Open-form client/protocol identifiers and tool category enum.
use serde::{Deserialize, Serialize};
use std::sync::{Arc, OnceLock};
/// Open newtype identifying the upstream client family (Claude, Codex, etc.).
///
/// Use the provided constants for known clients; use [`ClientKind::custom`]
/// for anything else. Comparison is by inner string equality.
///
/// The well-known constructors (`claude`, `codex`, `gemini`, `opencode`)
/// return cached values backed by a process-wide `OnceLock`, so calling them
/// repeatedly is a cheap `Arc` clone — no per-call heap allocation.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct ClientKind(Arc<str>);
impl ClientKind {
pub fn claude() -> Self {
static CELL: OnceLock<ClientKind> = OnceLock::new();
CELL.get_or_init(|| Self(Arc::from("claude"))).clone()
}
pub fn codex() -> Self {
static CELL: OnceLock<ClientKind> = OnceLock::new();
CELL.get_or_init(|| Self(Arc::from("codex"))).clone()
}
pub fn gemini() -> Self {
static CELL: OnceLock<ClientKind> = OnceLock::new();
CELL.get_or_init(|| Self(Arc::from("gemini"))).clone()
}
pub fn opencode() -> Self {
static CELL: OnceLock<ClientKind> = OnceLock::new();
CELL.get_or_init(|| Self(Arc::from("opencode"))).clone()
}
pub fn custom(name: impl Into<Arc<str>>) -> Self { Self(name.into()) }
pub fn as_str(&self) -> &str { &self.0 }
}
/// Open newtype identifying the wire protocol (ACP, OpenCode, native).
///
/// As with [`ClientKind`], the well-known constructors return cached values
/// backed by a process-wide `OnceLock`.
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
pub struct ProtocolKind(Arc<str>);
impl ProtocolKind {
pub fn acp() -> Self {
static CELL: OnceLock<ProtocolKind> = OnceLock::new();
CELL.get_or_init(|| Self(Arc::from("acp"))).clone()
}
pub fn opencode() -> Self {
static CELL: OnceLock<ProtocolKind> = OnceLock::new();
CELL.get_or_init(|| Self(Arc::from("opencode"))).clone()
}
pub fn native() -> Self {
static CELL: OnceLock<ProtocolKind> = OnceLock::new();
CELL.get_or_init(|| Self(Arc::from("native"))).clone()
}
pub fn custom(name: impl Into<Arc<str>>) -> Self { Self(name.into()) }
pub fn as_str(&self) -> &str { &self.0 }
}
/// Coarse category of a tool. Connectors use this to pick wire-level hints
/// (e.g. ACP `ToolKind::Edit` triggers diff rendering on the client).
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolKind {
Read,
Edit,
Search,
Execute,
Fetch,
Think,
Other,
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn client_kind_constants_are_distinct() {
assert_ne!(ClientKind::claude(), ClientKind::codex());
assert_eq!(ClientKind::claude(), ClientKind::claude());
}
#[test]
fn client_kind_custom_round_trips() {
let k = ClientKind::custom("aider");
assert_eq!(k.as_str(), "aider");
assert_eq!(k, ClientKind::custom("aider"));
}
#[test]
fn protocol_kind_distinct() {
assert_ne!(ProtocolKind::acp(), ProtocolKind::opencode());
}
#[test]
fn tool_kind_serde_round_trip() {
let json = serde_json::to_string(&ToolKind::Edit).unwrap();
assert_eq!(json, "\"edit\"");
let back: ToolKind = serde_json::from_str(&json).unwrap();
assert_eq!(back, ToolKind::Edit);
}
#[test]
fn well_known_client_kinds_are_cached() {
let a = ClientKind::claude();
let b = ClientKind::claude();
assert_eq!(a, b);
// Proves the cache works: same Arc, no per-call allocation.
assert!(Arc::ptr_eq(&a.0, &b.0));
// Sanity-check another well-known kind too.
let c1 = ClientKind::codex();
let c2 = ClientKind::codex();
assert!(Arc::ptr_eq(&c1.0, &c2.0));
// Custom values are *not* cached — each call allocates fresh.
let x = ClientKind::custom("aider");
let y = ClientKind::custom("aider");
assert_eq!(x, y);
assert!(!Arc::ptr_eq(&x.0, &y.0));
}
#[test]
fn well_known_protocol_kinds_are_cached() {
let a = ProtocolKind::acp();
let b = ProtocolKind::acp();
assert_eq!(a, b);
assert!(Arc::ptr_eq(&a.0, &b.0));
}
}
+106
View File
@@ -0,0 +1,106 @@
//! Compile-time built-in tool registration macro.
//!
//! ```ignore
//! crate::tools! { ReadTool, GrepTool, EditTool }
//! ```
//!
//! Produces:
//! - `pub const ALL_BUILT_IN_TOOL_NAMES: &[&str] = &[...];`
//! - `pub fn built_in_tools() -> impl Iterator<Item = std::sync::Arc<dyn AnyTool>>`
//! - A compile-time uniqueness check on `T::NAME`.
#[macro_export]
macro_rules! tools {
($($tool:ty),* $(,)?) => {
pub const ALL_BUILT_IN_TOOL_NAMES: &[&str] = &[
$( <$tool as $crate::tool::Tool>::NAME, )*
];
const _: () = {
const fn str_eq(a: &str, b: &str) -> bool {
let a = a.as_bytes();
let b = b.as_bytes();
if a.len() != b.len() { return false; }
let mut i = 0;
while i < a.len() {
if a[i] != b[i] { return false; }
i += 1;
}
true
}
const NAMES: &[&str] = ALL_BUILT_IN_TOOL_NAMES;
let mut i = 0;
while i < NAMES.len() {
let mut j = i + 1;
while j < NAMES.len() {
if str_eq(NAMES[i], NAMES[j]) {
panic!("Duplicate built-in tool name");
}
j += 1;
}
i += 1;
}
};
pub fn built_in_tools() -> Vec<std::sync::Arc<dyn $crate::tool::AnyTool>> {
vec![
$(
{
let t: std::sync::Arc<$tool> = std::sync::Arc::new(<$tool>::default());
<$tool as $crate::tool::Tool>::erase(t)
},
)*
]
}
};
}
#[cfg(test)]
mod tests {
use crate::tool::{Tool, ToolContext, ToolEventSink, ToolInput, ToolKind};
use async_trait::async_trait;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
#[derive(Default)]
struct AlphaTool;
#[derive(Default)]
struct BetaTool;
#[derive(Serialize, Deserialize, JsonSchema)]
struct Empty {}
macro_rules! impl_tool {
($t:ty, $name:literal) => {
#[async_trait]
impl Tool for $t {
type Input = Empty;
type Output = Empty;
const NAME: &'static str = $name;
fn kind() -> ToolKind { ToolKind::Other }
async fn run(
self: Arc<Self>, _i: ToolInput<Empty>,
_e: ToolEventSink, _c: &ToolContext,
) -> Result<Empty, Empty> { Ok(Empty {}) }
}
};
}
impl_tool!(AlphaTool, "alpha");
impl_tool!(BetaTool, "beta");
crate::tools! { AlphaTool, BetaTool }
#[test]
fn macro_lists_names() {
assert_eq!(ALL_BUILT_IN_TOOL_NAMES, &["alpha", "beta"]);
}
#[test]
fn macro_constructs_erased_tools() {
let v = built_in_tools();
assert_eq!(v.len(), 2);
assert_eq!(v[0].name(), "alpha");
assert_eq!(v[1].name(), "beta");
}
}
+14
View File
@@ -0,0 +1,14 @@
//! Tool trait, registry-facing types, and per-call context.
pub mod kinds;
pub mod events;
pub mod context;
pub mod erase;
pub mod macros;
pub use kinds::{ClientKind, ProtocolKind, ToolKind};
pub use events::{
PermissionRequestId, ToolEvent, ToolEventSink, ToolLocation, ToolResultContent,
};
pub use context::ToolContext;
pub use erase::{AnyTool, AnyToolInput, Erased, Tool, ToolInput};
+5
View File
@@ -0,0 +1,5 @@
//! Built-in tool implementations registered against the `Tool` trait.
pub mod read;
pub use read::ReadTool;
+151
View File
@@ -0,0 +1,151 @@
//! `ReadTool`: built-in trait wrapper around `crate::fs::read_text_file`.
use crate::fs::read::{read_text_file, ReadTextFileRequest, ReadTextFileResponse};
use crate::tool::{
Tool, ToolContext, ToolEvent, ToolEventSink, ToolInput, ToolKind, ToolLocation,
ToolResultContent,
};
use async_trait::async_trait;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
#[derive(Serialize, Deserialize, JsonSchema, Clone)]
pub struct ReadInput {
/// Absolute path to the file to read.
pub path: String,
/// Optional 1-indexed start line.
#[serde(default)] pub line: Option<usize>,
/// Optional max number of lines to return.
#[serde(default)] pub limit: Option<usize>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ReadOutput {
Ok { content: String },
Err { error: String },
}
impl ReadOutput {
fn err(msg: impl Into<String>) -> Self { ReadOutput::Err { error: msg.into() } }
}
#[derive(Default)]
pub struct ReadTool;
#[async_trait]
impl Tool for ReadTool {
type Input = ReadInput;
type Output = ReadOutput;
const NAME: &'static str = "read";
fn kind() -> ToolKind { ToolKind::Read }
async fn run(
self: Arc<Self>,
input: ToolInput<Self::Input>,
events: ToolEventSink,
ctx: &ToolContext,
) -> Result<Self::Output, Self::Output> {
let i = match input {
ToolInput::Final(i) => i,
_ => return Err(ReadOutput::err("streaming inputs not supported by read")),
};
events.emit(ToolEvent::Started {
title: format!("Read {}", i.path).into(),
kind: ToolKind::Read,
location: Some(ToolLocation { path: i.path.clone(), line: i.line.map(|n| n as u32) }),
});
let req = ReadTextFileRequest { path: i.path.clone(), line: i.line, limit: i.limit };
let res: ReadTextFileResponse = read_text_file(req, ctx.sandbox.as_ref()).await
.map_err(|e| ReadOutput::err(e.to_string()))?;
events.emit(ToolEvent::PartialOutput {
content: ToolResultContent::text(res.content.clone()),
});
events.emit(ToolEvent::Completed);
Ok(ReadOutput::Ok { content: res.content })
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{PermissionConfig, SandboxConfig, WhitelistConfig};
use crate::permission::check::PermissionContext;
use crate::permission::whitelist::CompiledWhitelist;
use crate::tool::{ClientKind, ProtocolKind};
use std::io::Write;
use tempfile::TempDir;
fn ctx_for(root: &std::path::Path) -> ToolContext {
let wl = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap();
let pc = PermissionContext::new("c".to_string(), None, wl);
let mut sandbox = SandboxConfig::default();
sandbox.allowed_roots = vec![root.to_path_buf()];
ToolContext::for_test(
"c", ClientKind::claude(), ProtocolKind::acp(),
root.to_path_buf(), sandbox, PermissionConfig::default(), pc,
)
}
#[tokio::test]
async fn reads_file_through_trait() {
let dir = TempDir::new().unwrap();
let path = dir.path().join("hello.txt");
std::fs::File::create(&path).unwrap().write_all(b"hi\nthere\n").unwrap();
let tool = Arc::new(ReadTool);
let (sink, mut rx) = ToolEventSink::new();
let result = Tool::run(
tool,
ToolInput::Final(ReadInput {
path: path.to_string_lossy().into_owned(),
line: None, limit: None,
}),
sink,
&ctx_for(dir.path()),
).await.unwrap();
match result {
ReadOutput::Ok { content } => {
assert!(content.contains("hi"));
assert!(content.contains("there"));
}
ReadOutput::Err { error } => panic!("expected ok, got err: {error}"),
}
// At least Started + PartialOutput + Completed should have fired.
let mut count = 0;
while let Ok(_ev) = rx.try_recv() { count += 1; }
assert!(count >= 3, "expected >=3 events, got {count}");
}
#[tokio::test]
async fn returns_structured_error_on_sandbox_violation() {
let dir = TempDir::new().unwrap();
let other = TempDir::new().unwrap();
let outside = other.path().join("nope.txt");
std::fs::File::create(&outside).unwrap().write_all(b"x").unwrap();
let tool = Arc::new(ReadTool);
let (sink, _rx) = ToolEventSink::new();
let result = Tool::run(
tool,
ToolInput::Final(ReadInput {
path: outside.to_string_lossy().into_owned(),
line: None, limit: None,
}),
sink,
&ctx_for(dir.path()),
).await.unwrap_err();
match result {
ReadOutput::Err { error } => assert!(!error.is_empty()),
ReadOutput::Ok { .. } => panic!("expected sandbox error"),
}
}
}