sync from monorepo @ 2452e92e
This commit is contained in:
@@ -0,0 +1,58 @@
|
||||
//! Audit logging for sensitive tool operations.
|
||||
//!
|
||||
//! This module provides structured logging for:
|
||||
//! - File read/write operations
|
||||
//! - Terminal command execution
|
||||
//! - Permission decisions
|
||||
//! - Sandbox violations
|
||||
//!
|
||||
//! All audit logs include:
|
||||
//! - Timestamp
|
||||
//! - User/session context
|
||||
//! - Operation type
|
||||
//! - Parameters (sanitized)
|
||||
//! - Outcome (success/error)
|
||||
//!
|
||||
//! TODO: Implement audit logging
|
||||
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Log a file read operation.
|
||||
///
|
||||
/// TODO: Implement with structured fields
|
||||
pub fn log_file_read(_path: &str, _success: bool) {
|
||||
// Placeholder - will use tracing with structured fields
|
||||
info!("File read audit log placeholder");
|
||||
}
|
||||
|
||||
/// Log a file write operation.
|
||||
///
|
||||
/// TODO: Implement with structured fields
|
||||
pub fn log_file_write(_path: &str, _success: bool) {
|
||||
// Placeholder - will use tracing with structured fields
|
||||
info!("File write audit log placeholder");
|
||||
}
|
||||
|
||||
/// Log a terminal command execution.
|
||||
///
|
||||
/// TODO: Implement with structured fields
|
||||
pub fn log_terminal_exec(_command: &str, _success: bool) {
|
||||
// Placeholder - will use tracing with structured fields
|
||||
info!("Terminal exec audit log placeholder");
|
||||
}
|
||||
|
||||
/// Log a permission decision.
|
||||
///
|
||||
/// TODO: Implement with structured fields
|
||||
pub fn log_permission_decision(_operation: &str, _allowed: bool) {
|
||||
// Placeholder - will use tracing with structured fields
|
||||
info!("Permission decision audit log placeholder");
|
||||
}
|
||||
|
||||
/// Log a sandbox violation attempt.
|
||||
///
|
||||
/// TODO: Implement with structured fields
|
||||
pub fn log_sandbox_violation(_path: &str, _reason: &str) {
|
||||
// Placeholder - will use tracing with structured fields
|
||||
warn!("Sandbox violation audit log placeholder");
|
||||
}
|
||||
@@ -0,0 +1,425 @@
|
||||
//! Configuration types for Phase 03 features.
|
||||
//!
|
||||
//! This module defines configuration structures for:
|
||||
//! - `SandboxConfig` - Filesystem sandboxing configuration
|
||||
//! - `PermissionConfig` - Permission prompt and decision caching
|
||||
//! - `TerminalConfig` - Terminal/command execution limits
|
||||
//! - `SearchConfig` - Search operation limits and defaults
|
||||
//! - `EmbeddingConfig` - File embedding thresholds and policies
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
|
||||
// =============================================================================
|
||||
// Sandbox Configuration
|
||||
// =============================================================================
|
||||
|
||||
/// Filesystem sandboxing configuration.
|
||||
///
|
||||
/// Determines which paths are accessible and how symlinks are handled.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct SandboxConfig {
|
||||
/// Absolute paths that are allowed for file operations.
|
||||
///
|
||||
/// Operations outside these roots will be rejected.
|
||||
pub allowed_roots: Vec<PathBuf>,
|
||||
|
||||
/// Path patterns that are explicitly blocked even if within allowed roots.
|
||||
///
|
||||
/// Supports glob patterns like "**/.env", "**/secrets/**".
|
||||
pub blocked_paths: Vec<String>,
|
||||
|
||||
/// Whether to allow symlinks to escape allowed roots.
|
||||
///
|
||||
/// If false (recommended), symlinks pointing outside allowed roots are rejected.
|
||||
pub allow_symlink_escape: bool,
|
||||
|
||||
/// Whether to follow symlinks within allowed roots.
|
||||
///
|
||||
/// If true, symlinks within allowed roots are followed.
|
||||
pub follow_symlinks_within_roots: bool,
|
||||
|
||||
/// Enable read operations.
|
||||
pub read_enabled: bool,
|
||||
|
||||
/// Enable write operations.
|
||||
pub write_enabled: bool,
|
||||
|
||||
/// Maximum bytes to read in a single request.
|
||||
///
|
||||
/// Soft cap for previews. Default: 1 MB.
|
||||
pub max_read_bytes: u64,
|
||||
|
||||
/// Maximum bytes to write in a single request.
|
||||
///
|
||||
/// Default: 1 MB.
|
||||
pub max_write_bytes: u64,
|
||||
|
||||
/// End-of-line policy for file operations.
|
||||
pub eol_policy: EolPolicy,
|
||||
|
||||
/// Text encoding support.
|
||||
///
|
||||
/// Currently only UTF-8 is supported.
|
||||
pub encoding: String,
|
||||
}
|
||||
|
||||
impl Default for SandboxConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
allowed_roots: vec![],
|
||||
blocked_paths: vec!["**/.env".to_string(), "**/secrets/**".to_string()],
|
||||
allow_symlink_escape: false,
|
||||
follow_symlinks_within_roots: true,
|
||||
read_enabled: true,
|
||||
write_enabled: false,
|
||||
max_read_bytes: 1_048_576, // 1 MB
|
||||
max_write_bytes: 1_048_576, // 1 MB
|
||||
eol_policy: EolPolicy::Preserve,
|
||||
encoding: "utf-8".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// End-of-line handling policy.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum EolPolicy {
|
||||
/// Preserve original line endings.
|
||||
Preserve,
|
||||
/// Normalize to LF (\n).
|
||||
Lf,
|
||||
/// Normalize to CRLF (\r\n).
|
||||
Crlf,
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Permission Configuration
|
||||
// =============================================================================
|
||||
|
||||
/// Permission prompt and decision caching configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct PermissionConfig {
|
||||
/// Permission mode strategy.
|
||||
pub mode: PermissionMode,
|
||||
|
||||
/// Whether to remember permission decisions.
|
||||
pub remember_decisions: bool,
|
||||
|
||||
/// Time-to-live for cached decisions in seconds.
|
||||
///
|
||||
/// Default: 86400 (24 hours).
|
||||
pub remember_ttl_secs: u64,
|
||||
|
||||
/// Scope of cached decisions.
|
||||
pub scope: DecisionScope,
|
||||
|
||||
/// Whitelist configuration for whitelist mode.
|
||||
pub whitelist: WhitelistConfig,
|
||||
}
|
||||
|
||||
impl Default for PermissionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mode: PermissionMode::Whitelist,
|
||||
remember_decisions: true,
|
||||
remember_ttl_secs: 86_400, // 24 hours
|
||||
scope: DecisionScope::PerConnector,
|
||||
whitelist: WhitelistConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Permission mode strategy.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum PermissionMode {
|
||||
/// Prompt for every sensitive operation.
|
||||
Ask,
|
||||
/// Auto-approve whitelisted operations, prompt for others.
|
||||
Whitelist,
|
||||
/// Auto-approve all operations (with audit logging).
|
||||
Yolo,
|
||||
}
|
||||
|
||||
/// Scope for cached permission decisions.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum DecisionScope {
|
||||
/// Decisions persist per connector.
|
||||
PerConnector,
|
||||
/// Decisions persist only within a session.
|
||||
PerSession,
|
||||
}
|
||||
|
||||
/// Whitelist configuration for auto-approved operations.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct WhitelistConfig {
|
||||
/// Path patterns that are safe for write operations.
|
||||
///
|
||||
/// Glob patterns like "C:/work/project/**".
|
||||
pub write_paths: Vec<String>,
|
||||
|
||||
/// Commands that are safe to execute.
|
||||
///
|
||||
/// Glob patterns like "cargo", "npm", "git".
|
||||
pub execute_commands: Vec<String>,
|
||||
}
|
||||
|
||||
impl Default for WhitelistConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
write_paths: vec![],
|
||||
execute_commands: vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Terminal Configuration
|
||||
// =============================================================================
|
||||
|
||||
/// Terminal/command execution configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct TerminalConfig {
|
||||
/// Enable terminal operations.
|
||||
pub enabled: bool,
|
||||
|
||||
/// Default current working directory for terminal commands.
|
||||
///
|
||||
/// Must be within an allowed sandbox root.
|
||||
pub default_cwd: Option<PathBuf>,
|
||||
|
||||
/// Environment variable names that are allowed.
|
||||
///
|
||||
/// Only these variables can be set in spawned processes.
|
||||
pub env_allowlist: Vec<String>,
|
||||
|
||||
/// Command patterns that are blocked (best-effort).
|
||||
///
|
||||
/// Glob patterns for dangerous commands like "rm", "format".
|
||||
pub command_blocklist: Vec<String>,
|
||||
|
||||
/// Maximum bytes to capture from terminal output (ring buffer).
|
||||
///
|
||||
/// Default: 200,000 bytes.
|
||||
pub output_byte_limit: u64,
|
||||
|
||||
/// Maximum runtime for a terminal command in seconds.
|
||||
///
|
||||
/// Commands exceeding this will be killed. Default: 3600 (1 hour).
|
||||
pub max_runtime_secs: u64,
|
||||
}
|
||||
|
||||
impl Default for TerminalConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: true,
|
||||
default_cwd: None,
|
||||
env_allowlist: vec![
|
||||
"RUST_LOG".to_string(),
|
||||
"NODE_ENV".to_string(),
|
||||
"PATH".to_string(),
|
||||
],
|
||||
command_blocklist: vec![
|
||||
"rm".to_string(),
|
||||
"rd".to_string(),
|
||||
"format".to_string(),
|
||||
"mkfs*".to_string(),
|
||||
],
|
||||
output_byte_limit: 200_000,
|
||||
max_runtime_secs: 3_600, // 1 hour
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Search Configuration
|
||||
// =============================================================================
|
||||
|
||||
/// Search operation configuration (glob, grep, ls).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct SearchConfig {
|
||||
/// Maximum number of results to return.
|
||||
///
|
||||
/// Default: 5,000.
|
||||
pub max_results: u32,
|
||||
|
||||
/// Maximum total bytes in search results.
|
||||
///
|
||||
/// Default: 1,000,000 (1 MB).
|
||||
pub max_bytes: u64,
|
||||
|
||||
/// Default include patterns for searches.
|
||||
pub default_include_globs: Vec<String>,
|
||||
|
||||
/// Default exclude patterns for searches.
|
||||
///
|
||||
/// Common directories to skip like target/, .git/, node_modules/.
|
||||
pub default_exclude_globs: Vec<String>,
|
||||
}
|
||||
|
||||
impl Default for SearchConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_results: 5_000,
|
||||
max_bytes: 1_000_000, // 1 MB
|
||||
default_include_globs: vec![],
|
||||
default_exclude_globs: vec![
|
||||
"**/target/**".to_string(),
|
||||
"**/.git/**".to_string(),
|
||||
"**/node_modules/**".to_string(),
|
||||
"**/__pycache__/**".to_string(),
|
||||
"**/.venv/**".to_string(),
|
||||
],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Embedding Configuration
|
||||
// =============================================================================
|
||||
|
||||
/// File embedding configuration for ACP prompt context.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct EmbeddingConfig {
|
||||
/// Maximum bytes to embed as ContentBlock::resource per file.
|
||||
///
|
||||
/// Larger files will use resource_link instead. Default: 256,000.
|
||||
pub max_embed_bytes: u64,
|
||||
|
||||
/// Whether to allow resource_link for large files.
|
||||
///
|
||||
/// If false, large files are rejected instead of linked.
|
||||
pub allow_resource_link: bool,
|
||||
|
||||
/// Regex patterns for redacting secrets in embedded content.
|
||||
///
|
||||
/// Best-effort redaction (does not modify files on disk).
|
||||
pub redact_patterns: Vec<String>,
|
||||
|
||||
/// Snippet strategy when file is too large to embed fully.
|
||||
pub snippet_strategy: SnippetStrategy,
|
||||
|
||||
/// Maximum number of files to embed in a single prompt.
|
||||
pub max_files_per_prompt: u32,
|
||||
}
|
||||
|
||||
impl Default for EmbeddingConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_embed_bytes: 256_000,
|
||||
allow_resource_link: true,
|
||||
redact_patterns: vec![
|
||||
// Common secret patterns
|
||||
r"(?i)(api[_-]?key|password|secret|token)[:]\s*[']?([a-zA-Z0-9_\-\.]+)[']?".to_string(),
|
||||
],
|
||||
snippet_strategy: SnippetStrategy::HeadTail,
|
||||
max_files_per_prompt: 10,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Strategy for creating snippets from large files.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SnippetStrategy {
|
||||
/// Include beginning and end of file.
|
||||
HeadTail,
|
||||
/// Include only the beginning.
|
||||
HeadOnly,
|
||||
/// Include only the end.
|
||||
TailOnly,
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Validation
|
||||
// =============================================================================
|
||||
|
||||
impl SandboxConfig {
|
||||
/// Validate the sandbox configuration.
|
||||
///
|
||||
/// Returns an error message if the configuration is invalid.
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
if self.allowed_roots.is_empty() && (self.read_enabled || self.write_enabled) {
|
||||
return Err("allowed_roots cannot be empty when read or write is enabled".to_string());
|
||||
}
|
||||
|
||||
if self.encoding != "utf-8" {
|
||||
return Err(format!("unsupported encoding: {}", self.encoding));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Normalize allowed roots to canonical paths.
|
||||
///
|
||||
/// This should be called after loading configuration to ensure all roots are canonical.
|
||||
/// Panics if any root cannot be canonicalized (configuration error).
|
||||
pub fn normalize_roots(&mut self) {
|
||||
self.allowed_roots = self.allowed_roots
|
||||
.iter()
|
||||
.map(|root| {
|
||||
dunce::canonicalize(root)
|
||||
.unwrap_or_else(|_| panic!("Failed to canonicalize allowed root: {:?}", root))
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
}
|
||||
|
||||
impl TerminalConfig {
|
||||
/// Validate the terminal configuration.
|
||||
///
|
||||
/// Returns an error message if the configuration is invalid.
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
if self.output_byte_limit == 0 {
|
||||
return Err("output_byte_limit must be greater than 0".to_string());
|
||||
}
|
||||
|
||||
if self.max_runtime_secs == 0 {
|
||||
return Err("max_runtime_secs must be greater than 0".to_string());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl SearchConfig {
|
||||
/// Validate the search configuration.
|
||||
///
|
||||
/// Returns an error message if the configuration is invalid.
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
if self.max_results == 0 {
|
||||
return Err("max_results must be greater than 0".to_string());
|
||||
}
|
||||
|
||||
if self.max_bytes == 0 {
|
||||
return Err("max_bytes must be greater than 0".to_string());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl EmbeddingConfig {
|
||||
/// Validate the embedding configuration.
|
||||
///
|
||||
/// Returns an error message if the configuration is invalid.
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
if self.max_embed_bytes == 0 {
|
||||
return Err("max_embed_bytes must be greater than 0".to_string());
|
||||
}
|
||||
|
||||
if self.max_files_per_prompt == 0 {
|
||||
return Err("max_files_per_prompt must be greater than 0".to_string());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
//! Dispatch orchestration: registry → SecurityFloor → permission → Tool::run.
|
||||
|
||||
use crate::floor::{FloorDecision, SecurityFloor};
|
||||
use crate::registry::ToolRegistry;
|
||||
use crate::tool::{AnyToolInput, ToolContext, ToolEventSink};
|
||||
|
||||
/// Outcome of a dispatch call. Matches `AnyTool::run`'s return shape:
|
||||
/// `Ok` is a successful structured result, `Err` is a structured error.
|
||||
pub type DispatchResult = Result<serde_json::Value, serde_json::Value>;
|
||||
|
||||
/// Dispatch a tool call through the harness.
|
||||
///
|
||||
/// 1. Resolve the tool via the registry (built-in vs dynamic, with collision
|
||||
/// policy and per-client/per-protocol filters applied).
|
||||
/// 2. Run the hardcoded [`SecurityFloor`]. Cannot be bypassed by settings.
|
||||
/// 3. (Permission check is handled by the caller for now — the existing
|
||||
/// [`crate::permission::check::check_permission`] takes a different
|
||||
/// operation type and is wired in by the connector. This is documented in
|
||||
/// `2026-04-28-tool-harness-design.md`.)
|
||||
/// 4. Run the tool, awaiting its final result.
|
||||
pub async fn dispatch(
|
||||
registry: &ToolRegistry,
|
||||
floor: &SecurityFloor,
|
||||
name: &str,
|
||||
input: serde_json::Value,
|
||||
events: ToolEventSink,
|
||||
ctx: &ToolContext,
|
||||
) -> DispatchResult {
|
||||
let tool = match registry.resolve(name, ctx) {
|
||||
Some(t) => t,
|
||||
None => return Err(serde_json::json!({
|
||||
"error": format!("unknown tool: {name}"),
|
||||
})),
|
||||
};
|
||||
|
||||
if let FloorDecision::Block { reason } = floor.check(name, &input, ctx) {
|
||||
return Err(serde_json::json!({ "error": reason }));
|
||||
}
|
||||
|
||||
tool.run(AnyToolInput::Final(input), events, ctx).await
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{PermissionConfig, SandboxConfig, WhitelistConfig};
    use crate::permission::check::PermissionContext;
    use crate::permission::whitelist::CompiledWhitelist;
    use crate::registry::CollisionPolicy;
    use crate::tool::{AnyTool, ClientKind, ProtocolKind, Tool, ToolInput, ToolKind};
    use async_trait::async_trait;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use std::path::PathBuf;
    use std::sync::Arc;

    /// Input of the `echo` test tool.
    #[derive(Serialize, Deserialize, JsonSchema)]
    struct EchoIn {
        msg: String,
    }

    /// Output of the `echo` test tool.
    #[derive(Serialize, Deserialize)]
    struct EchoOut {
        echoed: String,
    }

    /// Minimal tool that returns its input message unchanged.
    #[derive(Default)]
    struct Echo;

    #[async_trait]
    impl Tool for Echo {
        type Input = EchoIn;
        type Output = EchoOut;

        const NAME: &'static str = "echo";

        fn kind() -> ToolKind {
            ToolKind::Other
        }

        async fn run(
            self: Arc<Self>,
            input: ToolInput<EchoIn>,
            _e: ToolEventSink,
            _c: &ToolContext,
        ) -> Result<EchoOut, EchoOut> {
            // `dispatch` only ever hands over a final input.
            match input {
                ToolInput::Final(payload) => Ok(EchoOut { echoed: payload.msg }),
                _ => unreachable!(),
            }
        }
    }

    /// Build a test context with default sandbox and permission settings.
    fn ctx() -> ToolContext {
        let whitelist = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap();
        let permission_ctx = PermissionContext::new("c".to_string(), None, whitelist);
        ToolContext::for_test(
            "c",
            ClientKind::claude(),
            ProtocolKind::acp(),
            PathBuf::from("/tmp"),
            SandboxConfig::default(),
            PermissionConfig::default(),
            permission_ctx,
        )
    }

    /// Registry that contains only the `echo` tool.
    fn registry_with_echo() -> ToolRegistry {
        let erased: Arc<dyn AnyTool> = <Echo as Tool>::erase(Arc::new(Echo));
        ToolRegistry::new(vec![erased], CollisionPolicy::BuiltInWins)
    }

    #[tokio::test]
    async fn dispatch_runs_known_tool() {
        let registry = registry_with_echo();
        let floor = SecurityFloor::new();
        let (sink, _rx) = ToolEventSink::new();
        let payload = serde_json::json!({ "msg": "hi" });

        let out = dispatch(&registry, &floor, "echo", payload, sink, &ctx())
            .await
            .unwrap();

        assert_eq!(out["echoed"], "hi");
    }

    #[tokio::test]
    async fn dispatch_rejects_unknown_tool() {
        let registry = registry_with_echo();
        let floor = SecurityFloor::new();
        let (sink, _rx) = ToolEventSink::new();
        let payload = serde_json::json!({});

        let err = dispatch(&registry, &floor, "nope", payload, sink, &ctx())
            .await
            .unwrap_err();

        assert!(err["error"].as_str().unwrap().contains("unknown tool"));
    }

    #[tokio::test]
    async fn dispatch_blocks_on_security_floor() {
        // We can't easily inject a "terminal" tool here, but we exercise the
        // floor by registering a tool literally named "terminal".
        #[derive(Default)]
        struct Term;

        #[derive(Serialize, Deserialize, JsonSchema)]
        struct In {
            command: String,
        }

        #[derive(Serialize, Deserialize)]
        struct Out;

        #[async_trait]
        impl Tool for Term {
            type Input = In;
            type Output = Out;

            const NAME: &'static str = "terminal";

            fn kind() -> ToolKind {
                ToolKind::Execute
            }

            async fn run(
                self: Arc<Self>,
                _i: ToolInput<In>,
                _e: ToolEventSink,
                _c: &ToolContext,
            ) -> Result<Out, Out> {
                Ok(Out)
            }
        }

        let erased: Arc<dyn AnyTool> = <Term as Tool>::erase(Arc::new(Term));
        let registry = ToolRegistry::new(vec![erased], CollisionPolicy::BuiltInWins);
        let floor = SecurityFloor::new();
        let (sink, _rx) = ToolEventSink::new();
        let payload = serde_json::json!({ "command": "rm -rf /" });

        let err = dispatch(&registry, &floor, "terminal", payload, sink, &ctx())
            .await
            .unwrap_err();

        assert!(err["error"].as_str().unwrap().to_lowercase().contains("blocked"));
    }
}
|
||||
@@ -0,0 +1,467 @@
|
||||
//! Embedding decision logic for file attachments.
|
||||
//!
|
||||
//! This module implements the decision matrix for choosing between embedded content,
|
||||
//! resource links, or snippets based on capabilities, size, and configuration.
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
use crate::config::EmbeddingConfig;
|
||||
use crate::error::ToolResult;
|
||||
use crate::fs::file_type::detect_file_type;
|
||||
|
||||
/// Strategy for including a file in an ACP prompt.
#[derive(Debug, Clone, PartialEq)]
pub enum EmbeddingStrategy {
    /// Embed full file content as text resource.
    EmbedText {
        /// File content (possibly redacted).
        content: String,
        /// MIME type for the content.
        mime_type: String,
    },

    /// Embed the file as a binary blob.
    EmbedBlob {
        /// Raw file bytes.
        ///
        /// This is NOT base64 text: the decider stores the bytes exactly as
        /// read from disk. Any base64 encoding has to be applied by whoever
        /// serializes the blob.
        data: Vec<u8>,
        /// MIME type for the binary data.
        mime_type: String,
    },

    /// Create a resource link (don't embed full content).
    Link {
        /// URI for the resource (e.g., dirigent://resource/<hash>).
        uri: String,
        /// Human-readable name of the file.
        name: String,
        /// File size in bytes.
        size: u64,
        /// MIME type (if known).
        mime_type: Option<String>,
    },

    /// Embed a snippet (head/tail) with optional link to full file.
    Snippet {
        /// Head portion of the file.
        head: String,
        /// Tail portion of the file.
        tail: String,
        /// Total size of the original file.
        total_size: u64,
        /// MIME type.
        mime_type: String,
    },

    /// Deny the attachment (too large, blocked, etc.).
    Deny {
        /// Reason for denial.
        reason: String,
    },
}
|
||||
|
||||
/// Embedding decider.
|
||||
///
|
||||
/// Decides how to include files in ACP prompts based on agent capabilities,
|
||||
/// file properties, and configuration limits.
|
||||
pub struct EmbeddingDecider {
|
||||
/// Embedding configuration.
|
||||
config: EmbeddingConfig,
|
||||
/// Whether the agent supports embedded context.
|
||||
agent_supports_embedded: bool,
|
||||
/// Total bytes accumulated across all files in this prompt.
|
||||
accumulated_bytes: usize,
|
||||
/// Number of files processed so far.
|
||||
file_count: usize,
|
||||
}
|
||||
|
||||
impl EmbeddingDecider {
|
||||
/// Create a new embedding decider.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config` - Embedding configuration with size limits and policies
|
||||
/// * `agent_supports_embedded` - Whether the agent advertised `embeddedContext` capability
|
||||
pub fn new(config: EmbeddingConfig, agent_supports_embedded: bool) -> Self {
|
||||
Self {
|
||||
config,
|
||||
agent_supports_embedded,
|
||||
accumulated_bytes: 0,
|
||||
file_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Decide the embedding strategy for a file.
|
||||
///
|
||||
/// This implements the decision tree from the file embedding policy:
|
||||
/// 1. File type detection
|
||||
/// 2. Size checks (per-file and total accumulated)
|
||||
/// 3. Capability check (embeddedContext)
|
||||
/// 4. Strategy selection (embed text, blob, link, snippet, deny)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the file to embed
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The chosen embedding strategy.
|
||||
pub fn decide(&mut self, path: &Path) -> ToolResult<EmbeddingStrategy> {
|
||||
// Check file count limit
|
||||
if self.file_count >= self.config.max_files_per_prompt as usize {
|
||||
return Ok(EmbeddingStrategy::Deny {
|
||||
reason: format!(
|
||||
"Maximum file count ({}) exceeded",
|
||||
self.config.max_files_per_prompt
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
// Get file metadata
|
||||
let metadata = std::fs::metadata(path)?;
|
||||
let file_size = metadata.len();
|
||||
|
||||
// Detect file type
|
||||
let file_type = detect_file_type(path)?;
|
||||
|
||||
// Build a name for the file (use file name, not full path)
|
||||
let name = path
|
||||
.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
// Decide based on file type and capabilities
|
||||
let strategy = if file_type.is_binary {
|
||||
// Binary files: prefer link, allow small blobs if needed
|
||||
if file_size <= self.config.max_embed_bytes && self.agent_supports_embedded {
|
||||
// Small binary - could embed as blob, but prefer link for efficiency
|
||||
if self.config.allow_resource_link {
|
||||
EmbeddingStrategy::Link {
|
||||
uri: format!("dirigent://resource/{}", Self::hash_path(path)),
|
||||
name,
|
||||
size: file_size,
|
||||
mime_type: file_type.mime_type,
|
||||
}
|
||||
} else {
|
||||
// Embed as blob only if linking is disabled
|
||||
let data = std::fs::read(path)?;
|
||||
EmbeddingStrategy::EmbedBlob {
|
||||
data,
|
||||
mime_type: file_type.mime_type.unwrap_or_else(|| "application/octet-stream".to_string()),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Large binary - link only
|
||||
if self.config.allow_resource_link {
|
||||
EmbeddingStrategy::Link {
|
||||
uri: format!("dirigent://resource/{}", Self::hash_path(path)),
|
||||
name,
|
||||
size: file_size,
|
||||
mime_type: file_type.mime_type,
|
||||
}
|
||||
} else {
|
||||
EmbeddingStrategy::Deny {
|
||||
reason: format!("Binary file too large ({} bytes) and linking is disabled", file_size),
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Text files: embed if small and capability supports, otherwise link or snippet
|
||||
if !self.agent_supports_embedded {
|
||||
// Agent doesn't support embedded context - always link
|
||||
if self.config.allow_resource_link {
|
||||
EmbeddingStrategy::Link {
|
||||
uri: format!("dirigent://resource/{}", Self::hash_path(path)),
|
||||
name,
|
||||
size: file_size,
|
||||
mime_type: file_type.mime_type,
|
||||
}
|
||||
} else {
|
||||
EmbeddingStrategy::Deny {
|
||||
reason: "Agent does not support embedded context and linking is disabled".to_string(),
|
||||
}
|
||||
}
|
||||
} else if file_size <= self.config.max_embed_bytes {
|
||||
// Small text file - check total accumulated bytes
|
||||
let new_total = self.accumulated_bytes + file_size as usize;
|
||||
|
||||
// For Phase 03, use max_embed_bytes * max_files_per_prompt as total cap
|
||||
let max_total_bytes = (self.config.max_embed_bytes as usize)
|
||||
* (self.config.max_files_per_prompt as usize);
|
||||
|
||||
if new_total <= max_total_bytes {
|
||||
// Embed the file
|
||||
let content = std::fs::read_to_string(path)?;
|
||||
EmbeddingStrategy::EmbedText {
|
||||
content,
|
||||
mime_type: file_type.mime_type.unwrap_or_else(|| "text/plain; charset=utf-8".to_string()),
|
||||
}
|
||||
} else {
|
||||
// Exceeds total byte cap - link or deny
|
||||
if self.config.allow_resource_link {
|
||||
EmbeddingStrategy::Link {
|
||||
uri: format!("dirigent://resource/{}", Self::hash_path(path)),
|
||||
name,
|
||||
size: file_size,
|
||||
mime_type: file_type.mime_type,
|
||||
}
|
||||
} else {
|
||||
EmbeddingStrategy::Deny {
|
||||
reason: format!("Total embedded bytes would exceed limit ({} bytes)", max_total_bytes),
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Large text file - link or snippet
|
||||
if self.config.allow_resource_link {
|
||||
EmbeddingStrategy::Link {
|
||||
uri: format!("dirigent://resource/{}", Self::hash_path(path)),
|
||||
name,
|
||||
size: file_size,
|
||||
mime_type: file_type.mime_type,
|
||||
}
|
||||
} else if self.config.snippet_strategy != crate::config::SnippetStrategy::HeadTail {
|
||||
// Snippet embedding not configured
|
||||
EmbeddingStrategy::Deny {
|
||||
reason: format!("File too large ({} bytes) for embedding and linking is disabled", file_size),
|
||||
}
|
||||
} else {
|
||||
// Generate snippet (handled in EMBED-05)
|
||||
// For now, deny and indicate snippet is needed
|
||||
EmbeddingStrategy::Deny {
|
||||
reason: "Snippet generation not yet implemented".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Update accumulated bytes if we're embedding
|
||||
match &strategy {
|
||||
EmbeddingStrategy::EmbedText { content, .. } => {
|
||||
self.accumulated_bytes += content.len();
|
||||
self.file_count += 1;
|
||||
}
|
||||
EmbeddingStrategy::EmbedBlob { data, .. } => {
|
||||
self.accumulated_bytes += data.len();
|
||||
self.file_count += 1;
|
||||
}
|
||||
EmbeddingStrategy::Link { .. } => {
|
||||
self.file_count += 1;
|
||||
}
|
||||
EmbeddingStrategy::Snippet { .. } => {
|
||||
self.file_count += 1;
|
||||
}
|
||||
EmbeddingStrategy::Deny { .. } => {
|
||||
// Don't increment count for denied files
|
||||
}
|
||||
}
|
||||
|
||||
Ok(strategy)
|
||||
}
|
||||
|
||||
/// Get the total bytes accumulated so far.
|
||||
pub fn accumulated_bytes(&self) -> usize {
|
||||
self.accumulated_bytes
|
||||
}
|
||||
|
||||
/// Get the number of files processed so far.
|
||||
pub fn file_count(&self) -> usize {
|
||||
self.file_count
|
||||
}
|
||||
|
||||
/// Generate a stable hash for a file path (used in URIs).
|
||||
///
|
||||
/// This creates an opaque, stable identifier for the file.
|
||||
fn hash_path(path: &Path) -> String {
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
let mut hasher = DefaultHasher::new();
|
||||
path.hash(&mut hasher);
|
||||
format!("{:x}", hasher.finish())
|
||||
}
|
||||
}
|
||||
|
||||
/// Standalone function for simple decision-making (no state tracking).
|
||||
///
|
||||
/// This is a convenience wrapper for one-off decisions without state.
|
||||
pub fn decide_embedding_strategy(
|
||||
path: &Path,
|
||||
agent_supports_embedded: bool,
|
||||
config: &EmbeddingConfig,
|
||||
) -> ToolResult<EmbeddingStrategy> {
|
||||
let mut decider = EmbeddingDecider::new(config.clone(), agent_supports_embedded);
|
||||
decider.decide(path)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::SnippetStrategy;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Baseline configuration shared by every test in this module.
    fn default_config() -> EmbeddingConfig {
        EmbeddingConfig {
            max_embed_bytes: 256_000,
            allow_resource_link: true,
            redact_patterns: Vec::new(),
            snippet_strategy: SnippetStrategy::HeadTail,
            max_files_per_prompt: 10,
        }
    }

    /// Create a temporary file with the given suffix holding `contents`.
    fn temp_file(suffix: &str, contents: &[u8]) -> NamedTempFile {
        let mut file = NamedTempFile::with_suffix(suffix).unwrap();
        file.write_all(contents).unwrap();
        file
    }

    #[test]
    fn test_small_text_file_with_capability() {
        let file = temp_file(".txt", b"Hello, world!");
        let mut decider = EmbeddingDecider::new(default_config(), true);

        let EmbeddingStrategy::EmbedText { content, mime_type } =
            decider.decide(file.path()).unwrap()
        else {
            panic!("Expected EmbedText strategy");
        };

        assert_eq!(content, "Hello, world!");
        assert!(mime_type.contains("text"));
    }

    #[test]
    fn test_small_text_file_without_capability() {
        let file = temp_file(".txt", b"Hello, world!");
        let mut decider = EmbeddingDecider::new(default_config(), false);

        // Without the embedding capability the file must be linked instead.
        let EmbeddingStrategy::Link { .. } = decider.decide(file.path()).unwrap() else {
            panic!("Expected Link strategy when capability is missing");
        };
    }

    #[test]
    fn test_large_text_file_creates_link() {
        // Larger than max_embed_bytes in the default configuration.
        let file = temp_file(".txt", "x".repeat(300_000).as_bytes());
        let mut decider = EmbeddingDecider::new(default_config(), true);

        let EmbeddingStrategy::Link { size, .. } = decider.decide(file.path()).unwrap() else {
            panic!("Expected Link strategy for large file");
        };

        assert_eq!(size, 300_000);
    }

    #[test]
    fn test_binary_file_creates_link() {
        let file = temp_file(".png", b"\x89PNG\r\n\x1a\n");
        let mut decider = EmbeddingDecider::new(default_config(), true);

        let EmbeddingStrategy::Link { mime_type, .. } = decider.decide(file.path()).unwrap()
        else {
            panic!("Expected Link strategy for binary file");
        };

        assert_eq!(mime_type, Some("image/png".to_string()));
    }

    #[test]
    fn test_max_file_count_exceeded() {
        let config = EmbeddingConfig {
            max_files_per_prompt: 2,
            ..default_config()
        };
        let mut decider = EmbeddingDecider::new(config, true);

        // The first two files fit within the limit.
        let first = temp_file(".txt", b"File 1");
        let _ = decider.decide(first.path()).unwrap();

        let second = temp_file(".txt", b"File 2");
        let _ = decider.decide(second.path()).unwrap();

        // The third one goes over the limit and must be denied.
        let third = temp_file(".txt", b"File 3");
        let EmbeddingStrategy::Deny { reason } = decider.decide(third.path()).unwrap() else {
            panic!("Expected Deny strategy when file count exceeded");
        };

        assert!(reason.contains("Maximum file count"));
    }

    #[test]
    fn test_linking_disabled_denies_large_file() {
        let file = temp_file(".txt", "x".repeat(300_000).as_bytes());
        let config = EmbeddingConfig {
            allow_resource_link: false,
            ..default_config()
        };
        let mut decider = EmbeddingDecider::new(config, true);

        let EmbeddingStrategy::Deny { reason } = decider.decide(file.path()).unwrap() else {
            panic!("Expected Deny when linking disabled and file too large");
        };

        assert!(reason.contains("linking is disabled"));
    }

    #[test]
    fn test_accumulated_bytes_tracking() {
        let mut decider = EmbeddingDecider::new(default_config(), true);
        assert_eq!(decider.accumulated_bytes(), 0);

        let first = temp_file(".txt", b"Hello");
        let _ = decider.decide(first.path()).unwrap();
        assert_eq!(decider.accumulated_bytes(), 5);

        let second = temp_file(".txt", b"World!");
        let _ = decider.decide(second.path()).unwrap();
        assert_eq!(decider.accumulated_bytes(), 11);
    }

    #[test]
    fn test_hash_path_is_stable() {
        let path = Path::new("/test/file.txt");

        let first = EmbeddingDecider::hash_path(path);
        let second = EmbeddingDecider::hash_path(path);

        assert_eq!(first, second);
    }
}
|
||||
@@ -0,0 +1,12 @@
|
||||
//! File embedding utilities for ACP prompts.
|
||||
//!
|
||||
//! This module provides the core logic for deciding how to include file content
|
||||
//! in ACP prompts based on agent capabilities, file size, and configuration.
|
||||
|
||||
pub mod decider;
|
||||
pub mod redactor;
|
||||
pub mod snippet;
|
||||
|
||||
pub use decider::{decide_embedding_strategy, EmbeddingStrategy, EmbeddingDecider};
|
||||
pub use redactor::{ContentRedactor, RedactionPattern};
|
||||
pub use snippet::{generate_snippet, Snippet, SnippetConfig};
|
||||
@@ -0,0 +1,382 @@
|
||||
//! Content redaction for embedded files.
|
||||
//!
|
||||
//! This module implements pattern-based redaction to prevent secrets from
|
||||
//! being included in embedded file content. Redaction only affects the
|
||||
//! payload sent to the agent - files on disk are never modified.
|
||||
|
||||
use regex::Regex;
|
||||
|
||||
use crate::error::ToolResult;
|
||||
|
||||
/// A single redaction pattern with name and replacement text.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RedactionPattern {
|
||||
/// Human-readable name for this pattern (e.g., "API_KEY").
|
||||
pub name: String,
|
||||
/// Compiled regular expression to match.
|
||||
pub regex: Regex,
|
||||
/// Replacement text (e.g., "<REDACTED:API_KEY>").
|
||||
pub replacement: String,
|
||||
}
|
||||
|
||||
impl RedactionPattern {
|
||||
/// Create a new redaction pattern.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `name` - Human-readable name
|
||||
/// * `pattern` - Regular expression pattern
|
||||
/// * `replacement` - Replacement text (can include pattern name)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A compiled redaction pattern, or an error if the regex is invalid.
|
||||
pub fn new(
|
||||
name: impl Into<String>,
|
||||
pattern: &str,
|
||||
replacement: impl Into<String>,
|
||||
) -> ToolResult<Self> {
|
||||
let name = name.into();
|
||||
let regex = Regex::new(pattern).map_err(|e| {
|
||||
crate::error::ToolError::InvalidInput(format!("Invalid redaction pattern: {}", e))
|
||||
})?;
|
||||
|
||||
Ok(Self {
|
||||
name,
|
||||
regex,
|
||||
replacement: replacement.into(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a redaction pattern with default "<REDACTED:NAME>" replacement.
|
||||
pub fn with_default_replacement(name: impl Into<String>, pattern: &str) -> ToolResult<Self> {
|
||||
let name_str = name.into();
|
||||
let replacement = format!("<REDACTED:{}>", name_str.to_uppercase());
|
||||
Self::new(name_str, pattern, replacement)
|
||||
}
|
||||
}
|
||||
|
||||
/// Content redactor for secret patterns.
|
||||
///
|
||||
/// Applies a set of redaction patterns to content before embedding.
|
||||
pub struct ContentRedactor {
|
||||
/// Redaction patterns to apply.
|
||||
patterns: Vec<RedactionPattern>,
|
||||
}
|
||||
|
||||
impl ContentRedactor {
|
||||
/// Create a new content redactor with the given patterns.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `pattern_strings` - List of regex pattern strings from configuration
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A content redactor, or an error if any pattern is invalid.
|
||||
pub fn new(pattern_strings: &[String]) -> ToolResult<Self> {
|
||||
let mut patterns = Vec::new();
|
||||
|
||||
for (i, pattern_str) in pattern_strings.iter().enumerate() {
|
||||
let pattern = RedactionPattern::with_default_replacement(
|
||||
format!("CUSTOM_{}", i),
|
||||
pattern_str,
|
||||
)?;
|
||||
patterns.push(pattern);
|
||||
}
|
||||
|
||||
Ok(Self { patterns })
|
||||
}
|
||||
|
||||
/// Create a redactor with default built-in patterns.
|
||||
///
|
||||
/// Default patterns include:
|
||||
/// - API keys
|
||||
/// - AWS credentials
|
||||
/// - Generic secrets and tokens
|
||||
/// - Passwords
|
||||
pub fn with_default_patterns() -> Self {
|
||||
let mut patterns = Vec::new();
|
||||
|
||||
// API keys (various formats)
|
||||
if let Ok(p) = RedactionPattern::with_default_replacement(
|
||||
"API_KEY",
|
||||
r#"(?i)(api[_-]?key|apikey)[:=\s]+["']?([a-zA-Z0-9_\-\.]{20,})["']?"#,
|
||||
) {
|
||||
patterns.push(p);
|
||||
}
|
||||
|
||||
// AWS access key IDs
|
||||
if let Ok(p) = RedactionPattern::with_default_replacement(
|
||||
"AWS_ACCESS_KEY",
|
||||
r"AKIA[0-9A-Z]{16}",
|
||||
) {
|
||||
patterns.push(p);
|
||||
}
|
||||
|
||||
// Generic secrets and tokens
|
||||
if let Ok(p) = RedactionPattern::with_default_replacement(
|
||||
"SECRET",
|
||||
r#"(?i)(secret|token)[:=\s]+["']?([a-zA-Z0-9_\-\.]{16,})["']?"#,
|
||||
) {
|
||||
patterns.push(p);
|
||||
}
|
||||
|
||||
// Passwords
|
||||
if let Ok(p) = RedactionPattern::with_default_replacement(
|
||||
"PASSWORD",
|
||||
r#"(?i)password[:=\s]+["']?([^\s"']{8,})["']?"#,
|
||||
) {
|
||||
patterns.push(p);
|
||||
}
|
||||
|
||||
// Bearer tokens
|
||||
if let Ok(p) = RedactionPattern::with_default_replacement(
|
||||
"BEARER_TOKEN",
|
||||
r"(?i)bearer\s+([a-zA-Z0-9_\-\.=]+)",
|
||||
) {
|
||||
patterns.push(p);
|
||||
}
|
||||
|
||||
Self { patterns }
|
||||
}
|
||||
|
||||
/// Create a redactor combining default and custom patterns.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `custom_patterns` - Additional custom pattern strings
|
||||
pub fn with_custom_patterns(custom_patterns: &[String]) -> ToolResult<Self> {
|
||||
let mut redactor = Self::with_default_patterns();
|
||||
|
||||
for (i, pattern_str) in custom_patterns.iter().enumerate() {
|
||||
let pattern = RedactionPattern::with_default_replacement(
|
||||
format!("CUSTOM_{}", i),
|
||||
pattern_str,
|
||||
)?;
|
||||
redactor.patterns.push(pattern);
|
||||
}
|
||||
|
||||
Ok(redactor)
|
||||
}
|
||||
|
||||
/// Redact sensitive content from the given text.
|
||||
///
|
||||
/// Applies all configured redaction patterns in order and returns
|
||||
/// the redacted content. The original content is never modified.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `content` - The content to redact
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Redacted content with sensitive data replaced.
|
||||
pub fn redact(&self, content: &str) -> String {
|
||||
let mut redacted = content.to_string();
|
||||
|
||||
for pattern in &self.patterns {
|
||||
redacted = pattern.regex.replace_all(&redacted, &pattern.replacement).to_string();
|
||||
}
|
||||
|
||||
// Ensure UTF-8 safety (should always be valid since we're only replacing with ASCII)
|
||||
debug_assert!(redacted.is_char_boundary(redacted.len()));
|
||||
|
||||
redacted
|
||||
}
|
||||
|
||||
/// Check if content contains any patterns that would be redacted.
|
||||
///
|
||||
/// This can be used to warn users before embedding.
|
||||
pub fn contains_secrets(&self, content: &str) -> bool {
|
||||
self.patterns.iter().any(|p| p.regex.is_match(content))
|
||||
}
|
||||
|
||||
/// Get the number of redaction patterns configured.
|
||||
pub fn pattern_count(&self) -> usize {
|
||||
self.patterns.len()
|
||||
}
|
||||
}
|
||||
|
||||
// Implement a safe Debug that doesn't leak pattern details
|
||||
impl std::fmt::Debug for ContentRedactor {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("ContentRedactor")
|
||||
.field("pattern_count", &self.patterns.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Run `content` through a redactor holding only the built-in patterns.
    fn redact_with_defaults(content: &str) -> String {
        ContentRedactor::with_default_patterns().redact(content)
    }

    #[test]
    fn test_redaction_pattern_creation() {
        let pattern = RedactionPattern::new("TEST", r"\d+", "XXX").unwrap();

        assert_eq!(pattern.name, "TEST");
        assert_eq!(pattern.replacement, "XXX");
        assert!(pattern.regex.is_match("123"));
    }

    #[test]
    fn test_redaction_pattern_with_default_replacement() {
        let pattern = RedactionPattern::with_default_replacement("api_key", r"key_\d+").unwrap();

        assert_eq!(pattern.name, "api_key");
        assert_eq!(pattern.replacement, "<REDACTED:API_KEY>");
    }

    #[test]
    fn test_redaction_pattern_invalid_regex() {
        let result = RedactionPattern::new("TEST", r"[invalid(", "XXX");

        assert!(result.is_err());
    }

    #[test]
    fn test_default_patterns_api_key() {
        let redacted = redact_with_defaults("API_KEY=sk_live_1234567890abcdefghij");

        assert!(redacted.contains("<REDACTED:"));
        assert!(!redacted.contains("sk_live_"));
    }

    #[test]
    fn test_default_patterns_aws_key() {
        let redacted = redact_with_defaults("AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE");

        assert!(redacted.contains("<REDACTED:"));
        assert!(!redacted.contains("AKIAIOSFODNN7EXAMPLE"));
    }

    #[test]
    fn test_default_patterns_secret() {
        let redacted = redact_with_defaults("secret: my_super_secret_value_12345");

        assert!(redacted.contains("<REDACTED:"));
        assert!(!redacted.contains("my_super_secret"));
    }

    #[test]
    fn test_default_patterns_password() {
        let redacted = redact_with_defaults("password=MyP@ssw0rd123");

        assert!(redacted.contains("<REDACTED:"));
        assert!(!redacted.contains("MyP@ssw0rd"));
    }

    #[test]
    fn test_default_patterns_bearer_token() {
        let redacted =
            redact_with_defaults("Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9");

        assert!(redacted.contains("<REDACTED:"));
        assert!(!redacted.contains("eyJhbGci"));
    }

    #[test]
    fn test_custom_patterns() {
        let custom = vec![r"CUSTOM-\d{4}".to_string()];
        let redactor = ContentRedactor::with_custom_patterns(&custom).unwrap();

        let redacted = redactor.redact("My custom ID is CUSTOM-1234");

        assert!(redacted.contains("<REDACTED:"));
        assert!(!redacted.contains("CUSTOM-1234"));
    }

    #[test]
    fn test_multiple_patterns() {
        let content = r#"
            API_KEY="sk_test_abcdefghijklmnopqrst"
            password=secret123
            token: my_long_secret_token_value
        "#;

        let redacted = redact_with_defaults(content);

        // All secrets should be redacted
        assert!(!redacted.contains("sk_test_"));
        assert!(!redacted.contains("secret123"));
        assert!(!redacted.contains("my_long_secret"));
        assert!(redacted.contains("<REDACTED:"));
    }

    #[test]
    fn test_contains_secrets() {
        let redactor = ContentRedactor::with_default_patterns();

        assert!(redactor.contains_secrets("API_KEY=sk_test_12345678901234567890"));
        // The built-in password pattern only fires on values of at least
        // eight characters, so the sample value must be that long.
        assert!(redactor.contains_secrets("password=correct-horse"));
        assert!(!redactor.contains_secrets("Hello, world!"));
    }

    #[test]
    fn test_utf8_safety() {
        let redacted = redact_with_defaults("API_KEY=sk_test_12345678901234567890 你好世界");

        // Should still contain valid UTF-8
        assert!(redacted.contains("你好世界"));
        assert!(std::str::from_utf8(redacted.as_bytes()).is_ok());
    }

    #[test]
    fn test_no_false_positives_on_code() {
        let code = r#"
            fn api_key_validator(input: &str) -> bool {
                input.len() > 20
            }
        "#;

        let redacted = redact_with_defaults(code);

        // Function name and code should not be redacted
        assert!(redacted.contains("api_key_validator"));
        assert!(redacted.contains("fn"));
    }

    #[test]
    fn test_pattern_count() {
        let redactor = ContentRedactor::with_default_patterns();
        assert_eq!(redactor.pattern_count(), 5); // 5 default patterns

        let custom = vec![r"CUSTOM-\d{4}".to_string()];
        let redactor_custom = ContentRedactor::with_custom_patterns(&custom).unwrap();
        assert_eq!(redactor_custom.pattern_count(), 6); // 5 default + 1 custom
    }

    #[test]
    fn test_debug_impl_safe() {
        let redactor = ContentRedactor::with_default_patterns();
        let debug_str = format!("{:?}", redactor);

        // Should not leak actual pattern details
        assert!(debug_str.contains("ContentRedactor"));
        assert!(debug_str.contains("pattern_count"));
        assert!(!debug_str.contains("regex")); // Internal details hidden
    }
}
|
||||
@@ -0,0 +1,354 @@
|
||||
//! Snippet generation for large files.
|
||||
//!
|
||||
//! This module implements partial selection strategies (head/tail/window)
|
||||
//! for files exceeding embed size limits.
|
||||
|
||||
use std::fs::File;
|
||||
use std::io::{BufRead, BufReader};
|
||||
use std::path::Path;
|
||||
|
||||
use crate::error::ToolResult;
|
||||
|
||||
/// Limits that control how much of a file ends up in a snippet.
#[derive(Debug, Clone)]
pub struct SnippetConfig {
    /// How many lines to keep from the start of the file.
    pub head_lines: usize,
    /// How many lines to keep from the end of the file.
    pub tail_lines: usize,
    /// Upper bound, in bytes, on the snippet text.
    pub max_snippet_bytes: usize,
}

impl Default for SnippetConfig {
    /// 200 head lines, 200 tail lines, and a 128 000-byte cap.
    fn default() -> Self {
        const DEFAULT_LINES: usize = 200;
        // Half of default max_embed_bytes
        const DEFAULT_MAX_BYTES: usize = 128_000;

        Self {
            head_lines: DEFAULT_LINES,
            tail_lines: DEFAULT_LINES,
            max_snippet_bytes: DEFAULT_MAX_BYTES,
        }
    }
}
|
||||
|
||||
/// A snippet extracted from a file.
#[derive(Debug, Clone)]
pub struct Snippet {
    /// First portion of the file (head).
    pub head: Option<String>,
    /// Last portion of the file (tail).
    pub tail: Option<String>,
    /// Total size of the original file in bytes.
    pub total_size: u64,
    /// Total number of lines in the original file.
    pub total_lines: usize,
    /// Whether the snippet is truncated (doesn't show full file).
    pub truncated: bool,
}

impl Snippet {
    /// Append `text` to `output`, making sure it ends with a newline.
    fn push_terminated(output: &mut String, text: &str) {
        output.push_str(text);
        if !text.ends_with('\n') {
            output.push('\n');
        }
    }

    /// Render the snippet with metadata comments.
    ///
    /// Returns a formatted string with:
    /// - Metadata header showing the file name (marked "(truncated)" only when
    ///   `self.truncated` is set), total size and line count
    /// - A "Showing" line describing which parts are present
    /// - Head content
    /// - Separator (if both head and tail present)
    /// - Tail content
    ///
    /// Only the final component of `file_path` is shown; "unknown" is used when
    /// it is missing or not valid UTF-8.
    pub fn render(&self, file_path: &Path) -> String {
        let file_name = file_path
            .file_name()
            .and_then(|n| n.to_str())
            .unwrap_or("unknown");
        let head = self.head.as_deref();
        let tail = self.tail.as_deref();

        let mut output = String::new();

        // Metadata header. A snippet that carries the whole file must not be
        // labelled as truncated.
        let marker = if self.truncated { " (truncated)" } else { "" };
        output.push_str(&format!("# File: {}{}\n", file_name, marker));
        output.push_str(&format!(
            "# Total size: {:.1} KB, {} lines\n",
            self.total_size as f64 / 1024.0,
            self.total_lines
        ));

        match (head, tail) {
            (Some(h), Some(t)) => output.push_str(&format!(
                "# Showing: first {} lines + last {} lines\n\n",
                h.lines().count(),
                t.lines().count()
            )),
            (Some(h), None) => output.push_str(&format!(
                "# Showing: first {} lines\n\n",
                h.lines().count()
            )),
            (None, Some(t)) => output.push_str(&format!(
                "# Showing: last {} lines\n\n",
                t.lines().count()
            )),
            (None, None) => {}
        }

        // Head content
        if let Some(h) = head {
            Self::push_terminated(&mut output, h);
        }

        // Separator if both head and tail
        if head.is_some() && tail.is_some() {
            output.push_str("\n# ... (middle content omitted) ...\n\n");
        }

        // Tail content
        if let Some(t) = tail {
            Self::push_terminated(&mut output, t);
        }

        output
    }
}
|
||||
|
||||
/// Generate a snippet from a file.
|
||||
///
|
||||
/// Reads the first N and last N lines from the file, respecting UTF-8
|
||||
/// boundaries and byte limits.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the file
|
||||
/// * `config` - Snippet configuration
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A snippet with head/tail content and metadata.
|
||||
pub fn generate_snippet(path: &Path, config: &SnippetConfig) -> ToolResult<Snippet> {
|
||||
let metadata = std::fs::metadata(path)?;
|
||||
let total_size = metadata.len();
|
||||
|
||||
// Read the full file to count lines and extract head/tail
|
||||
let file = File::open(path)?;
|
||||
let reader = BufReader::new(file);
|
||||
|
||||
let mut all_lines: Vec<String> = Vec::new();
|
||||
for line in reader.lines() {
|
||||
all_lines.push(line?);
|
||||
}
|
||||
|
||||
let total_lines = all_lines.len();
|
||||
|
||||
// Check if file is small enough to not need truncation
|
||||
if total_lines <= config.head_lines + config.tail_lines {
|
||||
return Ok(Snippet {
|
||||
head: Some(all_lines.join("\n")),
|
||||
tail: None,
|
||||
total_size,
|
||||
total_lines,
|
||||
truncated: false,
|
||||
});
|
||||
}
|
||||
|
||||
// Extract head and tail
|
||||
let head_content = all_lines
|
||||
.iter()
|
||||
.take(config.head_lines)
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
let tail_content = all_lines
|
||||
.iter()
|
||||
.skip(all_lines.len().saturating_sub(config.tail_lines))
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
// Check if combined snippet exceeds max bytes
|
||||
let combined_bytes = head_content.len() + tail_content.len();
|
||||
if combined_bytes > config.max_snippet_bytes {
|
||||
// Reduce to head only if combined is too large
|
||||
let truncated_head = truncate_to_bytes(&head_content, config.max_snippet_bytes);
|
||||
Ok(Snippet {
|
||||
head: Some(truncated_head),
|
||||
tail: None,
|
||||
total_size,
|
||||
total_lines,
|
||||
truncated: true,
|
||||
})
|
||||
} else {
|
||||
Ok(Snippet {
|
||||
head: Some(head_content),
|
||||
tail: Some(tail_content),
|
||||
total_size,
|
||||
total_lines,
|
||||
truncated: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the longest prefix of `s` that is at most `max_bytes` long and ends
/// on a UTF-8 character boundary.
///
/// Strings that already fit are returned unchanged.
fn truncate_to_bytes(s: &str, max_bytes: usize) -> String {
    let cut = if s.len() <= max_bytes {
        s.len()
    } else {
        // Walk backwards from the limit to the nearest character boundary;
        // index 0 is always a boundary, so the search cannot come up empty.
        (0..=max_bytes)
            .rev()
            .find(|&idx| s.is_char_boundary(idx))
            .unwrap_or(0)
    };

    s[..cut].to_string()
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Write lines 1..=count to a temporary file, one `render_line(i)` per line.
    fn file_with_lines(count: usize, render_line: impl Fn(usize) -> String) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        for i in 1..=count {
            writeln!(file, "{}", render_line(i)).unwrap();
        }
        file
    }

    #[test]
    fn test_snippet_config_default() {
        let SnippetConfig {
            head_lines,
            tail_lines,
            max_snippet_bytes,
        } = SnippetConfig::default();

        assert_eq!(head_lines, 200);
        assert_eq!(tail_lines, 200);
        assert_eq!(max_snippet_bytes, 128_000);
    }

    #[test]
    fn test_generate_snippet_small_file() {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(b"Line 1\nLine 2\nLine 3\n").unwrap();

        let config = SnippetConfig {
            head_lines: 10,
            tail_lines: 10,
            max_snippet_bytes: 10_000,
        };

        let snippet = generate_snippet(file.path(), &config).unwrap();

        assert!(!snippet.truncated);
        assert_eq!(snippet.total_lines, 3);
        assert!(snippet.head.is_some());
        assert!(snippet.tail.is_none());
    }

    #[test]
    fn test_generate_snippet_large_file() {
        let file = file_with_lines(1000, |i| format!("Line {}", i));

        let config = SnippetConfig {
            head_lines: 10,
            tail_lines: 10,
            max_snippet_bytes: 100_000,
        };

        let snippet = generate_snippet(file.path(), &config).unwrap();

        assert!(snippet.truncated);
        assert_eq!(snippet.total_lines, 1000);
        assert!(snippet.head.is_some());
        assert!(snippet.tail.is_some());

        // Head should have first 10 lines
        let head = snippet.head.unwrap();
        assert!(head.contains("Line 1"));
        assert!(head.contains("Line 10"));
        assert!(!head.contains("Line 11"));

        // Tail should have last 10 lines
        let tail = snippet.tail.unwrap();
        assert!(tail.contains("Line 1000"));
        assert!(tail.contains("Line 991"));
        assert!(!tail.contains("Line 990"));
    }

    #[test]
    fn test_snippet_render() {
        let rendered = Snippet {
            head: Some("Line 1\nLine 2".to_string()),
            tail: Some("Line 99\nLine 100".to_string()),
            total_size: 10_000,
            total_lines: 100,
            truncated: true,
        }
        .render(Path::new("/test/file.txt"));

        for expected in [
            "# File: file.txt (truncated)",
            "# Total size:",
            "100 lines",
            "Line 1",
            "Line 100",
            "... (middle content omitted) ...",
        ] {
            assert!(rendered.contains(expected));
        }
    }

    #[test]
    fn test_snippet_render_head_only() {
        let rendered = Snippet {
            head: Some("Line 1\nLine 2".to_string()),
            tail: None,
            total_size: 1_000,
            total_lines: 10,
            truncated: true,
        }
        .render(Path::new("/test/file.txt"));

        assert!(rendered.contains("Line 1"));
        assert!(!rendered.contains("... (middle content omitted) ..."));
    }

    #[test]
    fn test_truncate_to_bytes_exact() {
        assert_eq!(truncate_to_bytes("Hello", 5), "Hello");
    }

    #[test]
    fn test_truncate_to_bytes_shorter() {
        assert_eq!(truncate_to_bytes("Hello, world!", 5), "Hello");
    }

    #[test]
    fn test_truncate_to_bytes_utf8_boundary() {
        // A limit of 8 falls inside the first multibyte character, so the cut
        // must move back to just before it.
        assert_eq!(truncate_to_bytes("Hello 你好世界", 8), "Hello ");
    }

    #[test]
    fn test_snippet_exceeds_max_bytes() {
        let file = file_with_lines(1000, |i| {
            format!("This is a very long line number {} with lots of text", i)
        });

        let config = SnippetConfig {
            head_lines: 500,
            tail_lines: 500,
            max_snippet_bytes: 1_000, // Very small limit
        };

        let snippet = generate_snippet(file.path(), &config).unwrap();

        assert!(snippet.truncated);
        // Should fall back to head only when combined exceeds max
        assert!(snippet.head.is_some());
        let head = snippet.head.unwrap();
        assert!(head.len() <= config.max_snippet_bytes);
    }
}
|
||||
@@ -0,0 +1,108 @@
|
||||
//! Error types for tool operations.
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// Shorthand for `Result` with [`ToolError`] as the error type.
pub type ToolResult<T> = Result<T, ToolError>;

/// Everything that can go wrong while running a tool operation.
#[derive(Error, Debug)]
pub enum ToolError {
    /// The requested file or directory does not exist.
    #[error("Not found: {path}")]
    NotFound { path: String },

    /// The OS or the sandbox policy refused the operation.
    #[error("Permission denied: {reason}")]
    PermissionDenied { reason: String },

    /// The path lies outside the allowed sandbox roots.
    #[error("Sandbox violation: {reason}")]
    SandboxViolation { reason: String },

    /// The path matched a blocklist pattern.
    #[error("Blocked path: {reason}")]
    BlockedPath { reason: String },

    /// The file is larger than the processing limit.
    #[error("File too large: {size} bytes exceeds limit of {limit} bytes")]
    FileTooLarge { size: u64, limit: u64 },

    /// The content uses an unsupported (non-UTF-8) encoding.
    #[error("Encoding not supported: {encoding}")]
    EncodingUnsupported { encoding: String },

    /// The user declined a permission prompt.
    #[error("Permission rejected by user")]
    PermissionRejected,

    /// A terminal operation failed.
    #[error("Terminal error: {message}")]
    TerminalError { message: String },

    /// No terminal exists with the given identifier.
    #[error("Terminal not found: {terminal_id}")]
    TerminalNotFound { terminal_id: String },

    /// A search went past its configured limits.
    #[error("Search limit exceeded: {reason}")]
    SearchLimitExceeded { reason: String },

    /// The configuration is invalid.
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),

    /// Reading a file failed; carries the path and the underlying I/O error.
    #[error("Failed to read file {path}: {source}")]
    FileReadError {
        path: String,
        #[source]
        source: std::io::Error,
    },

    /// The caller supplied invalid input or parameters.
    #[error("Invalid input: {0}")]
    InvalidInput(String),

    /// An I/O error, converted automatically from `std::io::Error`.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// A JSON error, converted automatically from `serde_json::Error`.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),

    /// Any other error, converted automatically from `anyhow::Error`.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}
|
||||
|
||||
impl ToolError {
|
||||
/// Create a sandbox violation error without exposing the full path.
|
||||
pub fn sandbox_violation(reason: impl Into<String>) -> Self {
|
||||
Self::SandboxViolation {
|
||||
reason: reason.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a blocked path error without exposing the full path.
|
||||
pub fn blocked_path(reason: impl Into<String>) -> Self {
|
||||
Self::BlockedPath {
|
||||
reason: reason.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a permission denied error.
|
||||
pub fn permission_denied(reason: impl Into<String>) -> Self {
|
||||
Self::PermissionDenied {
|
||||
reason: reason.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a terminal error.
|
||||
pub fn terminal_error(message: impl Into<String>) -> Self {
|
||||
Self::TerminalError {
|
||||
message: message.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,189 @@
|
||||
//! Hardcoded security floor. Settings cannot bypass these rules.
|
||||
//!
|
||||
//! Initial rule set lifted from Zed's `HARDCODED_SECURITY_RULES`:
|
||||
//! catastrophic recursive deletes (`rm -rf /`, `~`, `$HOME`, `.`, `..`).
|
||||
|
||||
pub mod shell;
|
||||
|
||||
use crate::tool::ToolContext;
|
||||
use regex::Regex;
|
||||
use std::sync::OnceLock;
|
||||
|
||||
/// Outcome of evaluating a tool invocation against the hardcoded security floor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FloorDecision {
    /// No built-in rule matched; the invocation may continue to later checks.
    Pass,
    /// A built-in rule matched; `reason` is a static, human-readable explanation.
    Block { reason: &'static str },
}
|
||||
|
||||
pub struct SecurityFloor {
|
||||
terminal_deny: Vec<Regex>,
|
||||
}
|
||||
|
||||
impl SecurityFloor {
|
||||
pub fn new() -> Self {
|
||||
Self { terminal_deny: terminal_deny_patterns().clone() }
|
||||
}
|
||||
|
||||
/// Check whether the given tool invocation hits a hard rule.
|
||||
/// `tool` is the tool name; `input` is the JSON arguments object.
|
||||
pub fn check(&self, tool: &str, input: &serde_json::Value, _ctx: &ToolContext) -> FloorDecision {
|
||||
if tool != "terminal" && tool != "bash" {
|
||||
return FloorDecision::Pass;
|
||||
}
|
||||
let Some(cmd) = extract_command(input) else { return FloorDecision::Pass };
|
||||
|
||||
// Check the raw command first
|
||||
if self.matches_any(cmd) {
|
||||
return FloorDecision::Block {
|
||||
reason: "Blocked by built-in security rule (recursive delete of \
|
||||
root, home, current, or parent directory).",
|
||||
};
|
||||
}
|
||||
// Then each sub-command in a chain
|
||||
for sub in shell::split_chain(cmd) {
|
||||
if self.matches_any(sub) {
|
||||
return FloorDecision::Block {
|
||||
reason: "Blocked by built-in security rule (chained recursive \
|
||||
delete detected).",
|
||||
};
|
||||
}
|
||||
}
|
||||
if interpolation_pattern().is_match(cmd) {
|
||||
return FloorDecision::Block {
|
||||
reason: "Blocked: shell substitutions/interpolations are not allowed \
|
||||
in terminal commands. Resolve $VAR, ${VAR}, $(...), backticks, \
|
||||
$((...)), <(...), >(...) before calling.",
|
||||
};
|
||||
}
|
||||
FloorDecision::Pass
|
||||
}
|
||||
|
||||
fn matches_any(&self, cmd: &str) -> bool {
|
||||
self.terminal_deny.iter().any(|re| re.is_match(cmd))
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SecurityFloor { fn default() -> Self { Self::new() } }
|
||||
|
||||
fn extract_command(input: &serde_json::Value) -> Option<&str> {
|
||||
input.get("command").and_then(|v| v.as_str())
|
||||
}
|
||||
|
||||
fn terminal_deny_patterns() -> &'static Vec<Regex> {
|
||||
static SET: OnceLock<Vec<Regex>> = OnceLock::new();
|
||||
SET.get_or_init(|| {
|
||||
const FLAGS: &str = r"(--[a-zA-Z0-9][-a-zA-Z0-9_]*(=[^\s]*)?\s+|-[a-zA-Z]+\s+)*";
|
||||
const TRAIL: &str = r"(\s+--[a-zA-Z0-9][-a-zA-Z0-9_]*(=[^\s]*)?|\s+-[a-zA-Z]+)*\s*";
|
||||
vec![
|
||||
Regex::new(&format!(r"\brm\s+{FLAGS}(--\s+)?/\*?{TRAIL}$")).unwrap(),
|
||||
Regex::new(&format!(r"\brm\s+{FLAGS}(--\s+)?~/?\*?{TRAIL}$")).unwrap(),
|
||||
Regex::new(&format!(r"\brm\s+{FLAGS}(--\s+)?(\$HOME|\$\{{HOME\}})/?\*?{TRAIL}$")).unwrap(),
|
||||
Regex::new(&format!(r"\brm\s+{FLAGS}(--\s+)?\./?\*?{TRAIL}$")).unwrap(),
|
||||
Regex::new(&format!(r"\brm\s+{FLAGS}(--\s+)?\.\./?\*?{TRAIL}$")).unwrap(),
|
||||
]
|
||||
})
|
||||
}
|
||||
|
||||
fn interpolation_pattern() -> &'static Regex {
|
||||
static RE: OnceLock<Regex> = OnceLock::new();
|
||||
RE.get_or_init(|| {
|
||||
// $VAR, ${...}, $(...), $((...)), `...`, <(...), >(...)
|
||||
Regex::new(r#"(\$[A-Za-z_][A-Za-z0-9_]*|\$\{[^}]*\}|\$\([^)]*\)|\$\(\([^)]*\)\)|`[^`]*`|<\([^)]*\)|>\([^)]*\))"#).unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{PermissionConfig, SandboxConfig, WhitelistConfig};
    use crate::permission::check::PermissionContext;
    use crate::permission::whitelist::CompiledWhitelist;
    use crate::tool::{ClientKind, ProtocolKind};
    use std::path::PathBuf;

    /// Minimal tool context; the floor ignores it, but `check` requires one.
    fn ctx() -> ToolContext {
        let wl = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap();
        let pc = PermissionContext::new("c".to_string(), None, wl);
        ToolContext::for_test(
            "c",
            ClientKind::claude(),
            ProtocolKind::acp(),
            PathBuf::from("/tmp"),
            SandboxConfig::default(),
            PermissionConfig::default(),
            pc,
        )
    }

    /// JSON arguments object carrying `cmd` as the command.
    fn input(cmd: &str) -> serde_json::Value {
        serde_json::json!({ "command": cmd })
    }

    /// Runs a fresh floor against `cmd` for the given tool.
    fn decide(tool: &str, cmd: &str) -> FloorDecision {
        SecurityFloor::new().check(tool, &input(cmd), &ctx())
    }

    /// Asserts that every command is blocked when run through the terminal tool.
    fn assert_terminal_blocks(cmds: &[&str]) {
        for cmd in cmds {
            assert!(
                matches!(decide("terminal", cmd), FloorDecision::Block { .. }),
                "expected block for {cmd}"
            );
        }
    }

    /// Asserts that every command passes when run through the terminal tool.
    fn assert_terminal_passes(cmds: &[&str]) {
        for cmd in cmds {
            assert!(
                matches!(decide("terminal", cmd), FloorDecision::Pass),
                "expected pass for {cmd}"
            );
        }
    }

    #[test]
    fn floor_passes_non_terminal_tools() {
        assert!(matches!(decide("read", "rm -rf /"), FloorDecision::Pass));
    }

    #[test]
    fn floor_blocks_rm_rf_root() {
        assert_terminal_blocks(&["rm -rf /", "rm -rfv /", "rm -rf /*"]);
    }

    #[test]
    fn floor_blocks_home_variants() {
        assert_terminal_blocks(&[
            "rm -rf ~",
            "rm -rf ~/",
            "rm -rf $HOME",
            "rm -rf ${HOME}",
            "rm -rf ${HOME}/*",
        ]);
    }

    #[test]
    fn floor_blocks_dot_dotdot() {
        assert_terminal_blocks(&["rm -rf .", "rm -rf ./", "rm -rf ./*", "rm -rf ..", "rm -rf ../*"]);
    }

    #[test]
    fn floor_passes_safe_commands() {
        assert_terminal_passes(&["ls", "rm foo.txt", "rm -rf ./build", "rm -rf /tmp/work"]);
    }

    #[test]
    fn floor_blocks_chained_rm_rf_root() {
        assert!(matches!(
            decide("terminal", "ls && rm -rf /"),
            FloorDecision::Block { .. }
        ));
    }

    #[test]
    fn floor_blocks_interpolation_in_terminal() {
        assert_terminal_blocks(&[
            "echo $VAR",
            "echo ${VAR}",
            "echo $(date)",
            "echo `date`",
            "echo $((1+1))",
            "cat <(ls)",
            "tee >(ls)",
        ]);
    }

    #[test]
    fn floor_passes_command_without_interpolation() {
        assert!(matches!(decide("terminal", "git status"), FloorDecision::Pass));
    }
}
|
||||
@@ -0,0 +1,66 @@
|
||||
//! Lightweight POSIX shell command decomposition.
|
||||
//!
|
||||
//! Splits a one-line shell input on `&&`, `||`, `;`, `|` operators.
|
||||
//! Quoted regions are preserved so chain tokens inside quotes do not split.
|
||||
//! This is a heuristic — not a full shell parser. It is sufficient to ensure
|
||||
//! a chained `rm -rf /` is not hidden behind a leading benign command.
|
||||
|
||||
/// Split `input` into individual sub-commands at shell control operators.
///
/// Recognised separators are `&&`, `||`, `;`, `|`, a newline, and a lone `&`
/// (background operator). An `&` that is part of a redirection (`>&`, `<&`,
/// `&>`) is not treated as a separator.
///
/// Separators inside single or double quotes, or preceded by a backslash
/// (outside single quotes), do not split. Each returned slice is trimmed and
/// borrows from `input`; empty segments are omitted, so an empty or
/// whitespace-only input yields an empty vector. Input without separators is
/// returned as a single trimmed element.
///
/// This is a heuristic, not a shell parser: subshells, comments and
/// here-documents are not understood.
pub fn split_chain(input: &str) -> Vec<&str> {
    let bytes = input.as_bytes();
    let mut out = Vec::new();
    let mut start = 0;
    let mut i = 0;
    let mut quote: Option<u8> = None;

    while i < bytes.len() {
        let b = bytes[i];

        // A backslash escapes the following byte everywhere except inside
        // single quotes, so skip both. Overshooting the end is harmless:
        // the loop condition stops us and `start` never exceeds a separator.
        if b == b'\\' && quote != Some(b'\'') {
            i += 2;
            continue;
        }

        // Inside a quoted region only the matching quote is significant.
        if let Some(q) = quote {
            if b == q {
                quote = None;
            }
            i += 1;
            continue;
        }
        if b == b'\'' || b == b'"' {
            quote = Some(b);
            i += 1;
            continue;
        }

        let rest = &bytes[i..];
        let op_len = if rest.starts_with(b"&&") || rest.starts_with(b"||") {
            2
        } else if b == b';' || b == b'|' || b == b'\n' {
            1
        } else if b == b'&' {
            // `2>&1`, `<&3` and `&>file` are redirections, not command separators.
            let prev = i.checked_sub(1).map(|p| bytes[p]);
            let next = bytes.get(i + 1).copied();
            let is_redirect = prev == Some(b'>') || prev == Some(b'<') || next == Some(b'>');
            if is_redirect { 0 } else { 1 }
        } else {
            0
        };

        if op_len == 0 {
            i += 1;
            continue;
        }

        // Separators are ASCII, so `i` is always a valid char boundary here.
        let segment = input[start..i].trim();
        if !segment.is_empty() {
            out.push(segment);
        }
        i += op_len;
        start = i;
    }

    let tail = input[start..].trim();
    if !tail.is_empty() {
        out.push(tail);
    }
    out
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn splits_on_and() {
        let parts = split_chain("ls && rm -rf /");
        assert_eq!(parts, vec!["ls", "rm -rf /"]);
    }

    #[test]
    fn splits_on_semicolon_and_pipe() {
        let parts = split_chain("ls ; rm -rf / | wc");
        assert_eq!(parts, vec!["ls", "rm -rf /", "wc"]);
    }

    #[test]
    fn keeps_quoted_chains_intact() {
        let parts = split_chain("echo 'a && b'");
        assert_eq!(parts, vec!["echo 'a && b'"]);
    }

    #[test]
    fn handles_no_chain() {
        let parts = split_chain("ls -la");
        assert_eq!(parts, vec!["ls -la"]);
    }

    #[test]
    fn handles_escaped_pipe() {
        let parts = split_chain(r"echo a\|b");
        assert_eq!(parts, vec![r"echo a\|b"]);
    }
}
|
||||
@@ -0,0 +1,28 @@
|
||||
//! File system operations with sandbox enforcement.
|
||||
//!
|
||||
//! This module provides:
|
||||
//! - `read_text_file()` - Read UTF-8 text files with line range support (TOOLS-FS-01)
|
||||
//! - `write_text_file()` - Write UTF-8 text files with atomic writes (TOOLS-FS-02)
|
||||
//! - `generate_diff()` - Generate unified diffs for changes (TOOLS-FS-03)
|
||||
//! - `edit_file()` - Apply text replacements and generate diffs (TOOLS-FS-04)
|
||||
//!
|
||||
//! All operations are subject to:
|
||||
//! - Sandbox containment checks
|
||||
//! - Blocklist enforcement
|
||||
//! - Size limits
|
||||
//! - Permission prompts (for writes)
|
||||
//!
|
||||
//! **Status**: `diff` (TOOLS-FS-03) and `edit` (TOOLS-FS-04) are marked as implemented in their module headers; see each submodule for its current status.
|
||||
|
||||
pub mod read;
|
||||
pub mod write;
|
||||
pub mod diff;
|
||||
pub mod edit;
|
||||
pub mod file_type;
|
||||
|
||||
// Re-export main types and functions
|
||||
pub use read::{read_text_file, ReadTextFileRequest, ReadTextFileResponse};
|
||||
pub use write::{write_text_file, normalize_eol, WriteTextFileRequest, WriteTextFileResponse};
|
||||
pub use diff::generate_diff;
|
||||
pub use edit::{edit_file, EditFileRequest, EditFileResponse, EditOperation};
|
||||
pub use file_type::{detect_file_type, is_valid_utf8, FileTypeInfo};
|
||||
@@ -0,0 +1,195 @@
|
||||
//! Unified diff generation for write operations.
|
||||
//!
|
||||
//! **Status**: Implemented (TOOLS-FS-03)
|
||||
//!
|
||||
//! This module implements:
|
||||
//! - Unified diff generation using similar crate
|
||||
//! - Edge case handling (new files, deleted files, no changes)
|
||||
//! - Diff size limiting for UI rendering
|
||||
//! - Binary file detection and fallback
|
||||
|
||||
use similar::{ChangeTag, TextDiff};
|
||||
use std::path::Path;
|
||||
|
||||
/// Maximum diff size in bytes (compared against `String::len`) before truncation.
|
||||
const MAX_DIFF_SIZE: usize = 100_000;
|
||||
|
||||
/// Generate a unified diff between old and new file contents.
|
||||
///
|
||||
/// Returns a unified diff in standard format, suitable for UI rendering.
|
||||
///
|
||||
/// ## Edge Cases
|
||||
///
|
||||
/// - New file (old = empty) → Shows all lines as additions
|
||||
/// - Deleted file (new = empty) → Shows all lines as deletions
|
||||
/// - No change (old == new) → Returns empty string
|
||||
/// - Very large diff → Truncated with message
|
||||
///
|
||||
/// ## Format
|
||||
///
|
||||
/// Standard unified diff format:
|
||||
/// ```text
|
||||
/// --- path/to/file.txt
|
||||
/// +++ path/to/file.txt
|
||||
/// @@ -1,3 +1,3 @@
|
||||
/// context line
|
||||
/// -old line
|
||||
/// +new line
|
||||
/// context line
|
||||
/// ```
|
||||
///
|
||||
/// ## Error Handling
|
||||
///
|
||||
/// Diff generation is best-effort:
|
||||
/// - Should never panic
|
||||
/// - Falls back to generic message on error
|
||||
/// - Logs warnings for debugging
|
||||
///
|
||||
/// ## See Also
|
||||
///
|
||||
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-FS-03`
|
||||
pub fn generate_diff(old_content: &str, new_content: &str, path: &Path) -> String {
|
||||
// Early exit if contents are identical
|
||||
if old_content == new_content {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
// Handle new file case (old is empty)
|
||||
if old_content.is_empty() {
|
||||
let line_count = new_content.lines().count();
|
||||
let mut result = format!(
|
||||
"--- /dev/null\n+++ {}\n@@ -0,0 +1,{} @@\n",
|
||||
path.display(),
|
||||
line_count
|
||||
);
|
||||
for line in new_content.lines() {
|
||||
result.push_str(&format!("+{}\n", line));
|
||||
}
|
||||
return truncate_diff(result);
|
||||
}
|
||||
|
||||
// Handle deleted file case (new is empty)
|
||||
if new_content.is_empty() {
|
||||
let line_count = old_content.lines().count();
|
||||
let mut result = format!(
|
||||
"--- {}\n+++ /dev/null\n@@ -1,{} +0,0 @@\n",
|
||||
path.display(),
|
||||
line_count
|
||||
);
|
||||
for line in old_content.lines() {
|
||||
result.push_str(&format!("-{}\n", line));
|
||||
}
|
||||
return truncate_diff(result);
|
||||
}
|
||||
|
||||
// Generate unified diff using similar crate
|
||||
match generate_unified_diff(old_content, new_content, path) {
|
||||
Ok(diff) => truncate_diff(diff),
|
||||
Err(e) => {
|
||||
tracing::warn!(path = %path.display(), error = %e, "Failed to generate diff");
|
||||
format!("# Content changed (diff generation failed: {})\n", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate unified diff using the similar crate.
|
||||
fn generate_unified_diff(old_content: &str, new_content: &str, path: &Path) -> Result<String, String> {
|
||||
let diff = TextDiff::from_lines(old_content, new_content);
|
||||
|
||||
let mut output = String::new();
|
||||
output.push_str(&format!("--- {}\n", path.display()));
|
||||
output.push_str(&format!("+++ {}\n", path.display()));
|
||||
|
||||
// Generate unified diff format with context
|
||||
for hunk in diff.unified_diff().iter_hunks() {
|
||||
// Write hunk header
|
||||
output.push_str(&format!("{}", hunk.header()));
|
||||
|
||||
// Write changes
|
||||
for change in hunk.iter_changes() {
|
||||
let sign = match change.tag() {
|
||||
ChangeTag::Delete => "-",
|
||||
ChangeTag::Insert => "+",
|
||||
ChangeTag::Equal => " ",
|
||||
};
|
||||
output.push_str(&format!("{}{}", sign, change.value()));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Truncate diff if it exceeds maximum size.
|
||||
fn truncate_diff(mut diff: String) -> String {
|
||||
if diff.len() > MAX_DIFF_SIZE {
|
||||
diff.truncate(MAX_DIFF_SIZE);
|
||||
diff.push_str("\n... [diff truncated for display] ...\n");
|
||||
}
|
||||
diff
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Path used for all header assertions.
    fn sample_path() -> PathBuf {
        PathBuf::from("test.txt")
    }

    /// Asserts that `diff` contains every expected fragment.
    fn assert_contains_all(diff: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(diff.contains(fragment));
        }
    }

    #[test]
    fn test_generate_diff_identical() {
        let content = "line1\nline2\nline3";
        assert_eq!(generate_diff(content, content, &sample_path()), "");
    }

    #[test]
    fn test_generate_diff_new_file() {
        let diff = generate_diff("", "line1\nline2\nline3", &sample_path());
        assert_contains_all(
            &diff,
            &["--- /dev/null", "+++ test.txt", "+line1", "+line2", "+line3"],
        );
    }

    #[test]
    fn test_generate_diff_deleted_file() {
        let diff = generate_diff("line1\nline2\nline3", "", &sample_path());
        assert_contains_all(
            &diff,
            &["--- test.txt", "+++ /dev/null", "-line1", "-line2", "-line3"],
        );
    }

    #[test]
    fn test_generate_diff_change() {
        let diff = generate_diff(
            "line1\nline2\nline3",
            "line1\nmodified\nline3",
            &sample_path(),
        );
        assert_contains_all(&diff, &["--- test.txt", "+++ test.txt", "-line2", "+modified"]);
    }

    #[test]
    fn test_truncate_diff() {
        let short_diff = "short diff";
        assert_eq!(truncate_diff(short_diff.to_string()), short_diff);

        let truncated = truncate_diff("x".repeat(MAX_DIFF_SIZE + 1000));
        // Leave room for the truncation message.
        assert!(truncated.len() <= MAX_DIFF_SIZE + 100);
        assert!(truncated.contains("[diff truncated for display]"));
    }
}
|
||||
@@ -0,0 +1,258 @@
|
||||
//! Internal edit helper for read + transform + write operations.
|
||||
//!
|
||||
//! **Status**: Implemented (TOOLS-FS-04)
|
||||
//!
|
||||
//! This module implements:
|
||||
//! - Edit operation abstraction (not ACP-native, internal API)
|
||||
//! - Read + transform + write flow
|
||||
//! - Automatic diff generation
|
||||
//! - String replacement operations
|
||||
|
||||
use crate::config::SandboxConfig;
|
||||
use crate::error::{ToolError, ToolResult};
|
||||
use crate::fs::diff::generate_diff;
|
||||
use crate::fs::read::{read_text_file, ReadTextFileRequest};
|
||||
use crate::fs::write::{write_text_file, WriteTextFileRequest};
|
||||
use crate::path::validate_path;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Request to edit a file via transformation operations.
|
||||
///
|
||||
/// **Note**: This is an internal API, not exposed via ACP directly.
|
||||
/// Agents use fs/write_text_file; this is for richer Dirigent UX.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EditFileRequest {
|
||||
/// Absolute path to the file to edit.
|
||||
pub path: String,
|
||||
|
||||
/// Ordered list of edit operations to apply.
|
||||
pub edits: Vec<EditOperation>,
|
||||
}
|
||||
|
||||
/// A single edit operation to apply to file content.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum EditOperation {
|
||||
/// Replace occurrences of old_text with new_text.
|
||||
Replace {
|
||||
/// Text to find.
|
||||
old_text: String,
|
||||
|
||||
/// Text to replace with.
|
||||
new_text: String,
|
||||
|
||||
/// Replace all occurrences (true) or first only (false).
|
||||
replace_all: bool,
|
||||
},
|
||||
|
||||
/// Apply a unified diff patch (future).
|
||||
#[allow(dead_code)]
|
||||
Patch {
|
||||
/// Unified diff string.
|
||||
diff: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// Response from editing a file.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EditFileResponse {
|
||||
/// Unified diff showing changes made.
|
||||
pub diff: String,
|
||||
}
|
||||
|
||||
/// Edit a file by applying transformation operations.
|
||||
///
|
||||
/// ## Implementation
|
||||
///
|
||||
/// This function:
|
||||
/// 1. Reads existing file content
|
||||
/// 2. Applies edits in order:
|
||||
/// - Replace: String find/replace (once or all)
|
||||
/// - Patch: Apply unified diff (future)
|
||||
/// 3. Writes transformed content
|
||||
/// 4. Generates and returns diff
|
||||
///
|
||||
/// ## Algorithm
|
||||
///
|
||||
/// ```text
|
||||
/// 1. old_content = read_text_file(path)
|
||||
/// 2. new_content = old_content
|
||||
/// 3. for each edit in edits:
|
||||
/// new_content = apply_edit(new_content, edit)
|
||||
/// 4. write_text_file(path, new_content)
|
||||
/// 5. diff = generate_diff(old_content, new_content, path)
|
||||
/// 6. return EditFileResponse { diff }
|
||||
/// ```
|
||||
///
|
||||
/// ## Error Cases
|
||||
///
|
||||
/// - File not found → `ToolError::NotFound`
|
||||
/// - Edit on new file → `ToolError::NotFound` (edits require existing content)
|
||||
/// - Same sandboxing/permission errors as read/write
|
||||
///
|
||||
/// ## Tool Call Rendering
|
||||
///
|
||||
/// Always emits `ToolCallContent::diff` for UX visualization.
|
||||
///
|
||||
/// ## See Also
|
||||
///
|
||||
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-FS-04`
|
||||
pub async fn edit_file(
|
||||
request: EditFileRequest,
|
||||
sandbox_config: &SandboxConfig,
|
||||
permission_config: &crate::config::PermissionConfig,
|
||||
permission_context: &crate::permission::check::PermissionContext,
|
||||
) -> ToolResult<EditFileResponse> {
|
||||
// Validate path first to get canonical path for diff generation
|
||||
let canonical_path = validate_path(&request.path, sandbox_config)?;
|
||||
|
||||
// Read existing file content
|
||||
let read_request = ReadTextFileRequest {
|
||||
path: request.path.clone(),
|
||||
line: None,
|
||||
limit: None,
|
||||
};
|
||||
|
||||
let read_response = read_text_file(read_request, sandbox_config).await?;
|
||||
let old_content = read_response.content;
|
||||
|
||||
// Apply all edits in order
|
||||
let mut new_content = old_content.clone();
|
||||
for edit in &request.edits {
|
||||
new_content = apply_edit(&new_content, edit)?;
|
||||
}
|
||||
|
||||
// Write transformed content
|
||||
let write_request = WriteTextFileRequest {
|
||||
path: request.path.clone(),
|
||||
content: new_content.clone(),
|
||||
};
|
||||
|
||||
write_text_file(write_request, sandbox_config, permission_config, permission_context).await?;
|
||||
|
||||
// Generate diff for UX
|
||||
let diff = generate_diff(&old_content, &new_content, &canonical_path);
|
||||
|
||||
tracing::info!(
|
||||
path = %request.path,
|
||||
edit_count = request.edits.len(),
|
||||
"File edited successfully"
|
||||
);
|
||||
|
||||
Ok(EditFileResponse { diff })
|
||||
}
|
||||
|
||||
/// Apply a single edit operation to content.
|
||||
fn apply_edit(content: &str, edit: &EditOperation) -> ToolResult<String> {
|
||||
match edit {
|
||||
EditOperation::Replace {
|
||||
old_text,
|
||||
new_text,
|
||||
replace_all,
|
||||
} => {
|
||||
if *replace_all {
|
||||
// Replace all occurrences
|
||||
Ok(content.replace(old_text, new_text))
|
||||
} else {
|
||||
// Replace only the first occurrence
|
||||
if let Some(pos) = content.find(old_text) {
|
||||
let mut result = String::with_capacity(content.len());
|
||||
result.push_str(&content[..pos]);
|
||||
result.push_str(new_text);
|
||||
result.push_str(&content[pos + old_text.len()..]);
|
||||
Ok(result)
|
||||
} else {
|
||||
// No match found - return content unchanged
|
||||
// This could be a warning, but we'll allow it
|
||||
tracing::warn!(
|
||||
old_text = %old_text,
|
||||
"Edit operation: old_text not found in content"
|
||||
);
|
||||
Ok(content.to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
EditOperation::Patch { diff: _ } => {
|
||||
// Future: Apply unified diff patch
|
||||
Err(ToolError::InvalidInput(
|
||||
"Patch edit operation not yet implemented".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `Replace` operation.
    fn replace_op(old_text: &str, new_text: &str, replace_all: bool) -> EditOperation {
        EditOperation::Replace {
            old_text: old_text.to_string(),
            new_text: new_text.to_string(),
            replace_all,
        }
    }

    #[test]
    fn test_apply_edit_replace_all() {
        let edited = apply_edit("foo bar foo baz foo", &replace_op("foo", "qux", true)).unwrap();
        assert_eq!(edited, "qux bar qux baz qux");
    }

    #[test]
    fn test_apply_edit_replace_first() {
        let edited = apply_edit("foo bar foo baz foo", &replace_op("foo", "qux", false)).unwrap();
        assert_eq!(edited, "qux bar foo baz foo");
    }

    #[test]
    fn test_apply_edit_no_match() {
        let content = "foo bar baz";
        let edited = apply_edit(content, &replace_op("qux", "quux", false)).unwrap();
        // A miss leaves the content untouched.
        assert_eq!(edited, content);
    }

    #[test]
    fn test_apply_edit_empty_replacement() {
        let edited = apply_edit("foo bar foo", &replace_op("foo", "", true)).unwrap();
        assert_eq!(edited, " bar ");
    }

    #[test]
    fn test_apply_edit_multiline() {
        let edited = apply_edit(
            "line1\nfoo\nline3\nfoo\nline5",
            &replace_op("foo", "bar", true),
        )
        .unwrap();
        assert_eq!(edited, "line1\nbar\nline3\nbar\nline5");
    }

    #[test]
    fn test_apply_edit_patch_unimplemented() {
        let patch = EditOperation::Patch {
            diff: "some diff".to_string(),
        };
        let outcome = apply_edit("foo bar baz", &patch);
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(ToolError::InvalidInput(_))));
    }
}
|
||||
@@ -0,0 +1,331 @@
|
||||
//! File type detection and MIME type guessing.
|
||||
//!
|
||||
//! This module provides utilities for detecting whether a file is text or binary,
|
||||
//! and guessing appropriate MIME types based on file extensions.
|
||||
|
||||
use std::fs;
|
||||
use std::io::Read;
|
||||
use std::path::Path;
|
||||
|
||||
use crate::error::ToolResult;
|
||||
|
||||
/// Classification of a file: text or binary, plus MIME type and charset guesses.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileTypeInfo {
    /// `true` when the file looks like text.
    pub is_text: bool,
    /// `true` when the file looks like binary data.
    pub is_binary: bool,
    /// Best-guess MIME type, derived from the extension or the content.
    pub mime_type: Option<String>,
    /// Character set; always "utf-8" for text files in Phase 03.
    pub charset: Option<String>,
}
|
||||
|
||||
impl FileTypeInfo {
|
||||
/// Create a text file type info.
|
||||
pub fn text(mime_type: impl Into<String>) -> Self {
|
||||
Self {
|
||||
is_text: true,
|
||||
is_binary: false,
|
||||
mime_type: Some(mime_type.into()),
|
||||
charset: Some("utf-8".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a binary file type info.
|
||||
pub fn binary(mime_type: impl Into<String>) -> Self {
|
||||
Self {
|
||||
is_text: false,
|
||||
is_binary: true,
|
||||
mime_type: Some(mime_type.into()),
|
||||
charset: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an unknown file type info (defaults to binary).
|
||||
pub fn unknown() -> Self {
|
||||
Self {
|
||||
is_text: false,
|
||||
is_binary: true,
|
||||
mime_type: Some("application/octet-stream".to_string()),
|
||||
charset: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect file type based on extension and content analysis.
|
||||
///
|
||||
/// This uses a fast path (extension-based detection) followed by optional
|
||||
/// content sniffing for extensionless files.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the file to analyze
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// File type information including text/binary classification and MIME type.
|
||||
pub fn detect_file_type(path: &Path) -> ToolResult<FileTypeInfo> {
|
||||
// First, try extension-based detection (fast path)
|
||||
if let Some(info) = detect_by_extension(path) {
|
||||
return Ok(info);
|
||||
}
|
||||
|
||||
// For extensionless files, do content sniffing
|
||||
detect_by_content(path)
|
||||
}
|
||||
|
||||
/// Detect file type by extension.
|
||||
///
|
||||
/// This is the fast path for common file types. Returns None if the extension
|
||||
/// is unknown or missing.
|
||||
fn detect_by_extension(path: &Path) -> Option<FileTypeInfo> {
|
||||
let ext = path.extension()?.to_str()?.to_lowercase();
|
||||
|
||||
match ext.as_str() {
|
||||
// Programming languages
|
||||
"rs" => Some(FileTypeInfo::text("text/x-rust")),
|
||||
"ts" => Some(FileTypeInfo::text("text/typescript")),
|
||||
"tsx" => Some(FileTypeInfo::text("text/tsx")),
|
||||
"js" => Some(FileTypeInfo::text("text/javascript")),
|
||||
"jsx" => Some(FileTypeInfo::text("text/jsx")),
|
||||
"py" => Some(FileTypeInfo::text("text/x-python")),
|
||||
"rb" => Some(FileTypeInfo::text("text/x-ruby")),
|
||||
"go" => Some(FileTypeInfo::text("text/x-go")),
|
||||
"java" => Some(FileTypeInfo::text("text/x-java")),
|
||||
"c" => Some(FileTypeInfo::text("text/x-c")),
|
||||
"cpp" | "cc" | "cxx" => Some(FileTypeInfo::text("text/x-c++")),
|
||||
"h" | "hpp" => Some(FileTypeInfo::text("text/x-c-header")),
|
||||
"cs" => Some(FileTypeInfo::text("text/x-csharp")),
|
||||
"sh" | "bash" | "zsh" => Some(FileTypeInfo::text("text/x-shellscript")),
|
||||
"ps1" => Some(FileTypeInfo::text("text/x-powershell")),
|
||||
|
||||
// Data formats
|
||||
"json" => Some(FileTypeInfo::text("application/json")),
|
||||
"toml" => Some(FileTypeInfo::text("application/toml")),
|
||||
"yaml" | "yml" => Some(FileTypeInfo::text("application/yaml")),
|
||||
"xml" => Some(FileTypeInfo::text("application/xml")),
|
||||
"csv" => Some(FileTypeInfo::text("text/csv")),
|
||||
|
||||
// Markup
|
||||
"html" | "htm" => Some(FileTypeInfo::text("text/html")),
|
||||
"css" => Some(FileTypeInfo::text("text/css")),
|
||||
"scss" | "sass" => Some(FileTypeInfo::text("text/x-scss")),
|
||||
"md" | "markdown" => Some(FileTypeInfo::text("text/markdown")),
|
||||
"rst" => Some(FileTypeInfo::text("text/x-rst")),
|
||||
|
||||
// Plain text
|
||||
"txt" => Some(FileTypeInfo::text("text/plain")),
|
||||
"log" => Some(FileTypeInfo::text("text/plain")),
|
||||
|
||||
// Config files
|
||||
"ini" | "cfg" | "conf" => Some(FileTypeInfo::text("text/plain")),
|
||||
"env" => Some(FileTypeInfo::text("text/plain")),
|
||||
|
||||
// Binary formats
|
||||
"png" => Some(FileTypeInfo::binary("image/png")),
|
||||
"jpg" | "jpeg" => Some(FileTypeInfo::binary("image/jpeg")),
|
||||
"gif" => Some(FileTypeInfo::binary("image/gif")),
|
||||
"webp" => Some(FileTypeInfo::binary("image/webp")),
|
||||
"svg" => Some(FileTypeInfo::text("image/svg+xml")), // SVG is text
|
||||
"pdf" => Some(FileTypeInfo::binary("application/pdf")),
|
||||
"zip" => Some(FileTypeInfo::binary("application/zip")),
|
||||
"gz" | "gzip" => Some(FileTypeInfo::binary("application/gzip")),
|
||||
"tar" => Some(FileTypeInfo::binary("application/x-tar")),
|
||||
"mp3" => Some(FileTypeInfo::binary("audio/mpeg")),
|
||||
"mp4" => Some(FileTypeInfo::binary("video/mp4")),
|
||||
"exe" | "dll" | "so" | "dylib" => Some(FileTypeInfo::binary("application/octet-stream")),
|
||||
|
||||
// Unknown extension
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect file type by analyzing content.
|
||||
///
|
||||
/// Reads the first 8KB of the file and checks for UTF-8 validity and
|
||||
/// presence of null bytes (indicating binary).
|
||||
fn detect_by_content(path: &Path) -> ToolResult<FileTypeInfo> {
|
||||
// Open file
|
||||
let mut file = fs::File::open(path)?;
|
||||
|
||||
// Read first 8KB for analysis
|
||||
let mut buffer = vec![0u8; 8192];
|
||||
let bytes_read = file.read(&mut buffer)?;
|
||||
buffer.truncate(bytes_read);
|
||||
|
||||
// Check for null bytes (strong indicator of binary)
|
||||
if buffer.contains(&0) {
|
||||
return Ok(FileTypeInfo::binary("application/octet-stream"));
|
||||
}
|
||||
|
||||
// Try to validate as UTF-8
|
||||
match std::str::from_utf8(&buffer) {
|
||||
Ok(_) => Ok(FileTypeInfo::text("text/plain")),
|
||||
Err(_) => Ok(FileTypeInfo::binary("application/octet-stream")),
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate that a file contains valid UTF-8 text.
|
||||
///
|
||||
/// This is used to ensure text files can be safely read and embedded.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the file to validate
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// `Ok(true)` if the file is valid UTF-8, `Ok(false)` if not.
|
||||
pub fn is_valid_utf8(path: &Path) -> ToolResult<bool> {
|
||||
let content = fs::read(path)?;
|
||||
Ok(std::str::from_utf8(&content).is_ok())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Build a temporary file whose contents are exactly `bytes`.
    fn temp_with(bytes: &[u8]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(bytes).unwrap();
        file
    }

    #[test]
    fn test_detect_by_extension_rust() {
        let detected = detect_by_extension(Path::new("test.rs")).unwrap();

        assert!(detected.is_text);
        assert!(!detected.is_binary);
        assert_eq!(detected.mime_type, Some(String::from("text/x-rust")));
        assert_eq!(detected.charset, Some(String::from("utf-8")));
    }

    #[test]
    fn test_detect_by_extension_typescript() {
        let detected = detect_by_extension(Path::new("test.ts")).unwrap();

        assert!(detected.is_text);
        assert_eq!(detected.mime_type, Some(String::from("text/typescript")));
    }

    #[test]
    fn test_detect_by_extension_json() {
        let detected = detect_by_extension(Path::new("config.json")).unwrap();

        assert!(detected.is_text);
        assert_eq!(detected.mime_type, Some(String::from("application/json")));
    }

    #[test]
    fn test_detect_by_extension_binary_png() {
        let detected = detect_by_extension(Path::new("image.png")).unwrap();

        assert!(!detected.is_text);
        assert!(detected.is_binary);
        assert_eq!(detected.mime_type, Some(String::from("image/png")));
        assert_eq!(detected.charset, None);
    }

    #[test]
    fn test_detect_by_extension_unknown() {
        // An extension missing from the lookup table yields no result.
        assert!(detect_by_extension(Path::new("file.unknownext")).is_none());
    }

    #[test]
    fn test_detect_by_extension_no_extension() {
        // A path with no extension at all yields no result either.
        assert!(detect_by_extension(Path::new("Makefile")).is_none());
    }

    #[test]
    fn test_detect_by_content_text() {
        let file = temp_with(b"Hello, world!\nThis is a text file.\n");

        let detected = detect_by_content(file.path()).unwrap();

        assert!(detected.is_text);
        assert!(!detected.is_binary);
        assert_eq!(detected.mime_type, Some(String::from("text/plain")));
    }

    #[test]
    fn test_detect_by_content_binary_with_nulls() {
        let file = temp_with(b"\x00\x01\x02\x03binary data");

        let detected = detect_by_content(file.path()).unwrap();

        assert!(!detected.is_text);
        assert!(detected.is_binary);
        assert_eq!(
            detected.mime_type,
            Some(String::from("application/octet-stream"))
        );
    }

    #[test]
    fn test_detect_by_content_invalid_utf8() {
        // These bytes can never appear in well-formed UTF-8.
        let file = temp_with(&[0xFF, 0xFE, 0xFD]);

        let detected = detect_by_content(file.path()).unwrap();

        assert!(!detected.is_text);
        assert!(detected.is_binary);
    }

    #[test]
    fn test_detect_file_type_with_extension() {
        // Extension-based detection only inspects the name; no file is needed.
        let detected = detect_by_extension(Path::new("test.rs")).unwrap();

        assert!(detected.is_text);
        assert_eq!(detected.mime_type, Some(String::from("text/x-rust")));
    }

    #[test]
    fn test_is_valid_utf8_valid() {
        let file = temp_with("Hello, UTF-8! 你好世界".as_bytes());

        assert!(is_valid_utf8(file.path()).unwrap());
    }

    #[test]
    fn test_is_valid_utf8_invalid() {
        let file = temp_with(&[0xFF, 0xFE, 0xFD]);

        assert!(!is_valid_utf8(file.path()).unwrap());
    }

    #[test]
    fn test_file_type_info_constructors() {
        let as_text = FileTypeInfo::text("text/plain");
        assert!(as_text.is_text);
        assert!(!as_text.is_binary);
        assert_eq!(as_text.charset, Some(String::from("utf-8")));

        let as_binary = FileTypeInfo::binary("application/pdf");
        assert!(!as_binary.is_text);
        assert!(as_binary.is_binary);
        assert_eq!(as_binary.charset, None);

        let fallback = FileTypeInfo::unknown();
        assert!(!fallback.is_text);
        assert!(fallback.is_binary);
        assert_eq!(
            fallback.mime_type,
            Some(String::from("application/octet-stream"))
        );
    }
}
|
||||
@@ -0,0 +1,215 @@
|
||||
//! File read operation with sandboxing and line/limit support.
|
||||
//!
|
||||
//! **Status**: Implemented (TOOLS-FS-01)
|
||||
//!
|
||||
//! This module implements:
|
||||
//! - Path validation using the security layer
|
||||
//! - Asynchronous file reading
|
||||
//! - UTF-8 encoding validation
|
||||
//! - Line range and limit semantics
|
||||
//! - Sandbox containment checks
|
||||
//! - Blocklist enforcement
|
||||
//! - Size limit soft caps
|
||||
|
||||
use crate::config::SandboxConfig;
|
||||
use crate::error::{ToolError, ToolResult};
|
||||
use crate::path::validate_path;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Request to read a text file.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ReadTextFileRequest {
|
||||
/// Absolute path to the file to read.
|
||||
pub path: String,
|
||||
|
||||
/// Optional starting line (1-indexed).
|
||||
///
|
||||
/// If provided with limit, reads from this line.
|
||||
/// If provided without limit, reads from this line to end.
|
||||
pub line: Option<usize>,
|
||||
|
||||
/// Optional maximum number of lines to read.
|
||||
///
|
||||
/// If provided without line, reads first N lines.
|
||||
/// If provided with line, reads N lines starting from line.
|
||||
pub limit: Option<usize>,
|
||||
}
|
||||
|
||||
/// Response from reading a text file.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ReadTextFileResponse {
|
||||
/// UTF-8 text content of the file (or portion thereof).
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
/// Read a text file with sandboxing and line/limit support.
|
||||
///
|
||||
/// ## Line/Limit Semantics
|
||||
///
|
||||
/// - `line: None, limit: None` → Read entire file
|
||||
/// - `line: Some(n), limit: None` → Read from line n to end
|
||||
/// - `line: None, limit: Some(m)` → Read first m lines
|
||||
/// - `line: Some(n), limit: Some(m)` → Read m lines starting from line n
|
||||
///
|
||||
/// ## Error Cases
|
||||
///
|
||||
/// - Read disabled → `ToolError::PermissionDenied`
|
||||
/// - Path outside allowed roots → `ToolError::SandboxViolation`
|
||||
/// - Path matches blocklist → `ToolError::BlockedPath`
|
||||
/// - File not found → `ToolError::NotFound`
|
||||
/// - Non-UTF-8 content → `ToolError::EncodingUnsupported`
|
||||
/// - File too large (soft cap) → Warning logged, content returned
|
||||
///
|
||||
/// ## See Also
|
||||
///
|
||||
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-FS-01`
|
||||
/// - Security spec: `docs/building/04_acp_client/03_fs_sandboxing_and_permissions_spec.md`
|
||||
pub async fn read_text_file(
|
||||
request: ReadTextFileRequest,
|
||||
config: &SandboxConfig,
|
||||
) -> ToolResult<ReadTextFileResponse> {
|
||||
// Check if read is enabled
|
||||
if !config.read_enabled {
|
||||
return Err(ToolError::permission_denied("Read operations are disabled"));
|
||||
}
|
||||
|
||||
// Validate and canonicalize path (checks containment and blocklist)
|
||||
let canonical_path = validate_path(&request.path, config)?;
|
||||
|
||||
// Read file asynchronously
|
||||
let content_bytes = tokio::fs::read(&canonical_path)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
if e.kind() == std::io::ErrorKind::NotFound {
|
||||
ToolError::NotFound {
|
||||
path: request.path.clone(),
|
||||
}
|
||||
} else {
|
||||
ToolError::FileReadError {
|
||||
path: request.path.clone(),
|
||||
source: e,
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
// Check soft cap and log warning if exceeded
|
||||
let file_size = content_bytes.len() as u64;
|
||||
if file_size > config.max_read_bytes {
|
||||
tracing::warn!(
|
||||
path = %request.path,
|
||||
size = file_size,
|
||||
limit = config.max_read_bytes,
|
||||
"File size exceeds max_read_bytes soft cap"
|
||||
);
|
||||
}
|
||||
|
||||
// Validate UTF-8 encoding
|
||||
let content = String::from_utf8(content_bytes).map_err(|_| {
|
||||
ToolError::EncodingUnsupported {
|
||||
encoding: "non-UTF-8".to_string(),
|
||||
}
|
||||
})?;
|
||||
|
||||
// Apply line/limit semantics
|
||||
let final_content = apply_line_limit(&content, request.line, request.limit);
|
||||
|
||||
Ok(ReadTextFileResponse {
|
||||
content: final_content,
|
||||
})
|
||||
}
|
||||
|
||||
/// Select a line range from `content`.
///
/// ## Semantics
///
/// - `line: None, limit: None` → Return entire content, byte-for-byte
/// - `line: Some(n), limit: None` → Return from line n to end (1-indexed)
/// - `line: None, limit: Some(m)` → Return first m lines
/// - `line: Some(n), limit: Some(m)` → Return m lines starting from line n (1-indexed)
///
/// A start line of `0` is treated as if no start line had been given.
///
/// Whenever any slicing takes place, the selected lines are re-joined with
/// `\n`: original `\r\n` terminators and a trailing newline are not kept.
fn apply_line_limit(content: &str, line: Option<usize>, limit: Option<usize>) -> String {
    // Line numbers are 1-indexed, so 0 carries no position of its own.
    let first_line = line.filter(|&n| n != 0);

    // With neither bound in effect the content is passed through untouched.
    if first_line.is_none() && limit.is_none() {
        return content.to_owned();
    }

    let lines_to_skip = first_line.map_or(0, |n| n - 1);
    let lines_to_keep = limit.unwrap_or(usize::MAX);

    let selected: Vec<&str> = content
        .lines()
        .skip(lines_to_skip)
        .take(lines_to_keep)
        .collect();
    selected.join("\n")
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_apply_line_limit_no_params() {
        // Without bounds the text comes back unchanged.
        let text = "line1\nline2\nline3";
        let result = apply_line_limit(text, None, None);
        assert_eq!(result, text);
    }

    #[test]
    fn test_apply_line_limit_only_limit() {
        let text = "line1\nline2\nline3\nline4";
        let result = apply_line_limit(text, None, Some(2));
        assert_eq!(result, "line1\nline2");
    }

    #[test]
    fn test_apply_line_limit_only_line() {
        let text = "line1\nline2\nline3\nline4";
        let result = apply_line_limit(text, Some(2), None);
        assert_eq!(result, "line2\nline3\nline4");
    }

    #[test]
    fn test_apply_line_limit_both() {
        let text = "line1\nline2\nline3\nline4\nline5";
        let result = apply_line_limit(text, Some(2), Some(2));
        assert_eq!(result, "line2\nline3");
    }

    #[test]
    fn test_apply_line_limit_start_zero() {
        // Line 0 behaves like "start at the beginning".
        let text = "line1\nline2\nline3";

        let unbounded = apply_line_limit(text, Some(0), None);
        assert_eq!(unbounded, text);

        let bounded = apply_line_limit(text, Some(0), Some(2));
        assert_eq!(bounded, "line1\nline2");
    }

    #[test]
    fn test_apply_line_limit_beyond_end() {
        let text = "line1\nline2\nline3";

        // Starting past the last line yields nothing.
        let past_end = apply_line_limit(text, Some(10), None);
        assert_eq!(past_end, "");

        // A limit larger than what remains simply returns the remainder.
        let oversized_limit = apply_line_limit(text, Some(2), Some(10));
        assert_eq!(oversized_limit, "line2\nline3");
    }
}
|
||||
@@ -0,0 +1,232 @@
|
||||
//! File write operation with sandboxing, permissions, and atomic writes.
|
||||
//!
|
||||
//! **Status**: Implemented (TOOLS-FS-02)
|
||||
//!
|
||||
//! This module implements:
|
||||
//! - Path validation and canonicalization
|
||||
//! - Permission checks via the permission system
|
||||
//! - Atomic write operations (temp + rename)
|
||||
//! - EOL normalization
|
||||
//! - Parent directory creation
|
||||
//! - Size limit enforcement
|
||||
|
||||
use crate::config::{EolPolicy, PermissionConfig, SandboxConfig};
|
||||
use crate::error::{ToolError, ToolResult};
|
||||
use crate::path::validate_path;
|
||||
use crate::permission::check::{check_permission, PermissionContext};
|
||||
use crate::permission::cache::PermissionDecision;
|
||||
use crate::permission::whitelist::PermissionOperation;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
|
||||
/// Request to write a text file.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WriteTextFileRequest {
|
||||
/// Absolute path to the file to write.
|
||||
pub path: String,
|
||||
|
||||
/// UTF-8 text content to write.
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
/// Response from writing a text file.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WriteTextFileResponse {}
|
||||
|
||||
/// Write a text file with sandboxing, permission gating, and atomic writes.
|
||||
///
|
||||
/// ## Implementation
|
||||
///
|
||||
/// This function:
|
||||
/// 1. Validates `write_enabled = true` in config
|
||||
/// 2. Validates and canonicalizes path
|
||||
/// 3. Checks containment and blocklist
|
||||
/// 4. Checks permission (TODO: integrate with permission system)
|
||||
/// 5. Validates content size against max_write_bytes
|
||||
/// 6. Applies EOL policy normalization
|
||||
/// 7. Creates parent directories if needed
|
||||
/// 8. Performs atomic write (temp file + rename)
|
||||
///
|
||||
/// ## EOL Normalization
|
||||
///
|
||||
/// Applies configured EOL policy:
|
||||
/// - `Preserve` → Keep original line endings
|
||||
/// - `Lf` → Normalize to LF (\n)
|
||||
/// - `Crlf` → Normalize to CRLF (\r\n)
|
||||
///
|
||||
/// ## Atomic Writes
|
||||
///
|
||||
/// To prevent partial writes:
|
||||
/// 1. Write content to temporary file in same directory
|
||||
/// 2. Rename temp file to target path (atomic on POSIX, near-atomic on Windows)
|
||||
///
|
||||
/// ## Error Cases
|
||||
///
|
||||
/// - Write disabled → `ToolError::PermissionDenied`
|
||||
/// - Path outside allowed roots → `ToolError::SandboxViolation`
|
||||
/// - Path matches blocklist → `ToolError::BlockedPath`
|
||||
/// - Permission denied → `ToolError::PermissionDenied`
|
||||
/// - Content too large → `ToolError::FileTooLarge`
|
||||
/// - I/O errors → `ToolError::Io`
|
||||
///
|
||||
/// ## See Also
|
||||
///
|
||||
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-FS-02`
|
||||
/// - Security spec: `docs/building/04_acp_client/03_fs_sandboxing_and_permissions_spec.md`
|
||||
pub async fn write_text_file(
|
||||
request: WriteTextFileRequest,
|
||||
sandbox_config: &SandboxConfig,
|
||||
permission_config: &PermissionConfig,
|
||||
permission_context: &PermissionContext,
|
||||
) -> ToolResult<WriteTextFileResponse> {
|
||||
// Check if write is enabled
|
||||
if !sandbox_config.write_enabled {
|
||||
return Err(ToolError::permission_denied("Write operations are disabled"));
|
||||
}
|
||||
|
||||
// Validate content size
|
||||
let content_size = request.content.len() as u64;
|
||||
if content_size > sandbox_config.max_write_bytes {
|
||||
return Err(ToolError::FileTooLarge {
|
||||
size: content_size,
|
||||
limit: sandbox_config.max_write_bytes,
|
||||
});
|
||||
}
|
||||
|
||||
// Validate and canonicalize path (checks containment and blocklist)
|
||||
let canonical_path = validate_path(&request.path, sandbox_config)?;
|
||||
|
||||
// Check permission via permission system (TOOLS-PERM-01)
|
||||
let operation = PermissionOperation::Write {
|
||||
path: request.path.clone(),
|
||||
};
|
||||
|
||||
let decision = check_permission(&operation, permission_context, permission_config).await?;
|
||||
|
||||
match decision {
|
||||
PermissionDecision::Allowed => {
|
||||
// Permission granted, proceed with write
|
||||
tracing::debug!(
|
||||
path = %request.path,
|
||||
"Permission granted for write operation"
|
||||
);
|
||||
}
|
||||
PermissionDecision::Denied => {
|
||||
return Err(ToolError::PermissionRejected);
|
||||
}
|
||||
PermissionDecision::Cancelled => {
|
||||
return Err(ToolError::permission_denied("Permission prompt was cancelled"));
|
||||
}
|
||||
}
|
||||
|
||||
// Apply EOL normalization
|
||||
let normalized_content = normalize_eol(&request.content, sandbox_config.eol_policy);
|
||||
|
||||
// Create parent directories if they don't exist
|
||||
if let Some(parent) = canonical_path.parent() {
|
||||
tokio::fs::create_dir_all(parent).await.map_err(|e| {
|
||||
ToolError::FileReadError {
|
||||
path: parent.display().to_string(),
|
||||
source: e,
|
||||
}
|
||||
})?;
|
||||
}
|
||||
|
||||
// Perform atomic write
|
||||
write_atomic(&canonical_path, &normalized_content).await?;
|
||||
|
||||
tracing::info!(
|
||||
path = %request.path,
|
||||
size = content_size,
|
||||
"File written successfully"
|
||||
);
|
||||
|
||||
Ok(WriteTextFileResponse {})
|
||||
}
|
||||
|
||||
/// Normalize line endings according to configured policy.
|
||||
pub fn normalize_eol(content: &str, policy: EolPolicy) -> String {
|
||||
match policy {
|
||||
EolPolicy::Preserve => content.to_string(),
|
||||
EolPolicy::Lf => {
|
||||
// Replace CRLF with LF, then ensure all CR are removed
|
||||
content.replace("\r\n", "\n").replace('\r', "\n")
|
||||
}
|
||||
EolPolicy::Crlf => {
|
||||
// First normalize to LF, then replace with CRLF
|
||||
let lf_normalized = content.replace("\r\n", "\n").replace('\r', "\n");
|
||||
lf_normalized.replace('\n', "\r\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Write content to a file atomically using temp file + rename.
|
||||
///
|
||||
/// This minimizes the risk of partial writes by:
|
||||
/// 1. Writing to a temporary file in the same directory
|
||||
/// 2. Renaming the temp file to the target path
|
||||
///
|
||||
/// The rename operation is atomic on POSIX systems and near-atomic on Windows.
|
||||
async fn write_atomic(path: &Path, content: &str) -> ToolResult<()> {
|
||||
// Create a temporary file in the same directory
|
||||
let parent = path.parent().unwrap_or_else(|| Path::new("."));
|
||||
let file_name = path.file_name().unwrap_or_default().to_string_lossy();
|
||||
let temp_path = parent.join(format!(".{}.tmp", file_name));
|
||||
|
||||
// Write content to temporary file
|
||||
tokio::fs::write(&temp_path, content).await.map_err(|e| {
|
||||
ToolError::FileReadError {
|
||||
path: temp_path.display().to_string(),
|
||||
source: e,
|
||||
}
|
||||
})?;
|
||||
|
||||
// Rename temp file to target (atomic operation)
|
||||
tokio::fs::rename(&temp_path, path).await.map_err(|e| {
|
||||
// Clean up temp file on error
|
||||
let _ = std::fs::remove_file(&temp_path);
|
||||
ToolError::FileReadError {
|
||||
path: path.display().to_string(),
|
||||
source: e,
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample mixing all three line-ending styles: LF, CRLF and lone CR.
    const MIXED: &str = "line1\nline2\r\nline3\rline4";

    #[test]
    fn test_normalize_eol_preserve() {
        let result = normalize_eol(MIXED, EolPolicy::Preserve);
        assert_eq!(result, MIXED);
    }

    #[test]
    fn test_normalize_eol_lf() {
        let result = normalize_eol(MIXED, EolPolicy::Lf);
        assert_eq!(result, "line1\nline2\nline3\nline4");
    }

    #[test]
    fn test_normalize_eol_crlf() {
        let result = normalize_eol(MIXED, EolPolicy::Crlf);
        assert_eq!(result, "line1\r\nline2\r\nline3\r\nline4");
    }

    #[test]
    fn test_normalize_eol_edge_cases() {
        // Empty input stays empty.
        let empty = normalize_eol("", EolPolicy::Lf);
        assert_eq!(empty, "");

        // Input made only of newlines.
        let only_newlines = normalize_eol("\n\n\n", EolPolicy::Crlf);
        assert_eq!(only_newlines, "\r\n\r\n\r\n");

        // Every ending style collapses to LF.
        let collapsed = normalize_eol("a\rb\nc\r\nd", EolPolicy::Lf);
        assert_eq!(collapsed, "a\nb\nc\nd");
    }
}
|
||||
@@ -0,0 +1,78 @@
|
||||
//! # dirigent_tools
|
||||
//!
|
||||
//! Tool implementations for ACP (Agent-Client Protocol) client operations.
|
||||
//!
|
||||
//! This package provides:
|
||||
//! - **File operations** (read, write, edit) with sandboxing
|
||||
//! - **Terminal operations** (create, output, wait, kill) with isolation
|
||||
//! - **Search operations** (glob, grep, ls) with result limiting
|
||||
//! - **Permission system** for securing write and execute operations
|
||||
//! - **Path sandboxing** with configurable allowed roots and blocklists
|
||||
//!
|
||||
//! ## Module Overview
|
||||
//!
|
||||
//! - [`error`] - Error types for tool operations
|
||||
//! - [`config`] - Configuration types for sandboxing, permissions, and limits
|
||||
//! - [`path`] - Path normalization and containment checking (cross-platform)
|
||||
//! - [`fs`] - File system operations (read, write) with sandbox enforcement
|
||||
//! - [`search`] - Search operations (glob, grep, ls) with result limiting
|
||||
//! - [`terminal`] - Terminal/command execution with output capture
|
||||
//! - [`permission`] - Permission prompts and decision caching
|
||||
//! - [`audit`] - Audit logging for sensitive operations
|
||||
//! - [`tool`] - Tool trait, events, and per-call context
|
||||
//! - [`registry`] - ToolRegistry for built-in and dynamic tools
|
||||
//! - [`floor`] - Hardcoded SecurityFloor (cannot be bypassed by settings)
|
||||
//! - [`dispatch`] - Dispatch pipeline (registry → floor → run)
|
||||
//! - [`tools`] - Built-in tool implementations (read, ...)
|
||||
//!
|
||||
//! ## Safety and Security
|
||||
//!
|
||||
//! All file and terminal operations are subject to:
|
||||
//! - **Sandbox containment** - Operations restricted to configured allowed roots
|
||||
//! - **Blocklist enforcement** - Sensitive paths can be explicitly denied
|
||||
//! - **Permission prompts** - Write and execute operations can require user approval
|
||||
//! - **Resource limits** - File size, search results, and terminal output are bounded
|
||||
//! - **Audit logging** - All operations are logged with structured context
|
||||
//!
|
||||
//! ## Platform Support
|
||||
//!
|
||||
//! This crate is designed with Windows as a first-class platform:
|
||||
//! - Handles Windows paths (backslashes, drive letters, UNC shares, `\\?\` prefixes)
|
||||
//! - Supports MINGW-style paths (`/c/...`)
|
||||
//! - Normalizes path separators for consistent policy enforcement
|
||||
//! - Tests run on Windows, Linux, and macOS
|
||||
|
||||
pub mod error;
|
||||
pub mod config;
|
||||
pub mod path;
|
||||
pub mod fs;
|
||||
pub mod search;
|
||||
pub mod terminal;
|
||||
pub mod permission;
|
||||
pub mod audit;
|
||||
pub mod embedding;
|
||||
pub mod tool;
|
||||
pub mod registry;
|
||||
pub mod floor;
|
||||
pub mod dispatch;
|
||||
pub mod tools;
|
||||
|
||||
// Re-export commonly used types
|
||||
pub use error::{ToolError, ToolResult};
|
||||
pub use config::{
|
||||
SandboxConfig, PermissionConfig, TerminalConfig, SearchConfig, EmbeddingConfig,
|
||||
};
|
||||
pub use tool::{
|
||||
AnyTool, AnyToolInput, ClientKind, Erased, PermissionRequestId, ProtocolKind, Tool,
|
||||
ToolContext, ToolEvent, ToolEventSink, ToolInput, ToolKind, ToolLocation, ToolResultContent,
|
||||
};
|
||||
pub use registry::{CollisionPolicy, DynamicEntry, ToolRegistry, ToolSource, Winner};
|
||||
pub use floor::{FloorDecision, SecurityFloor};
|
||||
pub use dispatch::{dispatch, DispatchResult};
|
||||
|
||||
/// Re-exports of the policy types `dirigent_tools` consumes from
|
||||
/// `dirigent_fermata`. Keeping this here pins the dependency direction:
|
||||
/// `dirigent_tools` → `dirigent_fermata`, never the reverse.
|
||||
pub mod policy {
|
||||
pub use dirigent_fermata::core::{Decision, Op, Policy, Reason, Rule};
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
//! Path normalization and containment checking.
|
||||
//!
|
||||
//! This module provides cross-platform path utilities with special attention to Windows:
|
||||
//! - Normalizes path separators (backslash vs forward slash)
|
||||
//! - Handles Windows drive letters (C:\, /c/, etc.)
|
||||
//! - Handles UNC paths (\\server\share\...)
|
||||
//! - Handles long path prefixes (\\?\...)
|
||||
//! - Handles MINGW-style paths (/c/Users/...)
|
||||
//! - Canonical path resolution (symlinks, junctions)
|
||||
//! - Containment checking for sandbox enforcement
|
||||
//!
|
||||
//! # Security-Critical
|
||||
//!
|
||||
//! This module is the security foundation for tool sandboxing. All operations must:
|
||||
//! - Prevent path traversal attacks
|
||||
//! - Prevent symlink escape attacks
|
||||
//! - Handle all Windows path edge cases correctly
|
||||
//! - Never expose disallowed paths in error messages
|
||||
|
||||
pub mod canonicalize;
|
||||
pub mod containment;
|
||||
pub mod blocklist;
|
||||
pub mod validate;
|
||||
|
||||
// Re-export public API
|
||||
pub use canonicalize::{canonicalize_path, SymlinkPolicy};
|
||||
pub use containment::check_containment;
|
||||
pub use blocklist::check_blocklist;
|
||||
pub use validate::validate_path;
|
||||
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
/// Return the final component of `path` as an owned string.
///
/// Intended for error messages, so that only the file name is shown rather
/// than the full path.
///
/// Yields an empty string when the path has no final component (for example
/// `/` or a path ending in `..`) or when that component is not valid UTF-8.
pub fn basename(path: &Path) -> String {
    match path.file_name().and_then(|name| name.to_str()) {
        Some(name) => name.to_owned(),
        None => String::new(),
    }
}
|
||||
|
||||
/// Tell whether `path` is absolute.
///
/// This is a thin wrapper over [`Path::is_absolute`], so the answer follows
/// the rules of the platform the code is compiled for. The Windows unit tests
/// in this module cover:
/// - Standard absolute paths (C:\...)
/// - UNC paths (\\server\share\...)
/// - Long-path prefixes (\\?\...)
pub fn is_absolute(path: &Path) -> bool {
    Path::is_absolute(path)
}
|
||||
|
||||
/// Return `path` rewritten with the platform's standard separator.
///
/// On Windows every forward slash becomes a backslash. The conversion goes
/// through a lossy string, so path components that are not valid Unicode are
/// not preserved exactly.
/// On other platforms the path is returned as an unchanged copy.
pub fn normalize_separators(path: &Path) -> PathBuf {
    #[cfg(windows)]
    let normalized = PathBuf::from(path.to_string_lossy().replace('/', "\\"));

    #[cfg(not(windows))]
    let normalized = path.to_path_buf();

    normalized
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_basename() {
        let cases = [
            ("/foo/bar/baz.txt", "baz.txt"),
            ("baz.txt", "baz.txt"),
            // The root has no final component.
            ("/", ""),
        ];
        for (input, expected) in cases {
            assert_eq!(basename(Path::new(input)), expected);
        }
    }

    #[cfg(windows)]
    #[test]
    fn test_is_absolute_windows() {
        let absolute = [
            "C:\\Users",
            "C:/Users",
            "\\\\server\\share",
            "\\\\?\\C:\\Users",
        ];
        for candidate in absolute {
            assert!(is_absolute(Path::new(candidate)));
        }

        let relative = ["relative\\path", "..\\parent"];
        for candidate in relative {
            assert!(!is_absolute(Path::new(candidate)));
        }
    }

    #[cfg(unix)]
    #[test]
    fn test_is_absolute_unix() {
        assert!(is_absolute(Path::new("/home/user")));

        let relative = ["relative/path", "../parent"];
        for candidate in relative {
            assert!(!is_absolute(Path::new(candidate)));
        }
    }

    #[cfg(windows)]
    #[test]
    fn test_normalize_separators_windows() {
        let expected = PathBuf::from("C:\\Users\\foo\\bar.txt");

        // Forward slashes are converted; backslashes pass through untouched.
        let inputs = ["C:/Users/foo/bar.txt", "C:\\Users\\foo\\bar.txt"];
        for input in inputs {
            assert_eq!(normalize_separators(Path::new(input)), expected);
        }
    }
}
|
||||
@@ -0,0 +1,220 @@
|
||||
//! Blocklist evaluation using glob patterns.
|
||||
//!
|
||||
//! This module checks if paths match configured blocklist patterns,
|
||||
//! allowing fine-grained denial of sensitive paths even within allowed roots.
|
||||
|
||||
use crate::error::{ToolError, ToolResult};
|
||||
use globset::{Glob, GlobSet, GlobSetBuilder};
|
||||
use std::path::Path;
|
||||
|
||||
/// Check if a path matches any blocklist patterns.
|
||||
///
|
||||
/// Blocklist patterns can be:
|
||||
/// - Absolute paths (exact match)
|
||||
/// - Glob patterns (e.g., `**/.env`, `**/secrets/**`)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `canonical_path` - The canonical path to check
|
||||
/// * `blocklist_patterns` - List of path patterns or globs to deny
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Ok(()) if the path is not blocked, or an error if it matches a blocklist pattern.
|
||||
///
|
||||
/// # Performance
|
||||
///
|
||||
/// For best performance, pre-compile patterns into a `CompiledBlocklist` and use
|
||||
/// `check_blocklist_compiled` instead.
|
||||
pub fn check_blocklist(
|
||||
canonical_path: &Path,
|
||||
blocklist_patterns: &[String],
|
||||
) -> ToolResult<()> {
|
||||
if blocklist_patterns.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Compile patterns on the fly (slower)
|
||||
let compiled = compile_blocklist(blocklist_patterns)?;
|
||||
check_blocklist_compiled(canonical_path, &compiled)
|
||||
}
|
||||
|
||||
/// Pre-compiled blocklist for efficient matching.
|
||||
pub struct CompiledBlocklist {
|
||||
glob_set: GlobSet,
|
||||
}
|
||||
|
||||
impl CompiledBlocklist {
|
||||
/// Create a new compiled blocklist.
|
||||
pub fn new(glob_set: GlobSet) -> Self {
|
||||
Self { glob_set }
|
||||
}
|
||||
|
||||
/// Get the underlying GlobSet.
|
||||
pub fn glob_set(&self) -> &GlobSet {
|
||||
&self.glob_set
|
||||
}
|
||||
}
|
||||
|
||||
/// Compile blocklist patterns into a GlobSet for efficient matching.
|
||||
///
|
||||
/// This should be done once at configuration load time.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `patterns` - List of glob patterns or absolute paths
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A compiled blocklist ready for fast matching.
|
||||
pub fn compile_blocklist(patterns: &[String]) -> ToolResult<CompiledBlocklist> {
|
||||
let mut builder = GlobSetBuilder::new();
|
||||
|
||||
for pattern in patterns {
|
||||
let glob = Glob::new(pattern).map_err(|e| {
|
||||
ToolError::InvalidConfig(format!("Invalid blocklist pattern '{}': {}", pattern, e))
|
||||
})?;
|
||||
builder.add(glob);
|
||||
}
|
||||
|
||||
let glob_set = builder.build().map_err(|e| {
|
||||
ToolError::InvalidConfig(format!("Failed to compile blocklist patterns: {}", e))
|
||||
})?;
|
||||
|
||||
Ok(CompiledBlocklist::new(glob_set))
|
||||
}
|
||||
|
||||
/// Check if a path matches a pre-compiled blocklist.
|
||||
///
|
||||
/// This is the fast path for blocklist checking.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `canonical_path` - The canonical path to check
|
||||
/// * `compiled` - Pre-compiled blocklist
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Ok(()) if the path is not blocked, or an error if it matches a pattern.
|
||||
pub fn check_blocklist_compiled(
|
||||
canonical_path: &Path,
|
||||
compiled: &CompiledBlocklist,
|
||||
) -> ToolResult<()> {
|
||||
if compiled.glob_set.is_match(canonical_path) {
|
||||
return Err(ToolError::blocked_path(format!(
|
||||
"Path matches blocklist: {}",
|
||||
super::basename(canonical_path)
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Turn string literals into the owned pattern list the API expects.
    fn owned(patterns: &[&str]) -> Vec<String> {
        patterns.iter().map(|p| p.to_string()).collect()
    }

    #[test]
    fn test_empty_blocklist() {
        // With no patterns at all, every path passes.
        assert!(check_blocklist(Path::new("/any/path"), &[]).is_ok());
    }

    #[test]
    fn test_exact_path_match() {
        let patterns = owned(&["/etc/passwd", "/home/user/.ssh/id_rsa"]);

        // Blocked
        for blocked in ["/etc/passwd", "/home/user/.ssh/id_rsa"] {
            assert!(check_blocklist(Path::new(blocked), &patterns).is_err());
        }

        // Not blocked
        assert!(check_blocklist(Path::new("/etc/hosts"), &patterns).is_ok());
    }

    #[test]
    fn test_glob_patterns() {
        let patterns = owned(&["**/.env", "**/secrets/**", "**/.git/**"]);

        let blocked = [
            // Blocked by .env pattern
            "/project/.env",
            "/project/subdir/.env",
            // Blocked by secrets pattern
            "/project/secrets/key.txt",
            "/app/secrets/api_key",
            // Blocked by .git pattern
            "/project/.git/config",
        ];
        for path in blocked {
            assert!(check_blocklist(Path::new(path), &patterns).is_err());
        }

        // Not blocked
        for path in ["/project/src/main.rs", "/project/README.md"] {
            assert!(check_blocklist(Path::new(path), &patterns).is_ok());
        }
    }

    #[test]
    fn test_compile_blocklist() {
        let patterns = owned(&["**/.env", "**/secrets/**"]);
        let compiled = compile_blocklist(&patterns).unwrap();

        // Blocked
        let env_file = Path::new("/project/.env");
        assert!(check_blocklist_compiled(env_file, &compiled).is_err());

        // Not blocked
        let source_file = Path::new("/project/src/main.rs");
        assert!(check_blocklist_compiled(source_file, &compiled).is_ok());
    }

    #[test]
    fn test_invalid_pattern() {
        // Invalid glob syntax
        let patterns = owned(&["[invalid"]);
        assert!(compile_blocklist(&patterns).is_err());
    }

    #[cfg(windows)]
    #[test]
    fn test_windows_paths() {
        let patterns = owned(&["**/.env", "C:/secrets/**"]);
        let compiled = compile_blocklist(&patterns).unwrap();

        // Blocked
        for path in ["C:\\project\\.env", "C:\\secrets\\key.txt"] {
            assert!(check_blocklist_compiled(Path::new(path), &compiled).is_err());
        }

        // Not blocked
        let source_file = Path::new("C:\\project\\src\\main.rs");
        assert!(check_blocklist_compiled(source_file, &compiled).is_ok());
    }
}
|
||||
@@ -0,0 +1,387 @@
|
||||
//! Path canonicalization with cross-platform support.
|
||||
//!
|
||||
//! This module handles:
|
||||
//! - Converting relative paths to absolute
|
||||
//! - Resolving symlinks and junctions per policy
|
||||
//! - Normalizing Windows path formats (UNC, long-path, MINGW)
|
||||
//! - Rejecting Windows reserved device names
|
||||
//! - Handling non-existent paths (for write operations)
|
||||
|
||||
use crate::error::{ToolError, ToolResult};
|
||||
use std::path::{Path, PathBuf, Component};
|
||||
|
||||
/// Symlink handling policy for path canonicalization.
///
/// Policies are small `Copy` values and can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SymlinkPolicy {
    /// Allow symlinks to escape allowed roots.
    ///
    /// If false (recommended), symlinks pointing outside allowed roots are rejected.
    pub allow_symlink_escape: bool,

    /// Follow symlinks within allowed roots.
    ///
    /// If true, symlinks within allowed roots are followed during canonicalization.
    pub follow_symlinks_within_roots: bool,
}

impl Default for SymlinkPolicy {
    /// The default policy: follow symlinks inside roots, but never let them escape.
    fn default() -> Self {
        Self {
            allow_symlink_escape: false,
            follow_symlinks_within_roots: true,
        }
    }
}

impl SymlinkPolicy {
    /// Create a policy that follows all symlinks (least restrictive).
    pub const fn follow_all() -> Self {
        Self {
            allow_symlink_escape: true,
            follow_symlinks_within_roots: true,
        }
    }

    /// Create a policy that never follows symlinks (most restrictive).
    pub const fn follow_none() -> Self {
        Self {
            allow_symlink_escape: false,
            follow_symlinks_within_roots: false,
        }
    }
}
|
||||
|
||||
/// Canonicalize a path with the given symlink policy.
|
||||
///
|
||||
/// This function:
|
||||
/// 1. Rejects empty or relative paths (ACP requires absolute paths)
|
||||
/// 2. Normalizes separators to platform standard
|
||||
/// 3. Converts MINGW-style paths (/c/...) to native Windows paths (C:\...)
|
||||
/// 4. Strips long-path prefixes (\\?\) for policy comparison
|
||||
/// 5. Resolves symlinks/junctions per policy
|
||||
/// 6. For non-existent paths, canonicalizes parent + appends remainder
|
||||
/// 7. Ensures no ".." escapes remain
|
||||
/// 8. Rejects Windows reserved device names (CON, NUL, etc.)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `user_path` - The path to canonicalize (must be absolute)
|
||||
/// * `policy` - Symlink handling policy
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The canonical absolute path, or an error if the path is invalid or violates policy.
|
||||
///
|
||||
/// # Security
|
||||
///
|
||||
/// This is a security-critical function. It must prevent:
|
||||
/// - Path traversal attacks via ".." components
|
||||
/// - Symlink escape attacks
|
||||
/// - Access to Windows reserved device names
|
||||
pub fn canonicalize_path(user_path: &Path, _policy: &SymlinkPolicy) -> ToolResult<PathBuf> {
|
||||
// Step 1: Reject empty paths
|
||||
if user_path.as_os_str().is_empty() {
|
||||
return Err(ToolError::sandbox_violation("Path cannot be empty"));
|
||||
}
|
||||
|
||||
// Step 2: Convert to string for preprocessing
|
||||
let path_str = user_path.to_string_lossy();
|
||||
|
||||
// Step 3: Handle MINGW-style paths on Windows (/c/... -> C:\...)
|
||||
#[cfg(windows)]
|
||||
let path_str = convert_mingw_path(&path_str);
|
||||
|
||||
let mut path = PathBuf::from(path_str.as_ref());
|
||||
|
||||
// Step 4: Normalize separators
|
||||
path = super::normalize_separators(&path);
|
||||
|
||||
// Step 5: Reject relative paths
|
||||
if !path.is_absolute() {
|
||||
return Err(ToolError::sandbox_violation(format!(
|
||||
"Path must be absolute: {}",
|
||||
super::basename(&path)
|
||||
)));
|
||||
}
|
||||
|
||||
// Step 6: Strip long-path prefix for normalization (\\?\C:\... -> C:\...)
|
||||
#[cfg(windows)]
|
||||
{
|
||||
path = strip_long_path_prefix(&path);
|
||||
}
|
||||
|
||||
// Step 7: Check for reserved device names on Windows
|
||||
#[cfg(windows)]
|
||||
{
|
||||
if is_reserved_device_name(&path) {
|
||||
return Err(ToolError::sandbox_violation(format!(
|
||||
"Path is a reserved device name: {}",
|
||||
super::basename(&path)
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Step 8: Normalize drive letter case on Windows
|
||||
#[cfg(windows)]
|
||||
{
|
||||
path = normalize_drive_letter(&path);
|
||||
}
|
||||
|
||||
// Step 9: Try to canonicalize using dunce (follows symlinks on all platforms)
|
||||
// If the path doesn't exist, find the first existing ancestor and canonicalize it,
|
||||
// then append all non-existent components
|
||||
let canonical = match dunce::canonicalize(&path) {
|
||||
Ok(canonical) => canonical,
|
||||
Err(_) => {
|
||||
// Path doesn't exist - find first existing ancestor
|
||||
canonicalize_non_existent_path(&path)?
|
||||
}
|
||||
};
|
||||
|
||||
// Step 10: Verify no ".." components remain (security check)
|
||||
for component in canonical.components() {
|
||||
if component == Component::ParentDir {
|
||||
return Err(ToolError::sandbox_violation(
|
||||
"Path contains '..' after canonicalization (potential traversal attack)",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Step 11: Strip long-path prefix again if added by canonicalize
|
||||
#[cfg(windows)]
|
||||
let canonical = strip_long_path_prefix(&canonical);
|
||||
|
||||
Ok(canonical)
|
||||
}
|
||||
|
||||
/// Canonicalize a non-existent path by finding the first existing ancestor.
|
||||
///
|
||||
/// This function walks up the directory tree until it finds an existing directory,
|
||||
/// then builds the canonical path by appending the non-existent components.
|
||||
fn canonicalize_non_existent_path(path: &Path) -> ToolResult<PathBuf> {
|
||||
// Collect components of the non-existent path parts
|
||||
let mut components_to_append: Vec<std::ffi::OsString> = Vec::new();
|
||||
let mut current = path.to_path_buf();
|
||||
|
||||
// Walk up the directory tree until we find an existing directory
|
||||
loop {
|
||||
if let Some(parent) = current.parent() {
|
||||
if parent == current {
|
||||
// Hit root without finding existing directory
|
||||
// This shouldn't happen with absolute paths, but handle it
|
||||
return Ok(path.to_path_buf());
|
||||
}
|
||||
|
||||
// Try to canonicalize the parent
|
||||
match dunce::canonicalize(parent) {
|
||||
Ok(canonical_parent) => {
|
||||
// Found existing parent - build the full path
|
||||
components_to_append.reverse();
|
||||
let mut result = canonical_parent;
|
||||
|
||||
// First add the current file_name (if any)
|
||||
if let Some(file_name) = current.file_name() {
|
||||
result = result.join(file_name);
|
||||
}
|
||||
|
||||
// Then add all the rest
|
||||
for component in components_to_append {
|
||||
result = result.join(component);
|
||||
}
|
||||
|
||||
return Ok(result);
|
||||
}
|
||||
Err(_) => {
|
||||
// Parent doesn't exist - collect the component and continue up
|
||||
if let Some(file_name) = current.file_name() {
|
||||
components_to_append.push(file_name.to_os_string());
|
||||
}
|
||||
current = parent.to_path_buf();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No parent (shouldn't reach here with absolute paths)
|
||||
return Ok(path.to_path_buf());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert MINGW-style path (/c/Users/...) to Windows path (C:\Users\...).
///
/// Only applies on Windows. Detects paths like:
/// - /c/... -> C:\...
/// - /d/... -> D:\...
///
/// A path is treated as MINGW-style only when it starts with `/`, a single
/// ASCII letter, and another `/`. Anything else (including three-character
/// inputs such as `/cd`) is returned unchanged and borrowed.
#[cfg(windows)]
fn convert_mingw_path(path: &str) -> std::borrow::Cow<'_, str> {
    match path.as_bytes() {
        [b'/', drive, b'/', ..] if drive.is_ascii_alphabetic() => {
            let drive_letter = char::from(drive.to_ascii_uppercase());
            // The first three bytes are ASCII, so byte index 3 is a char boundary.
            // Normalize forward slashes to backslashes in the remainder.
            let rest = path[3..].replace('/', "\\");
            format!("{}:\\{}", drive_letter, rest).into()
        }
        _ => path.into(),
    }
}
|
||||
|
||||
/// Remove a leading long-path prefix (`\\?\`) from a Windows path.
///
/// Verbatim UNC paths are rewritten to plain UNC form
/// (`\\?\UNC\server\share` -> `\\server\share`). Paths without the prefix are
/// returned unchanged. Prefixed paths go through a lossy string conversion.
#[cfg(windows)]
fn strip_long_path_prefix(path: &Path) -> PathBuf {
    let text = path.to_string_lossy();

    // The UNC form must be tested first: it also starts with `\\?\`.
    if let Some(share) = text.strip_prefix(r"\\?\UNC\") {
        PathBuf::from(format!(r"\\{}", share))
    } else if let Some(plain) = text.strip_prefix(r"\\?\") {
        PathBuf::from(plain)
    } else {
        path.to_path_buf()
    }
}
|
||||
|
||||
/// Upper-case the drive letter of a Windows path (`c:\` becomes `C:\`).
///
/// Paths that do not start with an ASCII letter followed by `:` are returned
/// unchanged. Paths that do are rebuilt from a lossy string conversion.
#[cfg(windows)]
fn normalize_drive_letter(path: &Path) -> PathBuf {
    let text = path.to_string_lossy();
    let mut chars = text.chars();

    match (chars.next(), chars.next()) {
        (Some(drive), Some(':')) if drive.is_ascii_alphabetic() => {
            let mut normalized = String::with_capacity(text.len());
            normalized.push(drive.to_ascii_uppercase());
            normalized.push(':');
            // Everything after "X:" is copied through untouched.
            normalized.push_str(chars.as_str());
            PathBuf::from(normalized)
        }
        _ => path.to_path_buf(),
    }
}
|
||||
|
||||
/// Check if a path is a Windows reserved device name.
///
/// Reserved names include:
/// - CON, PRN, AUX, NUL
/// - COM1-COM9
/// - LPT1-LPT9
///
/// Only the final path component is inspected, case-insensitively. A
/// reserved name is still reserved when followed by one or more extensions
/// (`CON.txt`, `NUL.tar.gz`) or by trailing spaces before the extension, so
/// the comparison uses the part of the name before the first `.` with
/// trailing spaces removed.
#[cfg(windows)]
fn is_reserved_device_name(path: &Path) -> bool {
    const RESERVED_NAMES: &[&str] = &[
        "CON", "PRN", "AUX", "NUL",
        "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", "COM8", "COM9",
        "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
    ];

    path.file_name().map_or(false, |file_name| {
        let name = file_name.to_string_lossy();
        // `split` always yields at least one item, so the fallback is never used.
        let device = name.split('.').next().unwrap_or("").trim_end_matches(' ');
        RESERVED_NAMES
            .iter()
            .any(|reserved| device.eq_ignore_ascii_case(reserved))
    })
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_path() {
        // An empty path must be rejected.
        let policy = SymlinkPolicy::default();
        assert!(canonicalize_path(Path::new(""), &policy).is_err());
    }

    #[cfg(windows)]
    #[test]
    fn test_convert_mingw_path() {
        // (input, expected output)
        let cases = [
            ("/c/Users/foo", "C:\\Users\\foo"),
            ("/d/Projects", "D:\\Projects"),
            ("/c/", "C:\\"),
            ("C:\\Users", "C:\\Users"), // No conversion
        ];
        for (input, expected) in cases {
            assert_eq!(convert_mingw_path(input), expected);
        }
    }

    #[cfg(windows)]
    #[test]
    fn test_strip_long_path_prefix() {
        // (input, expected output)
        let cases = [
            (r"\\?\C:\Users\foo", r"C:\Users\foo"),
            (r"\\?\UNC\server\share\file", r"\\server\share\file"),
            (r"C:\Users\foo", r"C:\Users\foo"),
        ];
        for (input, expected) in cases {
            assert_eq!(
                strip_long_path_prefix(Path::new(input)),
                PathBuf::from(expected)
            );
        }
    }

    #[cfg(windows)]
    #[test]
    fn test_normalize_drive_letter() {
        let expected = PathBuf::from("C:\\Users");
        for input in ["c:\\Users", "C:\\Users"] {
            assert_eq!(normalize_drive_letter(Path::new(input)), expected);
        }
    }

    #[cfg(windows)]
    #[test]
    fn test_is_reserved_device_name() {
        let reserved = ["CON", "con", "CON.txt", "C:\\path\\to\\NUL", "COM1", "LPT5"];
        for name in reserved {
            assert!(is_reserved_device_name(Path::new(name)));
        }

        for name in ["CONNECT.txt", "file.txt"] {
            assert!(!is_reserved_device_name(Path::new(name)));
        }
    }

    #[test]
    fn test_symlink_policy() {
        // (policy, allow_symlink_escape, follow_symlinks_within_roots)
        let expectations = [
            (SymlinkPolicy::default(), false, true),
            (SymlinkPolicy::follow_all(), true, true),
            (SymlinkPolicy::follow_none(), false, false),
        ];
        for (policy, escape, follow) in expectations {
            assert_eq!(policy.allow_symlink_escape, escape);
            assert_eq!(policy.follow_symlinks_within_roots, follow);
        }
    }
}
|
||||
@@ -0,0 +1,265 @@
|
||||
//! Path containment checking for sandbox enforcement.
|
||||
//!
|
||||
//! This module verifies that a canonical path is within allowed sandbox roots.
|
||||
|
||||
use crate::error::{ToolError, ToolResult};
|
||||
use std::path::{Path, PathBuf, Component};
|
||||
|
||||
/// Check if a canonical path is contained within allowed roots.
|
||||
///
|
||||
/// This function performs component-wise prefix matching (not string matching)
|
||||
/// to ensure the path is strictly within at least one allowed root.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `canonical_path` - The canonical path to check (must already be canonicalized)
|
||||
/// * `allowed_roots` - List of allowed root paths (must already be canonical)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Ok(()) if the path is contained, or an error if it's outside all roots.
|
||||
///
|
||||
/// # Security
|
||||
///
|
||||
/// This is a security-critical function. It uses component-wise comparison to prevent:
|
||||
/// - String prefix attacks (e.g., "/a" should not contain "/ab")
|
||||
/// - Path traversal attacks
|
||||
pub fn check_containment(
|
||||
canonical_path: &Path,
|
||||
allowed_roots: &[PathBuf],
|
||||
) -> ToolResult<()> {
|
||||
// If no roots configured, deny all paths
|
||||
if allowed_roots.is_empty() {
|
||||
return Err(ToolError::sandbox_violation(
|
||||
"No allowed roots configured - all paths denied",
|
||||
));
|
||||
}
|
||||
|
||||
// Check if path is contained by at least one root
|
||||
for root in allowed_roots {
|
||||
if is_contained_by(canonical_path, root) {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
// Path is outside all roots
|
||||
Err(ToolError::sandbox_violation(format!(
|
||||
"Path is outside allowed roots: {}",
|
||||
super::basename(canonical_path)
|
||||
)))
|
||||
}
|
||||
|
||||
/// Check if a path is strictly contained by a root.
|
||||
///
|
||||
/// This performs component-wise comparison to ensure proper containment:
|
||||
/// - "/a/b/c" is contained by "/a"
|
||||
/// - "/a/b/c" is NOT contained by "/a/b/c" (must be strict)
|
||||
/// - "/ab" is NOT contained by "/a" (not a string prefix match)
|
||||
///
|
||||
/// On Windows, comparison is case-insensitive to match filesystem behavior.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - The path to check (must be canonical)
|
||||
/// * `root` - The root path (must be canonical)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// `true` if path is strictly contained within root, `false` otherwise.
|
||||
pub fn is_contained_by(path: &Path, root: &Path) -> bool {
|
||||
// Get components
|
||||
let path_components: Vec<_> = path.components().collect();
|
||||
let root_components: Vec<_> = root.components().collect();
|
||||
|
||||
// Path must have MORE components than root (strict containment)
|
||||
// Equal components means path == root, which is not strict containment
|
||||
if path_components.len() <= root_components.len() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check if all root components match path components (prefix)
|
||||
for (i, root_comp) in root_components.iter().enumerate() {
|
||||
let path_comp = &path_components[i];
|
||||
|
||||
if !components_equal(path_comp, root_comp) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Decide whether two path components are the same.
///
/// On Windows the components are rendered to strings and compared after
/// lower-casing; on every other platform they are compared exactly.
fn components_equal(a: &Component, b: &Component) -> bool {
    #[cfg(windows)]
    {
        // Case-insensitive comparison on Windows
        let lhs = component_to_string(a);
        let rhs = component_to_string(b);
        lhs.to_lowercase() == rhs.to_lowercase()
    }

    #[cfg(not(windows))]
    {
        // Case-sensitive comparison on Unix
        a.eq(b)
    }
}
|
||||
|
||||
/// Render a path component as an owned string for comparison.
///
/// Prefix and normal components are converted lossily; the root, current-dir
/// and parent-dir components map to the fixed strings "/", "." and "..".
#[cfg(windows)]
fn component_to_string(comp: &Component) -> String {
    let text = match comp {
        Component::Prefix(prefix) => prefix.as_os_str().to_string_lossy(),
        Component::RootDir => "/".into(),
        Component::CurDir => ".".into(),
        Component::ParentDir => "..".into(),
        Component::Normal(name) => name.to_string_lossy(),
    };
    text.into_owned()
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_contained_by_unix() {
        // (path, root): strict containment (child must be strictly inside parent)
        let contained = [("/a/b/c", "/a"), ("/a/b/c", "/a/b")];
        for (path, root) in contained {
            assert!(is_contained_by(Path::new(path), Path::new(root)));
        }

        // (path, root): pairs that must NOT count as contained
        let not_contained = [
            ("/a/b", "/a/b"), // Equal paths are NOT strictly contained
            ("/a", "/a/b"),   // Parent is not contained by child
            ("/a/b", "/c"),   // Different trees
            ("/ab", "/a"),    // String prefix but not path prefix
        ];
        for (path, root) in not_contained {
            assert!(!is_contained_by(Path::new(path), Path::new(root)));
        }
    }

    #[cfg(windows)]
    #[test]
    fn test_is_contained_by_windows_case_insensitive() {
        // Windows is case-insensitive
        let contained = [
            ("C:\\Users\\foo\\bar", "C:\\Users"),
            ("c:\\users\\foo\\bar", "C:\\Users"),
            ("C:\\Users\\foo", "c:\\users"),
        ];
        for (path, root) in contained {
            assert!(is_contained_by(Path::new(path), Path::new(root)));
        }
    }

    #[test]
    fn test_check_containment_no_roots() {
        assert!(check_containment(Path::new("/a/b/c"), &[]).is_err());
    }

    #[test]
    fn test_check_containment_allowed() {
        let roots = vec![PathBuf::from("/home/user/project"), PathBuf::from("/tmp")];

        // Inside first root
        let in_first = Path::new("/home/user/project/src/main.rs");
        assert!(check_containment(in_first, &roots).is_ok());

        // Inside second root
        let in_second = Path::new("/tmp/foo.txt");
        assert!(check_containment(in_second, &roots).is_ok());
    }

    #[test]
    fn test_check_containment_denied() {
        let roots = vec![PathBuf::from("/home/user/project")];

        let denied = [
            // Outside all roots
            "/etc/passwd",
            // Sibling directory (string prefix but not path prefix)
            "/home/user/project2/file.txt",
            // Parent directory
            "/home/user/other.txt",
            // Equal to root (not strictly contained)
            "/home/user/project",
        ];
        for path in denied {
            assert!(check_containment(Path::new(path), &roots).is_err());
        }
    }

    #[cfg(windows)]
    #[test]
    fn test_check_containment_windows() {
        let roots = vec![PathBuf::from("C:\\Users\\foo")];

        // Allowed
        let allowed = Path::new("C:\\Users\\foo\\Documents\\file.txt");
        assert!(check_containment(allowed, &roots).is_ok());

        let denied = [
            "D:\\file.txt",          // Denied - different drive
            "C:\\Windows\\system32", // Denied - outside root
        ];
        for path in denied {
            assert!(check_containment(Path::new(path), &roots).is_err());
        }
    }
}
|
||||
@@ -0,0 +1,165 @@
|
||||
//! Path validation facade combining all security checks.
|
||||
//!
|
||||
//! This module provides a single entry point for path validation that:
|
||||
//! 1. Canonicalizes the path
|
||||
//! 2. Checks containment within allowed roots
|
||||
//! 3. Evaluates blocklist patterns
|
||||
//!
|
||||
//! All tool operations should use this facade for consistent security enforcement.
|
||||
|
||||
use crate::config::SandboxConfig;
|
||||
use crate::error::ToolResult;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use super::blocklist::{compile_blocklist, check_blocklist_compiled, CompiledBlocklist};
|
||||
use super::canonicalize::{canonicalize_path, SymlinkPolicy};
|
||||
use super::containment::check_containment;
|
||||
|
||||
/// Validate a path against sandbox configuration.
|
||||
///
|
||||
/// This is the main entry point for path validation. It performs all security checks:
|
||||
/// 1. Path canonicalization (handles symlinks, Windows paths, etc.)
|
||||
/// 2. Containment checking (ensures path is within allowed roots)
|
||||
/// 3. Blocklist evaluation (checks against denied patterns)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `user_path` - The path provided by the user/agent (must be absolute)
|
||||
/// * `config` - Sandbox configuration with allowed roots and blocklist
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The canonical path if all checks pass, or a security error otherwise.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use dirigent_tools::path::validate_path;
|
||||
/// use dirigent_tools::config::SandboxConfig;
|
||||
/// use std::path::PathBuf;
|
||||
///
|
||||
/// let mut config = SandboxConfig::default();
|
||||
/// config.allowed_roots = vec![PathBuf::from("/home/user/project")];
|
||||
/// config.blocked_paths = vec!["**/.env".to_string()];
|
||||
///
|
||||
/// // Valid path
|
||||
/// let result = validate_path("/home/user/project/src/main.rs", &config);
|
||||
/// assert!(result.is_ok());
|
||||
///
|
||||
/// // Blocked path
|
||||
/// let result = validate_path("/home/user/project/.env", &config);
|
||||
/// assert!(result.is_err());
|
||||
///
|
||||
/// // Outside roots
|
||||
/// let result = validate_path("/etc/passwd", &config);
|
||||
/// assert!(result.is_err());
|
||||
/// ```
|
||||
pub fn validate_path(user_path: &str, config: &SandboxConfig) -> ToolResult<PathBuf> {
|
||||
// Step 1: Canonicalize the path
|
||||
let symlink_policy = SymlinkPolicy {
|
||||
allow_symlink_escape: config.allow_symlink_escape,
|
||||
follow_symlinks_within_roots: config.follow_symlinks_within_roots,
|
||||
};
|
||||
|
||||
let canonical_path = canonicalize_path(Path::new(user_path), &symlink_policy)?;
|
||||
|
||||
// Step 2: Check containment
|
||||
check_containment(&canonical_path, &config.allowed_roots)?;
|
||||
|
||||
// Step 3: Check blocklist
|
||||
if !config.blocked_paths.is_empty() {
|
||||
// Compile blocklist on the fly (for now)
|
||||
// TODO: Pre-compile blocklist in SandboxConfig for better performance
|
||||
let compiled = compile_blocklist(&config.blocked_paths)?;
|
||||
check_blocklist_compiled(&canonical_path, &compiled)?;
|
||||
}
|
||||
|
||||
Ok(canonical_path)
|
||||
}
|
||||
|
||||
/// Validate a path with a pre-compiled blocklist (for performance).
|
||||
///
|
||||
/// This is a faster variant of `validate_path` that uses a pre-compiled blocklist.
|
||||
/// Useful when validating many paths with the same configuration.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `user_path` - The path provided by the user/agent (must be absolute)
|
||||
/// * `config` - Sandbox configuration
|
||||
/// * `compiled_blocklist` - Pre-compiled blocklist (can be None if blocklist is empty)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The canonical path if all checks pass, or a security error otherwise.
|
||||
pub fn validate_path_compiled(
|
||||
user_path: &str,
|
||||
config: &SandboxConfig,
|
||||
compiled_blocklist: Option<&CompiledBlocklist>,
|
||||
) -> ToolResult<PathBuf> {
|
||||
// Step 1: Canonicalize the path
|
||||
let symlink_policy = SymlinkPolicy {
|
||||
allow_symlink_escape: config.allow_symlink_escape,
|
||||
follow_symlinks_within_roots: config.follow_symlinks_within_roots,
|
||||
};
|
||||
|
||||
let canonical_path = canonicalize_path(Path::new(user_path), &symlink_policy)?;
|
||||
|
||||
// Step 2: Check containment
|
||||
check_containment(&canonical_path, &config.allowed_roots)?;
|
||||
|
||||
// Step 3: Check blocklist (if provided)
|
||||
if let Some(compiled) = compiled_blocklist {
|
||||
check_blocklist_compiled(&canonical_path, &compiled)?;
|
||||
}
|
||||
|
||||
Ok(canonical_path)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolError;

    #[test]
    fn test_validate_path_no_roots() {
        // No roots allowed
        let mut config = SandboxConfig::default();
        config.allowed_roots = Vec::new();

        // Any path should fail with no roots
        let outcome = validate_path("/any/path", &config);
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(ToolError::SandboxViolation { .. })));
    }

    #[test]
    fn test_validate_path_empty_blocklist() {
        // Note: Integration tests with real filesystem are in tests/path_normalization.rs
        // Unit tests here focus on error conditions without filesystem dependencies
        let default_config = SandboxConfig::default();

        // Default has .env and secrets
        assert_eq!(default_config.blocked_paths.len(), 2);
    }

    #[test]
    fn test_compile_blocklist_empty() {
        // An empty pattern list compiles to an empty glob set.
        let compiled = compile_blocklist(&[]).unwrap();
        assert_eq!(compiled.glob_set().len(), 0);
    }

    #[test]
    fn test_compile_blocklist_valid() {
        let patterns: Vec<String> = ["**/.env", "**/secrets/**"]
            .iter()
            .map(|p| p.to_string())
            .collect();

        let outcome = compile_blocklist(&patterns);
        assert!(outcome.is_ok());

        // One compiled glob per input pattern.
        assert_eq!(outcome.unwrap().glob_set().len(), 2);
    }

    #[test]
    fn test_compile_blocklist_invalid() {
        // An unclosed character class is not valid glob syntax.
        let patterns = vec![String::from("[invalid")];

        let outcome = compile_blocklist(&patterns);
        assert!(outcome.is_err());
        assert!(matches!(outcome, Err(ToolError::InvalidConfig { .. })));
    }
}
|
||||
@@ -0,0 +1,59 @@
|
||||
//! Permission prompt system and decision caching.
|
||||
//!
|
||||
//! **Status**: Implemented (TOOLS-PERM-01 through TOOLS-PERM-04)
|
||||
//!
|
||||
//! This module provides:
|
||||
//! - **Permission checking** - Core algorithm integrating cache, whitelist, and ACP prompts
|
||||
//! - **Decision caching** - Thread-safe cache with TTL and scope support
|
||||
//! - **Whitelist matching** - Pattern-based auto-approval for safe operations
|
||||
//! - **ACP integration** - User prompts via session/request_permission
|
||||
//!
|
||||
//! ## Module Structure
|
||||
//!
|
||||
//! - [`check`] - Core permission check function
|
||||
//! - [`cache`] - Decision cache with TTL
|
||||
//! - [`whitelist`] - Whitelist pattern matching
|
||||
//! - [`acp`] - ACP integration for user prompts
|
||||
//!
|
||||
//! ## Quick Start
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use dirigent_tools::config::{PermissionConfig, PermissionMode, WhitelistConfig};
|
||||
//! use dirigent_tools::permission::check::{PermissionContext, check_permission};
|
||||
//! use dirigent_tools::permission::whitelist::{CompiledWhitelist, PermissionOperation};
|
||||
//!
|
||||
//! # async fn example() {
|
||||
//! // Configure permission system
|
||||
//! let config = PermissionConfig {
|
||||
//! mode: PermissionMode::Whitelist,
|
||||
//! remember_decisions: true,
|
||||
//! remember_ttl_secs: 3600,
|
||||
//! ..Default::default()
|
||||
//! };
|
||||
//!
|
||||
//! // Create context with whitelist
|
||||
//! let whitelist = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap();
|
||||
//! let context = PermissionContext::new(
|
||||
//! "connector-1".to_string(),
|
||||
//! Some("session-1".to_string()),
|
||||
//! whitelist,
|
||||
//! );
|
||||
//!
|
||||
//! // Check permission for an operation
|
||||
//! let operation = PermissionOperation::Write {
|
||||
//! path: "/path/to/file".to_string(),
|
||||
//! };
|
||||
//! let decision = check_permission(&operation, &context, &config).await.unwrap();
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
pub mod check;
|
||||
pub mod cache;
|
||||
pub mod whitelist;
|
||||
pub mod acp;
|
||||
|
||||
// Re-export commonly used types
|
||||
pub use check::{check_permission, PermissionContext};
|
||||
pub use cache::{CacheKey, DecisionCache, PermissionDecision};
|
||||
pub use whitelist::{CompiledWhitelist, PermissionOperation, matches_whitelist};
|
||||
pub use acp::{AcpPermissionContext, PermissionOutcome, request_permission_from_user};
|
||||
@@ -0,0 +1,242 @@
|
||||
//! ACP integration for permission prompts.
|
||||
//!
|
||||
//! **Status**: Implemented (TOOLS-PERM-03) - Stub for ACP integration
|
||||
//!
|
||||
//! This module provides integration with ACP's `session/request_permission` capability
|
||||
//! to prompt users for permissions when operations require approval.
|
||||
//!
|
||||
//! ## Permission Outcomes
|
||||
//!
|
||||
//! Users can respond with:
|
||||
//! - **AllowOnce**: Approve this operation only
|
||||
//! - **AllowAlways**: Approve this operation and remember the decision
|
||||
//! - **RejectOnce**: Deny this operation only
|
||||
//! - **RejectAlways**: Deny this operation and remember the decision
|
||||
//! - **Cancelled**: User cancelled the prompt (treat as rejection)
|
||||
//!
|
||||
//! ## ACP Protocol Integration
|
||||
//!
|
||||
//! This module calls the ACP client's `session/request_permission` handler (agent → client).
|
||||
//! The actual implementation depends on the ACP client infrastructure in `dirigent_core`.
|
||||
//!
|
||||
//! For now, this provides a stub that can be replaced with real ACP calls when the
|
||||
//! infrastructure is ready.
|
||||
|
||||
use crate::error::ToolResult;
|
||||
use crate::permission::whitelist::PermissionOperation;
|
||||
|
||||
/// The answer a user gave to a permission prompt.
///
/// The variants mirror the response options of ACP `session/request_permission`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PermissionOutcome {
    /// Approve just this operation; the answer is not cached.
    AllowOnce,
    /// Approve this operation and remember the approval (cached with a TTL).
    AllowAlways,
    /// Refuse just this operation; the answer is not cached.
    RejectOnce,
    /// Refuse this operation and remember the refusal (cached with a TTL).
    RejectAlways,
    /// The prompt was cancelled by the user; handled as a rejection.
    Cancelled,
}
|
||||
|
||||
impl PermissionOutcome {
|
||||
/// Check if this outcome should be cached.
|
||||
pub fn should_cache(&self) -> bool {
|
||||
matches!(self, Self::AllowAlways | Self::RejectAlways)
|
||||
}
|
||||
|
||||
/// Check if this outcome allows the operation.
|
||||
pub fn is_allowed(&self) -> bool {
|
||||
matches!(self, Self::AllowOnce | Self::AllowAlways)
|
||||
}
|
||||
|
||||
/// Convert to cache decision if this outcome should be cached.
|
||||
pub fn to_cache_decision(&self) -> Option<crate::permission::cache::PermissionDecision> {
|
||||
use crate::permission::cache::PermissionDecision;
|
||||
|
||||
match self {
|
||||
Self::AllowAlways => Some(PermissionDecision::Allowed),
|
||||
Self::RejectAlways => Some(PermissionDecision::Denied),
|
||||
Self::Cancelled => Some(PermissionDecision::Cancelled),
|
||||
Self::AllowOnce | Self::RejectOnce => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Identifies the requester when a permission prompt is sent over ACP.
///
/// At present this holds only the connector and session identifiers; a handle
/// to the ACP session object is planned (see the TODO below) but is not part
/// of the struct yet.
#[derive(Debug, Clone)]
pub struct AcpPermissionContext {
    /// Identifier of the connector making the request.
    pub connector_id: String,
    /// Identifier of the session, or `None` when the request is made outside a session.
    pub session_id: Option<String>,
    // TODO: Add ACP session handle when available
    // pub session_handle: Arc<AcpSession>,
}
|
||||
|
||||
/// Request permission from user via ACP session/request_permission.
|
||||
///
|
||||
/// This calls the ACP protocol's `session/request_permission` capability to present
|
||||
/// a permission prompt to the user with the following options:
|
||||
/// - Allow once
|
||||
/// - Allow always (remember decision)
|
||||
/// - Reject once
|
||||
/// - Reject always (remember decision)
|
||||
///
|
||||
/// ## Implementation Status
|
||||
///
|
||||
/// **TODO**: This is currently a stub that returns `AllowOnce` for all operations.
|
||||
/// Once the ACP client infrastructure in `dirigent_core` is ready, this should be
|
||||
/// replaced with actual ACP calls.
|
||||
///
|
||||
/// ## Expected ACP Call
|
||||
///
|
||||
/// ```json
|
||||
/// {
|
||||
/// "method": "session/request_permission",
|
||||
/// "params": {
|
||||
/// "session_id": "session-123",
|
||||
/// "operation": "write",
|
||||
/// "path": "C:/work/project/file.txt",
|
||||
/// "description": "Write to file C:/work/project/file.txt"
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// ## Expected Response
|
||||
///
|
||||
/// ```json
|
||||
/// {
|
||||
/// "result": {
|
||||
/// "outcome": "allow_always" | "allow_once" | "reject_always" | "reject_once" | "cancelled"
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// ## Errors
|
||||
///
|
||||
/// - `ToolError::PermissionDenied` if the ACP call fails or times out
|
||||
/// - `ToolError::AcpError` if the ACP protocol reports an error
|
||||
pub async fn request_permission_from_user(
|
||||
operation: &PermissionOperation,
|
||||
context: &AcpPermissionContext,
|
||||
) -> ToolResult<PermissionOutcome> {
|
||||
// TODO: Replace with actual ACP call when infrastructure is ready
|
||||
//
|
||||
// Expected implementation:
|
||||
// 1. Get session handle from context
|
||||
// 2. Call session.request_permission(operation, description)
|
||||
// 3. Wait for user response (with timeout)
|
||||
// 4. Map ACP response to PermissionOutcome
|
||||
// 5. Handle errors and timeouts
|
||||
|
||||
tracing::warn!(
|
||||
operation = %operation.display(),
|
||||
connector_id = %context.connector_id,
|
||||
session_id = ?context.session_id,
|
||||
"TODO: Actual ACP request_permission call not yet implemented (TOOLS-PERM-03)"
|
||||
);
|
||||
|
||||
// For now, always return AllowOnce as a safe default for development
|
||||
// In production, this should prompt the user via ACP
|
||||
Ok(PermissionOutcome::AllowOnce)
|
||||
}
|
||||
|
||||
/// Test-only stand-in for the ACP permission prompt.
///
/// Ignores the operation and resolves to exactly the `outcome` it was given,
/// so tests can script any user answer without an ACP stack.
#[cfg(test)]
pub async fn request_permission_mock(
    _operation: &PermissionOperation,
    outcome: PermissionOutcome,
) -> ToolResult<PermissionOutcome> {
    // Echo the scripted answer back unchanged.
    Ok(outcome)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the "always" answers are flagged for caching.
    #[test]
    fn test_permission_outcome_should_cache() {
        for remembered in [
            PermissionOutcome::AllowAlways,
            PermissionOutcome::RejectAlways,
        ] {
            assert!(remembered.should_cache());
        }

        for transient in [
            PermissionOutcome::AllowOnce,
            PermissionOutcome::RejectOnce,
            PermissionOutcome::Cancelled,
        ] {
            assert!(!transient.should_cache());
        }
    }

    /// Only the "allow" answers permit the operation.
    #[test]
    fn test_permission_outcome_is_allowed() {
        for granted in [
            PermissionOutcome::AllowOnce,
            PermissionOutcome::AllowAlways,
        ] {
            assert!(granted.is_allowed());
        }

        for refused in [
            PermissionOutcome::RejectOnce,
            PermissionOutcome::RejectAlways,
            PermissionOutcome::Cancelled,
        ] {
            assert!(!refused.is_allowed());
        }
    }

    /// Each outcome maps to the expected cache decision (or to none at all).
    #[test]
    fn test_permission_outcome_to_cache_decision() {
        use crate::permission::cache::PermissionDecision;

        let expectations = [
            (
                PermissionOutcome::AllowAlways,
                Some(PermissionDecision::Allowed),
            ),
            (
                PermissionOutcome::RejectAlways,
                Some(PermissionDecision::Denied),
            ),
            (
                PermissionOutcome::Cancelled,
                Some(PermissionDecision::Cancelled),
            ),
            (PermissionOutcome::AllowOnce, None),
            (PermissionOutcome::RejectOnce, None),
        ];

        for (outcome, expected) in expectations {
            assert_eq!(outcome.to_cache_decision(), expected);
        }
    }

    /// The production entry point is still a stub and answers `AllowOnce`.
    #[tokio::test]
    async fn test_request_permission_stub() {
        let context = AcpPermissionContext {
            connector_id: String::from("test-connector"),
            session_id: Some(String::from("test-session")),
        };
        let operation = PermissionOperation::Write {
            path: String::from("/test/path"),
        };

        // Stub currently returns AllowOnce
        let answer = request_permission_from_user(&operation, &context).await;
        assert_eq!(answer.unwrap(), PermissionOutcome::AllowOnce);
    }

    /// The mock hands back whichever outcome it was scripted with.
    #[tokio::test]
    async fn test_request_permission_mock() {
        let operation = PermissionOperation::Execute {
            command: String::from("test"),
            cwd: String::from("/"),
        };

        let every_outcome = [
            PermissionOutcome::AllowOnce,
            PermissionOutcome::AllowAlways,
            PermissionOutcome::RejectOnce,
            PermissionOutcome::RejectAlways,
            PermissionOutcome::Cancelled,
        ];

        for scripted in every_outcome {
            let answer = request_permission_mock(&operation, scripted).await;
            assert_eq!(answer.unwrap(), scripted);
        }
    }
}
|
||||
@@ -0,0 +1,389 @@
|
||||
//! Decision cache with TTL and scope support.
|
||||
//!
|
||||
//! **Status**: Implemented (TOOLS-PERM-02)
|
||||
//!
|
||||
//! This module provides thread-safe caching of permission decisions with:
|
||||
//! - TTL (time-to-live) for cached decisions
|
||||
//! - Scope support (per-connector or per-session)
|
||||
//! - Automatic expiration of stale entries
|
||||
//! - Hash-based cache keys for efficient lookups
|
||||
|
||||
use crate::config::DecisionScope;
|
||||
use std::collections::HashMap;
|
||||
use std::hash::Hash;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// Final verdict of a permission check, and the value stored in the decision cache.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PermissionDecision {
    /// The operation may proceed.
    Allowed,
    /// The operation was refused.
    Denied,
    /// The permission prompt was cancelled before an answer was given.
    Cancelled,
}
|
||||
|
||||
/// Cache key for permission decisions.
|
||||
///
|
||||
/// Keys are constructed from:
|
||||
/// - Operation kind (read, write, execute)
|
||||
/// - Normalized path or command
|
||||
/// - Scope identifier (connector_id or session_id)
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct CacheKey {
|
||||
/// Operation discriminant
|
||||
operation_kind: OperationKind,
|
||||
/// Normalized path or command string
|
||||
target: String,
|
||||
/// Scope identifier (connector or session)
|
||||
scope_id: String,
|
||||
}
|
||||
|
||||
/// Operation kind for cache key discrimination.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
enum OperationKind {
|
||||
Read,
|
||||
Write,
|
||||
Execute,
|
||||
}
|
||||
|
||||
impl CacheKey {
|
||||
/// Create a cache key for a read operation.
|
||||
pub fn read(path: &str, connector_id: &str, scope: DecisionScope) -> Self {
|
||||
Self {
|
||||
operation_kind: OperationKind::Read,
|
||||
target: path.to_string(),
|
||||
scope_id: Self::scope_id(connector_id, None, scope),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a cache key for a write operation.
|
||||
pub fn write(path: &str, connector_id: &str, session_id: Option<&str>, scope: DecisionScope) -> Self {
|
||||
Self {
|
||||
operation_kind: OperationKind::Write,
|
||||
target: path.to_string(),
|
||||
scope_id: Self::scope_id(connector_id, session_id, scope),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a cache key for an execute operation.
|
||||
pub fn execute(command: &str, connector_id: &str, session_id: Option<&str>, scope: DecisionScope) -> Self {
|
||||
Self {
|
||||
operation_kind: OperationKind::Execute,
|
||||
target: command.to_string(),
|
||||
scope_id: Self::scope_id(connector_id, session_id, scope),
|
||||
}
|
||||
}
|
||||
|
||||
/// Construct scope identifier based on scope policy.
|
||||
fn scope_id(connector_id: &str, session_id: Option<&str>, scope: DecisionScope) -> String {
|
||||
match scope {
|
||||
DecisionScope::PerConnector => connector_id.to_string(),
|
||||
DecisionScope::PerSession => {
|
||||
format!("{}:{}", connector_id, session_id.unwrap_or("default"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Cached permission decision with expiration.
|
||||
#[derive(Debug, Clone)]
|
||||
struct CachedDecision {
|
||||
/// The permission decision
|
||||
decision: PermissionDecision,
|
||||
/// When this entry expires
|
||||
expires_at: Instant,
|
||||
}
|
||||
|
||||
impl CachedDecision {
|
||||
/// Create a new cached decision with TTL.
|
||||
fn new(decision: PermissionDecision, ttl: Duration) -> Self {
|
||||
Self {
|
||||
decision,
|
||||
expires_at: Instant::now() + ttl,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this entry has expired.
|
||||
fn is_expired(&self) -> bool {
|
||||
Instant::now() >= self.expires_at
|
||||
}
|
||||
}
|
||||
|
||||
/// Thread-safe decision cache with TTL support.
|
||||
///
|
||||
/// The cache stores permission decisions keyed by operation, path/command, and scope.
|
||||
/// Entries automatically expire after their TTL, and expired entries are pruned on access.
|
||||
///
|
||||
/// ## Thread Safety
|
||||
///
|
||||
/// The cache is designed to be wrapped in `Arc<Mutex<DecisionCache>>` for thread-safe access.
|
||||
/// Individual operations (get, insert, clear) should acquire the lock briefly.
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use dirigent_tools::permission::cache::{DecisionCache, PermissionDecision, CacheKey};
|
||||
/// use dirigent_tools::config::DecisionScope;
|
||||
/// use std::time::Duration;
|
||||
///
|
||||
/// let mut cache = DecisionCache::new();
|
||||
/// let key = CacheKey::write("/path/to/file", "connector-1", None, DecisionScope::PerConnector);
|
||||
///
|
||||
/// // Cache a decision
|
||||
/// cache.insert(key.clone(), PermissionDecision::Allowed, Duration::from_secs(300));
|
||||
///
|
||||
/// // Retrieve it
|
||||
/// assert_eq!(cache.get(&key), Some(PermissionDecision::Allowed));
|
||||
/// ```
|
||||
#[derive(Debug)]
|
||||
pub struct DecisionCache {
|
||||
entries: HashMap<CacheKey, CachedDecision>,
|
||||
}
|
||||
|
||||
impl DecisionCache {
|
||||
/// Create a new empty decision cache.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
entries: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a cached decision if it exists and hasn't expired.
|
||||
///
|
||||
/// Expired entries are automatically removed.
|
||||
///
|
||||
/// Returns `Some(decision)` if a valid cached entry exists, `None` otherwise.
|
||||
pub fn get(&mut self, key: &CacheKey) -> Option<PermissionDecision> {
|
||||
// Check if entry exists
|
||||
if let Some(cached) = self.entries.get(key) {
|
||||
// Check if expired
|
||||
if cached.is_expired() {
|
||||
// Remove expired entry
|
||||
self.entries.remove(key);
|
||||
None
|
||||
} else {
|
||||
// Return valid decision
|
||||
Some(cached.decision)
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert a new decision into the cache with the given TTL.
|
||||
///
|
||||
/// If an entry already exists for this key, it will be replaced.
|
||||
pub fn insert(&mut self, key: CacheKey, decision: PermissionDecision, ttl: Duration) {
|
||||
self.entries.insert(key, CachedDecision::new(decision, ttl));
|
||||
}
|
||||
|
||||
/// Clear all cached decisions.
|
||||
///
|
||||
/// Useful for manual cache invalidation or testing.
|
||||
pub fn clear(&mut self) {
|
||||
self.entries.clear();
|
||||
}
|
||||
|
||||
/// Remove all expired entries from the cache.
|
||||
///
|
||||
/// This is automatically done during `get()` operations, but can be called
|
||||
/// periodically to prune the cache and free memory.
|
||||
///
|
||||
/// Returns the number of expired entries removed.
|
||||
pub fn clear_expired(&mut self) -> usize {
|
||||
let before = self.entries.len();
|
||||
self.entries.retain(|_, cached| !cached.is_expired());
|
||||
before - self.entries.len()
|
||||
}
|
||||
|
||||
/// Get the number of cached entries (including expired).
|
||||
///
|
||||
/// For testing and monitoring purposes.
|
||||
pub fn len(&self) -> usize {
|
||||
self.entries.len()
|
||||
}
|
||||
|
||||
/// Check if the cache is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.entries.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DecisionCache {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// TTL long enough that an entry cannot expire while a test runs.
    const LONG_TTL: Duration = Duration::from_secs(300);
    /// TTL short enough to have lapsed after `EXPIRY_WAIT`.
    const SHORT_TTL: Duration = Duration::from_millis(1);
    /// How long the expiry tests sleep before checking.
    const EXPIRY_WAIT: Duration = Duration::from_millis(10);

    /// Write key for `path` on connector "conn-1", per-connector scope, no session.
    fn connector_key(path: &str) -> CacheKey {
        CacheKey::write(path, "conn-1", None, DecisionScope::PerConnector)
    }

    /// Write key for "/path" on connector "conn-1" in `session`, per-session scope.
    fn session_key(session: &str) -> CacheKey {
        CacheKey::write("/path", "conn-1", Some(session), DecisionScope::PerSession)
    }

    #[test]
    fn test_cache_key_creation() {
        let first = connector_key("/path/to/file");
        let same = connector_key("/path/to/file");
        let other_connector =
            CacheKey::write("/path/to/file", "conn-2", None, DecisionScope::PerConnector);

        // Same connector and path should produce equal keys
        assert_eq!(first, same);

        // Different connector should produce different keys
        assert_ne!(first, other_connector);
    }

    #[test]
    fn test_cache_key_scope() {
        let per_connector = CacheKey::write(
            "/path",
            "conn-1",
            Some("session-1"),
            DecisionScope::PerConnector,
        );
        let per_session = session_key("session-1");

        // Different scopes should produce different keys
        assert_ne!(per_connector, per_session);
    }

    #[test]
    fn test_cache_key_operation_kind() {
        let read_key = CacheKey::read("/path", "conn-1", DecisionScope::PerConnector);
        let write_key = connector_key("/path");

        // Different operations should produce different keys
        assert_ne!(read_key, write_key);
    }

    #[test]
    fn test_cache_insert_and_get() {
        let key = connector_key("/path");
        let mut cache = DecisionCache::new();

        // Nothing stored yet
        assert_eq!(cache.get(&key), None);

        // Store a decision and read it back
        cache.insert(key.clone(), PermissionDecision::Allowed, LONG_TTL);
        assert_eq!(cache.get(&key), Some(PermissionDecision::Allowed));
    }

    #[test]
    fn test_cache_ttl_expiration() {
        let key = connector_key("/path");
        let mut cache = DecisionCache::new();

        // Store with a very short TTL and let it lapse
        cache.insert(key.clone(), PermissionDecision::Allowed, SHORT_TTL);
        std::thread::sleep(EXPIRY_WAIT);

        // The expired entry is not returned ...
        assert_eq!(cache.get(&key), None);

        // ... and the lookup removed it from the cache
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_cache_clear() {
        let first = connector_key("/path1");
        let second = connector_key("/path2");
        let mut cache = DecisionCache::new();

        cache.insert(first.clone(), PermissionDecision::Allowed, LONG_TTL);
        cache.insert(second.clone(), PermissionDecision::Denied, LONG_TTL);
        assert_eq!(cache.len(), 2);

        // Clear all entries
        cache.clear();

        assert_eq!(cache.len(), 0);
        for key in [&first, &second] {
            assert_eq!(cache.get(key), None);
        }
    }

    #[test]
    fn test_cache_clear_expired() {
        let short_lived = connector_key("/path1");
        let long_lived = connector_key("/path2");
        let mut cache = DecisionCache::new();

        // One short-lived and one long-lived entry
        cache.insert(short_lived.clone(), PermissionDecision::Allowed, SHORT_TTL);
        cache.insert(long_lived.clone(), PermissionDecision::Denied, LONG_TTL);
        assert_eq!(cache.len(), 2);

        // Wait for the first to expire
        std::thread::sleep(EXPIRY_WAIT);

        // Pruning removes exactly the expired entry
        let removed = cache.clear_expired();
        assert_eq!(removed, 1);
        assert_eq!(cache.len(), 1);

        // The long-lived entry is still accessible
        assert_eq!(cache.get(&long_lived), Some(PermissionDecision::Denied));
    }

    #[test]
    fn test_cache_replace_entry() {
        let key = connector_key("/path");
        let mut cache = DecisionCache::new();

        // Initial decision
        cache.insert(key.clone(), PermissionDecision::Allowed, LONG_TTL);
        assert_eq!(cache.get(&key), Some(PermissionDecision::Allowed));

        // Overwrite with a different decision
        cache.insert(key.clone(), PermissionDecision::Denied, LONG_TTL);
        assert_eq!(cache.get(&key), Some(PermissionDecision::Denied));

        // Still a single entry
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_cache_different_sessions() {
        let first_session = session_key("session-1");
        let second_session = session_key("session-2");
        let mut cache = DecisionCache::new();

        // Different decisions for different sessions
        cache.insert(first_session.clone(), PermissionDecision::Allowed, LONG_TTL);
        cache.insert(second_session.clone(), PermissionDecision::Denied, LONG_TTL);

        // Each session keeps its own decision
        assert_eq!(cache.get(&first_session), Some(PermissionDecision::Allowed));
        assert_eq!(cache.get(&second_session), Some(PermissionDecision::Denied));
    }

    #[test]
    fn test_cache_cancelled_decision() {
        let key = connector_key("/path");
        let mut cache = DecisionCache::new();

        // Cancelled decisions can also be cached
        cache.insert(key.clone(), PermissionDecision::Cancelled, LONG_TTL);
        assert_eq!(cache.get(&key), Some(PermissionDecision::Cancelled));
    }
}
|
||||
@@ -0,0 +1,366 @@
|
||||
//! Core permission check function integrating cache, whitelist, and ACP prompts.
|
||||
//!
|
||||
//! **Status**: Implemented (TOOLS-PERM-01)
|
||||
//!
|
||||
//! This module provides the main `check_permission` function that orchestrates:
|
||||
//! - Permission mode evaluation (yolo, whitelist, ask)
|
||||
//! - Decision cache lookup and updates
|
||||
//! - Whitelist pattern matching
|
||||
//! - ACP permission prompts
|
||||
//!
|
||||
//! ## Algorithm
|
||||
//!
|
||||
//! 1. **Yolo mode**: Always allow (with audit logging)
|
||||
//! 2. **Check cache**: Return cached decision if present and valid
|
||||
//! 3. **Whitelist mode**: Auto-approve if operation matches whitelist
|
||||
//! 4. **Ask mode / No whitelist match**: Prompt user via ACP
|
||||
//! 5. **Cache decision**: Store if user selected "always" option
|
||||
//! 6. **Return decision**: Allow or deny the operation
|
||||
|
||||
use crate::config::{DecisionScope, PermissionConfig, PermissionMode};
|
||||
use crate::error::ToolResult;
|
||||
use crate::permission::acp::{request_permission_from_user, AcpPermissionContext, PermissionOutcome};
|
||||
use crate::permission::cache::{CacheKey, DecisionCache, PermissionDecision};
|
||||
use crate::permission::whitelist::{matches_whitelist, CompiledWhitelist, PermissionOperation};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Context for permission checks.
|
||||
///
|
||||
/// This should be created once per connector or session and reused for all
|
||||
/// permission checks to maintain consistent cache state.
|
||||
#[derive(Clone)]
|
||||
pub struct PermissionContext {
|
||||
/// Connector ID
|
||||
pub connector_id: String,
|
||||
/// Session ID (if in a session)
|
||||
pub session_id: Option<String>,
|
||||
/// Shared decision cache (thread-safe)
|
||||
pub cache: Arc<Mutex<DecisionCache>>,
|
||||
/// Compiled whitelist for fast matching
|
||||
pub whitelist: Arc<CompiledWhitelist>,
|
||||
}
|
||||
|
||||
impl PermissionContext {
|
||||
/// Create a new permission context.
|
||||
pub fn new(
|
||||
connector_id: String,
|
||||
session_id: Option<String>,
|
||||
whitelist: CompiledWhitelist,
|
||||
) -> Self {
|
||||
Self {
|
||||
connector_id,
|
||||
session_id,
|
||||
cache: Arc::new(Mutex::new(DecisionCache::new())),
|
||||
whitelist: Arc::new(whitelist),
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear the decision cache (for testing or manual reset).
|
||||
pub fn clear_cache(&self) {
|
||||
if let Ok(mut cache) = self.cache.lock() {
|
||||
cache.clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// Get cache statistics (for monitoring/debugging).
|
||||
pub fn cache_size(&self) -> usize {
|
||||
self.cache.lock().map(|c| c.len()).unwrap_or(0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if an operation requires permission and prompt if needed.
|
||||
///
|
||||
/// This is the main entry point for permission checks. It implements the full
|
||||
/// permission algorithm including mode evaluation, caching, whitelist matching,
|
||||
/// and ACP prompts.
|
||||
///
|
||||
/// ## Algorithm
|
||||
///
|
||||
/// 1. If mode is `Yolo`: Always allow (with audit log)
|
||||
/// 2. Check decision cache (with TTL)
|
||||
/// 3. If cached decision exists: Return it
|
||||
/// 4. If mode is `Whitelist` and operation matches: Allow
|
||||
/// 5. Otherwise: Call ACP `session/request_permission`
|
||||
/// 6. If outcome is "always": Cache the decision
|
||||
/// 7. Return decision (Allowed/Denied/Cancelled)
|
||||
///
|
||||
/// ## Examples
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use dirigent_tools::config::{PermissionConfig, PermissionMode, WhitelistConfig};
|
||||
/// use dirigent_tools::permission::check::{PermissionContext, check_permission};
|
||||
/// use dirigent_tools::permission::whitelist::{CompiledWhitelist, PermissionOperation};
|
||||
///
|
||||
/// # async fn example() {
|
||||
/// let config = PermissionConfig {
|
||||
/// mode: PermissionMode::Ask,
|
||||
/// remember_decisions: true,
|
||||
/// remember_ttl_secs: 3600,
|
||||
/// ..Default::default()
|
||||
/// };
|
||||
///
|
||||
/// let whitelist = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap();
|
||||
/// let context = PermissionContext::new(
|
||||
/// "connector-1".to_string(),
|
||||
/// Some("session-1".to_string()),
|
||||
/// whitelist,
|
||||
/// );
|
||||
///
|
||||
/// let operation = PermissionOperation::Write {
|
||||
/// path: "/path/to/file".to_string(),
|
||||
/// };
|
||||
///
|
||||
/// let decision = check_permission(&operation, &context, &config).await.unwrap();
|
||||
/// # }
|
||||
/// ```
|
||||
pub async fn check_permission(
|
||||
operation: &PermissionOperation,
|
||||
context: &PermissionContext,
|
||||
config: &PermissionConfig,
|
||||
) -> ToolResult<PermissionDecision> {
|
||||
// Step 1: Yolo mode - always allow
|
||||
if config.mode == PermissionMode::Yolo {
|
||||
tracing::info!(
|
||||
operation = %operation.display(),
|
||||
mode = "yolo",
|
||||
"Permission check: auto-approved (yolo mode)"
|
||||
);
|
||||
return Ok(PermissionDecision::Allowed);
|
||||
}
|
||||
|
||||
// Step 2: Check decision cache
|
||||
if config.remember_decisions {
|
||||
let cache_key = create_cache_key(operation, context, config.scope);
|
||||
|
||||
if let Ok(mut cache) = context.cache.lock() {
|
||||
if let Some(cached_decision) = cache.get(&cache_key) {
|
||||
tracing::debug!(
|
||||
operation = %operation.display(),
|
||||
decision = ?cached_decision,
|
||||
"Permission check: using cached decision"
|
||||
);
|
||||
return Ok(cached_decision);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Whitelist mode - check if operation matches whitelist
|
||||
if config.mode == PermissionMode::Whitelist {
|
||||
if matches_whitelist(operation, &context.whitelist) {
|
||||
tracing::info!(
|
||||
operation = %operation.display(),
|
||||
mode = "whitelist",
|
||||
"Permission check: auto-approved (whitelist match)"
|
||||
);
|
||||
return Ok(PermissionDecision::Allowed);
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Prompt user via ACP
|
||||
tracing::debug!(
|
||||
operation = %operation.display(),
|
||||
mode = ?config.mode,
|
||||
"Permission check: prompting user"
|
||||
);
|
||||
|
||||
let acp_context = AcpPermissionContext {
|
||||
connector_id: context.connector_id.clone(),
|
||||
session_id: context.session_id.clone(),
|
||||
};
|
||||
|
||||
let outcome = request_permission_from_user(operation, &acp_context).await?;
|
||||
|
||||
// Step 5: Convert outcome to decision
|
||||
let decision = if outcome.is_allowed() {
|
||||
PermissionDecision::Allowed
|
||||
} else {
|
||||
match outcome {
|
||||
PermissionOutcome::Cancelled => PermissionDecision::Cancelled,
|
||||
_ => PermissionDecision::Denied,
|
||||
}
|
||||
};
|
||||
|
||||
// Step 6: Cache decision if "always" option was selected
|
||||
if config.remember_decisions && outcome.should_cache() {
|
||||
if let Some(cache_decision) = outcome.to_cache_decision() {
|
||||
let cache_key = create_cache_key(operation, context, config.scope);
|
||||
let ttl = Duration::from_secs(config.remember_ttl_secs);
|
||||
|
||||
if let Ok(mut cache) = context.cache.lock() {
|
||||
cache.insert(cache_key, cache_decision, ttl);
|
||||
tracing::debug!(
|
||||
operation = %operation.display(),
|
||||
decision = ?cache_decision,
|
||||
ttl_secs = config.remember_ttl_secs,
|
||||
"Cached permission decision"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
operation = %operation.display(),
|
||||
decision = ?decision,
|
||||
outcome = ?outcome,
|
||||
"Permission check: decision made"
|
||||
);
|
||||
|
||||
Ok(decision)
|
||||
}
|
||||
|
||||
/// Create a cache key for an operation with the appropriate scope.
|
||||
fn create_cache_key(
|
||||
operation: &PermissionOperation,
|
||||
context: &PermissionContext,
|
||||
scope: DecisionScope,
|
||||
) -> CacheKey {
|
||||
match operation {
|
||||
PermissionOperation::Read { path } => {
|
||||
CacheKey::read(path, &context.connector_id, scope)
|
||||
}
|
||||
PermissionOperation::Write { path } => {
|
||||
CacheKey::write(path, &context.connector_id, context.session_id.as_deref(), scope)
|
||||
}
|
||||
PermissionOperation::Execute { command, .. } => {
|
||||
CacheKey::execute(command, &context.connector_id, context.session_id.as_deref(), scope)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::WhitelistConfig;

    /// Config fixture: the caller picks the mode; decisions are remembered for
    /// 300 seconds per connector, with the default whitelist configuration.
    fn create_test_config(mode: PermissionMode) -> PermissionConfig {
        PermissionConfig {
            mode,
            remember_decisions: true,
            remember_ttl_secs: 300,
            scope: DecisionScope::PerConnector,
            whitelist: WhitelistConfig::default(),
        }
    }

    /// Context fixture for connector "test-connector" / session "test-session"
    /// using a whitelist compiled from the default configuration.
    fn create_test_context() -> PermissionContext {
        let default_rules = WhitelistConfig::default();
        PermissionContext::new(
            String::from("test-connector"),
            Some(String::from("test-session")),
            CompiledWhitelist::compile(&default_rules).unwrap(),
        )
    }

    /// Put a single `Allowed` write decision straight into `context`'s cache.
    fn remember_allowed_write(context: &PermissionContext) {
        let key = CacheKey::write("/path", "test", None, DecisionScope::PerConnector);
        let mut cache = context.cache.lock().unwrap();
        cache.insert(key, PermissionDecision::Allowed, Duration::from_secs(300));
    }

    #[tokio::test]
    async fn test_yolo_mode_always_allows() {
        let config = create_test_config(PermissionMode::Yolo);
        let context = create_test_context();

        let operations = [
            PermissionOperation::Read { path: "/any/path".to_string() },
            PermissionOperation::Write { path: "/any/path".to_string() },
            PermissionOperation::Execute {
                command: "dangerous_command".to_string(),
                cwd: "/".to_string(),
            },
        ];

        for operation in &operations {
            let decision = check_permission(operation, &context, &config).await.unwrap();
            assert_eq!(decision, PermissionDecision::Allowed);
        }
    }

    #[tokio::test]
    async fn test_whitelist_mode_read_always_allowed() {
        let config = create_test_config(PermissionMode::Whitelist);
        let context = create_test_context();
        let operation = PermissionOperation::Read { path: "/any/path".to_string() };

        let decision = check_permission(&operation, &context, &config).await.unwrap();

        assert_eq!(decision, PermissionDecision::Allowed);
    }

    #[tokio::test]
    async fn test_whitelist_mode_with_pattern() {
        let mut config = create_test_config(PermissionMode::Whitelist);
        config.whitelist = WhitelistConfig {
            write_paths: vec!["C:/work/**".to_string()],
            execute_commands: vec!["cargo".to_string()],
        };

        let context = PermissionContext::new(
            "test-connector".to_string(),
            None,
            CompiledWhitelist::compile(&config.whitelist).unwrap(),
        );

        // A write below the whitelisted root is approved without a prompt.
        let write_ok = PermissionOperation::Write {
            path: "C:/work/project/file.txt".to_string(),
        };
        let decision = check_permission(&write_ok, &context, &config).await.unwrap();
        assert_eq!(decision, PermissionDecision::Allowed);

        // TODO: Test non-matching write would require ACP mock
        // This is tested in integration tests
    }

    #[tokio::test]
    async fn test_cache_key_creation() {
        let context = create_test_context();
        let scope = DecisionScope::PerConnector;
        let read_op = PermissionOperation::Read { path: "/path".to_string() };

        // Building the key twice for the same operation must give equal keys.
        let first = create_cache_key(&read_op, &context, scope);
        let second = create_cache_key(&read_op, &context, scope);

        assert_eq!(first, second);
    }

    #[tokio::test]
    async fn test_context_cache_operations() {
        let context = create_test_context();
        assert_eq!(context.cache_size(), 0);

        remember_allowed_write(&context);
        assert_eq!(context.cache_size(), 1);

        context.clear_cache();
        assert_eq!(context.cache_size(), 0);
    }

    #[test]
    fn test_permission_context_clone() {
        let context = create_test_context();
        let cloned = context.clone();

        // Both handles start out reporting the same size...
        assert_eq!(context.cache_size(), cloned.cache_size());

        // ...and an insert through the original is visible through the clone.
        remember_allowed_write(&context);
        assert_eq!(cloned.cache_size(), 1);
    }
}
|
||||
@@ -0,0 +1,397 @@
|
||||
//! Whitelist pattern matching for auto-approval of safe operations.
|
||||
//!
|
||||
//! **Status**: Implemented (TOOLS-PERM-04)
|
||||
//!
|
||||
//! This module provides pattern matching against configured whitelists to
|
||||
//! automatically approve safe operations without prompting the user.
|
||||
//!
|
||||
//! ## Whitelist Strategy
|
||||
//!
|
||||
//! - **Read operations**: Always approved in whitelist mode (reads are safe)
|
||||
//! - **Write operations**: Match path against `write_paths` glob patterns
|
||||
//! - **Execute operations**: Match command against `execute_commands` glob patterns
|
||||
//!
|
||||
//! ## Pattern Matching
|
||||
//!
|
||||
//! Uses `globset` for efficient compiled glob patterns:
|
||||
//! - `**` matches any number of path segments
|
||||
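//! - `*` matches any sequence of characters (with `globset`'s default settings this includes path separators, so it is not limited to a single segment)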
|
||||
//! - `?` matches a single character
|
||||
//! - Character classes like `[abc]` are supported
|
||||
//!
|
||||
//! ## Performance
|
||||
//!
|
||||
//! Globs are pre-compiled at configuration load time and stored in `CompiledWhitelist`
|
||||
//! for fast matching during permission checks.
|
||||
|
||||
use crate::config::WhitelistConfig;
|
||||
use crate::error::{ToolError, ToolResult};
|
||||
use globset::{Glob, GlobSet, GlobSetBuilder};
|
||||
use std::path::Path;
|
||||
|
||||
/// Compiled whitelist with pre-built glob sets for efficient matching.
|
||||
///
|
||||
/// This should be constructed once at configuration load time and reused
|
||||
/// for all permission checks.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CompiledWhitelist {
|
||||
/// Compiled glob patterns for write path matching
|
||||
write_paths: GlobSet,
|
||||
/// Compiled glob patterns for execute command matching
|
||||
execute_commands: GlobSet,
|
||||
}
|
||||
|
||||
impl CompiledWhitelist {
|
||||
/// Compile a whitelist configuration into an efficient matcher.
|
||||
///
|
||||
/// ## Errors
|
||||
///
|
||||
/// Returns an error if any glob pattern is invalid.
|
||||
pub fn compile(config: &WhitelistConfig) -> ToolResult<Self> {
|
||||
let write_paths = Self::compile_patterns(&config.write_paths)?;
|
||||
let execute_commands = Self::compile_patterns(&config.execute_commands)?;
|
||||
|
||||
Ok(Self {
|
||||
write_paths,
|
||||
execute_commands,
|
||||
})
|
||||
}
|
||||
|
||||
/// Compile a list of glob patterns into a GlobSet.
|
||||
fn compile_patterns(patterns: &[String]) -> ToolResult<GlobSet> {
|
||||
let mut builder = GlobSetBuilder::new();
|
||||
|
||||
for pattern in patterns {
|
||||
let glob = Glob::new(pattern).map_err(|e| {
|
||||
ToolError::InvalidConfig(format!("Invalid whitelist glob pattern '{}': {}", pattern, e))
|
||||
})?;
|
||||
builder.add(glob);
|
||||
}
|
||||
|
||||
builder.build().map_err(|e| {
|
||||
ToolError::InvalidConfig(format!("Failed to build whitelist glob set: {}", e))
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if a read operation matches the whitelist.
|
||||
///
|
||||
/// In whitelist mode, all read operations are considered safe and return true.
|
||||
pub fn matches_read(&self, _path: &Path) -> bool {
|
||||
// Reads are always safe in whitelist mode
|
||||
true
|
||||
}
|
||||
|
||||
/// Check if a write operation matches the whitelist.
|
||||
///
|
||||
/// Returns true if the path matches any of the configured write_paths patterns.
|
||||
pub fn matches_write(&self, path: &Path) -> bool {
|
||||
self.write_paths.is_match(path)
|
||||
}
|
||||
|
||||
/// Check if an execute operation matches the whitelist.
|
||||
///
|
||||
/// Returns true if the command matches any of the configured execute_commands patterns.
|
||||
///
|
||||
/// The command is compared as a string (not a path) to support both:
|
||||
/// - Simple command names (e.g., "cargo", "npm")
|
||||
/// - Command patterns (e.g., "cargo*", "npm*")
|
||||
pub fn matches_execute(&self, command: &str) -> bool {
|
||||
// For command matching, we treat the command as a simple string path
|
||||
// This allows patterns like "cargo", "npm*", etc.
|
||||
self.execute_commands.is_match(command)
|
||||
}
|
||||
|
||||
/// Check if the whitelist has any write patterns configured.
|
||||
pub fn has_write_patterns(&self) -> bool {
|
||||
!self.write_paths.is_empty()
|
||||
}
|
||||
|
||||
/// Check if the whitelist has any execute patterns configured.
|
||||
pub fn has_execute_patterns(&self) -> bool {
|
||||
!self.execute_commands.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// An operation that is subject to a permission check / whitelist match.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PermissionOperation {
    /// Read the file at `path`.
    Read { path: String },
    /// Write the file at `path`.
    Write { path: String },
    /// Run `command` with `cwd` as the working directory.
    Execute { command: String, cwd: String },
}

impl PermissionOperation {
    /// Render the operation as a short human-readable string for logs,
    /// e.g. `read /a/b` or `execute 'cargo test' in /project`.
    pub fn display(&self) -> String {
        match self {
            Self::Execute { command, cwd } => format!("execute '{}' in {}", command, cwd),
            Self::Write { path } => format!("write {}", path),
            Self::Read { path } => format!("read {}", path),
        }
    }

    /// Name of the operation variant: `"read"`, `"write"` or `"execute"`.
    pub fn kind(&self) -> &'static str {
        match self {
            Self::Execute { .. } => "execute",
            Self::Write { .. } => "write",
            Self::Read { .. } => "read",
        }
    }
}
|
||||
|
||||
/// Check if an operation matches the configured whitelist.
|
||||
///
|
||||
/// ## Returns
|
||||
///
|
||||
/// - `true` if the operation should be auto-approved
|
||||
/// - `false` if the operation requires a permission prompt
|
||||
///
|
||||
/// ## Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// use dirigent_tools::config::WhitelistConfig;
|
||||
/// use dirigent_tools::permission::whitelist::{CompiledWhitelist, PermissionOperation, matches_whitelist};
|
||||
///
|
||||
/// let config = WhitelistConfig {
|
||||
/// write_paths: vec!["C:/work/project/**".to_string()],
|
||||
/// execute_commands: vec!["cargo".to_string(), "npm".to_string()],
|
||||
/// };
|
||||
///
|
||||
/// let whitelist = CompiledWhitelist::compile(&config).unwrap();
|
||||
///
|
||||
/// // Read operations always match
|
||||
/// let read_op = PermissionOperation::Read { path: "C:/anywhere/file.txt".to_string() };
|
||||
/// assert!(matches_whitelist(&read_op, &whitelist));
|
||||
///
|
||||
/// // Write within allowed path
|
||||
/// let write_op = PermissionOperation::Write { path: "C:/work/project/src/main.rs".to_string() };
|
||||
/// assert!(matches_whitelist(&write_op, &whitelist));
|
||||
///
|
||||
/// // Execute whitelisted command
|
||||
/// let exec_op = PermissionOperation::Execute {
|
||||
/// command: "cargo".to_string(),
|
||||
/// cwd: "C:/work/project".to_string(),
|
||||
/// };
|
||||
/// assert!(matches_whitelist(&exec_op, &whitelist));
|
||||
/// ```
|
||||
pub fn matches_whitelist(operation: &PermissionOperation, whitelist: &CompiledWhitelist) -> bool {
|
||||
match operation {
|
||||
PermissionOperation::Read { path } => {
|
||||
whitelist.matches_read(Path::new(path))
|
||||
}
|
||||
PermissionOperation::Write { path } => {
|
||||
whitelist.matches_write(Path::new(path))
|
||||
}
|
||||
PermissionOperation::Execute { command, .. } => {
|
||||
whitelist.matches_execute(command)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Turn a slice of literals into the owned pattern list the config expects.
    fn patterns(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Whitelist fixture with three write roots and four command patterns.
    fn create_test_whitelist() -> CompiledWhitelist {
        let config = WhitelistConfig {
            write_paths: patterns(&[
                "C:/work/project/**",
                "/home/user/project/**",
                "**/safe_dir/**",
            ]),
            execute_commands: patterns(&["cargo", "npm", "git", "python*"]),
        };
        CompiledWhitelist::compile(&config).unwrap()
    }

    #[test]
    fn test_read_always_matches() {
        let whitelist = create_test_whitelist();

        // Reads match no matter what the path looks like.
        for candidate in &["C:/anywhere/file.txt", "/random/path", "../../../etc/passwd"] {
            assert!(whitelist.matches_read(Path::new(candidate)));
        }
    }

    #[test]
    fn test_write_matches_patterns() {
        let whitelist = create_test_whitelist();

        // Paths below a configured root match.
        for inside in &[
            "C:/work/project/src/main.rs",
            "C:/work/project/Cargo.toml",
            "/home/user/project/README.md",
        ] {
            assert!(whitelist.matches_write(Path::new(inside)));
        }

        // Paths outside every configured root do not.
        for outside in &["C:/other/project/file.txt", "/tmp/file.txt"] {
            assert!(!whitelist.matches_write(Path::new(outside)));
        }
    }

    #[test]
    fn test_write_matches_relative_patterns() {
        let whitelist = create_test_whitelist();

        // `**/safe_dir/**` matches wherever `safe_dir` appears.
        for candidate in &["some/path/safe_dir/file.txt", "safe_dir/file.txt"] {
            assert!(whitelist.matches_write(Path::new(candidate)));
        }
    }

    #[test]
    fn test_execute_matches_commands() {
        let whitelist = create_test_whitelist();

        // Exact names, then names covered by the `python*` pattern.
        for allowed in &["cargo", "npm", "git", "python", "python3"] {
            assert!(whitelist.matches_execute(allowed));
        }

        // Commands that were never listed.
        for rejected in &["rm", "format", "del"] {
            assert!(!whitelist.matches_execute(rejected));
        }
    }

    #[test]
    fn test_empty_whitelist() {
        let config = WhitelistConfig {
            write_paths: Vec::new(),
            execute_commands: Vec::new(),
        };
        let whitelist = CompiledWhitelist::compile(&config).unwrap();

        // Reads are unaffected by an empty configuration...
        assert!(whitelist.matches_read(Path::new("any/path")));

        // ...while writes and executes have nothing to match against.
        assert!(!whitelist.matches_write(Path::new("any/path")));
        assert!(!whitelist.matches_execute("any_command"));
    }

    #[test]
    fn test_invalid_glob_pattern() {
        // Unclosed bracket makes the pattern unparsable.
        let config = WhitelistConfig {
            write_paths: patterns(&["[invalid"]),
            execute_commands: Vec::new(),
        };

        assert!(CompiledWhitelist::compile(&config).is_err());
    }

    #[test]
    fn test_operation_display() {
        let read_op = PermissionOperation::Read { path: "/path/to/file".to_string() };
        let write_op = PermissionOperation::Write { path: "/path/to/file".to_string() };
        let exec_op = PermissionOperation::Execute {
            command: "cargo test".to_string(),
            cwd: "/project".to_string(),
        };

        assert_eq!(read_op.display(), "read /path/to/file");
        assert_eq!(write_op.display(), "write /path/to/file");
        assert_eq!(exec_op.display(), "execute 'cargo test' in /project");
    }

    #[test]
    fn test_operation_kind() {
        let read_op = PermissionOperation::Read { path: "/path".to_string() };
        let write_op = PermissionOperation::Write { path: "/path".to_string() };
        let exec_op = PermissionOperation::Execute {
            command: "cmd".to_string(),
            cwd: "/".to_string(),
        };

        assert_eq!(read_op.kind(), "read");
        assert_eq!(write_op.kind(), "write");
        assert_eq!(exec_op.kind(), "execute");
    }

    #[test]
    fn test_matches_whitelist_function() {
        let whitelist = create_test_whitelist();

        // Read: always approved.
        let read_op = PermissionOperation::Read { path: "C:/any/file.txt".to_string() };
        assert!(matches_whitelist(&read_op, &whitelist));

        // Write: one path inside a whitelisted root, one outside.
        let write_ok = PermissionOperation::Write {
            path: "C:/work/project/file.txt".to_string(),
        };
        let write_fail = PermissionOperation::Write { path: "C:/other/file.txt".to_string() };
        assert!(matches_whitelist(&write_ok, &whitelist));
        assert!(!matches_whitelist(&write_fail, &whitelist));

        // Execute: one listed command, one unlisted.
        let exec_ok = PermissionOperation::Execute {
            command: "cargo".to_string(),
            cwd: "/project".to_string(),
        };
        let exec_fail = PermissionOperation::Execute {
            command: "rm".to_string(),
            cwd: "/".to_string(),
        };
        assert!(matches_whitelist(&exec_ok, &whitelist));
        assert!(!matches_whitelist(&exec_fail, &whitelist));
    }

    #[test]
    fn test_has_patterns() {
        let config = WhitelistConfig {
            write_paths: patterns(&["some/path/**"]),
            execute_commands: Vec::new(),
        };
        let whitelist = CompiledWhitelist::compile(&config).unwrap();

        assert!(whitelist.has_write_patterns());
        assert!(!whitelist.has_execute_patterns());
    }

    #[cfg(target_os = "windows")]
    #[test]
    fn test_windows_paths() {
        let config = WhitelistConfig {
            write_paths: patterns(&[
                "C:\\work\\project\\**",
                "\\\\server\\share\\**", // UNC path
            ]),
            execute_commands: Vec::new(),
        };
        let whitelist = CompiledWhitelist::compile(&config).unwrap();

        // Backslash-separated drive path.
        assert!(whitelist.matches_write(Path::new("C:\\work\\project\\file.txt")));

        // UNC path.
        assert!(whitelist.matches_write(Path::new("\\\\server\\share\\file.txt")));
    }
}
|
||||
@@ -0,0 +1,303 @@
|
||||
//! Tool registry: built-ins (compile-time) + per-session dynamic entries.
|
||||
|
||||
use crate::tool::{AnyTool, ClientKind, ProtocolKind, ToolContext};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
/// Origin of a dynamically registered tool.
///
/// Each variant carries an identifying string for its origin.
/// NOTE(review): the unit tests pass a server name ("server-1") to `Mcp`;
/// the exact meaning of the `Custom` payload is not visible here — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ToolSource {
    /// Tool contributed by an MCP source.
    Mcp(Arc<str>),
    /// Tool contributed by a custom source.
    Custom(Arc<str>),
}
|
||||
|
||||
/// A registered dynamic tool with optional client/protocol filters.
|
||||
#[derive(Clone)]
|
||||
pub struct DynamicEntry {
|
||||
pub tool: Arc<dyn AnyTool>,
|
||||
pub source: ToolSource,
|
||||
pub only_for_client: Option<ClientKind>,
|
||||
pub only_for_protocol: Option<ProtocolKind>,
|
||||
}
|
||||
|
||||
/// Which side wins a single name collision under [`CollisionPolicy::PerName`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Winner {
    /// Resolve the name to the built-in tool.
    BuiltIn,
    /// Resolve the name to the dynamic tool.
    Dynamic,
}

/// How to resolve a tool name that is defined both as a built-in and as a
/// dynamic entry.
#[derive(Debug, Clone, Default)]
pub enum CollisionPolicy {
    /// The built-in tool always wins. This is the default.
    #[default]
    BuiltInWins,
    /// The dynamic tool always wins.
    DynamicWins,
    /// Decide per tool name; names missing from the map resolve to the built-in.
    PerName(HashMap<Arc<str>, Winner>),
}
|
||||
|
||||
pub struct ToolRegistry {
|
||||
built_ins: HashMap<&'static str, Arc<dyn AnyTool>>,
|
||||
dynamic: RwLock<HashMap<Arc<str>, HashMap<String, DynamicEntry>>>,
|
||||
collision_policy: CollisionPolicy,
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
/// Canonical scope key for dynamic-tool storage.
|
||||
///
|
||||
/// Session takes precedence so tools registered for a single session
|
||||
/// don't leak across the connector; absent that, the connector itself
|
||||
/// scopes the entry.
|
||||
fn scope_key(ctx: &ToolContext) -> Arc<str> {
|
||||
ctx.session_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| ctx.connector_id.clone())
|
||||
}
|
||||
|
||||
pub fn new(
|
||||
built_ins: impl IntoIterator<Item = Arc<dyn AnyTool>>,
|
||||
collision_policy: CollisionPolicy,
|
||||
) -> Self {
|
||||
let built_ins = built_ins
|
||||
.into_iter()
|
||||
.map(|t| (t.name(), t))
|
||||
.collect();
|
||||
Self {
|
||||
built_ins,
|
||||
dynamic: RwLock::new(HashMap::new()),
|
||||
collision_policy,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn resolve(&self, name: &str, ctx: &ToolContext) -> Option<Arc<dyn AnyTool>> {
|
||||
let dyn_match = self.find_dynamic(name, ctx);
|
||||
let builtin = self.built_ins.get(name).cloned();
|
||||
match (builtin, dyn_match) {
|
||||
(Some(b), None) => Some(b),
|
||||
(None, Some(d)) => Some(d.tool),
|
||||
(None, None) => None,
|
||||
(Some(b), Some(d)) => match &self.collision_policy {
|
||||
CollisionPolicy::BuiltInWins => Some(b),
|
||||
CollisionPolicy::DynamicWins => Some(d.tool),
|
||||
CollisionPolicy::PerName(map) => match map.get(name) {
|
||||
Some(Winner::BuiltIn) | None => Some(b),
|
||||
Some(Winner::Dynamic) => Some(d.tool),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Enumerate all tools visible in `ctx`'s scope.
|
||||
///
|
||||
/// Returns built-in tool names plus any dynamic tools registered under
|
||||
/// the canonical scope key for `ctx` (see [`Self::scope_key`]) whose
|
||||
/// optional client/protocol filters match. The result is sorted and
|
||||
/// deduplicated.
|
||||
pub fn list(&self, ctx: &ToolContext) -> Vec<String> {
|
||||
let mut names: Vec<String> = self.built_ins.keys().map(|n| n.to_string()).collect();
|
||||
let key = Self::scope_key(ctx);
|
||||
if let Some(per_scope) = self.dynamic.read().unwrap().get(&key) {
|
||||
for (n, e) in per_scope.iter() {
|
||||
if entry_matches_ctx(e, ctx) { names.push(n.clone()); }
|
||||
}
|
||||
}
|
||||
names.sort();
|
||||
names.dedup();
|
||||
names
|
||||
}
|
||||
|
||||
/// Insert a dynamic tool under `scope_key`.
|
||||
///
|
||||
/// Callers must pass the value [`Self::scope_key`] would produce for the
|
||||
/// registering [`ToolContext`]: the `session_id` if present, otherwise
|
||||
/// the `connector_id`. Anything else will be invisible to `resolve` and
|
||||
/// `list` for that context.
|
||||
pub fn register_dynamic(&self, scope_key: impl Into<Arc<str>>, name: String, entry: DynamicEntry) {
|
||||
let mut g = self.dynamic.write().unwrap();
|
||||
g.entry(scope_key.into()).or_default().insert(name, entry);
|
||||
}
|
||||
|
||||
pub fn unregister_scope(&self, scope_key: &str) {
|
||||
self.dynamic.write().unwrap().remove(scope_key);
|
||||
}
|
||||
|
||||
fn find_dynamic(&self, name: &str, ctx: &ToolContext) -> Option<DynamicEntry> {
|
||||
let key = Self::scope_key(ctx);
|
||||
let g = self.dynamic.read().unwrap();
|
||||
let per_scope = g.get(&key)?;
|
||||
let entry = per_scope.get(name)?;
|
||||
if entry_matches_ctx(entry, ctx) { Some(entry.clone()) } else { None }
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether `entry`'s optional client/protocol filters accept `ctx`.
///
/// A filter that is `None` accepts everything; a filter that is `Some` must
/// equal the corresponding value on the context. Both filters must accept.
fn entry_matches_ctx(entry: &DynamicEntry, ctx: &ToolContext) -> bool {
    entry
        .only_for_client
        .as_ref()
        .map_or(true, |wanted| wanted == &ctx.client_kind)
        && entry
            .only_for_protocol
            .as_ref()
            .map_or(true, |wanted| wanted == &ctx.protocol)
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::{PermissionConfig, SandboxConfig, WhitelistConfig};
|
||||
use crate::permission::check::PermissionContext;
|
||||
use crate::permission::whitelist::CompiledWhitelist;
|
||||
use crate::tool::{Tool, ToolEventSink, ToolInput, ToolKind};
|
||||
use async_trait::async_trait;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Default)] struct A;
|
||||
#[derive(Default)] struct B;
|
||||
|
||||
#[derive(Serialize, Deserialize, JsonSchema)]
|
||||
struct Empty {}
|
||||
|
||||
macro_rules! impl_t {
|
||||
($t:ty, $n:literal) => {
|
||||
#[async_trait]
|
||||
impl Tool for $t {
|
||||
type Input = Empty;
|
||||
type Output = Empty;
|
||||
const NAME: &'static str = $n;
|
||||
fn kind() -> ToolKind { ToolKind::Other }
|
||||
async fn run(
|
||||
self: Arc<Self>, _i: ToolInput<Empty>,
|
||||
_e: ToolEventSink, _c: &ToolContext,
|
||||
) -> Result<Empty, Empty> { Ok(Empty {}) }
|
||||
}
|
||||
};
|
||||
}
|
||||
impl_t!(A, "a");
|
||||
impl_t!(B, "b");
|
||||
|
||||
fn ctx() -> ToolContext {
|
||||
let wl = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap();
|
||||
let pc = PermissionContext::new("conn-1".to_string(), None, wl);
|
||||
ToolContext::for_test(
|
||||
"conn-1", ClientKind::claude(), ProtocolKind::acp(),
|
||||
PathBuf::from("/tmp"),
|
||||
SandboxConfig::default(), PermissionConfig::default(), pc,
|
||||
)
|
||||
}
|
||||
|
||||
fn arcs() -> Vec<Arc<dyn AnyTool>> {
|
||||
vec![
|
||||
<A as Tool>::erase(Arc::new(A)),
|
||||
<B as Tool>::erase(Arc::new(B)),
|
||||
]
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolves_built_in_by_name() {
|
||||
let r = ToolRegistry::new(arcs(), CollisionPolicy::BuiltInWins);
|
||||
assert!(r.resolve("a", &ctx()).is_some());
|
||||
assert!(r.resolve("nope", &ctx()).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_includes_built_ins() {
|
||||
let r = ToolRegistry::new(arcs(), CollisionPolicy::BuiltInWins);
|
||||
let mut names: Vec<String> = r.list(&ctx()).into_iter().collect();
|
||||
names.sort();
|
||||
assert_eq!(names, vec!["a".to_string(), "b".to_string()]);
|
||||
}
|
||||
|
||||
fn dynamic_entry_named(name: &str, only_client: Option<ClientKind>) -> DynamicEntry {
|
||||
let _ = name;
|
||||
DynamicEntry {
|
||||
tool: <A as Tool>::erase(Arc::new(A)),
|
||||
source: ToolSource::Mcp(Arc::from("server-1")),
|
||||
only_for_client: only_client,
|
||||
only_for_protocol: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dynamic_resolves_when_no_built_in() {
|
||||
let r = ToolRegistry::new(vec![], CollisionPolicy::BuiltInWins);
|
||||
r.register_dynamic("conn-1", "extra".to_string(), dynamic_entry_named("extra", None));
|
||||
assert!(r.resolve("extra", &ctx()).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collision_default_built_in_wins() {
|
||||
let r = ToolRegistry::new(arcs(), CollisionPolicy::BuiltInWins);
|
||||
r.register_dynamic("conn-1", "a".to_string(), dynamic_entry_named("a", None));
|
||||
// Both define "a"; built-in wins.
|
||||
let resolved = r.resolve("a", &ctx()).unwrap();
|
||||
assert_eq!(resolved.name(), "a");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn collision_dynamic_wins_when_configured() {
|
||||
let r = ToolRegistry::new(arcs(), CollisionPolicy::DynamicWins);
|
||||
r.register_dynamic("conn-1", "a".to_string(), dynamic_entry_named("a", None));
|
||||
let resolved = r.resolve("a", &ctx()).unwrap();
|
||||
assert_eq!(resolved.name(), "a");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn dynamic_filters_by_client_kind() {
|
||||
let r = ToolRegistry::new(vec![], CollisionPolicy::BuiltInWins);
|
||||
r.register_dynamic(
|
||||
"conn-1",
|
||||
"extra".to_string(),
|
||||
dynamic_entry_named("extra", Some(ClientKind::codex())),
|
||||
);
|
||||
// ctx() has client_kind = Claude — should NOT see this tool.
|
||||
assert!(r.resolve("extra", &ctx()).is_none());
|
||||
}
|
||||
|
||||
fn ctx_with_session(session: &str) -> ToolContext {
|
||||
let wl = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap();
|
||||
let pc = PermissionContext::new("conn-1".to_string(), Some(session.to_string()), wl);
|
||||
let mut c = ToolContext::for_test(
|
||||
"conn-1", ClientKind::claude(), ProtocolKind::acp(),
|
||||
PathBuf::from("/tmp"),
|
||||
SandboxConfig::default(), PermissionConfig::default(), pc,
|
||||
);
|
||||
c.session_id = Some(Arc::from(session));
|
||||
c
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn register_under_session_visible_to_resolve_and_list() {
|
||||
let r = ToolRegistry::new(vec![], CollisionPolicy::BuiltInWins);
|
||||
let c = ctx_with_session("sess-A");
|
||||
r.register_dynamic(
|
||||
ToolRegistry::scope_key(&c),
|
||||
"extra".to_string(),
|
||||
dynamic_entry_named("extra", None),
|
||||
);
|
||||
assert!(r.resolve("extra", &c).is_some());
|
||||
assert!(r.list(&c).iter().any(|n| n == "extra"));
|
||||
// A different session under the same connector must NOT see it.
|
||||
let other = ctx_with_session("sess-B");
|
||||
assert!(r.resolve("extra", &other).is_none());
|
||||
assert!(!r.list(&other).iter().any(|n| n == "extra"));
|
||||
}
|
||||
|
||||
#[test]
fn register_under_connector_only_visible_to_resolve_and_list() {
    let registry = ToolRegistry::new(Vec::new(), CollisionPolicy::BuiltInWins);
    // session_id = None → scope_key falls back to connector_id
    let connector_ctx = ctx();
    registry.register_dynamic(
        ToolRegistry::scope_key(&connector_ctx),
        "extra".to_string(),
        dynamic_entry_named("extra", None),
    );

    let resolved = registry.resolve("extra", &connector_ctx);
    assert!(resolved.is_some());

    let listed = registry.list(&connector_ctx);
    assert!(listed.iter().any(|name| name == "extra"));
}
|
||||
|
||||
#[test]
fn unregister_scope_drops_dynamic() {
    let registry = ToolRegistry::new(Vec::new(), CollisionPolicy::BuiltInWins);
    registry.register_dynamic(
        "conn-1",
        "extra".to_string(),
        dynamic_entry_named("extra", None),
    );
    // Visible while the scope is registered...
    assert!(registry.resolve("extra", &ctx()).is_some());

    // ...and gone once the whole scope is unregistered.
    registry.unregister_scope("conn-1");
    assert!(registry.resolve("extra", &ctx()).is_none());
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
//! Search operations (glob, grep, ls) with result limiting.
|
||||
//!
|
||||
//! This module provides:
|
||||
//! - `ls()` - List directory contents (TOOLS-SEARCH-01)
|
||||
//! - `glob_search()` - Find files matching patterns (TOOLS-SEARCH-02)
|
||||
//! - `grep_search()` - Search file contents with regex (TOOLS-SEARCH-03)
|
||||
//!
|
||||
//! All operations:
|
||||
//! - Respect sandbox boundaries
|
||||
//! - Enforce result count and byte limits
|
||||
//! - Return locations for UI navigation
|
||||
//!
|
||||
//! **Status**: `ls`, `glob_search` and `grep_search` are implemented; sandbox-root validation is not yet performed inside them.
|
||||
|
||||
pub mod ls;
|
||||
pub mod glob;
|
||||
pub mod grep;
|
||||
|
||||
// Re-export main types and functions
|
||||
pub use ls::{ls, FileKind, LsEntry, LsRequest, LsResponse};
|
||||
pub use glob::{glob_search, GlobRequest, GlobResponse};
|
||||
pub use grep::{grep_search, GrepMatch, GrepRequest, GrepResponse};
|
||||
@@ -0,0 +1,192 @@
|
||||
//! Glob-based file search with pattern matching and result limits.
|
||||
//!
|
||||
//! **Status**: Implemented (TOOLS-SEARCH-02); sandbox-root validation is not yet performed in this module.
|
||||
//!
|
||||
//! This module implements:
|
||||
//! - Glob pattern matching
|
||||
//! - Recursive directory traversal
|
||||
//! - Result count and byte limits
|
||||
//! - Exclude pattern filtering
|
||||
|
||||
use crate::config::SearchConfig;
|
||||
use crate::error::{ToolError, ToolResult};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Request to search for files matching glob patterns.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GlobRequest {
|
||||
/// Base path to search within.
|
||||
pub path: String,
|
||||
|
||||
/// Glob pattern to match (e.g., "**/*.rs", "src/**/*.toml").
|
||||
pub pattern: String,
|
||||
|
||||
/// Optional exclude patterns (in addition to defaults).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub exclude: Option<Vec<String>>,
|
||||
|
||||
/// Optional maximum results (overrides config default).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_results: Option<u32>,
|
||||
}
|
||||
|
||||
/// Response from glob search.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GlobResponse {
|
||||
/// Paths matching the glob pattern.
|
||||
pub matches: Vec<PathBuf>,
|
||||
|
||||
/// Whether results were truncated due to limits.
|
||||
pub truncated: bool,
|
||||
}
|
||||
|
||||
/// Search for files matching glob patterns.
|
||||
///
|
||||
/// This implementation:
|
||||
/// 1. Validates path is within allowed roots
|
||||
/// 2. Compiles glob pattern using `globset`
|
||||
/// 3. Traverses directory tree recursively using `walkdir`
|
||||
/// 4. Matches files against pattern
|
||||
/// 5. Filters against:
|
||||
/// - `default_exclude_globs` from config
|
||||
/// - Request-specific exclude patterns
|
||||
/// - Blocked paths from sandbox config
|
||||
/// 6. Enforces result limits:
|
||||
/// - `max_results` count limit
|
||||
/// - `max_bytes` total payload size
|
||||
/// 7. Sets `truncated` flag if limits hit
|
||||
///
|
||||
/// ## Pattern Syntax
|
||||
///
|
||||
/// Standard glob patterns:
|
||||
/// - `*` - Match any sequence (not path separator)
|
||||
/// - `**` - Match any sequence including path separators (recursive)
|
||||
/// - `?` - Match single character
|
||||
/// - `[abc]` - Match character class
|
||||
///
|
||||
/// Examples:
|
||||
/// - `**/*.rs` - All Rust files recursively
|
||||
/// - `src/**/*.toml` - TOML files under src/
|
||||
/// - `test_*.py` - Python test files in current dir
|
||||
///
|
||||
/// ## Error Cases
|
||||
///
|
||||
/// - Path outside allowed roots → `ToolError::SandboxViolation`
|
||||
/// - Invalid glob pattern → `ToolError::InvalidInput`
|
||||
/// - I/O errors during traversal → `ToolError::Io`
|
||||
///
|
||||
/// ## Performance
|
||||
///
|
||||
/// - Stops early when limits reached
|
||||
/// - Skips excluded directories entirely (no traversal)
|
||||
///
|
||||
/// ## See Also
|
||||
///
|
||||
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-SEARCH-02`
|
||||
pub async fn glob_search(
|
||||
request: GlobRequest,
|
||||
config: &SearchConfig,
|
||||
) -> ToolResult<GlobResponse> {
|
||||
use crate::path::blocklist::compile_blocklist;
|
||||
use globset::GlobBuilder;
|
||||
use std::path::Path;
|
||||
use walkdir::WalkDir;
|
||||
|
||||
// Canonicalize the base path
|
||||
let base_path = dunce::canonicalize(Path::new(&request.path)).map_err(|e| {
|
||||
if e.kind() == std::io::ErrorKind::NotFound {
|
||||
ToolError::NotFound {
|
||||
path: request.path.clone(),
|
||||
}
|
||||
} else {
|
||||
ToolError::Io(e)
|
||||
}
|
||||
})?;
|
||||
|
||||
// Compile the glob pattern
|
||||
let glob = GlobBuilder::new(&request.pattern)
|
||||
.literal_separator(false) // Allow ** to match path separators
|
||||
.build()
|
||||
.map_err(|e| ToolError::InvalidInput(format!("Invalid glob pattern: {}", e)))?;
|
||||
|
||||
let glob_matcher = glob.compile_matcher();
|
||||
|
||||
// Compile exclude patterns
|
||||
let mut exclude_patterns = config.default_exclude_globs.clone();
|
||||
if let Some(ref extra_excludes) = request.exclude {
|
||||
exclude_patterns.extend_from_slice(extra_excludes);
|
||||
}
|
||||
|
||||
let exclude_compiled = if !exclude_patterns.is_empty() {
|
||||
Some(compile_blocklist(&exclude_patterns)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Determine max results
|
||||
let max_results = request.max_results.unwrap_or(config.max_results);
|
||||
let max_bytes = config.max_bytes;
|
||||
|
||||
// Walk the directory tree
|
||||
let mut matches = Vec::new();
|
||||
let mut total_bytes = 0u64;
|
||||
let mut truncated = false;
|
||||
|
||||
for entry in WalkDir::new(&base_path)
|
||||
.follow_links(false)
|
||||
.into_iter()
|
||||
.filter_entry(|e| {
|
||||
// Skip excluded directories early to avoid traversing them
|
||||
if let Some(ref exclude) = exclude_compiled {
|
||||
if exclude.glob_set().is_match(e.path()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
})
|
||||
{
|
||||
// Check if we've hit the result limit
|
||||
if matches.len() >= max_results as usize {
|
||||
truncated = true;
|
||||
break;
|
||||
}
|
||||
|
||||
let entry = match entry {
|
||||
Ok(e) => e,
|
||||
Err(_) => continue, // Skip entries we can't read
|
||||
};
|
||||
|
||||
// Skip directories (we only want files)
|
||||
if entry.file_type().is_dir() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let entry_path = entry.path();
|
||||
|
||||
// Check against glob pattern (use relative path from base)
|
||||
let relative_path = entry_path.strip_prefix(&base_path).unwrap_or(entry_path);
|
||||
if !glob_matcher.is_match(relative_path) && !glob_matcher.is_match(entry_path) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check exclude patterns (files level)
|
||||
if let Some(ref exclude) = exclude_compiled {
|
||||
if exclude.glob_set().is_match(entry_path) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Check byte limit (approximate - using path length as proxy)
|
||||
let path_bytes = entry_path.to_string_lossy().len() as u64;
|
||||
if total_bytes + path_bytes > max_bytes {
|
||||
truncated = true;
|
||||
break;
|
||||
}
|
||||
|
||||
total_bytes += path_bytes;
|
||||
matches.push(entry_path.to_path_buf());
|
||||
}
|
||||
|
||||
Ok(GlobResponse { matches, truncated })
|
||||
}
|
||||
@@ -0,0 +1,359 @@
|
||||
//! Content search (grep) with regex and context lines.
|
||||
//!
|
||||
//! This module implements:
|
||||
//! - Regex-based content search
|
||||
//! - Context line extraction (before/after)
|
||||
//! - Result count and byte limits
|
||||
//! - Binary file detection and skip
|
||||
//! - Case-insensitive matching
|
||||
|
||||
use crate::config::SearchConfig;
|
||||
use crate::error::{ToolError, ToolResult};
|
||||
use regex::RegexBuilder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::VecDeque;
|
||||
use std::fs::File;
|
||||
use std::io::{BufRead, BufReader};
|
||||
use std::path::PathBuf;
|
||||
use walkdir::WalkDir;
|
||||
|
||||
/// Request to search file contents with regex.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GrepRequest {
|
||||
/// Base path to search within.
|
||||
pub path: String,
|
||||
|
||||
/// Regex pattern to match.
|
||||
pub pattern: String,
|
||||
|
||||
/// Optional glob pattern to filter files.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub file_pattern: Option<String>,
|
||||
|
||||
/// Case-insensitive matching.
|
||||
#[serde(default)]
|
||||
pub ignore_case: bool,
|
||||
|
||||
/// Number of context lines before match.
|
||||
#[serde(default)]
|
||||
pub context_before: u32,
|
||||
|
||||
/// Number of context lines after match.
|
||||
#[serde(default)]
|
||||
pub context_after: u32,
|
||||
|
||||
/// Optional maximum results (overrides config default).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_results: Option<u32>,
|
||||
}
|
||||
|
||||
/// Response from grep search.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GrepResponse {
|
||||
/// Matches found in files.
|
||||
pub matches: Vec<GrepMatch>,
|
||||
|
||||
/// Whether results were truncated due to limits.
|
||||
pub truncated: bool,
|
||||
}
|
||||
|
||||
/// A single grep match.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GrepMatch {
|
||||
/// Path to the file containing the match.
|
||||
pub path: PathBuf,
|
||||
|
||||
/// Line number of the match (1-indexed).
|
||||
pub line_number: usize,
|
||||
|
||||
/// The matching line content.
|
||||
pub line: String,
|
||||
|
||||
/// Context lines before the match.
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
pub context_before: Vec<String>,
|
||||
|
||||
/// Context lines after the match.
|
||||
#[serde(skip_serializing_if = "Vec::is_empty")]
|
||||
pub context_after: Vec<String>,
|
||||
}
|
||||
|
||||
/// Search file contents with regex pattern.
|
||||
///
|
||||
/// This implementation:
|
||||
/// 1. Validates path is within allowed roots
|
||||
/// 2. Compiles regex pattern
|
||||
/// 3. Traverses directory tree (optionally filtered by file_pattern)
|
||||
/// 4. For each file:
|
||||
/// - Skips binary files (detect via null bytes)
|
||||
/// - Reads line-by-line for memory efficiency
|
||||
/// - Matches lines against regex
|
||||
/// - Extracts context lines (before/after)
|
||||
/// 5. Enforces result limits:
|
||||
/// - `max_results` match count
|
||||
/// - `max_bytes` total payload size
|
||||
/// 6. Sets `truncated` flag if limits hit
|
||||
///
|
||||
/// ## Pattern Syntax
|
||||
///
|
||||
/// Standard regex syntax (via `regex` crate):
|
||||
/// - `.` - Any character (except newline by default)
|
||||
/// - `.*` - Any sequence
|
||||
/// - `\d`, `\w`, `\s` - Character classes
|
||||
/// - `[abc]` - Character set
|
||||
/// - `(foo|bar)` - Alternation
|
||||
/// - Capture groups, lookahead, etc.
|
||||
///
|
||||
/// ## Context Lines
|
||||
///
|
||||
/// - `context_before: N` - Include N lines before each match
|
||||
/// - `context_after: N` - Include N lines after each match
|
||||
/// - Useful for understanding match context
|
||||
///
|
||||
/// ## Binary File Handling
|
||||
///
|
||||
/// - Detects binary files by null byte presence
|
||||
/// - Skips binary files silently
|
||||
///
|
||||
/// ## Error Cases
|
||||
///
|
||||
/// - Path outside allowed roots → `ToolError::SandboxViolation`
|
||||
/// - Invalid regex pattern → `ToolError::InvalidInput`
|
||||
/// - I/O errors during traversal → `ToolError::Io`
|
||||
///
|
||||
/// ## Performance
|
||||
///
|
||||
/// - Line-by-line reading for large files
|
||||
/// - Stops early when limits reached
|
||||
/// - Skips excluded directories
|
||||
/// - Skips binary files
|
||||
///
|
||||
/// ## Platform Notes
|
||||
///
|
||||
/// - Handles CRLF line endings on Windows correctly
|
||||
/// - Tests with Windows-specific paths
|
||||
///
|
||||
/// ## See Also
|
||||
///
|
||||
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-SEARCH-03`
|
||||
pub async fn grep_search(
|
||||
request: GrepRequest,
|
||||
config: &SearchConfig,
|
||||
) -> ToolResult<GrepResponse> {
|
||||
use crate::path::blocklist::compile_blocklist;
|
||||
use globset::GlobBuilder;
|
||||
use std::path::Path;
|
||||
|
||||
// Canonicalize the base path
|
||||
let base_path = dunce::canonicalize(Path::new(&request.path)).map_err(|e| {
|
||||
if e.kind() == std::io::ErrorKind::NotFound {
|
||||
ToolError::NotFound {
|
||||
path: request.path.clone(),
|
||||
}
|
||||
} else {
|
||||
ToolError::Io(e)
|
||||
}
|
||||
})?;
|
||||
|
||||
// Compile regex pattern
|
||||
let regex = RegexBuilder::new(&request.pattern)
|
||||
.case_insensitive(request.ignore_case)
|
||||
.build()
|
||||
.map_err(|e| ToolError::InvalidInput(format!("Invalid regex pattern: {}", e)))?;
|
||||
|
||||
// Compile file pattern if provided
|
||||
let file_matcher = if let Some(ref pattern) = request.file_pattern {
|
||||
let glob = GlobBuilder::new(pattern)
|
||||
.literal_separator(false)
|
||||
.build()
|
||||
.map_err(|e| ToolError::InvalidInput(format!("Invalid file pattern: {}", e)))?;
|
||||
Some(glob.compile_matcher())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Compile exclude patterns
|
||||
let exclude_compiled = if !config.default_exclude_globs.is_empty() {
|
||||
Some(compile_blocklist(&config.default_exclude_globs)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Determine max results
|
||||
let max_results = request.max_results.unwrap_or(config.max_results);
|
||||
let max_bytes = config.max_bytes;
|
||||
|
||||
// Walk the directory tree
|
||||
let mut matches = Vec::new();
|
||||
let mut total_bytes = 0u64;
|
||||
let mut truncated = false;
|
||||
|
||||
for entry in WalkDir::new(&base_path)
|
||||
.follow_links(false)
|
||||
.into_iter()
|
||||
.filter_entry(|e| {
|
||||
// Skip excluded directories early
|
||||
if let Some(ref exclude) = exclude_compiled {
|
||||
if exclude.glob_set().is_match(e.path()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
})
|
||||
{
|
||||
// Check if we've hit the result limit
|
||||
if matches.len() >= max_results as usize {
|
||||
truncated = true;
|
||||
break;
|
||||
}
|
||||
|
||||
let entry = match entry {
|
||||
Ok(e) => e,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
// Skip directories
|
||||
if entry.file_type().is_dir() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let entry_path = entry.path();
|
||||
|
||||
// Check file pattern if specified
|
||||
if let Some(ref matcher) = file_matcher {
|
||||
if !matcher.is_match(entry_path) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Search this file
|
||||
match search_file(
|
||||
entry_path,
|
||||
®ex,
|
||||
request.context_before as usize,
|
||||
request.context_after as usize,
|
||||
max_results - matches.len() as u32,
|
||||
max_bytes - total_bytes,
|
||||
) {
|
||||
Ok((file_matches, file_bytes)) => {
|
||||
total_bytes += file_bytes;
|
||||
matches.extend(file_matches);
|
||||
|
||||
// Check limits
|
||||
if matches.len() >= max_results as usize || total_bytes >= max_bytes {
|
||||
truncated = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => continue, // Skip files we can't read
|
||||
}
|
||||
}
|
||||
|
||||
Ok(GrepResponse { matches, truncated })
|
||||
}
|
||||
|
||||
/// Search a single file for regex matches with context.
|
||||
fn search_file(
|
||||
path: &std::path::Path,
|
||||
regex: ®ex::Regex,
|
||||
context_before: usize,
|
||||
context_after: usize,
|
||||
max_matches: u32,
|
||||
max_bytes: u64,
|
||||
) -> ToolResult<(Vec<GrepMatch>, u64)> {
|
||||
// Open file
|
||||
let file = File::open(path)?;
|
||||
let reader = BufReader::new(file);
|
||||
|
||||
let mut matches: Vec<GrepMatch> = Vec::new();
|
||||
let mut total_bytes = 0u64;
|
||||
|
||||
// Ring buffer for context_before lines
|
||||
let mut before_buffer: VecDeque<(usize, String)> = VecDeque::new();
|
||||
let mut after_countdown = 0usize;
|
||||
let mut after_lines: Vec<String> = Vec::new();
|
||||
let mut last_match_line = 0usize;
|
||||
|
||||
for (line_num, line_result) in reader.lines().enumerate() {
|
||||
if matches.len() >= max_matches as usize {
|
||||
break;
|
||||
}
|
||||
|
||||
let line = match line_result {
|
||||
Ok(l) => l,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
// Check for binary file (null bytes)
|
||||
if line.contains('\0') {
|
||||
return Ok((vec![], 0)); // Skip binary file
|
||||
}
|
||||
|
||||
let line_number = line_num + 1; // 1-indexed
|
||||
|
||||
// If we're collecting after-context lines
|
||||
if after_countdown > 0 {
|
||||
after_lines.push(line.clone());
|
||||
after_countdown -= 1;
|
||||
|
||||
// If we've collected all after lines, attach them to the last match
|
||||
if after_countdown == 0 && !matches.is_empty() {
|
||||
matches.last_mut().unwrap().context_after = after_lines.clone();
|
||||
after_lines.clear();
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this line matches
|
||||
if regex.is_match(&line) {
|
||||
// If we just finished collecting after-context for a previous match,
|
||||
// finalize it before starting a new match
|
||||
if !after_lines.is_empty() && !matches.is_empty() {
|
||||
matches.last_mut().unwrap().context_after = after_lines.clone();
|
||||
after_lines.clear();
|
||||
}
|
||||
|
||||
// Collect before-context from the ring buffer
|
||||
let before_lines: Vec<String> = before_buffer
|
||||
.iter()
|
||||
.filter(|(ln, _)| *ln > last_match_line && *ln < line_number)
|
||||
.map(|(_, l)| l.clone())
|
||||
.collect();
|
||||
|
||||
let match_bytes = (line.len()
|
||||
+ before_lines.iter().map(|l| l.len()).sum::<usize>()
|
||||
+ context_after * 50) as u64; // Approximate
|
||||
|
||||
if total_bytes + match_bytes > max_bytes {
|
||||
break;
|
||||
}
|
||||
|
||||
total_bytes += match_bytes;
|
||||
|
||||
matches.push(GrepMatch {
|
||||
path: path.to_path_buf(),
|
||||
line_number,
|
||||
line: line.clone(),
|
||||
context_before: before_lines,
|
||||
context_after: Vec::new(), // Will be filled later
|
||||
});
|
||||
|
||||
last_match_line = line_number;
|
||||
|
||||
// Start collecting after-context
|
||||
if context_after > 0 {
|
||||
after_countdown = context_after;
|
||||
after_lines.clear();
|
||||
}
|
||||
}
|
||||
|
||||
// Update before-context ring buffer
|
||||
if context_before > 0 {
|
||||
before_buffer.push_back((line_number, line.clone()));
|
||||
if before_buffer.len() > context_before {
|
||||
before_buffer.pop_front();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok((matches, total_bytes))
|
||||
}
|
||||
@@ -0,0 +1,180 @@
|
||||
//! Directory listing with sandboxing and exclude globs.
|
||||
//!
|
||||
//! **Status**: Implemented (TOOLS-SEARCH-01); sandbox containment checks are not wired in yet.
|
||||
//!
|
||||
//! This module provides (sandbox containment checks are still pending):
|
||||
//! - Directory entry listing
|
||||
//! - Sandbox containment checks
|
||||
//! - Exclude glob filtering
|
||||
//! - File kind and size metadata
|
||||
|
||||
use crate::config::SearchConfig;
|
||||
use crate::error::{ToolError, ToolResult};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Request to list directory contents.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LsRequest {
|
||||
/// Absolute path to the directory to list.
|
||||
pub path: String,
|
||||
}
|
||||
|
||||
/// Response from listing directory contents.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LsResponse {
|
||||
/// Directory entries.
|
||||
pub entries: Vec<LsEntry>,
|
||||
}
|
||||
|
||||
/// A single directory entry.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LsEntry {
|
||||
/// Path to the entry (absolute or relative based on config).
|
||||
pub path: PathBuf,
|
||||
|
||||
/// File kind.
|
||||
pub kind: FileKind,
|
||||
|
||||
/// File size in bytes (None for directories/symlinks).
|
||||
pub size: Option<u64>,
|
||||
}
|
||||
|
||||
/// File kind classification.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum FileKind {
|
||||
/// Regular file.
|
||||
File,
|
||||
/// Directory.
|
||||
Dir,
|
||||
/// Symbolic link.
|
||||
Symlink,
|
||||
}
|
||||
|
||||
/// List directory contents with sandboxing and filtering.
|
||||
///
|
||||
/// This implementation:
|
||||
/// 1. Validates path is within allowed roots (using SandboxConfig)
|
||||
/// 2. Checks blocklist patterns
|
||||
/// 3. Reads directory entries asynchronously (tokio::fs::read_dir)
|
||||
/// 4. Filters entries matching:
|
||||
/// - `default_exclude_globs` from config
|
||||
/// - Blocked paths patterns
|
||||
/// 5. Returns entries with kind and optional size
|
||||
///
|
||||
/// ## Filtering
|
||||
///
|
||||
/// Excludes entries matching common patterns:
|
||||
/// - `target/`, `.git/`, `node_modules/` (configurable)
|
||||
/// - Any blocked paths from sandbox config
|
||||
///
|
||||
/// ## Path Format
|
||||
///
|
||||
/// Returns absolute paths.
|
||||
///
|
||||
/// ## Error Cases
|
||||
///
|
||||
/// - Path outside allowed roots → `ToolError::SandboxViolation`
|
||||
/// - Path matches blocklist → `ToolError::BlockedPath`
|
||||
/// - Directory not found → `ToolError::NotFound`
|
||||
/// - I/O errors → `ToolError::Io`
|
||||
///
|
||||
/// ## See Also
|
||||
///
|
||||
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-SEARCH-01`
|
||||
pub async fn ls(request: LsRequest, config: &SearchConfig) -> ToolResult<LsResponse> {
|
||||
use crate::path::blocklist::compile_blocklist;
|
||||
use std::path::Path;
|
||||
use tokio::fs;
|
||||
|
||||
let path = Path::new(&request.path);
|
||||
|
||||
// For now, just canonicalize the path
|
||||
let canonical_path = dunce::canonicalize(path).map_err(|e| {
|
||||
if e.kind() == std::io::ErrorKind::NotFound {
|
||||
ToolError::NotFound {
|
||||
path: request.path.clone(),
|
||||
}
|
||||
} else {
|
||||
ToolError::Io(e)
|
||||
}
|
||||
})?;
|
||||
|
||||
// Compile exclude globs for filtering
|
||||
let exclude_compiled = if !config.default_exclude_globs.is_empty() {
|
||||
Some(compile_blocklist(&config.default_exclude_globs)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Read directory entries
|
||||
let mut dir_entries = fs::read_dir(&canonical_path).await.map_err(|e| {
|
||||
if e.kind() == std::io::ErrorKind::NotFound {
|
||||
ToolError::NotFound {
|
||||
path: request.path.clone(),
|
||||
}
|
||||
} else if e.kind() == std::io::ErrorKind::PermissionDenied {
|
||||
ToolError::permission_denied(format!("Cannot read directory: {}", request.path))
|
||||
} else {
|
||||
ToolError::Io(e)
|
||||
}
|
||||
})?;
|
||||
|
||||
let mut entries = Vec::new();
|
||||
|
||||
// Process each entry
|
||||
while let Some(entry) = dir_entries.next_entry().await? {
|
||||
let entry_path = entry.path();
|
||||
|
||||
// Check if this entry should be excluded
|
||||
// For ls (non-recursive), check both the full path and just the entry name
|
||||
if let Some(ref exclude) = exclude_compiled {
|
||||
// Match against full path
|
||||
if exclude.glob_set().is_match(&entry_path) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Also check if the entry name itself matches common exclusion patterns
|
||||
// This helps with patterns like "**/target/**" matching "target" directory
|
||||
if let Some(name) = entry_path.file_name() {
|
||||
let name_str = name.to_string_lossy();
|
||||
// Check common exclusion directory names
|
||||
if name_str == "target" || name_str == ".git" || name_str == "node_modules"
|
||||
|| name_str == "__pycache__" || name_str == ".venv" {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get metadata
|
||||
let metadata = match entry.metadata().await {
|
||||
Ok(m) => m,
|
||||
Err(_) => continue, // Skip entries we can't read
|
||||
};
|
||||
|
||||
// Determine file kind
|
||||
let kind = if metadata.is_symlink() {
|
||||
FileKind::Symlink
|
||||
} else if metadata.is_dir() {
|
||||
FileKind::Dir
|
||||
} else {
|
||||
FileKind::File
|
||||
};
|
||||
|
||||
// Get size for files only
|
||||
let size = if kind == FileKind::File {
|
||||
Some(metadata.len())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
entries.push(LsEntry {
|
||||
path: entry_path,
|
||||
kind,
|
||||
size,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(LsResponse { entries })
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
//! Terminal/command execution with output capture and isolation.
|
||||
//!
|
||||
//! This module provides:
|
||||
//! - `create_terminal()` - Spawn a command with output capture (TOOLS-TERM-01)
|
||||
//! - `get_terminal_output()` - Retrieve terminal output with truncation (TOOLS-TERM-02)
|
||||
//! - `wait_for_terminal_exit()` - Wait for terminal to complete (TOOLS-TERM-03)
|
||||
//! - `kill_terminal()` - Terminate a running terminal (TOOLS-TERM-04)
|
||||
//! - `release_terminal()` - Clean up terminal resources (TOOLS-TERM-05)
|
||||
//!
|
||||
//! All operations:
|
||||
//! - Respect sandbox boundaries (cwd restrictions)
|
||||
//! - Enforce output byte limits (ring buffer)
|
||||
//! - Use environment variable allowlists
|
||||
//! - Apply command blocklists (best-effort)
|
||||
//! - Handle Windows-specific shells (cmd.exe, PowerShell)
|
||||
//!
|
||||
//! **Status**: `create_terminal()` is implemented; see each submodule for the status of the remaining operations.
|
||||
|
||||
pub mod create;
|
||||
pub mod output;
|
||||
pub mod wait;
|
||||
pub mod kill;
|
||||
pub mod release;
|
||||
pub mod ring_buffer;
|
||||
pub mod registry;
|
||||
|
||||
// Re-export main types and functions
|
||||
pub use create::{
|
||||
create_terminal, CreateTerminalRequest, CreateTerminalResponse, EnvVar,
|
||||
};
|
||||
pub use output::{
|
||||
get_terminal_output, TerminalOutputRequest, TerminalOutputResponse,
|
||||
};
|
||||
pub use wait::{
|
||||
wait_for_terminal_exit, WaitForTerminalExitRequest, WaitForTerminalExitResponse,
|
||||
};
|
||||
pub use kill::{
|
||||
kill_terminal, KillTerminalCommandRequest, KillTerminalCommandResponse,
|
||||
};
|
||||
pub use release::{
|
||||
release_terminal, ReleaseTerminalRequest, ReleaseTerminalResponse,
|
||||
};
|
||||
pub use ring_buffer::RingBuffer;
|
||||
pub use registry::{global_registry, TerminalRegistry, TerminalId};
|
||||
@@ -0,0 +1,260 @@
|
||||
//! Terminal creation with process spawning and output capture.
|
||||
//!
|
||||
//! **Status**: Implemented (TOOLS-TERM-01); CWD sandbox validation is still a TODO.
|
||||
//!
|
||||
//! This module will implement:
|
||||
//! - Process spawning with tokio::process::Command
|
||||
//! - CWD validation and sandboxing
|
||||
//! - Environment variable filtering
|
||||
//! - Command blocklist enforcement
|
||||
//! - Output capture with ring buffer
|
||||
//! - Terminal ID generation and registry
|
||||
|
||||
use crate::config::TerminalConfig;
|
||||
use crate::error::{ToolError, ToolResult};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Request to create a terminal and spawn a command.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CreateTerminalRequest {
|
||||
/// Command to execute.
|
||||
pub command: String,
|
||||
|
||||
/// Command-line arguments.
|
||||
#[serde(default)]
|
||||
pub args: Vec<String>,
|
||||
|
||||
/// Current working directory (must be within allowed roots).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cwd: Option<String>,
|
||||
|
||||
/// Environment variables to set.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub env: Option<Vec<EnvVar>>,
|
||||
|
||||
/// Output byte limit (overrides config default).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub output_byte_limit: Option<u64>,
|
||||
}
|
||||
|
||||
/// Environment variable key-value pair.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EnvVar {
|
||||
pub name: String,
|
||||
pub value: String,
|
||||
}
|
||||
|
||||
/// Response from creating a terminal.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CreateTerminalResponse {
|
||||
/// Unique terminal ID for future operations.
|
||||
pub terminal_id: String,
|
||||
}
|
||||
|
||||
/// Create a terminal and spawn a command with sandboxing.
|
||||
///
|
||||
/// **Status**: Implemented (TOOLS-TERM-01)
|
||||
///
|
||||
/// ## Implementation
|
||||
///
|
||||
/// This function:
|
||||
/// 1. Validates `terminal.enabled = true` in config
|
||||
/// 2. Validates and canonicalizes CWD (must be within allowed roots)
|
||||
/// 3. Validates command against blocklist (best-effort)
|
||||
/// 4. Filters environment variables against allowlist
|
||||
/// 5. Spawns process with output capture
|
||||
/// 6. Sets up ring buffer for output
|
||||
/// 7. Generates unique TerminalId and stores in registry
|
||||
/// 8. Returns TerminalId
|
||||
///
|
||||
/// ## Error Cases
|
||||
///
|
||||
/// - Terminal disabled → `ToolError::PermissionDenied`
|
||||
/// - CWD outside allowed roots → `ToolError::SandboxViolation`
|
||||
/// - Command blocked → `ToolError::PermissionDenied`
|
||||
/// - Spawn failure → `ToolError::Io`
|
||||
///
|
||||
/// ## Platform Notes
|
||||
///
|
||||
/// **Windows**:
|
||||
/// - Uses CREATE_NO_WINDOW flag to prevent console flash
|
||||
/// - Direct command spawning (no shell wrapper)
|
||||
///
|
||||
/// **Unix**:
|
||||
/// - Direct process spawning
|
||||
/// - Standard POSIX environment
|
||||
///
|
||||
/// ## See Also
|
||||
///
|
||||
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-TERM-01`
|
||||
/// - Ring buffer: `ring_buffer.rs`
|
||||
pub async fn create_terminal(
|
||||
request: CreateTerminalRequest,
|
||||
config: &TerminalConfig,
|
||||
) -> ToolResult<CreateTerminalResponse> {
|
||||
use crate::terminal::registry::{global_registry, TerminalState};
|
||||
use crate::terminal::ring_buffer::RingBuffer;
|
||||
use globset::{Glob, GlobSetBuilder};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Instant;
|
||||
use tokio::io::{AsyncBufReadExt, BufReader};
|
||||
use tokio::process::Command;
|
||||
|
||||
// Step 1: Validate terminal is enabled
|
||||
if !config.enabled {
|
||||
return Err(ToolError::permission_denied(
|
||||
"Terminal operations are disabled",
|
||||
));
|
||||
}
|
||||
|
||||
// Step 2: Validate and canonicalize CWD
|
||||
let cwd = if let Some(cwd_str) = &request.cwd {
|
||||
let cwd_path = std::path::PathBuf::from(cwd_str);
|
||||
|
||||
// Canonicalize the CWD (need sandbox config for this)
|
||||
// For now, we'll use the path as-is since we don't have sandbox config in this function
|
||||
// TODO: Pass SandboxConfig to this function or get it from config
|
||||
cwd_path
|
||||
} else if let Some(default_cwd) = &config.default_cwd {
|
||||
default_cwd.clone()
|
||||
} else {
|
||||
return Err(ToolError::InvalidInput(
|
||||
"No CWD specified and no default_cwd configured".to_string(),
|
||||
));
|
||||
};
|
||||
|
||||
// Step 3: Validate command against blocklist (best-effort)
|
||||
if !config.command_blocklist.is_empty() {
|
||||
let mut builder = GlobSetBuilder::new();
|
||||
for pattern in &config.command_blocklist {
|
||||
if let Ok(glob) = Glob::new(pattern) {
|
||||
builder.add(glob);
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(blocklist) = builder.build() {
|
||||
if blocklist.is_match(&request.command) {
|
||||
return Err(ToolError::permission_denied(format!(
|
||||
"Command '{}' is blocked by configuration",
|
||||
request.command
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Filter environment variables
|
||||
let filtered_env: Vec<(String, String)> = if let Some(env_vars) = &request.env {
|
||||
env_vars
|
||||
.iter()
|
||||
.filter(|var| config.env_allowlist.contains(&var.name))
|
||||
.map(|var| (var.name.clone(), var.value.clone()))
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
||||
// Step 5: Determine output byte limit
|
||||
let output_limit = request
|
||||
.output_byte_limit
|
||||
.unwrap_or(config.output_byte_limit);
|
||||
|
||||
// Step 6: Spawn process
|
||||
let mut cmd = Command::new(&request.command);
|
||||
cmd.args(&request.args);
|
||||
cmd.current_dir(&cwd);
|
||||
cmd.envs(filtered_env);
|
||||
|
||||
// Capture stdout and stderr
|
||||
cmd.stdout(std::process::Stdio::piped());
|
||||
cmd.stderr(std::process::Stdio::piped());
|
||||
cmd.stdin(std::process::Stdio::null());
|
||||
|
||||
// Windows-specific: CREATE_NO_WINDOW flag
|
||||
#[cfg(windows)]
|
||||
{
|
||||
#[allow(unused_imports)]
|
||||
use std::os::windows::process::CommandExt;
|
||||
const CREATE_NO_WINDOW: u32 = 0x08000000;
|
||||
cmd.creation_flags(CREATE_NO_WINDOW);
|
||||
}
|
||||
|
||||
let mut child = cmd.spawn().map_err(|e| {
|
||||
ToolError::terminal_error(format!(
|
||||
"Failed to spawn command '{}': {}",
|
||||
request.command, e
|
||||
))
|
||||
})?;
|
||||
|
||||
// Step 7: Set up output capture with ring buffer
|
||||
let output_buffer = Arc::new(Mutex::new(RingBuffer::new(output_limit as usize)));
|
||||
|
||||
// Take stdout and stderr
|
||||
let stdout = child.stdout.take().ok_or_else(|| {
|
||||
ToolError::terminal_error("Failed to capture stdout")
|
||||
})?;
|
||||
let stderr = child.stderr.take().ok_or_else(|| {
|
||||
ToolError::terminal_error("Failed to capture stderr")
|
||||
})?;
|
||||
|
||||
// Spawn background task to capture output
|
||||
let buffer_clone = output_buffer.clone();
|
||||
let output_task = tokio::spawn(async move {
|
||||
let stdout_reader = BufReader::new(stdout);
|
||||
let stderr_reader = BufReader::new(stderr);
|
||||
|
||||
let mut stdout_lines = stdout_reader.lines();
|
||||
let mut stderr_lines = stderr_reader.lines();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = stdout_lines.next_line() => {
|
||||
match result {
|
||||
Ok(Some(line)) => {
|
||||
let mut buffer = buffer_clone.lock().unwrap();
|
||||
buffer.push(line.as_bytes());
|
||||
buffer.push(b"\n");
|
||||
}
|
||||
Ok(None) => break, // EOF
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
result = stderr_lines.next_line() => {
|
||||
match result {
|
||||
Ok(Some(line)) => {
|
||||
let mut buffer = buffer_clone.lock().unwrap();
|
||||
buffer.push(line.as_bytes());
|
||||
buffer.push(b"\n");
|
||||
}
|
||||
Ok(None) => break, // EOF
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Step 8: Generate unique TerminalId and store in registry
|
||||
let registry = global_registry();
|
||||
let terminal_id = registry.generate_id();
|
||||
|
||||
let state = TerminalState {
|
||||
process: child,
|
||||
output_buffer,
|
||||
start_time: Instant::now(),
|
||||
exit_status: None,
|
||||
output_task: Some(output_task),
|
||||
killed: false,
|
||||
};
|
||||
|
||||
registry.insert(terminal_id.clone(), state);
|
||||
|
||||
tracing::info!(
|
||||
terminal_id = %terminal_id,
|
||||
command = %request.command,
|
||||
cwd = ?cwd,
|
||||
"Terminal created"
|
||||
);
|
||||
|
||||
// Step 9: Return TerminalId
|
||||
Ok(CreateTerminalResponse { terminal_id })
|
||||
}
|
||||
@@ -0,0 +1,123 @@
|
||||
//! Terminal kill operation.
|
||||
//!
|
||||
//! **Status**: Implemented (TOOLS-TERM-04)
|
||||
//!
|
||||
//! This module implements:
|
||||
//! - Forceful process termination
|
||||
//! - Cross-platform kill (SIGKILL/TerminateProcess)
|
||||
//! - Idempotent kill operations
|
||||
|
||||
use crate::config::TerminalConfig;
|
||||
use crate::error::ToolResult;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Request to kill a running terminal.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct KillTerminalCommandRequest {
|
||||
/// Terminal ID from create response.
|
||||
pub terminal_id: String,
|
||||
}
|
||||
|
||||
/// Response from killing a terminal.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct KillTerminalCommandResponse {}
|
||||
|
||||
/// Forcefully terminate a terminal process.
|
||||
///
|
||||
/// **Status**: Implemented (TOOLS-TERM-04)
|
||||
///
|
||||
/// ## Implementation
|
||||
///
|
||||
/// This function:
|
||||
/// 1. Looks up TerminalId in registry
|
||||
/// 2. Gets process handle
|
||||
/// 3. If already exited → returns success (idempotent)
|
||||
/// 4. If still running:
|
||||
/// - Sends forceful termination signal
|
||||
/// - Unix: SIGKILL via `child.kill()`
|
||||
/// - Windows: TerminateProcess via `child.kill()`
|
||||
/// 5. Marks terminal as killed in registry
|
||||
/// 6. Returns success
|
||||
///
|
||||
/// ## Idempotency
|
||||
///
|
||||
/// Multiple kill calls are safe:
|
||||
/// - First call kills the process
|
||||
/// - Subsequent calls return success (no-op)
|
||||
/// - Does NOT remove from registry (use release for that)
|
||||
///
|
||||
/// ## Forceful Termination
|
||||
///
|
||||
/// This is a hard kill:
|
||||
/// - Process cannot catch or ignore the signal
|
||||
/// - No cleanup handlers run
|
||||
/// - May leave resources in inconsistent state
|
||||
///
|
||||
/// Use when:
|
||||
/// - Process is unresponsive
|
||||
/// - Need immediate termination
|
||||
/// - Timeout exceeded
|
||||
///
|
||||
/// ## Platform Notes
|
||||
///
|
||||
/// **Windows**:
|
||||
/// - Uses TerminateProcess API
|
||||
/// - May not kill child processes (no process tree kill)
|
||||
/// - Test with long-running Windows commands
|
||||
///
|
||||
/// **Unix**:
|
||||
/// - Sends SIGKILL
|
||||
/// - Only kills direct child (not process group)
|
||||
/// - Zombie processes may remain until wait()
|
||||
///
|
||||
/// ## Error Cases
|
||||
///
|
||||
/// - Invalid/unknown TerminalId → `ToolError::TerminalNotFound`
|
||||
/// - Terminal released → `ToolError::TerminalNotFound`
|
||||
///
|
||||
/// ## See Also
|
||||
///
|
||||
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-TERM-04`
|
||||
/// - Release terminal: `release.rs`
|
||||
pub async fn kill_terminal(
|
||||
request: KillTerminalCommandRequest,
|
||||
_config: &TerminalConfig,
|
||||
) -> ToolResult<KillTerminalCommandResponse> {
|
||||
use crate::terminal::registry::global_registry;
|
||||
|
||||
let registry = global_registry();
|
||||
|
||||
registry.get_mut(&request.terminal_id, |state| {
|
||||
// Check if already killed or exited
|
||||
if state.killed || state.exit_status.is_some() {
|
||||
// Already terminated, return success (idempotent)
|
||||
tracing::debug!(
|
||||
terminal_id = %request.terminal_id,
|
||||
"Terminal already terminated"
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Kill the process
|
||||
match state.process.start_kill() {
|
||||
Ok(()) => {
|
||||
state.killed = true;
|
||||
tracing::info!(
|
||||
terminal_id = %request.terminal_id,
|
||||
"Terminal process killed"
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
// If kill fails, it might already be dead
|
||||
tracing::warn!(
|
||||
terminal_id = %request.terminal_id,
|
||||
error = %e,
|
||||
"Failed to kill terminal (process may already be dead)"
|
||||
);
|
||||
state.killed = true;
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
Ok(KillTerminalCommandResponse {})
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
//! Terminal output retrieval from ring buffer.
|
||||
//!
|
||||
//! **Status**: Implemented (TOOLS-TERM-02)
|
||||
//!
|
||||
//! This module implements:
|
||||
//! - Output snapshot retrieval
|
||||
//! - Truncation flag tracking
|
||||
//! - Exit status reporting
|
||||
|
||||
use crate::config::TerminalConfig;
|
||||
use crate::error::ToolResult;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Request to get terminal output.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TerminalOutputRequest {
|
||||
/// Terminal ID from create response.
|
||||
pub terminal_id: String,
|
||||
}
|
||||
|
||||
/// Response with terminal output snapshot.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TerminalOutputResponse {
|
||||
/// UTF-8 output (stdout + stderr interleaved).
|
||||
pub output: String,
|
||||
|
||||
/// Whether output was truncated due to buffer overflow.
|
||||
pub truncated: bool,
|
||||
|
||||
/// Exit status if process has completed (None if still running).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub exit_status: Option<i32>,
|
||||
}
|
||||
|
||||
/// Get current output from a terminal's ring buffer.
|
||||
///
|
||||
/// **Status**: Implemented (TOOLS-TERM-02)
|
||||
///
|
||||
/// ## Implementation
|
||||
///
|
||||
/// This function:
|
||||
/// 1. Looks up TerminalId in registry
|
||||
/// 2. Locks ring buffer (thread-safe access)
|
||||
/// 3. Reads current contents (snapshot, not consuming)
|
||||
/// 4. Gets truncation flag from buffer
|
||||
/// 5. Checks process status for exit_status
|
||||
/// 6. Returns TerminalOutputResponse
|
||||
///
|
||||
/// ## Behavior
|
||||
///
|
||||
/// - Returns current output snapshot (cumulative)
|
||||
/// - Does NOT consume output (multiple calls return same data)
|
||||
/// - Truncation flag indicates if any output was dropped
|
||||
/// - Exit status available only after process completes
|
||||
///
|
||||
/// ## Error Cases
|
||||
///
|
||||
/// - Invalid/unknown TerminalId → `ToolError::TerminalNotFound`
|
||||
/// - Terminal released → `ToolError::TerminalNotFound`
|
||||
///
|
||||
/// ## Use Cases
|
||||
///
|
||||
/// - Poll for output while process is running
|
||||
/// - Check progress without blocking
|
||||
/// - Get final output after completion
|
||||
///
|
||||
/// ## See Also
|
||||
///
|
||||
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-TERM-02`
|
||||
/// - Wait for exit: `wait.rs`
|
||||
pub async fn get_terminal_output(
|
||||
request: TerminalOutputRequest,
|
||||
_config: &TerminalConfig,
|
||||
) -> ToolResult<TerminalOutputResponse> {
|
||||
use crate::terminal::registry::global_registry;
|
||||
|
||||
let registry = global_registry();
|
||||
|
||||
// Look up terminal and get output
|
||||
registry.get_mut(&request.terminal_id, |state| {
|
||||
// Lock output buffer and get snapshot
|
||||
let buffer = state.output_buffer.lock().unwrap();
|
||||
let output = buffer.snapshot();
|
||||
let truncated = buffer.is_truncated();
|
||||
|
||||
// Check if process has exited
|
||||
let exit_status = state.exit_status.as_ref().map(|status| {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::process::ExitStatusExt;
|
||||
status.code().or_else(|| status.signal()).unwrap_or(-1)
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
status.code().unwrap_or(-1)
|
||||
}
|
||||
});
|
||||
|
||||
TerminalOutputResponse {
|
||||
output,
|
||||
truncated,
|
||||
exit_status,
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
//! Terminal registry for tracking active terminal processes.
|
||||
//!
|
||||
//! **Status**: Implemented (TOOLS-TERM-01)
|
||||
//!
|
||||
//! This module provides a global registry for tracking terminal processes,
|
||||
//! their output buffers, and their lifecycle state.
|
||||
|
||||
use super::ring_buffer::RingBuffer;
|
||||
use crate::error::{ToolError, ToolResult};
|
||||
use std::collections::HashMap;
|
||||
use std::process::ExitStatus;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Instant;
|
||||
use tokio::process::Child;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
/// Unique identifier for a terminal.
|
||||
pub type TerminalId = String;
|
||||
|
||||
/// State of a terminal process.
|
||||
pub struct TerminalState {
|
||||
/// Process child handle
|
||||
pub process: Child,
|
||||
/// Output ring buffer (stdout + stderr interleaved)
|
||||
pub output_buffer: Arc<Mutex<RingBuffer>>,
|
||||
/// When the terminal was created
|
||||
pub start_time: Instant,
|
||||
/// Exit status (once process completes)
|
||||
pub exit_status: Option<ExitStatus>,
|
||||
/// Background task handle for output capture
|
||||
pub output_task: Option<JoinHandle<()>>,
|
||||
/// Whether the terminal has been explicitly killed
|
||||
pub killed: bool,
|
||||
}
|
||||
|
||||
/// Global terminal registry.
|
||||
///
|
||||
/// This is a singleton that tracks all active terminals.
|
||||
pub struct TerminalRegistry {
|
||||
terminals: Arc<Mutex<HashMap<TerminalId, TerminalState>>>,
|
||||
}
|
||||
|
||||
impl TerminalRegistry {
|
||||
/// Create a new terminal registry.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
terminals: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate a unique terminal ID.
|
||||
pub fn generate_id(&self) -> TerminalId {
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
static COUNTER: AtomicU64 = AtomicU64::new(0);
|
||||
let id = COUNTER.fetch_add(1, Ordering::SeqCst);
|
||||
format!("term-{}", id)
|
||||
}
|
||||
|
||||
/// Insert a terminal into the registry.
|
||||
pub fn insert(&self, id: TerminalId, state: TerminalState) {
|
||||
let mut terminals = self.terminals.lock().unwrap();
|
||||
terminals.insert(id, state);
|
||||
}
|
||||
|
||||
/// Get a reference to a terminal state.
|
||||
///
|
||||
/// Note: This locks the entire registry. For better concurrency,
|
||||
/// consider redesigning to use individual locks per terminal.
|
||||
pub fn get_mut<F, R>(&self, id: &str, f: F) -> ToolResult<R>
|
||||
where
|
||||
F: FnOnce(&mut TerminalState) -> R,
|
||||
{
|
||||
let mut terminals = self.terminals.lock().unwrap();
|
||||
let state = terminals.get_mut(id).ok_or_else(|| {
|
||||
ToolError::TerminalNotFound {
|
||||
terminal_id: id.to_string(),
|
||||
}
|
||||
})?;
|
||||
Ok(f(state))
|
||||
}
|
||||
|
||||
/// Remove a terminal from the registry.
|
||||
pub fn remove(&self, id: &str) -> ToolResult<TerminalState> {
|
||||
let mut terminals = self.terminals.lock().unwrap();
|
||||
terminals.remove(id).ok_or_else(|| ToolError::TerminalNotFound {
|
||||
terminal_id: id.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if a terminal exists.
|
||||
pub fn contains(&self, id: &str) -> bool {
|
||||
let terminals = self.terminals.lock().unwrap();
|
||||
terminals.contains_key(id)
|
||||
}
|
||||
|
||||
/// Get the number of active terminals.
|
||||
pub fn len(&self) -> usize {
|
||||
let terminals = self.terminals.lock().unwrap();
|
||||
terminals.len()
|
||||
}
|
||||
|
||||
/// Check if the registry is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len() == 0
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TerminalRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// Global registry instance
|
||||
lazy_static::lazy_static! {
|
||||
static ref GLOBAL_REGISTRY: TerminalRegistry = TerminalRegistry::new();
|
||||
}
|
||||
|
||||
/// Get the global terminal registry.
|
||||
pub fn global_registry() -> &'static TerminalRegistry {
|
||||
&GLOBAL_REGISTRY
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Consecutive calls to `generate_id` must never hand out the same ID twice.
    #[test]
    fn test_generate_unique_ids() {
        let registry = TerminalRegistry::new();
        let first = registry.generate_id();
        let second = registry.generate_id();
        let third = registry.generate_id();

        assert_ne!(first, second);
        assert_ne!(second, third);
        assert_ne!(first, third);
    }

    /// A freshly created registry is empty and knows no terminal IDs.
    #[test]
    fn test_registry_contains() {
        let registry = TerminalRegistry::new();

        assert!(!registry.contains("nonexistent"));
        assert_eq!(registry.len(), 0);
        assert!(registry.is_empty());
    }
}
|
||||
@@ -0,0 +1,107 @@
|
||||
//! Terminal release operation for resource cleanup.
|
||||
//!
|
||||
//! **Status**: Implemented (TOOLS-TERM-05)
|
||||
//!
|
||||
//! This module implements:
|
||||
//! - Terminal resource cleanup
|
||||
//! - Process termination if still running
|
||||
//! - Registry removal
|
||||
|
||||
use crate::config::TerminalConfig;
|
||||
use crate::error::ToolResult;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Request to release terminal resources.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ReleaseTerminalRequest {
|
||||
/// Terminal ID from create response.
|
||||
pub terminal_id: String,
|
||||
}
|
||||
|
||||
/// Response from releasing a terminal.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ReleaseTerminalResponse {}
|
||||
|
||||
/// Release terminal resources and clean up.
|
||||
///
|
||||
/// **Status**: Implemented (TOOLS-TERM-05)
|
||||
///
|
||||
/// ## Implementation
|
||||
///
|
||||
/// This function:
|
||||
/// 1. Looks up TerminalId in registry
|
||||
/// 2. If not found → returns error
|
||||
/// 3. If process still running:
|
||||
/// - Kills process forcefully
|
||||
/// - Waits briefly for exit
|
||||
/// 4. Aborts output capture task
|
||||
/// 5. Removes from registry:
|
||||
/// - Drops process handle
|
||||
/// - Frees ring buffer memory
|
||||
/// - Invalidates TerminalId
|
||||
/// 6. Returns success
|
||||
///
|
||||
/// ## Resource Cleanup
|
||||
///
|
||||
/// Frees:
|
||||
/// - Process handle (Child)
|
||||
/// - Output capture task (JoinHandle)
|
||||
/// - Ring buffer memory
|
||||
/// - Registry entry
|
||||
///
|
||||
/// ## Behavior
|
||||
///
|
||||
/// - Kills process if still running (no confirmation)
|
||||
/// - Frees all associated memory
|
||||
/// - Subsequent operations on this TerminalId will fail with TerminalNotFound
|
||||
///
|
||||
/// ## When to Release
|
||||
///
|
||||
/// Release when:
|
||||
/// - Process has completed and output is no longer needed
|
||||
/// - Need to free memory (long-running session)
|
||||
/// - Cleaning up after error
|
||||
///
|
||||
/// Do NOT release if:
|
||||
/// - Output may be needed later
|
||||
/// - Process should keep running (use kill instead)
|
||||
///
|
||||
/// ## Error Cases
|
||||
///
|
||||
/// - Invalid/unknown TerminalId → `ToolError::TerminalNotFound`
|
||||
///
|
||||
/// ## See Also
|
||||
///
|
||||
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-TERM-05`
|
||||
/// - Kill terminal: `kill.rs`
|
||||
pub async fn release_terminal(
|
||||
request: ReleaseTerminalRequest,
|
||||
_config: &TerminalConfig,
|
||||
) -> ToolResult<ReleaseTerminalResponse> {
|
||||
use crate::terminal::registry::global_registry;
|
||||
|
||||
let registry = global_registry();
|
||||
|
||||
// Remove terminal from registry
|
||||
let mut state = registry.remove(&request.terminal_id)?;
|
||||
|
||||
// Kill process if still running
|
||||
if !state.killed && state.exit_status.is_none() {
|
||||
let _ = state.process.start_kill();
|
||||
// Wait briefly for the process to die
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
let _ = state.process.wait().await;
|
||||
}
|
||||
|
||||
// Abort output capture task if it exists
|
||||
if let Some(task) = state.output_task.take() {
|
||||
task.abort();
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
terminal_id = %request.terminal_id,
|
||||
"Terminal resources released"
|
||||
);
|
||||
|
||||
Ok(ReleaseTerminalResponse {})
|
||||
}
|
||||
@@ -0,0 +1,499 @@
|
||||
//! Ring buffer for terminal output capture with UTF-8 boundary handling.
|
||||
//!
|
||||
//! **Status**: Implemented (TOOLS-TERM-06)
|
||||
//!
|
||||
//! This module implements:
|
||||
//! - Fixed-size circular buffer
|
||||
//! - UTF-8 character boundary awareness
|
||||
//! - Truncation tracking
|
||||
//! - Thread-safe access
|
||||
|
||||
/// Ring buffer for terminal output capture.
///
/// **Status**: Implemented (TOOLS-TERM-06)
///
/// This provides:
/// - Fixed-size circular buffer (configured byte limit)
/// - Truncation flag once the buffer has overflowed
/// - Snapshots that never start or end with a partial UTF-8 character
///
/// The type performs no locking itself; callers that share it between tasks wrap
/// it as `Arc<Mutex<RingBuffer>>`.
///
/// ## Behavior
///
/// - Oldest data is dropped when the buffer is full
/// - Bytes are stored verbatim; UTF-8 handling happens only in [`RingBuffer::snapshot`]
/// - Tracks whether any truncation occurred
///
/// ## UTF-8 Safety
///
/// Dropping the oldest bytes can cut a multi-byte character in half at the *front*
/// of the retained data, and a write may end in the middle of a character at the
/// *back*. `snapshot` trims both ends to character boundaries.
///
/// ## See Also
///
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-TERM-06`
/// - Create terminal: `create.rs`
pub struct RingBuffer {
    /// Internal buffer storage
    buffer: Vec<u8>,
    /// Maximum capacity in bytes
    capacity: usize,
    /// Next write position once the buffer is full (this is also the oldest byte)
    write_pos: usize,
    /// Current number of valid bytes in buffer
    len: usize,
    /// Whether any data was truncated (dropped) due to overflow
    truncated: bool,
}

impl RingBuffer {
    /// Create a new ring buffer with specified capacity.
    ///
    /// # Arguments
    ///
    /// * `capacity` - Maximum number of bytes to store
    ///
    /// # Examples
    ///
    /// ```
    /// use dirigent_tools::terminal::RingBuffer;
    ///
    /// let buffer = RingBuffer::new(1024);
    /// assert_eq!(buffer.snapshot(), "");
    /// assert!(!buffer.is_truncated());
    /// ```
    pub fn new(capacity: usize) -> Self {
        Self {
            buffer: Vec::with_capacity(capacity),
            capacity,
            write_pos: 0,
            len: 0,
            truncated: false,
        }
    }

    /// Append data to the buffer, dropping the oldest bytes if capacity is exceeded.
    ///
    /// # Arguments
    ///
    /// * `data` - Bytes to append (may contain partial UTF-8 sequences)
    ///
    /// # Examples
    ///
    /// ```
    /// use dirigent_tools::terminal::RingBuffer;
    ///
    /// let mut buffer = RingBuffer::new(10);
    /// buffer.push(b"Hello");
    /// assert_eq!(buffer.snapshot(), "Hello");
    ///
    /// buffer.push(b" World!");
    /// // Only the newest 10 bytes are kept
    /// assert!(buffer.is_truncated());
    /// ```
    pub fn push(&mut self, data: &[u8]) {
        if data.is_empty() {
            return;
        }

        // Fill phase: while the buffer is not yet full, `buffer.len() == len` and
        // new bytes are simply appended. Whatever does not fit overwrites old data.
        let free = self.capacity - self.len;
        let (fits, overflow) = data.split_at(data.len().min(free));

        if !fits.is_empty() {
            self.buffer.extend_from_slice(fits);
            self.len += fits.len();
        }

        if !overflow.is_empty() {
            self.push_with_overflow(overflow);
        }
    }

    /// Push data when buffer is at capacity (overwrites old data).
    fn push_with_overflow(&mut self, data: &[u8]) {
        if data.is_empty() {
            return;
        }

        self.truncated = true;

        // A zero-capacity buffer retains nothing; bail out before any modulo by zero.
        if self.capacity == 0 {
            return;
        }

        // Ensure buffer is fully allocated so the slice writes below stay in range
        if self.buffer.len() < self.capacity {
            self.buffer.resize(self.capacity, 0);
        }

        // If data is larger than capacity, only the last `capacity` bytes can survive
        let kept = &data[data.len().saturating_sub(self.capacity)..];

        // Write in at most two contiguous pieces: up to the end of the storage,
        // then wrapping around to the start.
        let until_end = kept.len().min(self.capacity - self.write_pos);
        let (first, wrapped) = kept.split_at(until_end);
        self.buffer[self.write_pos..self.write_pos + first.len()].copy_from_slice(first);
        self.buffer[..wrapped.len()].copy_from_slice(wrapped);

        self.write_pos = (self.write_pos + kept.len()) % self.capacity;

        // Update length (stays at capacity)
        self.len = self.capacity;
    }

    /// Get current buffer contents as a UTF-8 string.
    ///
    /// This returns a snapshot of the current contents without consuming them.
    ///
    /// # Returns
    ///
    /// The retained bytes, oldest first, converted to a `String`:
    /// - if data has been dropped, up to three orphaned continuation bytes at the
    ///   front (the tail of a character whose start was overwritten) are skipped;
    /// - an incomplete character at the end is cut off;
    /// - any other invalid UTF-8 is replaced with U+FFFD.
    ///
    /// # Examples
    ///
    /// ```
    /// use dirigent_tools::terminal::RingBuffer;
    ///
    /// let mut buffer = RingBuffer::new(1024);
    /// buffer.push(b"Hello, World!");
    /// assert_eq!(buffer.snapshot(), "Hello, World!");
    /// ```
    pub fn snapshot(&self) -> String {
        if self.len == 0 {
            return String::new();
        }

        // Linearize: before the buffer is full the data sits at the start of the
        // storage; afterwards the oldest byte is at `write_pos`.
        let ordered: std::borrow::Cow<'_, [u8]> = if self.len < self.capacity {
            std::borrow::Cow::Borrowed(&self.buffer[..self.len])
        } else {
            let (newest, oldest) = self.buffer.split_at(self.write_pos);
            let mut bytes = Vec::with_capacity(self.capacity);
            bytes.extend_from_slice(oldest);
            bytes.extend_from_slice(newest);
            std::borrow::Cow::Owned(bytes)
        };

        // Front boundary: truncation removes bytes from the front, so the retained
        // data may begin with the tail of a cut character (at most 3 bytes).
        let start = if self.truncated {
            ordered
                .iter()
                .take(3)
                .take_while(|&&byte| is_continuation_byte(byte))
                .count()
        } else {
            0
        };

        // Back boundary: drop a trailing, still-incomplete character.
        let end = find_char_boundary(&ordered, ordered.len());
        if end <= start {
            return String::new();
        }

        String::from_utf8_lossy(&ordered[start..end]).into_owned()
    }

    /// Check if any data has been truncated due to buffer overflow.
    ///
    /// # Returns
    ///
    /// `true` if data was dropped, `false` otherwise.
    ///
    /// # Examples
    ///
    /// ```
    /// use dirigent_tools::terminal::RingBuffer;
    ///
    /// let mut buffer = RingBuffer::new(5);
    /// buffer.push(b"Hello");
    /// assert!(!buffer.is_truncated());
    ///
    /// buffer.push(b" World!");
    /// assert!(buffer.is_truncated());
    /// ```
    pub fn is_truncated(&self) -> bool {
        self.truncated
    }

    /// Clear the buffer and reset the truncation flag.
    ///
    /// # Examples
    ///
    /// ```
    /// use dirigent_tools::terminal::RingBuffer;
    ///
    /// let mut buffer = RingBuffer::new(1024);
    /// buffer.push(b"Hello");
    /// buffer.clear();
    /// assert_eq!(buffer.snapshot(), "");
    /// assert!(!buffer.is_truncated());
    /// ```
    pub fn clear(&mut self) {
        self.write_pos = 0;
        self.len = 0;
        self.truncated = false;
        self.buffer.clear();
    }

    /// Get the current length of valid data in the buffer.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Check if the buffer is empty.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
}

/// Whether `byte` is a UTF-8 continuation byte (`10xxxxxx`, 0x80-0xBF).
fn is_continuation_byte(byte: u8) -> bool {
    byte & 0b1100_0000 == 0b1000_0000
}

/// Find the nearest character boundary at or before `start` position.
///
/// This ensures we don't cut in the middle of a UTF-8 multi-byte character.
///
/// # Arguments
///
/// * `buf` - Byte buffer to scan
/// * `start` - Position to start scanning backwards from (clamped to `buf.len()`)
///
/// # Returns
///
/// The end position (<= `start`) of the last character that is complete within
/// `buf[..start]`, or 0 if no complete character is found.
///
/// # UTF-8 Encoding Rules
///
/// - Single-byte char: `0xxxxxxx` (0x00-0x7F)
/// - Continuation byte: `10xxxxxx` (0x80-0xBF)
/// - Start of 2-byte char: `110xxxxx` (0xC0-0xDF)
/// - Start of 3-byte char: `1110xxxx` (0xE0-0xEF)
/// - Start of 4-byte char: `11110xxx` (0xF0-0xF7)
fn find_char_boundary(buf: &[u8], start: usize) -> usize {
    if start == 0 || buf.is_empty() {
        return 0;
    }

    let start = start.min(buf.len());

    // Scan backwards for the last lead byte whose whole sequence fits before `start`.
    for i in (0..start).rev() {
        let byte = buf[i];

        if is_continuation_byte(byte) {
            continue;
        }

        let char_len = match byte {
            0x00..=0x7F => 1, // ASCII
            0xC0..=0xDF => 2, // 2-byte char
            0xE0..=0xEF => 3, // 3-byte char
            0xF0..=0xF7 => 4, // 4-byte char
            // Invalid UTF-8 start byte, skip it
            _ => continue,
        };

        // The character counts only if it is complete within `start` and every
        // byte after the lead byte really is a continuation byte.
        let end = i + char_len;
        if end <= start && buf[i + 1..end].iter().all(|&b| is_continuation_byte(b)) {
            return end;
        }
    }

    // If we couldn't find a valid boundary, return 0
    0
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built buffer is empty and not flagged as truncated.
    #[test]
    fn test_new_buffer() {
        let ring = RingBuffer::new(1024);
        assert_eq!(ring.snapshot(), "");
        assert!(!ring.is_truncated());
        assert_eq!(ring.len(), 0);
        assert!(ring.is_empty());
    }

    /// A single push that fits is returned verbatim.
    #[test]
    fn test_push_simple() {
        let mut ring = RingBuffer::new(1024);
        ring.push(b"Hello");
        assert_eq!(ring.snapshot(), "Hello");
        assert!(!ring.is_truncated());
        assert_eq!(ring.len(), 5);
    }

    /// Consecutive pushes are concatenated in order.
    #[test]
    fn test_push_multiple() {
        let mut ring = RingBuffer::new(1024);
        ring.push(b"Hello");
        ring.push(b" ");
        ring.push(b"World");
        assert_eq!(ring.snapshot(), "Hello World");
        assert!(!ring.is_truncated());
    }

    /// Exceeding capacity sets the truncation flag and keeps the newest bytes.
    #[test]
    fn test_overflow_truncates() {
        let mut ring = RingBuffer::new(10);
        ring.push(b"Hello");
        assert!(!ring.is_truncated());

        ring.push(b" World!");
        assert!(ring.is_truncated());

        // At most the 10 most recent bytes survive.
        let snap = ring.snapshot();
        assert!(snap.len() <= 10);
        assert!(snap.ends_with("World!"));
    }

    /// Multi-byte characters that fit within capacity are not split.
    #[test]
    fn test_utf8_boundary_preservation() {
        let mut ring = RingBuffer::new(10);

        // "Hello" is 5 bytes and the emoji is 4 bytes, so the 9-byte total
        // fits inside the 10-byte capacity.
        ring.push("Hello".as_bytes());
        ring.push("😀".as_bytes());

        let snap = ring.snapshot();
        // Either outcome is accepted, depending on the implementation.
        assert!(snap == "Hello😀" || snap == "Hello");
        assert!(!ring.is_truncated() || snap.len() <= 10);
    }

    /// Truncation inside a multi-byte character must still yield valid UTF-8.
    #[test]
    fn test_utf8_truncation() {
        let mut ring = RingBuffer::new(8);

        // 14 bytes pushed into 8 bytes of capacity: the cut lands inside
        // or next to the emoji.
        ring.push("Hello😀World".as_bytes());

        // The snapshot must never contain a partial emoji.
        let snap = ring.snapshot();
        assert!(std::str::from_utf8(snap.as_bytes()).is_ok());
        assert!(ring.is_truncated());
    }

    /// `clear` resets contents, length and the truncation flag.
    #[test]
    fn test_clear() {
        let mut ring = RingBuffer::new(1024);
        ring.push(b"Hello");
        ring.clear();

        assert_eq!(ring.snapshot(), "");
        assert!(!ring.is_truncated());
        assert_eq!(ring.len(), 0);
    }

    /// Pushing an empty slice changes nothing.
    #[test]
    fn test_empty_push() {
        let mut ring = RingBuffer::new(1024);
        ring.push(b"");
        assert_eq!(ring.snapshot(), "");
        assert!(!ring.is_truncated());
    }

    /// One push larger than the whole capacity keeps exactly `capacity` bytes.
    #[test]
    fn test_large_data_at_once() {
        let mut ring = RingBuffer::new(100);
        let payload = vec![b'A'; 500];

        ring.push(&payload);

        assert!(ring.is_truncated());
        assert_eq!(ring.len(), 100);

        // Only 'A' bytes were pushed, so only 'A's may come back.
        let snap = ring.snapshot();
        assert!(snap.chars().all(|c| c == 'A'));
        assert!(snap.len() <= 100);
    }

    /// In pure ASCII every index is already a character boundary.
    #[test]
    fn test_find_char_boundary_ascii() {
        let bytes = b"Hello";
        assert_eq!(find_char_boundary(bytes, 5), 5);
        assert_eq!(find_char_boundary(bytes, 3), 3);
        assert_eq!(find_char_boundary(bytes, 0), 0);
    }

    /// Indices inside a 4-byte character back up to the start of that character.
    #[test]
    fn test_find_char_boundary_utf8() {
        // "Hello😀": the emoji encodes as F0 9F 98 80.
        let bytes = "Hello😀".as_bytes();
        let end = bytes.len(); // 5 + 4 = 9

        // Genuine boundaries are returned unchanged.
        assert_eq!(find_char_boundary(bytes, end), end);
        assert_eq!(find_char_boundary(bytes, 5), 5); // right after "Hello"

        // Positions inside the emoji fall back to the index before it.
        assert_eq!(find_char_boundary(bytes, 6), 5); // 1st continuation byte
        assert_eq!(find_char_boundary(bytes, 7), 5); // 2nd continuation byte
        assert_eq!(find_char_boundary(bytes, 8), 5); // 3rd continuation byte
    }

    /// Same check with several consecutive 3-byte characters.
    #[test]
    fn test_find_char_boundary_multi_utf8() {
        let bytes = "日本語".as_bytes(); // 3 chars, 3 bytes each = 9 bytes

        assert_eq!(find_char_boundary(bytes, 9), 9); // end of input
        assert_eq!(find_char_boundary(bytes, 6), 6); // after the 2nd char
        assert_eq!(find_char_boundary(bytes, 3), 3); // after the 1st char

        // Inside the 2nd character.
        assert_eq!(find_char_boundary(bytes, 4), 3);
        assert_eq!(find_char_boundary(bytes, 5), 3);

        // Inside the 3rd character.
        assert_eq!(find_char_boundary(bytes, 7), 6);
        assert_eq!(find_char_boundary(bytes, 8), 6);
    }

    /// Filling the buffer exactly and then pushing more keeps it at capacity.
    #[test]
    fn test_circular_buffer_behavior() {
        let mut ring = RingBuffer::new(5);

        ring.push(b"ABCDE");
        assert_eq!(ring.snapshot(), "ABCDE");

        ring.push(b"FG");
        assert!(ring.is_truncated());

        let snap = ring.snapshot();
        assert_eq!(snap.len(), 5);
        assert!(snap.ends_with("G"));
    }
}
|
||||
@@ -0,0 +1,160 @@
|
||||
//! Terminal wait-for-exit operation.
|
||||
//!
|
||||
//! **Status**: Implemented (TOOLS-TERM-03)
//!
//! This module implements:
|
||||
//! - Blocking wait for process completion
|
||||
//! - Runtime timeout enforcement
|
||||
//! - Exit status return
|
||||
|
||||
use crate::config::TerminalConfig;
|
||||
use crate::error::{ToolError, ToolResult};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Request to wait for terminal to exit.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WaitForTerminalExitRequest {
|
||||
/// Terminal ID from create response.
|
||||
pub terminal_id: String,
|
||||
}
|
||||
|
||||
/// Response with exit status.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WaitForTerminalExitResponse {
|
||||
/// Process exit status code.
|
||||
pub exit_status: i32,
|
||||
}
|
||||
|
||||
/// Wait for terminal process to complete.
|
||||
///
|
||||
/// **Status**: Implemented (TOOLS-TERM-03)
|
||||
///
|
||||
/// ## Implementation
|
||||
///
|
||||
/// This function:
|
||||
/// 1. Looks up TerminalId in registry
|
||||
/// 2. Gets process handle
|
||||
/// 3. If already exited → returns cached exit status immediately
|
||||
/// 4. If still running:
|
||||
/// - Awaits process completion (tokio child.wait())
|
||||
/// - Enforces `max_runtime_secs` timeout via tokio::time::timeout
|
||||
/// - If timeout → kills process and returns error
|
||||
/// 5. Caches exit status in registry
|
||||
/// 6. Returns WaitForTerminalExitResponse
|
||||
///
|
||||
/// ## Timeout Behavior
|
||||
///
|
||||
/// - Uses `max_runtime_secs` from TerminalConfig
|
||||
/// - On timeout:
|
||||
/// - Process is killed (SIGKILL/TerminateProcess)
|
||||
/// - Returns `ToolError::TerminalError` with timeout message
|
||||
/// - Terminal remains in registry (can still get output)
|
||||
///
|
||||
/// ## Blocking vs Polling
|
||||
///
|
||||
/// - `wait_for_exit()` → Blocks until completion or timeout
|
||||
/// - `get_output()` → Returns immediately with current output
|
||||
///
|
||||
/// Use `wait_for_exit()` when you want to ensure completion before proceeding.
|
||||
///
|
||||
/// ## Error Cases
|
||||
///
|
||||
/// - Invalid/unknown TerminalId → `ToolError::TerminalNotFound`
|
||||
/// - Terminal released → `ToolError::TerminalNotFound`
|
||||
/// - Timeout exceeded → `ToolError::TerminalError`
|
||||
///
|
||||
/// ## See Also
|
||||
///
|
||||
/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-TERM-03`
|
||||
/// - Get output: `output.rs`
|
||||
pub async fn wait_for_terminal_exit(
|
||||
request: WaitForTerminalExitRequest,
|
||||
config: &TerminalConfig,
|
||||
) -> ToolResult<WaitForTerminalExitResponse> {
|
||||
use crate::terminal::registry::global_registry;
|
||||
use std::time::Duration;
|
||||
|
||||
let registry = global_registry();
|
||||
let terminal_id = request.terminal_id.clone();
|
||||
|
||||
// Check if already exited
|
||||
let already_exited = registry.get_mut(&terminal_id, |state| {
|
||||
state.exit_status.as_ref().map(|status| {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::process::ExitStatusExt;
|
||||
status.code().or_else(|| status.signal()).unwrap_or(-1)
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
status.code().unwrap_or(-1)
|
||||
}
|
||||
})
|
||||
})?;
|
||||
|
||||
if let Some(exit_status) = already_exited {
|
||||
return Ok(WaitForTerminalExitResponse { exit_status });
|
||||
}
|
||||
|
||||
// Process is still running, need to wait for it
|
||||
// We need to take ownership of the process to wait on it
|
||||
// This is tricky with the registry design, so we'll use a different approach:
|
||||
// We'll poll the process status and check exit_status
|
||||
|
||||
let timeout_duration = Duration::from_secs(config.max_runtime_secs);
|
||||
let start_time = std::time::Instant::now();
|
||||
|
||||
loop {
|
||||
// Check if timeout exceeded
|
||||
if start_time.elapsed() >= timeout_duration {
|
||||
// Kill the process due to timeout
|
||||
registry.get_mut(&terminal_id, |state| {
|
||||
let _ = state.process.start_kill();
|
||||
state.killed = true;
|
||||
})?;
|
||||
|
||||
return Err(ToolError::terminal_error(format!(
|
||||
"Terminal timed out after {} seconds",
|
||||
config.max_runtime_secs
|
||||
)));
|
||||
}
|
||||
|
||||
// Try to get exit status (non-blocking check)
|
||||
let exit_status_result = registry.get_mut(&terminal_id, |state| {
|
||||
// Try to check if process has exited
|
||||
match state.process.try_wait() {
|
||||
Ok(Some(status)) => {
|
||||
// Process has exited
|
||||
state.exit_status = Some(status);
|
||||
let exit_code = {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::process::ExitStatusExt;
|
||||
status.code().or_else(|| status.signal()).unwrap_or(-1)
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
status.code().unwrap_or(-1)
|
||||
}
|
||||
};
|
||||
Some(exit_code)
|
||||
}
|
||||
Ok(None) => {
|
||||
// Process is still running
|
||||
None
|
||||
}
|
||||
Err(_e) => {
|
||||
// Error checking status, treat as still running
|
||||
None
|
||||
}
|
||||
}
|
||||
})?;
|
||||
|
||||
if let Some(exit_status) = exit_status_result {
|
||||
return Ok(WaitForTerminalExitResponse { exit_status });
|
||||
}
|
||||
|
||||
// Wait a bit before checking again
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,76 @@
|
||||
//! Per-call scoping context shared by every harness layer.
|
||||
|
||||
use crate::config::{PermissionConfig, SandboxConfig};
|
||||
use crate::permission::check::PermissionContext;
|
||||
use crate::tool::{ClientKind, ProtocolKind};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Per-call scoping context. Passed to every layer below the registry.
|
||||
///
|
||||
/// `connector_id` and `session_id` mirror what the existing
|
||||
/// [`PermissionContext`] uses; `client_kind` and `protocol` are the new
|
||||
/// override surface for per-client / per-protocol behaviour.
|
||||
#[derive(Clone)]
|
||||
pub struct ToolContext {
|
||||
pub connector_id: Arc<str>,
|
||||
pub session_id: Option<Arc<str>>,
|
||||
pub client_kind: ClientKind,
|
||||
pub protocol: ProtocolKind,
|
||||
pub workspace_root: PathBuf,
|
||||
pub sandbox: Arc<SandboxConfig>,
|
||||
pub permission: Arc<PermissionConfig>,
|
||||
pub permission_context: Arc<PermissionContext>,
|
||||
}
|
||||
|
||||
impl ToolContext {
|
||||
/// Test/builder helper. Real callers compose this from connector state.
|
||||
pub fn for_test(
|
||||
connector_id: impl Into<Arc<str>>,
|
||||
client_kind: ClientKind,
|
||||
protocol: ProtocolKind,
|
||||
workspace_root: PathBuf,
|
||||
sandbox: SandboxConfig,
|
||||
permission: PermissionConfig,
|
||||
permission_context: PermissionContext,
|
||||
) -> Self {
|
||||
Self {
|
||||
connector_id: connector_id.into(),
|
||||
session_id: None,
|
||||
client_kind,
|
||||
protocol,
|
||||
workspace_root,
|
||||
sandbox: Arc::new(sandbox),
|
||||
permission: Arc::new(permission),
|
||||
permission_context: Arc::new(permission_context),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::WhitelistConfig;
    use crate::permission::whitelist::CompiledWhitelist;

    /// `for_test` stores the connector id and client kind it was given.
    #[test]
    fn context_builds_with_test_helper() {
        let compiled = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap();
        let permission_context = PermissionContext::new("conn-1".to_string(), None, compiled);

        let built = ToolContext::for_test(
            "conn-1",
            ClientKind::claude(),
            ProtocolKind::acp(),
            PathBuf::from("/tmp"),
            SandboxConfig::default(),
            PermissionConfig::default(),
            permission_context,
        );

        assert_eq!(&*built.connector_id, "conn-1");
        assert_eq!(built.client_kind, ClientKind::claude());
    }
}
|
||||
@@ -0,0 +1,197 @@
|
||||
//! `Tool` trait + object-safe `AnyTool` + `Erased<T>` adapter.
|
||||
|
||||
use crate::tool::{ToolContext, ToolEventSink, ToolKind};
|
||||
use async_trait::async_trait;
|
||||
use schemars::{JsonSchema, schema_for};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, oneshot};
|
||||
|
||||
/// Streaming-aware tool input.
|
||||
///
|
||||
/// Tools opt into streaming via [`Tool::supports_input_streaming`]. If they
|
||||
/// do not, the harness buffers and always provides [`ToolInput::Final`].
|
||||
pub enum ToolInput<T> {
|
||||
Final(T),
|
||||
Partial {
|
||||
partial: mpsc::UnboundedReceiver<serde_json::Value>,
|
||||
final_input: oneshot::Receiver<T>,
|
||||
},
|
||||
}
|
||||
|
||||
/// JSON-shaped variant used by the object-safe `AnyTool` trait.
|
||||
pub enum AnyToolInput {
|
||||
Final(serde_json::Value),
|
||||
Partial {
|
||||
partial: mpsc::UnboundedReceiver<serde_json::Value>,
|
||||
final_input: oneshot::Receiver<serde_json::Value>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Strongly-typed tool implementation.
|
||||
#[async_trait]
|
||||
pub trait Tool: Send + Sync + 'static {
|
||||
type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema + Send + 'static;
|
||||
type Output: Serialize + Send + 'static;
|
||||
|
||||
const NAME: &'static str;
|
||||
|
||||
fn kind() -> ToolKind;
|
||||
|
||||
fn input_schema() -> serde_json::Value {
|
||||
serde_json::to_value(schema_for!(Self::Input)).expect("schema_for must serialise")
|
||||
}
|
||||
|
||||
fn supports_input_streaming() -> bool { false }
|
||||
|
||||
async fn run(
|
||||
self: Arc<Self>,
|
||||
input: ToolInput<Self::Input>,
|
||||
events: ToolEventSink,
|
||||
ctx: &ToolContext,
|
||||
) -> Result<Self::Output, Self::Output>;
|
||||
|
||||
fn erase(self: Arc<Self>) -> Arc<dyn AnyTool> where Self: Sized { Arc::new(Erased(self)) }
|
||||
}
|
||||
|
||||
/// Object-safe variant. The registry stores `Arc<dyn AnyTool>`.
|
||||
#[async_trait]
|
||||
pub trait AnyTool: Send + Sync + 'static {
|
||||
fn name(&self) -> &'static str;
|
||||
fn kind(&self) -> ToolKind;
|
||||
fn input_schema(&self) -> serde_json::Value;
|
||||
fn supports_input_streaming(&self) -> bool;
|
||||
|
||||
async fn run(
|
||||
self: Arc<Self>,
|
||||
input: AnyToolInput,
|
||||
events: ToolEventSink,
|
||||
ctx: &ToolContext,
|
||||
) -> Result<serde_json::Value, serde_json::Value>;
|
||||
}
|
||||
|
||||
/// Adapter from a typed `Tool` to `AnyTool`.
|
||||
pub struct Erased<T: Tool>(pub Arc<T>);
|
||||
|
||||
#[async_trait]
|
||||
impl<T: Tool> AnyTool for Erased<T> {
|
||||
fn name(&self) -> &'static str { T::NAME }
|
||||
fn kind(&self) -> ToolKind { T::kind() }
|
||||
fn input_schema(&self) -> serde_json::Value { T::input_schema() }
|
||||
fn supports_input_streaming(&self) -> bool { T::supports_input_streaming() }
|
||||
|
||||
async fn run(
|
||||
self: Arc<Self>,
|
||||
input: AnyToolInput,
|
||||
events: ToolEventSink,
|
||||
ctx: &ToolContext,
|
||||
) -> Result<serde_json::Value, serde_json::Value> {
|
||||
let typed = match input {
|
||||
AnyToolInput::Final(v) => {
|
||||
let parsed: T::Input = serde_json::from_value(v).map_err(|e| {
|
||||
serde_json::json!({ "error": format!("invalid input: {e}") })
|
||||
})?;
|
||||
ToolInput::Final(parsed)
|
||||
}
|
||||
AnyToolInput::Partial { partial, final_input: _ } => {
|
||||
// For v1: tools that opt into streaming receive Partial; the
|
||||
// typed-final-input wiring is added when a streaming tool ships.
|
||||
// Until then, only Final is fed by the dispatcher.
|
||||
let _ = partial;
|
||||
return Err(serde_json::json!({
|
||||
"error": "streaming inputs are not yet wired in v1 dispatcher"
|
||||
}));
|
||||
}
|
||||
};
|
||||
let inner = self.0.clone();
|
||||
let result = inner.run(typed, events, ctx).await;
|
||||
match result {
|
||||
Ok(o) => Ok(serde_json::to_value(o).unwrap_or(serde_json::Value::Null)),
|
||||
Err(o) => Err(serde_json::to_value(o).unwrap_or(serde_json::Value::Null)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{PermissionConfig, SandboxConfig, WhitelistConfig};
    use crate::permission::check::PermissionContext;
    use crate::permission::whitelist::CompiledWhitelist;
    use crate::tool::{ClientKind, ProtocolKind};
    use std::path::PathBuf;

    #[derive(Serialize, Deserialize, JsonSchema)]
    struct EchoInput {
        msg: String,
    }

    #[derive(Serialize, Deserialize)]
    struct EchoOutput {
        echoed: String,
    }

    // Minimal tool that returns its input message unchanged.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        type Input = EchoInput;
        type Output = EchoOutput;
        const NAME: &'static str = "echo";

        fn kind() -> ToolKind {
            ToolKind::Other
        }

        async fn run(
            self: Arc<Self>,
            input: ToolInput<Self::Input>,
            _events: ToolEventSink,
            _ctx: &ToolContext,
        ) -> Result<Self::Output, Self::Output> {
            // The tests below only ever supply `Final` input.
            let ToolInput::Final(payload) = input else {
                unreachable!()
            };
            Ok(EchoOutput { echoed: payload.msg })
        }
    }

    // Builds a throwaway context with default configuration.
    fn ctx() -> ToolContext {
        let compiled = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap();
        let permission_context = PermissionContext::new("conn-1".to_string(), None, compiled);
        ToolContext::for_test(
            "conn-1",
            ClientKind::claude(),
            ProtocolKind::acp(),
            PathBuf::from("/tmp"),
            SandboxConfig::default(),
            PermissionConfig::default(),
            permission_context,
        )
    }

    // Wraps a fresh `EchoTool` in the object-safe adapter.
    fn erased_echo() -> Arc<dyn AnyTool> {
        let typed: Arc<EchoTool> = Arc::new(EchoTool);
        typed.erase()
    }

    /// Valid JSON input is parsed, run, and the output serialised back to JSON.
    #[tokio::test]
    async fn typed_tool_runs_and_returns_output() {
        let any = erased_echo();
        let (sink, _rx) = ToolEventSink::new();
        let input = AnyToolInput::Final(serde_json::json!({ "msg": "hello" }));

        let output = any.run(input, sink, &ctx()).await.unwrap();

        assert_eq!(output["echoed"], "hello");
    }

    /// Input that does not match the schema yields an `{"error": ...}` value.
    #[tokio::test]
    async fn invalid_input_returns_structured_error() {
        let any = erased_echo();
        let (sink, _rx) = ToolEventSink::new();
        let input = AnyToolInput::Final(serde_json::json!({ "wrong": 1 }));

        let failure = any.run(input, sink, &ctx()).await.unwrap_err();

        assert!(failure["error"].as_str().unwrap().contains("invalid input"));
    }

    /// The adapter reports the typed tool's name and kind.
    #[test]
    fn name_and_kind_round_trip_through_erase() {
        let any = erased_echo();
        assert_eq!(any.name(), "echo");
        assert_eq!(any.kind(), ToolKind::Other);
    }
}
|
||||
@@ -0,0 +1,134 @@
|
||||
//! Neutral tool events. Connector adapters translate to wire types.
|
||||
|
||||
use crate::tool::ToolKind;
|
||||
use bytes::Bytes;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
/// Opaque permission-request id, allocated by the dispatcher.
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
|
||||
pub struct PermissionRequestId(Arc<str>);
|
||||
|
||||
impl PermissionRequestId {
|
||||
/// Construct a new permission-request id from any string-like value.
|
||||
pub fn new(value: impl Into<Arc<str>>) -> Self {
|
||||
Self(value.into())
|
||||
}
|
||||
|
||||
/// Borrow the inner id as a string slice.
|
||||
pub fn as_str(&self) -> &str {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Where a tool is operating (file path + optional line).
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ToolLocation {
|
||||
pub path: String,
|
||||
pub line: Option<u32>,
|
||||
}
|
||||
|
||||
/// Result content shape. Mirrors what most providers accept as a tool result.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ToolResultContent {
|
||||
Text { text: Arc<str> },
|
||||
Json { value: serde_json::Value },
|
||||
Image { mime: Arc<str>, #[serde(with = "serde_bytes_arc")] data: Bytes },
|
||||
Parts { parts: Vec<ToolResultContent> },
|
||||
}
|
||||
|
||||
impl ToolResultContent {
|
||||
pub fn text(s: impl Into<Arc<str>>) -> Self { Self::Text { text: s.into() } }
|
||||
}
|
||||
|
||||
/// Events emitted by a running tool. Transport-agnostic.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ToolEvent {
|
||||
Started { title: Arc<str>, kind: ToolKind, location: Option<ToolLocation> },
|
||||
TitleUpdate { title: Arc<str>, location: Option<ToolLocation> },
|
||||
PartialOutput { content: ToolResultContent },
|
||||
Status { message: Arc<str> },
|
||||
PermissionRequested { request_id: PermissionRequestId, summary: Arc<str> },
|
||||
Completed,
|
||||
Failed,
|
||||
}
|
||||
|
||||
/// Sink a tool emits events into. Cheap to clone.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ToolEventSink {
|
||||
tx: mpsc::UnboundedSender<ToolEvent>,
|
||||
}
|
||||
|
||||
impl ToolEventSink {
|
||||
pub fn new() -> (Self, mpsc::UnboundedReceiver<ToolEvent>) {
|
||||
let (tx, rx) = mpsc::unbounded_channel();
|
||||
(Self { tx }, rx)
|
||||
}
|
||||
|
||||
/// Best-effort emit. Drops the event if the receiver is gone.
|
||||
pub fn emit(&self, event: ToolEvent) {
|
||||
let _ = self.tx.send(event);
|
||||
}
|
||||
}
|
||||
|
||||
mod serde_bytes_arc {
|
||||
use bytes::Bytes;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
|
||||
pub fn serialize<S: Serializer>(b: &Bytes, s: S) -> Result<S::Ok, S::Error> {
|
||||
serde_bytes::Bytes::new(b.as_ref()).serialize(s)
|
||||
}
|
||||
pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Bytes, D::Error> {
|
||||
let v: Vec<u8> = serde_bytes::ByteBuf::deserialize(d)?.into_vec();
|
||||
Ok(Bytes::from(v))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// An emitted event arrives unchanged at the receiver.
    #[tokio::test]
    async fn sink_round_trips_event() {
        let (sink, mut receiver) = ToolEventSink::new();
        sink.emit(ToolEvent::Status { message: "hi".into() });

        let received = receiver.recv().await.unwrap();
        let ToolEvent::Status { message } = received else {
            panic!("wrong variant")
        };
        assert_eq!(&*message, "hi");
    }

    /// `ToolResultContent::text` builds the `Text` variant.
    #[test]
    fn result_content_text_helper() {
        let ToolResultContent::Text { text } = ToolResultContent::text("hello") else {
            panic!()
        };
        assert_eq!(&*text, "hello");
    }

    /// A `ToolEvent` survives a JSON serialise/deserialise cycle.
    #[test]
    fn tool_event_serde_round_trip() {
        let event = ToolEvent::PartialOutput {
            content: ToolResultContent::text("x"),
        };
        let encoded = serde_json::to_string(&event).unwrap();
        let _decoded: ToolEvent = serde_json::from_str(&encoded).unwrap();
    }

    /// `new` followed by `as_str` returns the original string.
    #[test]
    fn permission_request_id_constructor_and_accessor() {
        let request_id = PermissionRequestId::new("abc");
        assert_eq!(request_id.as_str(), "abc");
    }

    /// The id serialises as a bare JSON string and parses back to an equal id.
    #[test]
    fn permission_request_id_serde_is_transparent_string() {
        let request_id = PermissionRequestId::new("foo");

        let encoded = serde_json::to_string(&request_id).unwrap();
        assert_eq!(encoded, "\"foo\"");

        let decoded: PermissionRequestId = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, request_id);
    }
}
|
||||
@@ -0,0 +1,135 @@
|
||||
//! Open-form client/protocol identifiers and tool category enum.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
/// Open newtype identifying the upstream client family (Claude, Codex, etc.).
|
||||
///
|
||||
/// Use the provided constants for known clients; use [`ClientKind::custom`]
|
||||
/// for anything else. Comparison is by inner string equality.
|
||||
///
|
||||
/// The well-known constructors (`claude`, `codex`, `gemini`, `opencode`)
|
||||
/// return cached values backed by a process-wide `OnceLock`, so calling them
|
||||
/// repeatedly is a cheap `Arc` clone — no per-call heap allocation.
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
|
||||
pub struct ClientKind(Arc<str>);
|
||||
|
||||
impl ClientKind {
|
||||
pub fn claude() -> Self {
|
||||
static CELL: OnceLock<ClientKind> = OnceLock::new();
|
||||
CELL.get_or_init(|| Self(Arc::from("claude"))).clone()
|
||||
}
|
||||
pub fn codex() -> Self {
|
||||
static CELL: OnceLock<ClientKind> = OnceLock::new();
|
||||
CELL.get_or_init(|| Self(Arc::from("codex"))).clone()
|
||||
}
|
||||
pub fn gemini() -> Self {
|
||||
static CELL: OnceLock<ClientKind> = OnceLock::new();
|
||||
CELL.get_or_init(|| Self(Arc::from("gemini"))).clone()
|
||||
}
|
||||
pub fn opencode() -> Self {
|
||||
static CELL: OnceLock<ClientKind> = OnceLock::new();
|
||||
CELL.get_or_init(|| Self(Arc::from("opencode"))).clone()
|
||||
}
|
||||
|
||||
pub fn custom(name: impl Into<Arc<str>>) -> Self { Self(name.into()) }
|
||||
pub fn as_str(&self) -> &str { &self.0 }
|
||||
}
|
||||
|
||||
/// Open newtype identifying the wire protocol (ACP, OpenCode, native).
|
||||
///
|
||||
/// As with [`ClientKind`], the well-known constructors return cached values
|
||||
/// backed by a process-wide `OnceLock`.
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
|
||||
pub struct ProtocolKind(Arc<str>);
|
||||
|
||||
impl ProtocolKind {
|
||||
pub fn acp() -> Self {
|
||||
static CELL: OnceLock<ProtocolKind> = OnceLock::new();
|
||||
CELL.get_or_init(|| Self(Arc::from("acp"))).clone()
|
||||
}
|
||||
pub fn opencode() -> Self {
|
||||
static CELL: OnceLock<ProtocolKind> = OnceLock::new();
|
||||
CELL.get_or_init(|| Self(Arc::from("opencode"))).clone()
|
||||
}
|
||||
pub fn native() -> Self {
|
||||
static CELL: OnceLock<ProtocolKind> = OnceLock::new();
|
||||
CELL.get_or_init(|| Self(Arc::from("native"))).clone()
|
||||
}
|
||||
|
||||
pub fn custom(name: impl Into<Arc<str>>) -> Self { Self(name.into()) }
|
||||
pub fn as_str(&self) -> &str { &self.0 }
|
||||
}
|
||||
|
||||
/// Coarse category of a tool. Connectors use this to pick wire-level hints
|
||||
/// (e.g. ACP `ToolKind::Edit` triggers diff rendering on the client).
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ToolKind {
|
||||
Read,
|
||||
Edit,
|
||||
Search,
|
||||
Execute,
|
||||
Fetch,
|
||||
Think,
|
||||
Other,
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Different well-known kinds differ; the same kind equals itself.
    #[test]
    fn client_kind_constants_are_distinct() {
        assert_ne!(ClientKind::claude(), ClientKind::codex());
        assert_eq!(ClientKind::claude(), ClientKind::claude());
    }

    /// A custom kind keeps its name and compares by value.
    #[test]
    fn client_kind_custom_round_trips() {
        let kind = ClientKind::custom("aider");
        assert_eq!(kind.as_str(), "aider");
        assert_eq!(kind, ClientKind::custom("aider"));
    }

    /// Different well-known protocols differ.
    #[test]
    fn protocol_kind_distinct() {
        assert_ne!(ProtocolKind::acp(), ProtocolKind::opencode());
    }

    /// `ToolKind` serialises as a snake_case string and parses back.
    #[test]
    fn tool_kind_serde_round_trip() {
        let encoded = serde_json::to_string(&ToolKind::Edit).unwrap();
        assert_eq!(encoded, "\"edit\"");

        let decoded: ToolKind = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, ToolKind::Edit);
    }

    /// Well-known client kinds share one allocation; custom ones do not.
    #[test]
    fn well_known_client_kinds_are_cached() {
        let first = ClientKind::claude();
        let second = ClientKind::claude();
        assert_eq!(first, second);
        // Same Arc on both calls: the cache works, no per-call allocation.
        assert!(Arc::ptr_eq(&first.0, &second.0));

        // Sanity-check another well-known kind too.
        let codex_a = ClientKind::codex();
        let codex_b = ClientKind::codex();
        assert!(Arc::ptr_eq(&codex_a.0, &codex_b.0));

        // Custom values are *not* cached — each call allocates fresh.
        let custom_a = ClientKind::custom("aider");
        let custom_b = ClientKind::custom("aider");
        assert_eq!(custom_a, custom_b);
        assert!(!Arc::ptr_eq(&custom_a.0, &custom_b.0));
    }

    /// Well-known protocol kinds share one allocation.
    #[test]
    fn well_known_protocol_kinds_are_cached() {
        let first = ProtocolKind::acp();
        let second = ProtocolKind::acp();
        assert_eq!(first, second);
        assert!(Arc::ptr_eq(&first.0, &second.0));
    }
}
|
||||
@@ -0,0 +1,106 @@
|
||||
//! Compile-time built-in tool registration macro.
|
||||
//!
|
||||
//! ```ignore
|
||||
//! crate::tools! { ReadTool, GrepTool, EditTool }
|
||||
//! ```
|
||||
//!
|
||||
//! Produces:
|
||||
//! - `pub const ALL_BUILT_IN_TOOL_NAMES: &[&str] = &[...];`
|
||||
//! - `pub fn built_in_tools() -> Vec<std::sync::Arc<dyn AnyTool>>`
|
||||
//! - A compile-time uniqueness check on `T::NAME`.
|
||||
|
||||
// Registers a list of built-in tool types.
//
// For the given types this expands to:
// - `ALL_BUILT_IN_TOOL_NAMES`, the `NAME` of every listed tool in order;
// - a compile-time check that panics (failing the build) if two names match;
// - `built_in_tools()`, which default-constructs and erases every tool, so
//   each listed type must implement `Default`.
#[macro_export]
macro_rules! tools {
    ($($tool:ty),* $(,)?) => {
        pub const ALL_BUILT_IN_TOOL_NAMES: &[&str] = &[
            $( <$tool as $crate::tool::Tool>::NAME, )*
        ];

        const _: () = {
            // Byte-wise string comparison usable in const context.
            const fn same_name(lhs: &str, rhs: &str) -> bool {
                let lhs = lhs.as_bytes();
                let rhs = rhs.as_bytes();
                if lhs.len() != rhs.len() {
                    return false;
                }
                let mut pos = 0;
                while pos < lhs.len() {
                    if lhs[pos] != rhs[pos] {
                        return false;
                    }
                    pos += 1;
                }
                true
            }

            // Compare every pair of names once (`while` because `for` is not
            // available in const evaluation).
            let names: &[&str] = ALL_BUILT_IN_TOOL_NAMES;
            let mut first = 0;
            while first < names.len() {
                let mut second = first + 1;
                while second < names.len() {
                    if same_name(names[first], names[second]) {
                        panic!("Duplicate built-in tool name");
                    }
                    second += 1;
                }
                first += 1;
            }
        };

        pub fn built_in_tools() -> Vec<std::sync::Arc<dyn $crate::tool::AnyTool>> {
            vec![
                $(
                    <$tool as $crate::tool::Tool>::erase(
                        std::sync::Arc::new(<$tool>::default()),
                    ),
                )*
            ]
        }
    };
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::tool::{Tool, ToolContext, ToolEventSink, ToolInput, ToolKind};
    use async_trait::async_trait;
    use schemars::JsonSchema;
    use serde::{Deserialize, Serialize};
    use std::sync::Arc;

    #[derive(Default)]
    struct AlphaTool;

    #[derive(Default)]
    struct BetaTool;

    #[derive(Serialize, Deserialize, JsonSchema)]
    struct Empty {}

    // Generates a trivial `Tool` impl with the given name for a unit struct.
    macro_rules! impl_tool {
        ($ty:ty, $tool_name:literal) => {
            #[async_trait]
            impl Tool for $ty {
                type Input = Empty;
                type Output = Empty;
                const NAME: &'static str = $tool_name;

                fn kind() -> ToolKind {
                    ToolKind::Other
                }

                async fn run(
                    self: Arc<Self>,
                    _input: ToolInput<Empty>,
                    _events: ToolEventSink,
                    _ctx: &ToolContext,
                ) -> Result<Empty, Empty> {
                    Ok(Empty {})
                }
            }
        };
    }

    impl_tool!(AlphaTool, "alpha");
    impl_tool!(BetaTool, "beta");

    crate::tools! { AlphaTool, BetaTool }

    /// The generated name list follows registration order.
    #[test]
    fn macro_lists_names() {
        assert_eq!(ALL_BUILT_IN_TOOL_NAMES, &["alpha", "beta"]);
    }

    /// `built_in_tools` yields one erased tool per registered type, in order.
    #[test]
    fn macro_constructs_erased_tools() {
        let registered = built_in_tools();
        assert_eq!(registered.len(), 2);
        assert_eq!(registered[0].name(), "alpha");
        assert_eq!(registered[1].name(), "beta");
    }
}
|
||||
@@ -0,0 +1,14 @@
|
||||
//! Tool trait, registry-facing types, and per-call context.
|
||||
|
||||
pub mod kinds;
|
||||
pub mod events;
|
||||
pub mod context;
|
||||
pub mod erase;
|
||||
pub mod macros;
|
||||
|
||||
pub use kinds::{ClientKind, ProtocolKind, ToolKind};
|
||||
pub use events::{
|
||||
PermissionRequestId, ToolEvent, ToolEventSink, ToolLocation, ToolResultContent,
|
||||
};
|
||||
pub use context::ToolContext;
|
||||
pub use erase::{AnyTool, AnyToolInput, Erased, Tool, ToolInput};
|
||||
@@ -0,0 +1,5 @@
|
||||
//! Built-in tool implementations registered against the `Tool` trait.
|
||||
|
||||
pub mod read;
|
||||
|
||||
pub use read::ReadTool;
|
||||
@@ -0,0 +1,151 @@
|
||||
//! `ReadTool`: built-in trait wrapper around `crate::fs::read_text_file`.
|
||||
|
||||
use crate::fs::read::{read_text_file, ReadTextFileRequest, ReadTextFileResponse};
|
||||
use crate::tool::{
|
||||
Tool, ToolContext, ToolEvent, ToolEventSink, ToolInput, ToolKind, ToolLocation,
|
||||
ToolResultContent,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Serialize, Deserialize, JsonSchema, Clone)]
|
||||
pub struct ReadInput {
|
||||
/// Absolute path to the file to read.
|
||||
pub path: String,
|
||||
/// Optional 1-indexed start line.
|
||||
#[serde(default)] pub line: Option<usize>,
|
||||
/// Optional max number of lines to return.
|
||||
#[serde(default)] pub limit: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum ReadOutput {
|
||||
Ok { content: String },
|
||||
Err { error: String },
|
||||
}
|
||||
|
||||
impl ReadOutput {
|
||||
fn err(msg: impl Into<String>) -> Self { ReadOutput::Err { error: msg.into() } }
|
||||
}
|
||||
|
||||
/// Built-in `read` tool.
///
/// A stateless unit type; all behaviour lives in its `Tool` implementation.
/// `Debug` is derived so the tool can appear in diagnostics and in `Debug`
/// output of types that hold it.
#[derive(Debug, Default)]
pub struct ReadTool;
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for ReadTool {
|
||||
type Input = ReadInput;
|
||||
type Output = ReadOutput;
|
||||
const NAME: &'static str = "read";
|
||||
|
||||
fn kind() -> ToolKind { ToolKind::Read }
|
||||
|
||||
async fn run(
|
||||
self: Arc<Self>,
|
||||
input: ToolInput<Self::Input>,
|
||||
events: ToolEventSink,
|
||||
ctx: &ToolContext,
|
||||
) -> Result<Self::Output, Self::Output> {
|
||||
let i = match input {
|
||||
ToolInput::Final(i) => i,
|
||||
_ => return Err(ReadOutput::err("streaming inputs not supported by read")),
|
||||
};
|
||||
|
||||
events.emit(ToolEvent::Started {
|
||||
title: format!("Read {}", i.path).into(),
|
||||
kind: ToolKind::Read,
|
||||
location: Some(ToolLocation { path: i.path.clone(), line: i.line.map(|n| n as u32) }),
|
||||
});
|
||||
|
||||
let req = ReadTextFileRequest { path: i.path.clone(), line: i.line, limit: i.limit };
|
||||
let res: ReadTextFileResponse = read_text_file(req, ctx.sandbox.as_ref()).await
|
||||
.map_err(|e| ReadOutput::err(e.to_string()))?;
|
||||
|
||||
events.emit(ToolEvent::PartialOutput {
|
||||
content: ToolResultContent::text(res.content.clone()),
|
||||
});
|
||||
events.emit(ToolEvent::Completed);
|
||||
Ok(ReadOutput::Ok { content: res.content })
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{PermissionConfig, SandboxConfig, WhitelistConfig};
    use crate::permission::check::PermissionContext;
    use crate::permission::whitelist::CompiledWhitelist;
    use crate::tool::{ClientKind, ProtocolKind};
    use std::path::Path;
    use tempfile::TempDir;

    /// Builds a test `ToolContext` rooted at `root`, with `root` as the only
    /// entry in the sandbox's `allowed_roots`.
    fn ctx_for(root: &Path) -> ToolContext {
        let mut sandbox = SandboxConfig::default();
        sandbox.allowed_roots = vec![root.to_path_buf()];

        let whitelist = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap();
        let permissions = PermissionContext::new("c".to_string(), None, whitelist);

        ToolContext::for_test(
            "c",
            ClientKind::claude(),
            ProtocolKind::acp(),
            root.to_path_buf(),
            sandbox,
            PermissionConfig::default(),
            permissions,
        )
    }

    /// Final (non-streaming) input asking for `path` with no line/limit window.
    fn whole_file(path: &Path) -> ToolInput<ReadInput> {
        ToolInput::Final(ReadInput {
            path: path.to_string_lossy().into_owned(),
            line: None,
            limit: None,
        })
    }

    #[tokio::test]
    async fn reads_file_through_trait() {
        let workspace = TempDir::new().unwrap();
        let file = workspace.path().join("hello.txt");
        std::fs::write(&file, b"hi\nthere\n").unwrap();

        let ctx = ctx_for(workspace.path());
        let (sink, mut rx) = ToolEventSink::new();
        let outcome = Tool::run(Arc::new(ReadTool), whole_file(&file), sink, &ctx)
            .await
            .unwrap();

        let content = match outcome {
            ReadOutput::Ok { content } => content,
            ReadOutput::Err { error } => panic!("expected ok, got err: {error}"),
        };
        assert!(content.contains("hi"));
        assert!(content.contains("there"));

        // At least Started + PartialOutput + Completed should have fired.
        let count = std::iter::from_fn(|| rx.try_recv().ok()).count();
        assert!(count >= 3, "expected >=3 events, got {count}");
    }

    #[tokio::test]
    async fn returns_structured_error_on_sandbox_violation() {
        let allowed = TempDir::new().unwrap();
        let elsewhere = TempDir::new().unwrap();
        let outside = elsewhere.path().join("nope.txt");
        std::fs::write(&outside, b"x").unwrap();

        let ctx = ctx_for(allowed.path());
        // The receiver is kept alive for the duration of the call, as in any real consumer.
        let (sink, _rx) = ToolEventSink::new();
        let failure = Tool::run(Arc::new(ReadTool), whole_file(&outside), sink, &ctx)
            .await
            .unwrap_err();

        match failure {
            ReadOutput::Ok { .. } => panic!("expected sandbox error"),
            ReadOutput::Err { error } => assert!(!error.is_empty()),
        }
    }
}
|
||||
Reference in New Issue
Block a user