commit b03dc15371a9d2b0d11df445c75e3c69cff865ba Author: Gabor Körber Date: Fri May 8 01:59:04 2026 +0200 sync from monorepo @ 2452e92e diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..dae6540 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,53 @@ +[workspace] +resolver = "2" +members = [ + "crates/dirigent_protocol", + "crates/dirigent_core", + "crates/dirigent_tools", + "crates/dirigent_fermata", + "crates/dirigent_auth", + "crates/dirigent_config", + "crates/dirigent_acp_api", + "crates/dirigent_archivist", + "crates/dirigent_process", + "crates/dirigent_taskrunner", + "crates/dirigent_anth", + "crates/dirigent_inspector", + "crates/dirigent_projects", + "crates/dirigent_matrix", + "crates/dirigent_zed", + "crates/dirigent_langfuse", + "crates/dirigent_chatgpt", + "crates/dirigent_codex", + "crates/dirigent_testing", + "crates/opencode_client", +] + +[workspace.lints.rust] +dead_code = "allow" +unused_imports = "allow" +unused_variables = "allow" +unused_mut = "allow" +unused_assignments = "allow" + +[workspace.dependencies] +dirigent_protocol = { path = "crates/dirigent_protocol" } +dirigent_core = { path = "crates/dirigent_core" } +dirigent_tools = { path = "crates/dirigent_tools" } +dirigent_fermata = { path = "crates/dirigent_fermata" } +dirigent_auth = { path = "crates/dirigent_auth" } +dirigent_config = { path = "crates/dirigent_config" } +dirigent_acp_api = { path = "crates/dirigent_acp_api" } +dirigent_archivist = { path = "crates/dirigent_archivist" } +dirigent_process = { path = "crates/dirigent_process" } +dirigent_taskrunner = { path = "crates/dirigent_taskrunner" } +dirigent_anth = { path = "crates/dirigent_anth" } +dirigent_inspector = { path = "crates/dirigent_inspector" } +dirigent_projects = { path = "crates/dirigent_projects" } +dirigent_matrix = { path = "crates/dirigent_matrix", default-features = true } +dirigent_zed = { path = "crates/dirigent_zed" } +dirigent_langfuse = { path = "crates/dirigent_langfuse" } +dirigent_chatgpt = { path = "crates/dirigent_chatgpt" } +dirigent_codex = { path = "crates/dirigent_codex" } +dirigent_testing = { path = "crates/dirigent_testing" } +opencode_client = { path = "crates/opencode_client" } diff --git a/README.md b/README.md new file mode 100644 index 0000000..d266bdc --- /dev/null +++ b/README.md @@ -0,0 +1,81 @@ +# Dirigent + +

+ Dirigent +

+ +

Core libraries for the Dirigent agent orchestration platform.

+ +--- + +Dirigent is a multi-agent orchestration platform built around the Agent-Client Protocol (ACP). This repository contains the foundational library crates — the building blocks used by downstream tools such as [dirigate](https://git.g4b.org/dirigence/dirigate) and [fermata](https://git.g4b.org/dirigence/fermata). + +> **Downstream mirror.** Active development happens in an upstream monorepo. This repository is an export of the core library crates and is updated on each release. Issues and contributions should be directed to the upstream project. + +--- + +## Crates + +| Crate | Description | +|-------|-------------| +| `dirigent_protocol` | ACP protocol types — messages, events, and RPC definitions | +| `dirigent_core` | Multi-connector orchestration runtime | +| `dirigent_tools` | Tool sandbox and execution abstractions | +| `dirigent_fermata` | Policy gate for AI coding agents (`.botignore` / `botignore.toml`) | +| `dirigent_auth` | User authorization model | +| `dirigent_config` | Configuration management | +| `dirigent_acp_api` | ACP server for incoming agent connections | +| `dirigent_archivist` | Event-driven session archival | +| `dirigent_process` | Child process management | +| `dirigent_taskrunner` | Background task runner | +| `dirigent_anth` | Claude Code JSONL session parser | +| `dirigent_inspector` | Session inspection tools | +| `dirigent_projects` | Project management primitives | +| `dirigent_matrix` | Matrix integration for session sharing | +| `dirigent_zed` | Zed editor integration | +| `dirigent_langfuse` | Langfuse observability integration | +| `dirigent_chatgpt` | ChatGPT `conversations.json` parser | +| `dirigent_codex` | OpenAI Codex session parser | +| `dirigent_testing` | Test utilities | +| `opencode_client` | OpenCode.ai HTTP client | + +--- + +## Usage + +### Library crates (via git dependency) + +Add a crate to your `Cargo.toml`: + +```toml +[dependencies] +dirigent_protocol = { git = "https://git.g4b.org/dirigence/dirigent", path = "crates/dirigent_protocol" } +dirigent_core = { git = "https://git.g4b.org/dirigence/dirigent", path = "crates/dirigent_core" } +``` + +Replace `dirigent_protocol` / `dirigent_core` with the crate you need. All crates follow the same pattern. + +### Binary crates (cargo install) + +**fermata** — policy gate CLI and Claude hook adapter: + +```bash +cargo install --git https://git.g4b.org/dirigence/dirigent --features cli +``` + +**anth** — Claude Code session inspector: + +```bash +cargo install --git https://git.g4b.org/dirigence/dirigent --bin anth_bear --features dirigent-paths +``` + +--- + +## License + +Licensed under either of + +- Apache License, Version 2.0 ([LICENSE-APACHE](LICENSE-APACHE) or ) +- MIT License ([LICENSE-MIT](LICENSE-MIT) or ) + +at your option. diff --git a/crates/dirigent_acp_api/CLAUDE.md b/crates/dirigent_acp_api/CLAUDE.md new file mode 100644 index 0000000..a57b19b --- /dev/null +++ b/crates/dirigent_acp_api/CLAUDE.md @@ -0,0 +1,124 @@ +# Package: dirigent_acp_api + +ACP Server implementation for accepting incoming ACP connections from external agents. + +## Quick Facts +- **Type**: Library +- **Main Entry**: src/lib.rs +- **Dependencies**: axum, tokio, serde, tracing, uuid, async-trait, dirigent_protocol +- **Status**: Core structure complete, integration with CoreRuntime pending + +## Overview + +The `dirigent_acp_api` package implements an ACP (Agent-Client Protocol) server that allows Dirigent to accept incoming connections from external ACP clients like Claude Code or custom agents. This enables session sharing, remote orchestration, and multi-client collaboration. + +## Architecture + +### Core Components + +- **config.rs** - Server configuration types (`AcpServerConfig`) +- **error.rs** - Error types (`AcpServerError`, `JsonRpcErrorObject`) +- **jsonrpc.rs** - JSON-RPC 2.0 types and parsing +- **rpc.rs** - RPC handler and method dispatch +- **session_manager.rs** - Session/client tracking (TODO) +- **sse.rs** - SSE notification system (TODO) +- **event_bridge.rs** - Event translation (TODO) + +### Key Types + +```rust +pub struct AcpServerConfig { + pub enabled: bool, // Enable/disable server + pub port: u16, // Listen port (default: 3001) + pub allowed_origins: Option>, // CORS origins + pub max_connections: usize, // Connection limit (default: 100) +} +``` + +### ConnectorOperations Trait + +The RPC handler uses a trait abstraction to avoid circular dependencies with dirigent_core: + +```rust +#[async_trait] +pub trait ConnectorOperations: Send + Sync { + async fn create_session(&self, connector_id: &str) -> Result; + async fn load_session(&self, connector_id: &str, session_id: &str) -> Result; + async fn send_prompt(&self, connector_id: &str, session_id: &str, prompt: &str) -> Result; + // ... more methods +} +``` + +## API Endpoints + +### POST `/rpc` + +JSON-RPC 2.0 endpoint supporting: +- `initialize` - Client handshake +- `session/new` - Create session +- `session/load` - Load existing session +- `session/prompt` - Send prompt +- `session/cancel` - Cancel generation +- `session/close` - Close session + +### GET `/events` + +Server-Sent Events for streaming notifications: +- `acp/messageChunk` - Streaming content +- `acp/messageComplete` - Generation complete +- `acp/sessionIdle` - Ready for input + +### GET `/health` + +Health check endpoint. + +## Configuration UI + +The ACP Server is configured via the web UI at **Configuration > ACP Server**: + +- Enable/disable toggle +- Port configuration +- Max connections limit +- Allowed origins (CORS) +- Default connector selection +- Connected clients management + +Server functions in `crates/api/src/acp_server.rs` bridge the UI and this package. + +## Implementation Status + +**Completed:** +- Configuration types (`AcpServerConfig`) +- Error types (`AcpServerError`) +- JSON-RPC types and parsing +- RPC handler structure with ConnectorOperations trait + +**Pending:** +- Session Manager implementation +- SSE Notifier implementation +- Event Bridge implementation +- Axum router integration +- Web server integration + +## Key Files + +| File | Description | +|------|-------------| +| `src/lib.rs` | Module exports and router creation | +| `src/config.rs` | AcpServerConfig with validation | +| `src/error.rs` | Error types and codes | +| `src/jsonrpc.rs` | JSON-RPC 2.0 implementation | +| `src/rpc.rs` | RPC handler and method dispatch | + +## Related Packages + +- **dirigent_core** - Provides CoreHandle implementation of ConnectorOperations +- **dirigent_protocol** - Shared event and message types +- **api** - Server functions for UI configuration +- **web** - Configuration UI components + +## Documentation + +- **Architecture**: `docs/architecture/acp_server.md` +- **Configuration**: `docs/configuration/acp-connectors.md` +- **Tasks**: `docs/building/07_acp_serve/02_acp_server_tasks.md` diff --git a/crates/dirigent_acp_api/Cargo.toml b/crates/dirigent_acp_api/Cargo.toml new file mode 100644 index 0000000..9216db8 --- /dev/null +++ b/crates/dirigent_acp_api/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "dirigent_acp_api" +version = "0.1.0" +edition = "2021" + +[lib] +path = "src/lib.rs" + +[dependencies] +anyhow = "1.0" +async-trait = "0.1" +# ACP (Agent-Client Protocol) dependencies +axum = "0.8" +# ACP Server dependencies (Phase 2) +chrono = { version = "0.4", features = ["serde"] } +# Workspace dependencies +# Note: dirigent_protocol is used for Event types +dirigent_protocol = { path = "../dirigent_protocol" } +# Note: dirigent_core is NOT a direct dependency to avoid circular dependency. +# CoreHandle is passed into the ACP server at runtime via generics or trait objects. +# The mode/model mapping logic is duplicated here for legacy mode support. +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +tokio = { version = "1", features = ["full"] } +tokio-stream = { version = "0.1", features = ["sync"] } +tower = "0.5" +tower-http = { version = "0.6", features = ["cors"] } +# ACP Server dependencies (Phase 1) +tracing = "0.1" +uuid = { version = "1.0", features = ["serde", "v4", "v7"] } diff --git a/crates/dirigent_acp_api/README.md b/crates/dirigent_acp_api/README.md new file mode 100644 index 0000000..81b87b1 --- /dev/null +++ b/crates/dirigent_acp_api/README.md @@ -0,0 +1,5 @@ +This crate should expose an ACP API for dirigent to be used by another ACP Client. + +Funny Integration Test would be to use Dirigent using Dirigent (but that would need probably a dummy acp agent to be used by that) + +This will however also require dirigent to "pass through" functionality, it believes to be responsible for. diff --git a/crates/dirigent_acp_api/src/agent_requests.rs b/crates/dirigent_acp_api/src/agent_requests.rs new file mode 100644 index 0000000..691c1ee --- /dev/null +++ b/crates/dirigent_acp_api/src/agent_requests.rs @@ -0,0 +1,438 @@ +//! Agent Request Tracker +//! +//! This module provides infrastructure for tracking agent-initiated requests +//! that require responses from HTTP clients. It's used to implement bidirectional +//! JSON-RPC communication in the ACP Server. +//! +//! ## Use Case +//! +//! When an agent (like Claude) sends a request to the client (like a permission +//! prompt), the ACP Server needs to: +//! 1. Forward the request to the client via SSE +//! 2. Wait for the client's response via HTTP POST +//! 3. Deliver the response back to the agent +//! +//! The `AgentRequestTracker` manages the pending requests and provides a way +//! to correlate responses with their corresponding requests. +//! +//! ## Example Flow +//! +//! ```rust,ignore +//! use dirigent_acp_api::agent_requests::AgentRequestTracker; +//! use serde_json::json; +//! +//! // Create tracker +//! let tracker = AgentRequestTracker::new(); +//! +//! // Agent sends request - register it and get receiver +//! let request_id = json!(0); +//! let client_id = "client-123"; +//! let receiver = tracker.register(client_id, request_id.clone()); +//! +//! // Forward request to client via SSE... +//! +//! // Client responds via HTTP POST - complete the request +//! let response = json!({"selectedOptionId": "allow"}); +//! tracker.complete(client_id, request_id, response)?; +//! +//! // The receiver now gets the response +//! let response_value = receiver.await?; +//! ``` + +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use anyhow::{anyhow, Result}; +use serde_json::Value; +use tokio::sync::oneshot; +use tracing::{debug, warn}; + +/// Tracks pending agent requests awaiting client responses +/// +/// This struct provides thread-safe storage for correlating agent requests +/// with their eventual client responses. It uses oneshot channels to deliver +/// responses to waiting tasks. +/// +/// ## Thread Safety +/// +/// The tracker is designed to be cloned and shared across async tasks. +/// Internal state is protected by `Arc>` for thread-safe access. +#[derive(Debug, Clone)] +pub struct AgentRequestTracker { + /// Maps (client_id, request_id_string) to oneshot sender for delivering response + /// + /// The key is a tuple of client ID and request ID (as string) to uniquely + /// identify each pending request. The value is a oneshot sender that will + /// be used to deliver the response when it arrives. + pending: Arc>>>, +} + +impl Default for AgentRequestTracker { + fn default() -> Self { + Self::new() + } +} + +impl AgentRequestTracker { + /// Create a new agent request tracker + /// + /// Returns an empty tracker ready to register pending requests. + pub fn new() -> Self { + Self { + pending: Arc::new(Mutex::new(HashMap::new())), + } + } + + /// Register a pending agent request and return a receiver for the response + /// + /// This method creates a oneshot channel, stores the sender in the pending + /// requests map, and returns the receiver. The caller can await on the + /// receiver to get the client's response. + /// + /// # Parameters + /// + /// - `client_id`: The ID of the client that should respond to this request + /// - `request_id`: The request ID from the agent (JSON-RPC id field) + /// + /// # Returns + /// + /// A oneshot receiver that will receive the client's response when + /// `complete()` is called with matching client_id and request_id. + /// + /// # Example + /// + /// ```rust,ignore + /// let receiver = tracker.register("client-123", json!(0)); + /// + /// // Later, when client responds... + /// tracker.complete("client-123", json!(0), response)?; + /// + /// // The receiver gets the response + /// let response = receiver.await?; + /// ``` + pub fn register(&self, client_id: &str, request_id: Value) -> oneshot::Receiver { + let (tx, rx) = oneshot::channel(); + let key = (client_id.to_string(), request_id.to_string()); + + let mut pending = self.pending.lock().expect("Lock poisoned"); + pending.insert(key.clone(), tx); + + debug!( + client_id = %client_id, + request_id = %request_id, + "Registered pending agent request" + ); + + rx + } + + /// Complete a pending agent request with the client's response + /// + /// This method looks up the pending request, sends the response through + /// the oneshot channel, and removes it from the pending map. + /// + /// # Parameters + /// + /// - `client_id`: The ID of the client sending the response + /// - `request_id`: The request ID from the original agent request + /// - `response`: The client's response (JSON-RPC response object) + /// + /// # Returns + /// + /// - `Ok(())` if the request was found and the response was delivered + /// - `Err` if the request_id was not found in pending requests + /// + /// # Errors + /// + /// Returns an error if: + /// - The request_id is not found in pending requests (may have timed out) + /// - The receiver has been dropped (unlikely but possible) + /// + /// # Example + /// + /// ```rust,ignore + /// // Client POSTs response to /acp/agent_response + /// let response = json!({ + /// "jsonrpc": "2.0", + /// "id": 0, + /// "result": {"selectedOptionId": "allow"} + /// }); + /// + /// tracker.complete("client-123", json!(0), response)?; + /// ``` + pub fn complete(&self, client_id: &str, request_id: Value, response: Value) -> Result<()> { + let key = (client_id.to_string(), request_id.to_string()); + + let mut pending = self.pending.lock().expect("Lock poisoned"); + + if let Some(sender) = pending.remove(&key) { + debug!( + client_id = %client_id, + request_id = %request_id, + "Completing pending agent request" + ); + + // Send the response through the oneshot channel + sender.send(response).map_err(|_| { + anyhow!( + "Failed to send response for request {}: receiver dropped", + request_id + ) + })?; + + Ok(()) + } else { + warn!( + client_id = %client_id, + request_id = %request_id, + "Attempted to complete non-existent agent request (may have timed out)" + ); + + Err(anyhow!( + "Request ID {} not found for client {}", + request_id, + client_id + )) + } + } + + /// Timeout a pending agent request + /// + /// This method removes a pending request from the map and logs a timeout + /// warning. It should be called when a request has been pending for too + /// long (e.g., 30 seconds) without a client response. + /// + /// # Parameters + /// + /// - `client_id`: The ID of the client that was supposed to respond + /// - `request_id`: The request ID that timed out + /// + /// # Note + /// + /// This method does not send any error response through the channel. + /// The receiver will get a `RecvError` when it tries to receive, which + /// the caller should interpret as a timeout. + /// + /// # Example + /// + /// ```rust,ignore + /// use tokio::time::{timeout, Duration}; + /// + /// let receiver = tracker.register("client-123", json!(0)); + /// + /// // Wait up to 30 seconds for response + /// match timeout(Duration::from_secs(30), receiver).await { + /// Ok(Ok(response)) => { + /// // Got response + /// } + /// Ok(Err(_)) => { + /// // Channel closed (timeout was called) + /// tracker.timeout("client-123", json!(0)); + /// } + /// Err(_) => { + /// // Timeout elapsed + /// tracker.timeout("client-123", json!(0)); + /// } + /// } + /// ``` + pub fn timeout(&self, client_id: &str, request_id: Value) { + let key = (client_id.to_string(), request_id.to_string()); + + let mut pending = self.pending.lock().expect("Lock poisoned"); + + if pending.remove(&key).is_some() { + warn!( + client_id = %client_id, + request_id = %request_id, + "Agent request timed out (30s elapsed without client response)" + ); + } + } + + /// Get the number of pending requests + /// + /// This method returns the total number of requests currently awaiting + /// client responses across all clients. + pub fn pending_count(&self) -> usize { + let pending = self.pending.lock().expect("Lock poisoned"); + pending.len() + } + + /// Get the number of pending requests for a specific client + /// + /// # Parameters + /// + /// - `client_id`: The ID of the client to query + pub fn client_pending_count(&self, client_id: &str) -> usize { + let pending = self.pending.lock().expect("Lock poisoned"); + pending + .keys() + .filter(|(cid, _)| cid == client_id) + .count() + } + + /// Clear all pending requests (used when shutting down or on client disconnect) + /// + /// This method removes all pending requests from the tracker. The oneshot + /// senders are dropped, which will cause their receivers to get `RecvError`. + /// + /// # Parameters + /// + /// - `client_id`: Optional client ID to clear only that client's pending requests. + /// If None, clears all pending requests. + pub fn clear(&self, client_id: Option<&str>) { + let mut pending = self.pending.lock().expect("Lock poisoned"); + + match client_id { + Some(id) => { + let keys_to_remove: Vec<_> = pending + .keys() + .filter(|(cid, _)| cid == id) + .cloned() + .collect(); + + for key in keys_to_remove { + pending.remove(&key); + } + + debug!( + client_id = %id, + "Cleared all pending agent requests for client" + ); + } + None => { + let count = pending.len(); + pending.clear(); + + debug!( + count = count, + "Cleared all pending agent requests" + ); + } + } + } +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[tokio::test] + async fn test_register_and_complete() { + let tracker = AgentRequestTracker::new(); + + let client_id = "client-123"; + let request_id = json!(0); + let response = json!({"result": "success"}); + + // Register request + let receiver = tracker.register(client_id, request_id.clone()); + assert_eq!(tracker.pending_count(), 1); + + // Complete request + let result = tracker.complete(client_id, request_id, response.clone()); + assert!(result.is_ok()); + assert_eq!(tracker.pending_count(), 0); + + // Receiver should get the response + let received = receiver.await.unwrap(); + assert_eq!(received, response); + } + + #[tokio::test] + async fn test_complete_non_existent_request() { + let tracker = AgentRequestTracker::new(); + + let client_id = "client-123"; + let request_id = json!(999); + let response = json!({"result": "success"}); + + // Try to complete non-existent request + let result = tracker.complete(client_id, request_id, response); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_timeout() { + let tracker = AgentRequestTracker::new(); + + let client_id = "client-123"; + let request_id = json!(0); + + // Register request + let receiver = tracker.register(client_id, request_id.clone()); + assert_eq!(tracker.pending_count(), 1); + + // Timeout the request + tracker.timeout(client_id, request_id); + assert_eq!(tracker.pending_count(), 0); + + // Receiver should get error (channel closed) + let result = receiver.await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_multiple_pending_requests() { + let tracker = AgentRequestTracker::new(); + + let client1 = "client-1"; + let client2 = "client-2"; + + // Register multiple requests + let _rx1 = tracker.register(client1, json!(0)); + let _rx2 = tracker.register(client1, json!(1)); + let _rx3 = tracker.register(client2, json!(0)); + + assert_eq!(tracker.pending_count(), 3); + assert_eq!(tracker.client_pending_count(client1), 2); + assert_eq!(tracker.client_pending_count(client2), 1); + } + + #[tokio::test] + async fn test_clear_all() { + let tracker = AgentRequestTracker::new(); + + let _rx1 = tracker.register("client-1", json!(0)); + let _rx2 = tracker.register("client-2", json!(0)); + + assert_eq!(tracker.pending_count(), 2); + + tracker.clear(None); + assert_eq!(tracker.pending_count(), 0); + } + + #[tokio::test] + async fn test_clear_client() { + let tracker = AgentRequestTracker::new(); + + let client1 = "client-1"; + let client2 = "client-2"; + + let _rx1 = tracker.register(client1, json!(0)); + let _rx2 = tracker.register(client1, json!(1)); + let _rx3 = tracker.register(client2, json!(0)); + + assert_eq!(tracker.pending_count(), 3); + + // Clear only client1's requests + tracker.clear(Some(client1)); + assert_eq!(tracker.pending_count(), 1); + assert_eq!(tracker.client_pending_count(client1), 0); + assert_eq!(tracker.client_pending_count(client2), 1); + } + + #[test] + fn test_tracker_clone() { + let tracker = AgentRequestTracker::new(); + let tracker_clone = tracker.clone(); + + // Both should point to same underlying state + let _rx = tracker.register("client-1", json!(0)); + assert_eq!(tracker_clone.pending_count(), 1); + } +} diff --git a/crates/dirigent_acp_api/src/config.rs b/crates/dirigent_acp_api/src/config.rs new file mode 100644 index 0000000..c270fc0 --- /dev/null +++ b/crates/dirigent_acp_api/src/config.rs @@ -0,0 +1,275 @@ +//! Configuration types for the ACP Server +//! +//! This module defines configuration types for the ACP Server, including +//! server settings, connection limits, and CORS configuration. + +use serde::{Deserialize, Serialize}; + +/// Default port for the ACP Server +pub const DEFAULT_PORT: u16 = 3001; + +/// Default maximum number of concurrent connections +pub const DEFAULT_MAX_CONNECTIONS: usize = 100; + +/// Configuration for the ACP Server +/// +/// This struct contains all configurable options for the ACP Server, +/// including network settings, security options, and resource limits. +/// +/// **Note**: This config is used when starting an actual TCP server +/// (separate port mode only). For integrated mode (mounting at /acp), +/// the port field is still required but represents which port was chosen +/// for the separate server path. Higher-level configs (dirigent_core, api) +/// use `Option` to represent the integrated vs separate distinction. +/// +/// TODO: Consider moving the AcpPortConfig enum from web package to core +/// and using it here to make the integrated/separate distinction explicit. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct AcpServerConfig { + /// Whether the ACP Server is enabled + /// + /// When disabled, the server will not accept incoming connections. + /// Default: false + #[serde(default)] + pub enabled: bool, + + /// The port to listen on for incoming connections + /// + /// This is always a concrete port number because this config is used + /// to start an actual TCP server. Use Option in higher-level configs + /// to represent integrated mode (None) vs separate mode (Some(port)). + /// + /// Default: 3001 + #[serde(default = "default_port")] + pub port: u16, + + /// List of allowed origins for CORS + /// + /// When Some, only requests from these origins are allowed. + /// When None, all origins are allowed (use with caution). + /// + /// Example: ["http://localhost:3000", "https://app.example.com"] + #[serde(default)] + pub allowed_origins: Option>, + + /// Maximum number of concurrent client connections + /// + /// New connections will be rejected when this limit is reached. + /// Default: 100 + #[serde(default = "default_max_connections")] + pub max_connections: usize, +} + +/// Returns the default port value +fn default_port() -> u16 { + DEFAULT_PORT +} + +/// Returns the default max connections value +fn default_max_connections() -> usize { + DEFAULT_MAX_CONNECTIONS +} + +impl Default for AcpServerConfig { + fn default() -> Self { + Self { + enabled: false, + port: DEFAULT_PORT, + allowed_origins: None, + max_connections: DEFAULT_MAX_CONNECTIONS, + } + } +} + +impl AcpServerConfig { + /// Create a new configuration with default values + pub fn new() -> Self { + Self::default() + } + + /// Create a configuration with the server enabled + pub fn enabled() -> Self { + Self { + enabled: true, + ..Default::default() + } + } + + /// Create a configuration with a specific port + pub fn with_port(port: u16) -> Self { + Self { + port, + ..Default::default() + } + } + + /// Set whether the server is enabled + pub fn set_enabled(mut self, enabled: bool) -> Self { + self.enabled = enabled; + self + } + + /// Set the port + pub fn set_port(mut self, port: u16) -> Self { + self.port = port; + self + } + + /// Set the allowed origins + pub fn set_allowed_origins(mut self, origins: Option>) -> Self { + self.allowed_origins = origins; + self + } + + /// Set the maximum number of connections + pub fn set_max_connections(mut self, max: usize) -> Self { + self.max_connections = max; + self + } + + /// Check if the configuration is valid + pub fn validate(&self) -> Result<(), String> { + if self.port == 0 { + return Err("Port cannot be 0".to_string()); + } + + if self.max_connections == 0 { + return Err("max_connections must be at least 1".to_string()); + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_config() { + let config = AcpServerConfig::default(); + assert!(!config.enabled); + assert_eq!(config.port, DEFAULT_PORT); + assert!(config.allowed_origins.is_none()); + assert_eq!(config.max_connections, DEFAULT_MAX_CONNECTIONS); + } + + #[test] + fn test_enabled_config() { + let config = AcpServerConfig::enabled(); + assert!(config.enabled); + assert_eq!(config.port, DEFAULT_PORT); + } + + #[test] + fn test_with_port() { + let config = AcpServerConfig::with_port(8080); + assert!(!config.enabled); + assert_eq!(config.port, 8080); + } + + #[test] + fn test_builder_pattern() { + let config = AcpServerConfig::new() + .set_enabled(true) + .set_port(4000) + .set_allowed_origins(Some(vec!["http://localhost:3000".to_string()])) + .set_max_connections(50); + + assert!(config.enabled); + assert_eq!(config.port, 4000); + assert_eq!( + config.allowed_origins, + Some(vec!["http://localhost:3000".to_string()]) + ); + assert_eq!(config.max_connections, 50); + } + + #[test] + fn test_validation_valid() { + let config = AcpServerConfig::default(); + assert!(config.validate().is_ok()); + } + + #[test] + fn test_validation_invalid_port() { + let config = AcpServerConfig { + port: 0, + ..Default::default() + }; + assert!(config.validate().is_err()); + } + + #[test] + fn test_validation_invalid_max_connections() { + let config = AcpServerConfig { + max_connections: 0, + ..Default::default() + }; + assert!(config.validate().is_err()); + } + + #[test] + fn test_serialization() { + let config = AcpServerConfig { + enabled: true, + port: 3001, + allowed_origins: Some(vec!["http://localhost:3000".to_string()]), + max_connections: 100, + }; + + let json = serde_json::to_string(&config).unwrap(); + assert!(json.contains("\"enabled\":true")); + assert!(json.contains("\"port\":3001")); + assert!(json.contains("\"allowed_origins\"")); + assert!(json.contains("\"max_connections\":100")); + + // Deserialize back + let parsed: AcpServerConfig = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed, config); + } + + #[test] + fn test_deserialization_with_defaults() { + // Minimal JSON with defaults + let json = r#"{"enabled":true}"#; + let config: AcpServerConfig = serde_json::from_str(json).unwrap(); + + assert!(config.enabled); + assert_eq!(config.port, DEFAULT_PORT); + assert!(config.allowed_origins.is_none()); + assert_eq!(config.max_connections, DEFAULT_MAX_CONNECTIONS); + } + + #[test] + fn test_deserialization_empty() { + // Empty JSON should use all defaults + let json = "{}"; + let config: AcpServerConfig = serde_json::from_str(json).unwrap(); + + assert!(!config.enabled); + assert_eq!(config.port, DEFAULT_PORT); + assert!(config.allowed_origins.is_none()); + assert_eq!(config.max_connections, DEFAULT_MAX_CONNECTIONS); + } + + #[test] + fn test_equality() { + let config1 = AcpServerConfig::default(); + let config2 = AcpServerConfig::default(); + assert_eq!(config1, config2); + + let config3 = AcpServerConfig::enabled(); + assert_ne!(config1, config3); + } + + #[test] + fn test_clone() { + let config = AcpServerConfig::enabled() + .set_port(8080) + .set_allowed_origins(Some(vec!["origin".to_string()])); + + let cloned = config.clone(); + assert_eq!(config, cloned); + } +} diff --git a/crates/dirigent_acp_api/src/error.rs b/crates/dirigent_acp_api/src/error.rs new file mode 100644 index 0000000..981ae84 --- /dev/null +++ b/crates/dirigent_acp_api/src/error.rs @@ -0,0 +1,362 @@ +//! Error types for the ACP Server +//! +//! This module defines error types used throughout the ACP Server implementation, +//! including conversions to JSON-RPC error format for client responses. + +use std::fmt; + +use serde::{Deserialize, Serialize}; + +/// Standard JSON-RPC error codes as defined in the specification +pub mod error_codes { + /// Parse error - Invalid JSON was received by the server + pub const PARSE_ERROR: i32 = -32700; + + /// Invalid Request - The JSON sent is not a valid Request object + pub const INVALID_REQUEST: i32 = -32600; + + /// Method not found - The method does not exist / is not available + pub const METHOD_NOT_FOUND: i32 = -32601; + + /// Invalid params - Invalid method parameter(s) + pub const INVALID_PARAMS: i32 = -32602; + + /// Internal error - Internal JSON-RPC error + pub const INTERNAL_ERROR: i32 = -32603; + + // Server errors reserved for implementation-defined errors (-32000 to -32099) + + /// Session not found error + pub const SESSION_NOT_FOUND: i32 = -32001; + + /// Connector not found error + pub const CONNECTOR_NOT_FOUND: i32 = -32002; + + /// Invalid session error + pub const INVALID_SESSION: i32 = -32003; + + /// Transport error + pub const TRANSPORT_ERROR: i32 = -32004; + + /// Client not found error + pub const CLIENT_NOT_FOUND: i32 = -32005; +} + +/// ACP Server error enum representing all possible error conditions +#[derive(Debug, Clone, PartialEq)] +pub enum AcpServerError { + /// The session ID provided is invalid or malformed + InvalidSession, + + /// An RPC-related error occurred + /// + /// Contains a description of the RPC error, such as method not found, + /// invalid params, or parse errors. + RpcError(String), + + /// A transport-level error occurred + /// + /// Contains a description of the transport error, such as connection + /// failures, timeout, or network issues. + TransportError(String), + + /// The requested session was not found + /// + /// The session ID does not correspond to any known session in the + /// session manager. + SessionNotFound, + + /// The requested connector was not found + /// + /// The connector ID does not correspond to any registered connector + /// in the runtime. Contains the connector ID that was not found. + ConnectorNotFound(String), + + /// The connector is not in a ready state + /// + /// The connector exists but is not available to handle requests + /// (e.g., still connecting, error state, or stopped). + /// Contains the connector ID that is not ready. + ConnectorNotReady(String), + + /// Operation timed out + /// + /// Contains a description of what timed out. + Timeout(String), + + /// An internal server error occurred + /// + /// Contains a description of the internal error. Used for unexpected + /// conditions that don't fit other categories. + Internal(String), + + /// The requested client was not found + /// + /// The client ID does not correspond to any connected client in the + /// session manager. Contains the client ID that was not found. + ClientNotFound(String), +} + +impl fmt::Display for AcpServerError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AcpServerError::InvalidSession => { + write!(f, "Invalid session ID provided") + } + AcpServerError::RpcError(msg) => { + write!(f, "RPC error: {}", msg) + } + AcpServerError::TransportError(msg) => { + write!(f, "Transport error: {}", msg) + } + AcpServerError::SessionNotFound => { + write!(f, "Session not found") + } + AcpServerError::ConnectorNotFound(id) => { + write!(f, "Connector not found: {}", id) + } + AcpServerError::ConnectorNotReady(id) => { + write!(f, "Connector not ready: {}", id) + } + AcpServerError::Timeout(msg) => { + write!(f, "Timeout: {}", msg) + } + AcpServerError::Internal(msg) => { + write!(f, "Internal error: {}", msg) + } + AcpServerError::ClientNotFound(id) => { + write!(f, "Client not found: {}", id) + } + } + } +} + +impl std::error::Error for AcpServerError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + // None of our error variants wrap other errors currently + None + } +} + +/// JSON-RPC error representation for wire format +/// +/// This struct is used in JSON-RPC responses when an error occurs. +/// It follows the JSON-RPC 2.0 specification for error objects. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct JsonRpcErrorObject { + /// A Number that indicates the error type that occurred + pub code: i32, + + /// A String providing a short description of the error + pub message: String, + + /// Optional additional information about the error + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +impl JsonRpcErrorObject { + /// Create a new JSON-RPC error object + pub fn new(code: i32, message: impl Into) -> Self { + Self { + code, + message: message.into(), + data: None, + } + } + + /// Create a new JSON-RPC error object with additional data + pub fn with_data(code: i32, message: impl Into, data: serde_json::Value) -> Self { + Self { + code, + message: message.into(), + data: Some(data), + } + } + + /// Create a parse error + pub fn parse_error(message: impl Into) -> Self { + Self::new(error_codes::PARSE_ERROR, message) + } + + /// Create an invalid request error + pub fn invalid_request(message: impl Into) -> Self { + Self::new(error_codes::INVALID_REQUEST, message) + } + + /// Create a method not found error + pub fn method_not_found(method: impl Into) -> Self { + Self::new( + error_codes::METHOD_NOT_FOUND, + format!("Method not found: {}", method.into()), + ) + } + + /// Create an invalid params error + pub fn invalid_params(message: impl Into) -> Self { + Self::new(error_codes::INVALID_PARAMS, message) + } + + /// Create an internal error + pub fn internal_error(message: impl Into) -> Self { + Self::new(error_codes::INTERNAL_ERROR, message) + } +} + +impl From for JsonRpcErrorObject { + fn from(error: AcpServerError) -> Self { + match error { + AcpServerError::InvalidSession => { + JsonRpcErrorObject::new(error_codes::INVALID_SESSION, "Invalid session ID provided") + } + AcpServerError::RpcError(msg) => { + JsonRpcErrorObject::new(error_codes::INTERNAL_ERROR, msg) + } + AcpServerError::TransportError(msg) => { + JsonRpcErrorObject::new(error_codes::TRANSPORT_ERROR, msg) + } + AcpServerError::SessionNotFound => { + JsonRpcErrorObject::new(error_codes::SESSION_NOT_FOUND, "Session not found") + } + AcpServerError::ConnectorNotFound(id) => { + JsonRpcErrorObject::new(error_codes::CONNECTOR_NOT_FOUND, format!("Connector not found: {}", id)) + } + AcpServerError::ConnectorNotReady(id) => { + JsonRpcErrorObject::new(error_codes::CONNECTOR_NOT_FOUND, format!("Connector not ready: {}", id)) + } + AcpServerError::Timeout(msg) => { + JsonRpcErrorObject::new(error_codes::TRANSPORT_ERROR, format!("Timeout: {}", msg)) + } + AcpServerError::Internal(msg) => JsonRpcErrorObject::internal_error(msg), + AcpServerError::ClientNotFound(id) => { + JsonRpcErrorObject::new(error_codes::CLIENT_NOT_FOUND, format!("Client not found: {}", id)) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_error_display() { + assert_eq!( + AcpServerError::InvalidSession.to_string(), + "Invalid session ID provided" + ); + assert_eq!( + AcpServerError::RpcError("test".to_string()).to_string(), + "RPC error: test" + ); + assert_eq!( + AcpServerError::TransportError("timeout".to_string()).to_string(), + "Transport error: timeout" + ); + assert_eq!( + AcpServerError::SessionNotFound.to_string(), + "Session not found" + ); + assert_eq!( + AcpServerError::ConnectorNotFound("test-conn".to_string()).to_string(), + "Connector not found: test-conn" + ); + assert_eq!( + AcpServerError::ConnectorNotReady("test-conn".to_string()).to_string(), + "Connector not ready: test-conn" + ); + assert_eq!( + AcpServerError::Timeout("request timed out".to_string()).to_string(), + "Timeout: request timed out" + ); + assert_eq!( + AcpServerError::Internal("something went wrong".to_string()).to_string(), + "Internal error: something went wrong" + ); + assert_eq!( + AcpServerError::ClientNotFound("client-123".to_string()).to_string(), + "Client not found: client-123" + ); + } + + #[test] + fn test_error_to_jsonrpc() { + let error: JsonRpcErrorObject = AcpServerError::InvalidSession.into(); + assert_eq!(error.code, error_codes::INVALID_SESSION); + assert_eq!(error.message, "Invalid session ID provided"); + + let error: JsonRpcErrorObject = AcpServerError::SessionNotFound.into(); + assert_eq!(error.code, error_codes::SESSION_NOT_FOUND); + + let error: JsonRpcErrorObject = AcpServerError::ConnectorNotFound("conn1".to_string()).into(); + assert_eq!(error.code, error_codes::CONNECTOR_NOT_FOUND); + assert!(error.message.contains("conn1")); + + let error: JsonRpcErrorObject = AcpServerError::ConnectorNotReady("conn2".to_string()).into(); + assert_eq!(error.code, error_codes::CONNECTOR_NOT_FOUND); + assert!(error.message.contains("conn2")); + + let error: JsonRpcErrorObject = AcpServerError::Timeout("session creation".to_string()).into(); + assert_eq!(error.code, error_codes::TRANSPORT_ERROR); + + let error: JsonRpcErrorObject = AcpServerError::TransportError("net error".to_string()).into(); + assert_eq!(error.code, error_codes::TRANSPORT_ERROR); + assert_eq!(error.message, "net error"); + + let error: JsonRpcErrorObject = AcpServerError::Internal("internal".to_string()).into(); + assert_eq!(error.code, error_codes::INTERNAL_ERROR); + + let error: JsonRpcErrorObject = AcpServerError::ClientNotFound("client-456".to_string()).into(); + assert_eq!(error.code, error_codes::CLIENT_NOT_FOUND); + assert!(error.message.contains("client-456")); + } + + #[test] + fn test_jsonrpc_error_serialization() { + let error = JsonRpcErrorObject::new(error_codes::PARSE_ERROR, "Invalid JSON"); + let json = serde_json::to_string(&error).unwrap(); + assert!(json.contains("-32700")); + assert!(json.contains("Invalid JSON")); + + // With data + let error = JsonRpcErrorObject::with_data( + error_codes::INVALID_PARAMS, + "Missing field", + serde_json::json!({"field": "session_id"}), + ); + let json = serde_json::to_string(&error).unwrap(); + assert!(json.contains("session_id")); + } + + #[test] + fn test_jsonrpc_error_factories() { + let error = JsonRpcErrorObject::parse_error("bad json"); + assert_eq!(error.code, error_codes::PARSE_ERROR); + + let error = JsonRpcErrorObject::invalid_request("missing jsonrpc field"); + assert_eq!(error.code, error_codes::INVALID_REQUEST); + + let error = JsonRpcErrorObject::method_not_found("session.unknown"); + assert_eq!(error.code, error_codes::METHOD_NOT_FOUND); + assert!(error.message.contains("session.unknown")); + + let error = JsonRpcErrorObject::invalid_params("session_id required"); + assert_eq!(error.code, error_codes::INVALID_PARAMS); + + let error = JsonRpcErrorObject::internal_error("panic"); + assert_eq!(error.code, error_codes::INTERNAL_ERROR); + } + + #[test] + fn test_error_is_error_trait() { + fn assert_is_error() {} + assert_is_error::(); + } + + #[test] + fn test_error_clone() { + let error = AcpServerError::Internal("test".to_string()); + let cloned = error.clone(); + assert_eq!(error, cloned); + } +} diff --git a/crates/dirigent_acp_api/src/event_bridge.rs b/crates/dirigent_acp_api/src/event_bridge.rs new file mode 100644 index 0000000..57f6aa8 --- /dev/null +++ b/crates/dirigent_acp_api/src/event_bridge.rs @@ -0,0 +1,1613 @@ +//! Event Bridge for ACP Server +//! +//! This module provides the `EventBridge` which subscribes to a source event stream +//! (e.g., from `dirigent_core`) and forwards translated events to connected ACP clients +//! via the SSE notifier. +//! +//! ## Architecture +//! +//! The `EventBridge` runs as a background task that: +//! 1. Subscribes to a source event stream +//! 2. For each event, looks up the session mapping to find the client +//! 3. Translates the event to an ACP notification +//! 4. Broadcasts the notification to the appropriate client via SSE +//! +//! ## Example +//! +//! ```rust,ignore +//! use dirigent_acp_api::event_bridge::EventBridge; +//! use dirigent_acp_api::session_manager::SessionManager; +//! use dirigent_acp_api::sse::SseNotifier; +//! use tokio_stream::wrappers::BroadcastStream; +//! +//! // Create components +//! let session_manager = SessionManager::new(); +//! let sse_notifier = SseNotifier::new(); +//! +//! // Create event bridge +//! let agent_request_tracker = Arc::new(AgentRequestTracker::new()); +//! let bridge = EventBridge::new(session_manager, sse_notifier, agent_request_tracker); +//! +//! // Start with an event stream (e.g., from CoreHandle::subscribe()) +//! bridge.start(event_stream).await; +//! +//! // Later, stop the bridge +//! bridge.stop(); +//! ``` + +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +use tokio::sync::mpsc; +use tokio::task::JoinHandle; +use tracing::{debug, info, trace, warn}; + +use dirigent_protocol::Event; + +use crate::rpc::ConnectorOperations; +use crate::session_manager::SessionManager; +use crate::sse::{translate_event, SessionUpdateParams, SseNotifier}; + +// ============================================================================ +// EventBridge (T032-T033) +// ============================================================================ + +/// Configuration for the event bridge +#[derive(Debug, Clone)] +pub struct EventBridgeConfig { + /// Whether to broadcast system-wide errors to all clients + pub broadcast_system_errors: bool, + + /// Whether to log events that don't have a session mapping + pub log_unmapped_events: bool, +} + +impl Default for EventBridgeConfig { + fn default() -> Self { + Self { + broadcast_system_errors: true, + log_unmapped_events: true, + } + } +} + +/// Internal state for the event bridge +struct EventBridgeInner { + /// Session manager for looking up session mappings + session_manager: SessionManager, + + /// SSE notifier for broadcasting to clients + sse_notifier: SseNotifier, + + /// Agent request tracker for bidirectional request/response + agent_request_tracker: Arc, + + /// Configuration + config: EventBridgeConfig, + + /// Flag to signal shutdown + shutdown: AtomicBool, + + /// Optional connector operations handle for querying connectors + /// Used for finding Gateway connectors during fallback scenarios and sending connector commands + core_handle: Option>, +} + +/// Event Bridge for forwarding events from a source stream to ACP clients +/// +/// The `EventBridge` subscribes to an event stream (typically from `CoreHandle`) +/// and forwards relevant events to connected ACP clients via SSE. +/// +/// ## Thread Safety +/// +/// The `EventBridge` is designed to be cloned and shared. The internal state +/// is wrapped in `Arc` for thread-safe access. +/// +/// ## Lifecycle +/// +/// 1. Create with `new()` or `with_config()` +/// 2. Start the background task with `start()` +/// 3. The bridge runs until `stop()` is called or the source stream ends +/// 4. Use `is_running()` to check status +#[derive(Clone)] +pub struct EventBridge { + inner: Arc, + + /// Sender for the shutdown command + shutdown_tx: Option>, +} + +impl EventBridge { + /// Create a new event bridge with default configuration + /// + /// # Parameters + /// + /// - `session_manager`: The session manager for looking up mappings + /// - `sse_notifier`: The SSE notifier for broadcasting + /// - `agent_request_tracker`: The agent request tracker for bidirectional requests + pub fn new( + session_manager: SessionManager, + sse_notifier: SseNotifier, + agent_request_tracker: Arc, + ) -> Self { + Self::with_config( + session_manager, + sse_notifier, + agent_request_tracker, + EventBridgeConfig::default(), + ) + } + + /// Create a new event bridge with custom configuration + /// + /// # Parameters + /// + /// - `session_manager`: The session manager for looking up mappings + /// - `sse_notifier`: The SSE notifier for broadcasting + /// - `agent_request_tracker`: The agent request tracker for bidirectional requests + /// - `config`: The bridge configuration + pub fn with_config( + session_manager: SessionManager, + sse_notifier: SseNotifier, + agent_request_tracker: Arc, + config: EventBridgeConfig, + ) -> Self { + Self { + inner: Arc::new(EventBridgeInner { + session_manager, + sse_notifier, + agent_request_tracker, + config, + shutdown: AtomicBool::new(false), + core_handle: None, + }), + shutdown_tx: None, + } + } + + /// Create a new event bridge with core handle for connector operations + /// + /// # Parameters + /// + /// - `session_manager`: The session manager for looking up mappings + /// - `sse_notifier`: The SSE notifier for broadcasting + /// - `agent_request_tracker`: The agent request tracker for bidirectional requests + /// - `config`: The bridge configuration + /// - `core_handle`: Optional connector operations handle for querying connectors and sending commands + pub fn with_core_handle( + session_manager: SessionManager, + sse_notifier: SseNotifier, + agent_request_tracker: Arc, + config: EventBridgeConfig, + core_handle: Option>, + ) -> Self { + Self { + inner: Arc::new(EventBridgeInner { + session_manager, + sse_notifier, + agent_request_tracker, + config, + shutdown: AtomicBool::new(false), + core_handle, + }), + shutdown_tx: None, + } + } + + /// Start the event bridge background task + /// + /// This method spawns a background task that reads from the provided event + /// stream and forwards translated events to the appropriate ACP clients. + /// + /// # Type Parameters + /// + /// - `S`: The event stream type (must implement `Stream`) + /// + /// # Parameters + /// + /// - `event_stream`: The source event stream to subscribe to + /// + /// # Returns + /// + /// A `JoinHandle` for the background task. The task will run until: + /// - `stop()` is called + /// - The source stream ends + /// - An unrecoverable error occurs + /// + /// # Example + /// + /// ```rust,ignore + /// let agent_request_tracker = Arc::new(AgentRequestTracker::new()); + /// let bridge = EventBridge::new(session_manager, sse_notifier, agent_request_tracker); + /// let handle = bridge.start(event_stream); + /// + /// // Later... + /// bridge.stop(); + /// handle.await; // Wait for clean shutdown + /// ``` + pub fn start(&mut self, event_stream: S) -> JoinHandle<()> + where + S: tokio_stream::Stream + Send + Unpin + 'static, + { + use tokio_stream::StreamExt; + + // Reset shutdown flag + self.inner.shutdown.store(false, Ordering::SeqCst); + + // Create shutdown channel + let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); + self.shutdown_tx = Some(shutdown_tx); + + let inner = self.inner.clone(); + + tokio::spawn(async move { + info!("Event bridge started"); + + let mut stream = Box::pin(event_stream); + let mut event_count = 0u64; + + loop { + tokio::select! { + // Check for shutdown signal + _ = shutdown_rx.recv() => { + info!("Event bridge received shutdown signal"); + break; + } + + // Check for shutdown flag + _ = tokio::time::sleep(tokio::time::Duration::from_millis(100)), if inner.shutdown.load(Ordering::SeqCst) => { + info!("Event bridge shutdown flag set"); + break; + } + + // Process next event from stream + event = stream.next() => { + match event { + Some(evt) => { + event_count += 1; + Self::process_event(&inner, evt).await; + } + None => { + info!("Event bridge source stream ended"); + break; + } + } + } + } + } + + info!( + "Event bridge stopped after processing {} events", + event_count + ); + }) + } + + /// Process a single event + async fn process_event(inner: &EventBridgeInner, event: Event) { + debug!("EventBridge processing event: {:?}", event); + + // Handle special events that don't need translation + match &event { + Event::SessionTransferred { + from_connector, + from_session, + to_connector, + to_session, + is_new_session, + models, + modes, + } => { + Self::handle_session_transferred_internal( + inner, + from_connector, + from_session, + to_connector, + to_session, + *is_new_session, + models.clone(), + modes.clone(), + ) + .await; + return; + } + Event::ForwardingPanic { + connector_id, + session_id, + reason, + fallback_gateway_session, + } => { + Self::handle_forwarding_panic_internal( + inner, + connector_id, + session_id, + reason, + fallback_gateway_session.as_deref(), + ) + .await; + return; + } + Event::AgentRequest { + connector_id, + session_id, + request_id, + method, + params, + is_forwarded: _, // Ignored: handled by routing logic before this point + } => { + Self::handle_agent_request_internal( + inner, + connector_id, + session_id, + request_id.clone(), + method, + params.clone(), + ) + .await; + return; + } + _ => {} + } + + // Try to translate the event to session update params + let update_params_list = translate_event(&event); + + if update_params_list.is_empty() { + trace!("Event filtered (not forwarded to clients): {:?}", event); + return; + } + + // Process each update (most events emit 1, SessionMetadataReceived may emit 2) + for mut update_params in update_params_list { + debug!( + "Translated to SessionUpdateParams: session_id={}, variant={}", + update_params.session_id, + update_params.update.variant_name() + ); + + // Determine which client(s) to send to based on the event + // Also translate internal session ID to client session ID + match Self::get_client_for_event_and_translate(inner, &event, &mut update_params) { + EventTarget::Client(client_id) => { + // Send to specific client (session_id already translated) + match inner.sse_notifier.broadcast(&client_id, update_params) { + Ok(n) => { + debug!("Forwarded event to client {}: {} receivers", client_id, n); + } + Err(()) => { + if inner.config.log_unmapped_events { + warn!("Client {} is not subscribed to SSE", client_id); + } + } + } + } + EventTarget::AllClients => { + // Broadcast to all clients (e.g., system errors) + let count = inner.sse_notifier.broadcast_all(update_params); + debug!("Broadcast event to all clients: {} receivers", count); + } + EventTarget::NoTarget => { + // No target found + if inner.config.log_unmapped_events { + debug!("No target found for event, dropping"); + } + } + } + } + } + + /// Determine which client(s) should receive the event AND translate session ID + /// + /// This function: + /// 1. Looks up the session mapping (internal -> client session ID) + /// 2. Translates the session_id in update_params from internal to client ID + /// 3. Returns which client should receive the event + fn get_client_for_event_and_translate( + inner: &EventBridgeInner, + event: &Event, + update_params: &mut SessionUpdateParams, + ) -> EventTarget { + // Extract session_id from the update params (this is the internal session ID) + let session_id = &update_params.session_id; + + // Try to find the session mapping by client_session_id first + if let Some(mapping) = inner.session_manager.get_mapping(session_id) { + // Already using client session ID + return EventTarget::Client(mapping.client_id); + } + + // Try to find by internal_session_id (events from connectors use internal IDs) + // We need to extract the internal session ID from the original event + if let Some(internal_id) = Self::extract_internal_session_id(event) { + debug!( + "Looking up mapping for internal session ID: {}", + internal_id + ); + if let Some(mapping) = inner + .session_manager + .get_mapping_by_internal_id(&internal_id) + { + // TRANSLATE: Replace internal session ID with client session ID + update_params.session_id = mapping.client_session_id.clone(); + debug!( + "Translated session ID: internal {} -> client {}, client_id: {}", + internal_id, mapping.client_session_id, mapping.client_id + ); + return EventTarget::Client(mapping.client_id); + } else { + warn!( + "No session mapping found for internal session ID: {}", + internal_id + ); + } + } else { + warn!( + "Could not extract internal session ID from event: {:?}", + event + ); + } + + debug!( + "No target found for event with session_id: {}", + update_params.session_id + ); + EventTarget::NoTarget + } + + /// Determine which client(s) should receive the event (deprecated - use get_client_for_event_and_translate) + #[allow(dead_code)] + fn get_client_for_event( + inner: &EventBridgeInner, + event: &Event, + update_params: &SessionUpdateParams, + ) -> EventTarget { + // Extract session_id from the update params + let session_id = &update_params.session_id; + + // Try to find the session mapping by client_session_id first + if let Some(mapping) = inner.session_manager.get_mapping(session_id) { + return EventTarget::Client(mapping.client_id); + } + + // Try to find by internal_session_id (events from connectors use internal IDs) + // We need to extract the internal session ID from the original event + if let Some(internal_id) = Self::extract_internal_session_id(event) { + if let Some(mapping) = inner + .session_manager + .get_mapping_by_internal_id(&internal_id) + { + return EventTarget::Client(mapping.client_id); + } + } + + EventTarget::NoTarget + } + + /// Extract the internal session ID from an event + fn extract_internal_session_id(event: &Event) -> Option { + match event { + Event::SessionUpdate { session_id, .. } => Some(session_id.clone()), + Event::MessageCompleted { message, .. } => Some(message.session_id.clone()), + Event::SessionIdle { session_id, .. } => Some(session_id.clone()), + Event::SessionMetadataReceived { session_id, .. } => Some(session_id.clone()), + Event::MessageFailed { .. } => { + // MessageFailed uses message_id, not session_id + // We can't directly map this, so return None + None + } + Event::Error { .. } => None, + _ => None, + } + } + + /// Internal handler for session transfer (static method for use in process_event) + /// + /// The `models` and `modes` parameters are now passed directly from the `SessionTransferred` + /// event, which carries the new connector's session initialization data. This avoids needing + /// to query the connector for metadata after the transfer. + async fn handle_session_transferred_internal( + inner: &EventBridgeInner, + from_connector: &str, + from_session: &str, + to_connector: &str, + to_session: &str, + is_new_session: bool, + models: Option, + modes: Option, + ) { + tracing::info!( + from_connector = %from_connector, + from_session = %from_session, + to_connector = %to_connector, + to_session = %to_session, + is_new = %is_new_session, + has_models = %models.is_some(), + has_modes = %modes.is_some(), + "Processing session transfer" + ); + + // Find the client mapping for the source session + let mapping = inner + .session_manager + .get_mapping_by_gateway_session(from_session); + + if let Some(mapping) = mapping { + // Update the mapping to point to new connector/session + // Note: connector_title is None here, will use connector_id as fallback + if let Some(old_mapping) = inner.session_manager.update_mapping_connector( + &mapping.client_session_id, + to_connector.to_string(), + None, // connector_title - could be looked up from ConnectorOperations + to_session.to_string(), + ) { + tracing::info!( + client_session_id = %mapping.client_session_id, + old_connector = %old_mapping.connector_id, + new_connector = %to_connector, + "Session mapping updated for transfer" + ); + + // Notify client via SSE that the transfer completed + use crate::sse::{SessionUpdateParams, SessionUpdateVariant}; + let update = SessionUpdateParams { + session_id: mapping.client_session_id.clone(), + update: SessionUpdateVariant::ConnectorChanged { + new_connector_id: to_connector.to_string(), + new_internal_session_id: to_session.to_string(), + is_new_session, + }, + event_type_override: None, + }; + + match inner.sse_notifier.broadcast(&mapping.client_id, update) { + Ok(receivers) => { + tracing::info!( + client_id = %mapping.client_id, + session_id = %mapping.client_session_id, + receivers = receivers, + "Notified client about session transfer via SSE" + ); + } + Err(_) => { + tracing::warn!( + client_id = %mapping.client_id, + "Failed to notify client about session transfer via SSE - client not subscribed" + ); + } + } + + // Send config_option_update with the new connector's modes/models + // These are passed directly from the SessionTransferred event, which + // carries the initialization data from the new connector's session/new response. + // See: https://agentclientprotocol.com/rfds/session-config-options + { + let mut config_options = vec![]; + + if let Some(ref modes_state) = modes { + config_options.push(crate::sse::modes_to_config_option(modes_state)); + } + + if let Some(ref models_state) = models { + config_options.push(crate::sse::models_to_config_option(models_state)); + } + + // Send config_option_update if we have any options + if !config_options.is_empty() { + let config_update = SessionUpdateParams { + session_id: mapping.client_session_id.clone(), + update: SessionUpdateVariant::ConfigOptionUpdate { config_options }, + event_type_override: None, + }; + + if let Ok(receivers) = inner + .sse_notifier + .broadcast(&mapping.client_id, config_update) + { + tracing::info!( + client_id = %mapping.client_id, + session_id = %mapping.client_session_id, + receivers = receivers, + "Sent config_option_update after transfer" + ); + } + } else { + tracing::debug!( + client_id = %mapping.client_id, + "No modes/models to send in config_option_update after transfer" + ); + } + } + + // Get available commands from new connector and send update + if let Some(ref core_handle) = inner.core_handle { + match core_handle.get_connector_commands(to_connector).await { + Ok(commands) => { + let commands_update = SessionUpdateParams { + session_id: mapping.client_session_id.clone(), + update: SessionUpdateVariant::AvailableCommandsUpdate { + available_commands: commands.clone(), + }, + event_type_override: None, + }; + + match inner + .sse_notifier + .broadcast(&mapping.client_id, commands_update) + { + Ok(receivers) => { + tracing::info!( + client_id = %mapping.client_id, + session_id = %mapping.client_session_id, + commands = commands.len(), + receivers = receivers, + "Sent AvailableCommandsUpdate after transfer" + ); + } + Err(_) => { + tracing::warn!( + client_id = %mapping.client_id, + "Failed to send AvailableCommandsUpdate after transfer" + ); + } + } + } + Err(e) => { + tracing::warn!( + connector_id = %to_connector, + error = %e, + "Failed to get connector commands after transfer" + ); + } + } + } else { + tracing::debug!("No core handle available - skipping metadata/commands query after transfer"); + } + + // Send confirmation message so the user sees the transfer completed + let mode_text = if is_new_session { + "new session" + } else { + "existing session" + }; + let confirmation_text = format!("Connected to {} ({}).", to_connector, mode_text); + let confirmation_update = SessionUpdateParams { + session_id: mapping.client_session_id.clone(), + update: SessionUpdateVariant::AgentMessageChunk { + content: dirigent_protocol::types::ContentBlock::Text { + text: confirmation_text, + }, + }, + event_type_override: None, + }; + + match inner + .sse_notifier + .broadcast(&mapping.client_id, confirmation_update) + { + Ok(receivers) => { + tracing::info!( + client_id = %mapping.client_id, + session_id = %mapping.client_session_id, + receivers = receivers, + "Sent transfer confirmation message" + ); + } + Err(_) => { + tracing::warn!( + client_id = %mapping.client_id, + "Failed to send transfer confirmation message" + ); + } + } + + // After transfer, send SessionIdle to give control back to client + let idle_update = SessionUpdateParams { + session_id: mapping.client_session_id.clone(), + update: SessionUpdateVariant::SessionIdle {}, + event_type_override: None, + }; + + match inner + .sse_notifier + .broadcast(&mapping.client_id, idle_update) + { + Ok(receivers) => { + tracing::info!( + client_id = %mapping.client_id, + session_id = %mapping.client_session_id, + receivers = receivers, + "Sent SessionIdle after transfer to give control back" + ); + } + Err(_) => { + tracing::warn!( + client_id = %mapping.client_id, + "Failed to send SessionIdle after transfer" + ); + } + } + + } + } else { + tracing::warn!( + from_session = %from_session, + "No client mapping found for transferred session - session may have been created outside ACP Server" + ); + } + } + + /// Internal handler for forwarding panic (static method for use in process_event) + async fn handle_forwarding_panic_internal( + inner: &EventBridgeInner, + connector_id: &str, + session_id: &str, + reason: &str, + fallback_gateway_session: Option<&str>, + ) { + tracing::warn!( + connector_id = %connector_id, + session_id = %session_id, + reason = %reason, + "Handling forwarding panic" + ); + + // Find mapping by the failed connector's session + let mapping = inner.session_manager.get_mapping_by_internal_id(session_id); + + if let Some(mapping) = mapping { + // Determine fallback session + let fallback_session = fallback_gateway_session + .map(String::from) + .unwrap_or_else(|| format!("fallback-{}", uuid::Uuid::new_v4())); + + // Find a Gateway connector to fall back to + let gateway_connector = Self::find_gateway_connector_internal(inner).await; + + if let Some(gateway_id) = gateway_connector { + inner.session_manager.update_mapping_connector( + &mapping.client_session_id, + gateway_id.clone(), + Some("Gateway".to_string()), // Fallback to Gateway + fallback_session.clone(), + ); + + // Log the error - client notification would happen through the Gateway + // when they send their next message and see they're back on Gateway + tracing::warn!( + client_id = %mapping.client_id, + connector_id = %connector_id, + reason = %reason, + "Connector failed, client mapping reverted to Gateway" + ); + + tracing::info!( + client_id = %mapping.client_id, + fallback_connector = %gateway_id, + "Client reverted to Gateway after forwarding panic" + ); + } else { + tracing::error!("No Gateway connector available for fallback!"); + } + } + } + + /// Internal method to find an available Gateway connector + async fn find_gateway_connector_internal(inner: &EventBridgeInner) -> Option { + // Query CoreHandle for connectors + if let Some(ref core_handle) = inner.core_handle { + match core_handle.list_connectors().await { + Ok(connectors) => connectors + .into_iter() + .find(|c| c.connector_type.to_lowercase() == "gateway" && c.available) + .map(|c| c.id), + Err(e) => { + tracing::error!("Failed to list connectors: {:?}", e); + None + } + } + } else { + None + } + } + + /// Internal handler for agent requests (T024-T030, Phase 2) + /// + /// This method implements the complete bidirectional request/response flow: + /// 1. Translate internal session_id to client session_id + /// 2. Check ownership to determine routing (UI permission modal vs external client) + /// 3. Route to appropriate handler + async fn handle_agent_request_internal( + inner: &EventBridgeInner, + connector_id: &str, + session_id: &str, + request_id: serde_json::Value, + method: &str, + params: serde_json::Value, + ) { + tracing::info!( + connector_id = %connector_id, + session_id = %session_id, + method = %method, + request_id = %request_id, + "Handling agent request" + ); + + // Get mapping with ownership information + let mapping = match inner.session_manager.get_mapping_by_internal_id(session_id) { + Some(m) => m, + None => { + tracing::warn!( + connector_id = %connector_id, + session_id = %session_id, + "No client mapping found for agent request - cannot forward to client" + ); + return; + } + }; + + // Phase 2: Check if we should forward to external client + if let Some(forward_client_id) = mapping.ownership.forward_to_client() { + // Forward to external client with capabilities + Self::forward_to_external_client( + inner, + forward_client_id, + connector_id, + &mapping.client_session_id, + request_id, + method, + params, + ) + .await; + } else { + // Show permission modal in UI (current behavior) + Self::handle_permission_request_for_ui( + inner, + &mapping.client_id, + &mapping.client_session_id, + connector_id, + request_id, + method, + params, + ) + .await; + } + } + + /// Handle permission request for UI client (original flow) + /// + /// This method implements the UI permission modal flow: + /// 1. Register request with tracker + /// 2. Broadcast SSE notification to UI + /// 3. Wait for UI response (30s timeout) + /// 4. Send response back to connector + async fn handle_permission_request_for_ui( + inner: &EventBridgeInner, + client_id: &str, + client_session_id: &str, + connector_id: &str, + request_id: serde_json::Value, + method: &str, + params: serde_json::Value, + ) { + use crate::sse::{SessionUpdateParams, SessionUpdateVariant}; + + tracing::debug!( + client_id = %client_id, + client_session_id = %client_session_id, + method = %method, + "Routing agent request to UI permission modal" + ); + + // Register request with tracker and get receiver + let response_rx = inner + .agent_request_tracker + .register(client_id, request_id.clone()); + + tracing::debug!( + client_id = %client_id, + request_id = %request_id, + "Registered agent request with tracker" + ); + + // Create SessionUpdateParams with AgentRequest variant + // Translate sessionId in params from internal to client-facing + let mut translated_params = params.clone(); + if let Some(params_obj) = translated_params.as_object_mut() { + params_obj.insert( + "sessionId".to_string(), + serde_json::json!(client_session_id), + ); + } + + let update = SessionUpdateParams { + session_id: client_session_id.to_string(), + update: SessionUpdateVariant::AgentRequest { + request_id: request_id.clone(), + method: method.to_string(), + params: translated_params, + }, + event_type_override: None, + }; + + // Broadcast SSE event to UI client + match inner.sse_notifier.broadcast(client_id, update) { + Ok(receivers) => { + tracing::info!( + client_id = %client_id, + request_id = %request_id, + receivers = receivers, + "Broadcasted agent request via SSE to UI" + ); + } + Err(()) => { + tracing::warn!( + client_id = %client_id, + request_id = %request_id, + "Failed to broadcast agent request - UI not subscribed to SSE" + ); + // Clean up the tracker entry since we can't deliver the request + inner + .agent_request_tracker + .timeout(client_id, request_id.clone()); + return; + } + } + + // Wait for UI response with 30s timeout + let response = + match tokio::time::timeout(tokio::time::Duration::from_secs(30), response_rx).await { + Ok(Ok(response_value)) => { + tracing::info!( + client_id = %client_id, + request_id = %request_id, + "Received UI response for agent request" + ); + response_value + } + Ok(Err(_)) => { + tracing::warn!( + client_id = %client_id, + request_id = %request_id, + "Response channel dropped for agent request" + ); + return; + } + Err(_) => { + tracing::warn!( + client_id = %client_id, + request_id = %request_id, + "Agent request timed out after 30s - UI did not respond" + ); + // Clean up tracker + inner + .agent_request_tracker + .timeout(client_id, request_id.clone()); + return; + } + }; + + // Send response to connector via send_agent_response + if let Some(ref core_handle) = inner.core_handle { + match core_handle + .send_agent_response(connector_id, request_id.clone(), response) + .await + { + Ok(()) => { + tracing::info!( + connector_id = %connector_id, + request_id = %request_id, + "Sent agent response back to connector" + ); + } + Err(e) => { + tracing::error!( + connector_id = %connector_id, + request_id = %request_id, + error = %e, + "Failed to send agent response to connector" + ); + } + } + } else { + tracing::error!( + connector_id = %connector_id, + request_id = %request_id, + "No core handle available - cannot send agent response to connector" + ); + } + } + + /// Forward an agent request to an external client via SSE (Phase 2) + /// + /// This method implements external client forwarding: + /// 1. Register request with tracker + /// 2. Translate sessionId in params + /// 3. Broadcast SSE notification to external client + /// 4. Wait for external client response (60s timeout - longer than UI) + /// 5. Send response back to connector + async fn forward_to_external_client( + inner: &EventBridgeInner, + client_id: &str, + connector_id: &str, + client_session_id: &str, + request_id: serde_json::Value, + method: &str, + params: serde_json::Value, + ) { + use crate::sse::{SessionUpdateParams, SessionUpdateVariant}; + + tracing::info!( + client_id = %client_id, + method = %method, + "Forwarding agent request to external client" + ); + + // Register with tracker (reuse existing mechanism) + let response_rx = inner + .agent_request_tracker + .register(client_id, request_id.clone()); + + // Translate sessionId in params from internal to client-facing + let mut translated_params = params.clone(); + if let Some(params_obj) = translated_params.as_object_mut() { + params_obj.insert( + "sessionId".to_string(), + serde_json::json!(client_session_id), + ); + } + + // Create SSE notification (same format as permission requests) + let notification = SessionUpdateParams { + session_id: client_session_id.to_string(), + update: SessionUpdateVariant::AgentRequest { + request_id: request_id.clone(), + method: method.to_string(), + params: translated_params, + }, + event_type_override: None, + }; + + // Broadcast to external client + if inner + .sse_notifier + .broadcast(client_id, notification) + .is_err() + { + tracing::warn!(client_id = %client_id, "External client not subscribed to SSE"); + inner.agent_request_tracker.timeout(client_id, request_id); + return; + } + + // Wait for response with longer timeout for external clients (60s) + let response = match tokio::time::timeout(tokio::time::Duration::from_secs(60), response_rx) + .await + { + Ok(Ok(response)) => { + tracing::info!(client_id = %client_id, "Received response from external client"); + response + } + Ok(Err(_)) => { + tracing::warn!(client_id = %client_id, "External client response channel dropped"); + return; + } + Err(_) => { + tracing::warn!(client_id = %client_id, "External client request timed out after 60s"); + inner + .agent_request_tracker + .timeout(client_id, request_id.clone()); + return; + } + }; + + // Send back to connector + if let Some(ref core_handle) = inner.core_handle { + if let Err(e) = core_handle + .send_agent_response(connector_id, request_id.clone(), response) + .await + { + tracing::error!(connector_id = %connector_id, error = %e, "Failed to send response to connector"); + } + } + } + + /// Stop the event bridge + /// + /// Signals the background task to shut down gracefully. The task will + /// complete processing any current event and then exit. + pub fn stop(&self) { + info!("Stopping event bridge"); + self.inner.shutdown.store(true, Ordering::SeqCst); + + // Also send via channel for faster shutdown + if let Some(tx) = &self.shutdown_tx { + let _ = tx.try_send(()); + } + } + + /// Check if the event bridge is running + /// + /// Returns `true` if the bridge has been started and has not been stopped. + /// Note that this doesn't guarantee the background task is still alive - + /// use the `JoinHandle` returned by `start()` for that. + pub fn is_running(&self) -> bool { + !self.inner.shutdown.load(Ordering::SeqCst) + } + + /// Get a reference to the session manager + pub fn session_manager(&self) -> &SessionManager { + &self.inner.session_manager + } + + /// Get a reference to the SSE notifier + pub fn sse_notifier(&self) -> &SseNotifier { + &self.inner.sse_notifier + } +} + +/// Target for an event notification +enum EventTarget { + /// Send to a specific client + Client(String), + + /// Broadcast to all clients (future use for system-wide notifications) + #[allow(dead_code)] + AllClients, + + /// No target found, drop the event + NoTarget, +} + +impl std::fmt::Debug for EventBridge { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("EventBridge") + .field("running", &self.is_running()) + .field("config", &self.inner.config) + .finish() + } +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + use dirigent_protocol::{ContentBlock, Message, MessageRole, MessageStatus, SessionUpdate}; + use tokio_stream::StreamExt; + + fn create_test_bridge() -> EventBridge { + let session_manager = SessionManager::new(); + let sse_notifier = SseNotifier::new(); + let agent_request_tracker = Arc::new(crate::agent_requests::AgentRequestTracker::new()); + EventBridge::new(session_manager, sse_notifier, agent_request_tracker) + } + + #[test] + fn test_event_bridge_new() { + let bridge = create_test_bridge(); + assert!(bridge.is_running()); // Not started, but not stopped either + } + + #[test] + fn test_event_bridge_with_config() { + let session_manager = SessionManager::new(); + let sse_notifier = SseNotifier::new(); + let agent_request_tracker = Arc::new(crate::agent_requests::AgentRequestTracker::new()); + let config = EventBridgeConfig { + broadcast_system_errors: false, + log_unmapped_events: false, + }; + + let bridge = + EventBridge::with_config(session_manager, sse_notifier, agent_request_tracker, config); + assert!(!bridge.inner.config.broadcast_system_errors); + assert!(!bridge.inner.config.log_unmapped_events); + } + + #[test] + fn test_event_bridge_stop() { + let bridge = create_test_bridge(); + assert!(bridge.is_running()); + + bridge.stop(); + assert!(!bridge.is_running()); + } + + #[tokio::test] + async fn test_event_bridge_start_and_stop() { + let session_manager = SessionManager::new(); + let sse_notifier = SseNotifier::new(); + let agent_request_tracker = Arc::new(crate::agent_requests::AgentRequestTracker::new()); + let mut bridge = EventBridge::new(session_manager, sse_notifier, agent_request_tracker); + + // Create an empty stream that will end immediately + let stream = tokio_stream::empty::(); + + let handle = bridge.start(stream); + + // Wait for the task to complete (stream ends immediately) + tokio::time::timeout(tokio::time::Duration::from_secs(1), handle) + .await + .expect("Task should complete quickly") + .expect("Task should not panic"); + } + + #[tokio::test] + async fn test_event_bridge_with_events() { + let session_manager = SessionManager::new(); + let sse_notifier = SseNotifier::new(); + + // Register a client and create a session + let client_id = session_manager.register_client(None); + let mapping = session_manager.create_mapping( + &client_id, + Some("test-session".to_string()), + "connector-1".to_string(), + ); + + // Subscribe to SSE + let mut sse_stream = sse_notifier.subscribe(&client_id); + + let agent_request_tracker = Arc::new(crate::agent_requests::AgentRequestTracker::new()); + let mut bridge = EventBridge::new(session_manager, sse_notifier, agent_request_tracker); + + // Create a stream with one event + let event = Event::SessionIdle { + connector_id: "test-connector".to_string(), + session_id: mapping.internal_session_id.clone(), + }; + let stream = tokio_stream::iter(vec![event]); + + let handle = bridge.start(stream); + + // Wait for the event to be processed + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Check that the client received the notification + let notification = + tokio::time::timeout(tokio::time::Duration::from_millis(500), sse_stream.next()).await; + + // The notification should have been received + assert!(notification.is_ok()); + let notification = notification.unwrap(); + assert!(notification.is_some()); + + let notification = notification.unwrap(); + assert!(notification.is_ok()); + + let update_params = notification.unwrap(); + // The session_id in the update params should be the CLIENT session ID + // because get_client_for_event_and_translate() translates internal IDs to client IDs + assert_eq!(update_params.session_id, mapping.client_session_id); + + // Verify it's a SessionIdle update + match update_params.update { + crate::sse::SessionUpdateVariant::SessionIdle {} => { + // Expected + } + other => panic!("Expected SessionIdle variant, got {:?}", other), + } + + // Stop and wait for completion + bridge.stop(); + let _ = tokio::time::timeout(tokio::time::Duration::from_secs(1), handle).await; + } + + #[tokio::test] + async fn test_event_bridge_filtered_events() { + let session_manager = SessionManager::new(); + let sse_notifier = SseNotifier::new(); + + // Register a client + let client_id = session_manager.register_client(None); + + // Subscribe to SSE + let mut sse_stream = sse_notifier.subscribe(&client_id); + + let agent_request_tracker = Arc::new(crate::agent_requests::AgentRequestTracker::new()); + let mut bridge = EventBridge::new(session_manager, sse_notifier, agent_request_tracker); + + // Create events that should be filtered (not forwarded) + let events = vec![ + Event::Error { + message: "System-wide error".to_string(), + }, + Event::MessageFailed { + message_id: "msg-1".to_string(), + error: "Failed".to_string(), + }, + Event::Connected, + ]; + let stream = tokio_stream::iter(events); + + let _handle = bridge.start(stream); + + // Wait for processing + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // No notifications should be received (filtered events) + let result = + tokio::time::timeout(tokio::time::Duration::from_millis(200), sse_stream.next()).await; + + // Timeout expected since no events should be forwarded + assert!( + result.is_err(), + "Expected timeout since events should be filtered" + ); + + bridge.stop(); + } + + #[test] + fn test_event_bridge_debug() { + let bridge = create_test_bridge(); + let debug_str = format!("{:?}", bridge); + + assert!(debug_str.contains("EventBridge")); + assert!(debug_str.contains("running")); + assert!(debug_str.contains("config")); + } + + #[test] + fn test_event_bridge_clone() { + let bridge1 = create_test_bridge(); + let bridge2 = bridge1.clone(); + + // Both should share the same inner state + bridge1.stop(); + assert!(!bridge2.is_running()); + } + + #[test] + fn test_extract_internal_session_id() { + // SessionUpdate + let event = Event::SessionUpdate { + connector_id: "conn".to_string(), + session_id: "sess-123".to_string(), + update: SessionUpdate::AgentMessageChunk { + message_id: "msg-1".to_string(), + content: ContentBlock::Text { + text: "Hello".to_string(), + }, + _meta: None, + }, + }; + assert_eq!( + EventBridge::extract_internal_session_id(&event), + Some("sess-123".to_string()) + ); + + // SessionIdle + let event = Event::SessionIdle { + connector_id: "test-connector".to_string(), + session_id: "sess-456".to_string(), + }; + assert_eq!( + EventBridge::extract_internal_session_id(&event), + Some("sess-456".to_string()) + ); + + // MessageCompleted + let event = Event::MessageCompleted { + connector_id: "conn".to_string(), + message: Message { + id: "msg-1".to_string(), + session_id: "sess-789".to_string(), + role: MessageRole::Assistant, + created_at: chrono::Utc::now(), + content: vec![], + status: MessageStatus::Completed, + metadata: None, + }, + }; + assert_eq!( + EventBridge::extract_internal_session_id(&event), + Some("sess-789".to_string()) + ); + + // Error (no session) + let event = Event::Error { + message: "error".to_string(), + }; + assert_eq!(EventBridge::extract_internal_session_id(&event), None); + } + + #[test] + fn test_event_bridge_config_default() { + let config = EventBridgeConfig::default(); + assert!(config.broadcast_system_errors); + assert!(config.log_unmapped_events); + } + + #[tokio::test] + async fn test_handle_session_transferred() { + let session_manager = SessionManager::new(); + let sse_notifier = SseNotifier::new(); + let agent_request_tracker = Arc::new(crate::agent_requests::AgentRequestTracker::new()); + let bridge = EventBridge::new(session_manager.clone(), sse_notifier, agent_request_tracker); + + // Setup initial mapping + let client_id = session_manager.register_client(None); + session_manager.create_mapping( + &client_id, + Some("client-session-1".to_string()), + "gateway-1".to_string(), + ); + + // Get the internal session ID (which represents the gateway session) + let mapping = session_manager.get_mapping("client-session-1").unwrap(); + let gateway_session_id = mapping.internal_session_id.clone(); + + // Simulate SessionTransferred event via internal handler + EventBridge::handle_session_transferred_internal( + &bridge.inner, + "gateway-1", + &gateway_session_id, + "opencode-1", + "opencode-session-new", + true, + None, + None, + ) + .await; + + // Verify mapping updated + let updated_mapping = session_manager.get_mapping("client-session-1").unwrap(); + assert_eq!(updated_mapping.connector_id, "opencode-1"); + assert_eq!(updated_mapping.internal_session_id, "opencode-session-new"); + assert_eq!(updated_mapping.client_session_id, "client-session-1"); // Should remain unchanged + } + + #[tokio::test] + async fn test_handle_session_transferred_no_mapping() { + let session_manager = SessionManager::new(); + let sse_notifier = SseNotifier::new(); + let agent_request_tracker = Arc::new(crate::agent_requests::AgentRequestTracker::new()); + let bridge = EventBridge::new(session_manager, sse_notifier, agent_request_tracker); + + // Try to transfer a session that doesn't have a client mapping + // This should log a warning but not panic + EventBridge::handle_session_transferred_internal( + &bridge.inner, + "gateway-1", + "nonexistent-session", + "opencode-1", + "opencode-session-new", + true, + None, + None, + ) + .await; + + // Test passes if no panic occurred + } + + #[tokio::test] + async fn test_process_event_session_transferred() { + let session_manager = SessionManager::new(); + let sse_notifier = SseNotifier::new(); + let agent_request_tracker = Arc::new(crate::agent_requests::AgentRequestTracker::new()); + let mut bridge = + EventBridge::new(session_manager.clone(), sse_notifier, agent_request_tracker); + + // Setup initial mapping + let client_id = session_manager.register_client(None); + session_manager.create_mapping( + &client_id, + Some("client-session-1".to_string()), + "gateway-1".to_string(), + ); + + let mapping = session_manager.get_mapping("client-session-1").unwrap(); + let gateway_session_id = mapping.internal_session_id.clone(); + + // Create a SessionTransferred event + let event = Event::SessionTransferred { + from_connector: "gateway-1".to_string(), + from_session: gateway_session_id, + to_connector: "opencode-1".to_string(), + to_session: "opencode-session-new".to_string(), + is_new_session: true, + models: None, + modes: None, + }; + + let stream = tokio_stream::iter(vec![event]); + let handle = bridge.start(stream); + + // Wait for processing + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Verify mapping was updated + let updated_mapping = session_manager.get_mapping("client-session-1").unwrap(); + assert_eq!(updated_mapping.connector_id, "opencode-1"); + assert_eq!(updated_mapping.internal_session_id, "opencode-session-new"); + + bridge.stop(); + let _ = tokio::time::timeout(tokio::time::Duration::from_secs(1), handle).await; + } + + #[tokio::test] + async fn test_handle_agent_request_forwards_to_external_client() { + use dirigent_protocol::SessionOwnership; + + // Setup with external ownership + let session_manager = SessionManager::new(); + let sse_notifier = SseNotifier::new(); + let agent_request_tracker = Arc::new(crate::agent_requests::AgentRequestTracker::new()); + + // Register external client with capabilities + let client_capabilities = + serde_json::json!({"fs": {"readTextFile": true}, "terminal": true}); + let client_id = session_manager.register_client(Some(client_capabilities.clone())); + + // Create mapping + let mapping = session_manager.create_mapping( + &client_id, + Some("client-session-1".to_string()), + "connector-1".to_string(), + ); + + // Update ownership to external with forwarding + let ownership = + SessionOwnership::external_forwarded(client_id.clone(), Some(client_capabilities)); + session_manager.update_mapping_ownership("client-session-1", ownership); + + // Subscribe external client to SSE + let mut sse_stream = sse_notifier.subscribe(&client_id); + + // Create event bridge + let bridge = EventBridge::new( + session_manager.clone(), + sse_notifier.clone(), + agent_request_tracker.clone(), + ); + + // Simulate an agent request + tokio::spawn({ + let bridge_inner = bridge.inner.clone(); + let internal_session_id = mapping.internal_session_id.clone(); + async move { + EventBridge::handle_agent_request_internal( + &bridge_inner, + "connector-1", + &internal_session_id, + serde_json::json!("req-123"), + "fs/readTextFile", + serde_json::json!({"path": "/test.txt"}), + ) + .await; + } + }); + + // External client should receive the request + let notification = + tokio::time::timeout(tokio::time::Duration::from_secs(1), sse_stream.next()) + .await + .expect("Should receive notification") + .expect("Should have notification") + .expect("Should parse notification"); + + // Verify the request was forwarded + assert_eq!(notification.session_id, "client-session-1"); + match notification.update { + crate::sse::SessionUpdateVariant::AgentRequest { + request_id, + method, + params, + } => { + assert_eq!(request_id, serde_json::json!("req-123")); + assert_eq!(method, "fs/readTextFile"); + assert_eq!(params["sessionId"], serde_json::json!("client-session-1")); + assert_eq!(params["path"], serde_json::json!("/test.txt")); + } + _ => panic!("Expected AgentRequest variant"), + } + + // Note: Full roundtrip testing (client response -> connector) would require + // a mock CoreHandle implementation, which is out of scope for this unit test. + // Integration tests should verify the complete flow. + } + +} diff --git a/crates/dirigent_acp_api/src/jsonrpc.rs b/crates/dirigent_acp_api/src/jsonrpc.rs new file mode 100644 index 0000000..31e2824 --- /dev/null +++ b/crates/dirigent_acp_api/src/jsonrpc.rs @@ -0,0 +1,460 @@ +//! JSON-RPC 2.0 types for the ACP Server +//! +//! This module implements JSON-RPC 2.0 request/response types according to +//! the specification at https://www.jsonrpc.org/specification. +//! +//! Key features: +//! - Support for both numeric and string IDs +//! - Batch request/response handling +//! - Proper serialization of null vs missing fields + +use serde::{Deserialize, Serialize}; + +use crate::error::JsonRpcErrorObject; + +/// JSON-RPC protocol version constant +pub const JSONRPC_VERSION: &str = "2.0"; + +/// JSON-RPC request/response identifier +/// +/// According to the JSON-RPC 2.0 spec, an id can be a String, Number, +/// or Null. This type uses an untagged enum to handle both string and +/// number identifiers. +/// +/// Note: The spec recommends not using Null as an id for requests. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(untagged)] +pub enum JsonRpcId { + /// Numeric identifier (integer) + Number(i64), + + /// String identifier + String(String), + + /// Null identifier (typically used in error responses for invalid requests) + Null, +} + +impl From for JsonRpcId { + fn from(n: i64) -> Self { + JsonRpcId::Number(n) + } +} + +impl From for JsonRpcId { + fn from(s: String) -> Self { + JsonRpcId::String(s) + } +} + +impl From<&str> for JsonRpcId { + fn from(s: &str) -> Self { + JsonRpcId::String(s.to_string()) + } +} + +/// A JSON-RPC 2.0 request object +/// +/// Represents a remote procedure call with optional parameters. +/// The `id` field determines whether this is a request (with id) or +/// notification (without id). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcRequest { + /// JSON-RPC protocol version (must be "2.0") + pub jsonrpc: String, + + /// A String containing the name of the method to be invoked + pub method: String, + + /// Optional structured value that holds the parameter values + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option, + + /// Optional identifier established by the client + /// + /// If absent, the request is a notification (no response expected) + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, +} + +impl JsonRpcRequest { + /// Create a new JSON-RPC request + pub fn new(method: impl Into, params: Option, id: JsonRpcId) -> Self { + Self { + jsonrpc: JSONRPC_VERSION.to_string(), + method: method.into(), + params, + id: Some(id), + } + } + + /// Create a new JSON-RPC notification (request without id) + pub fn notification(method: impl Into, params: Option) -> Self { + Self { + jsonrpc: JSONRPC_VERSION.to_string(), + method: method.into(), + params, + id: None, + } + } + + /// Check if this request is a notification (no id) + pub fn is_notification(&self) -> bool { + self.id.is_none() + } + + /// Validate the request format + pub fn validate(&self) -> Result<(), String> { + if self.jsonrpc != JSONRPC_VERSION { + return Err(format!( + "Invalid JSON-RPC version: expected '{}', got '{}'", + JSONRPC_VERSION, self.jsonrpc + )); + } + + if self.method.is_empty() { + return Err("Method name cannot be empty".to_string()); + } + + // Methods starting with "rpc." are reserved for internal use + if self.method.starts_with("rpc.") { + return Err(format!( + "Method name '{}' is reserved (starts with 'rpc.')", + self.method + )); + } + + Ok(()) + } +} + +/// A JSON-RPC 2.0 response object +/// +/// Contains either a result (success) or an error (failure), never both. +/// The id must match the corresponding request id. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcResponse { + /// JSON-RPC protocol version (must be "2.0") + pub jsonrpc: String, + + /// The result of the call (on success) + /// + /// This member is REQUIRED on success and MUST NOT exist on error. + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + + /// The error object (on failure) + /// + /// This member is REQUIRED on error and MUST NOT exist on success. + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + + /// The identifier matching the request + /// + /// If there was an error detecting the id in the Request object + /// (e.g. Parse error/Invalid Request), it MUST be Null. + pub id: JsonRpcId, +} + +impl JsonRpcResponse { + /// Create a successful response + pub fn success(result: serde_json::Value, id: JsonRpcId) -> Self { + Self { + jsonrpc: JSONRPC_VERSION.to_string(), + result: Some(result), + error: None, + id, + } + } + + /// Create an error response + pub fn error(error: JsonRpcErrorObject, id: JsonRpcId) -> Self { + Self { + jsonrpc: JSONRPC_VERSION.to_string(), + result: None, + error: Some(error), + id, + } + } + + /// Create an error response with a null id (for parse errors) + pub fn error_with_null_id(error: JsonRpcErrorObject) -> Self { + Self::error(error, JsonRpcId::Null) + } + + /// Check if this response represents success + pub fn is_success(&self) -> bool { + self.result.is_some() && self.error.is_none() + } + + /// Check if this response represents an error + pub fn is_error(&self) -> bool { + self.error.is_some() + } +} + +/// Represents either a single request or a batch of requests +/// +/// The JSON-RPC 2.0 spec allows sending multiple requests in a single +/// JSON array for batch processing. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum JsonRpcRequestBatch { + /// A single request + Single(JsonRpcRequest), + + /// A batch of requests + Batch(Vec), +} + +impl JsonRpcRequestBatch { + /// Check if this is an empty batch + pub fn is_empty(&self) -> bool { + match self { + JsonRpcRequestBatch::Single(_) => false, + JsonRpcRequestBatch::Batch(batch) => batch.is_empty(), + } + } + + /// Get the number of requests + pub fn len(&self) -> usize { + match self { + JsonRpcRequestBatch::Single(_) => 1, + JsonRpcRequestBatch::Batch(batch) => batch.len(), + } + } + + /// Convert to a vector of requests + pub fn into_vec(self) -> Vec { + match self { + JsonRpcRequestBatch::Single(req) => vec![req], + JsonRpcRequestBatch::Batch(batch) => batch, + } + } +} + +/// Represents either a single response or a batch of responses +/// +/// The response format must match the request format: single request +/// gets single response, batch request gets batch response. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum JsonRpcResponseBatch { + /// A single response + Single(JsonRpcResponse), + + /// A batch of responses + Batch(Vec), +} + +impl JsonRpcResponseBatch { + /// Create a batch response from a vector + /// + /// Returns Single if there's exactly one response, otherwise Batch. + pub fn from_vec(responses: Vec) -> Self { + if responses.len() == 1 { + JsonRpcResponseBatch::Single(responses.into_iter().next().unwrap()) + } else { + JsonRpcResponseBatch::Batch(responses) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_jsonrpc_id_number() { + let id = JsonRpcId::Number(42); + let json = serde_json::to_string(&id).unwrap(); + assert_eq!(json, "42"); + + let parsed: JsonRpcId = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed, id); + } + + #[test] + fn test_jsonrpc_id_string() { + let id = JsonRpcId::String("abc-123".to_string()); + let json = serde_json::to_string(&id).unwrap(); + assert_eq!(json, "\"abc-123\""); + + let parsed: JsonRpcId = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed, id); + } + + #[test] + fn test_jsonrpc_id_null() { + let id = JsonRpcId::Null; + let json = serde_json::to_string(&id).unwrap(); + assert_eq!(json, "null"); + + let parsed: JsonRpcId = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed, id); + } + + #[test] + fn test_jsonrpc_id_from_conversions() { + let id: JsonRpcId = 42i64.into(); + assert_eq!(id, JsonRpcId::Number(42)); + + let id: JsonRpcId = "test".into(); + assert_eq!(id, JsonRpcId::String("test".to_string())); + + let id: JsonRpcId = String::from("owned").into(); + assert_eq!(id, JsonRpcId::String("owned".to_string())); + } + + #[test] + fn test_request_creation() { + let req = JsonRpcRequest::new("session.new", Some(json!({"title": "Test"})), 1.into()); + assert_eq!(req.jsonrpc, "2.0"); + assert_eq!(req.method, "session.new"); + assert!(req.params.is_some()); + assert_eq!(req.id, Some(JsonRpcId::Number(1))); + assert!(!req.is_notification()); + } + + #[test] + fn test_notification_creation() { + let notif = JsonRpcRequest::notification("event.ping", None); + assert!(notif.is_notification()); + assert_eq!(notif.id, None); + } + + #[test] + fn test_request_validation() { + let valid = JsonRpcRequest::new("test.method", None, 1.into()); + assert!(valid.validate().is_ok()); + + // Invalid version + let mut invalid_version = valid.clone(); + invalid_version.jsonrpc = "1.0".to_string(); + assert!(invalid_version.validate().is_err()); + + // Empty method + let mut empty_method = valid.clone(); + empty_method.method = String::new(); + assert!(empty_method.validate().is_err()); + + // Reserved method + let mut reserved = valid.clone(); + reserved.method = "rpc.internal".to_string(); + assert!(reserved.validate().is_err()); + } + + #[test] + fn test_request_serialization() { + let req = JsonRpcRequest::new( + "session.prompt", + Some(json!({"session_id": "abc", "content": "Hello"})), + "req-123".into(), + ); + + let json = serde_json::to_string(&req).unwrap(); + assert!(json.contains("\"jsonrpc\":\"2.0\"")); + assert!(json.contains("\"method\":\"session.prompt\"")); + assert!(json.contains("\"id\":\"req-123\"")); + + // Deserialize back + let parsed: JsonRpcRequest = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.method, "session.prompt"); + } + + #[test] + fn test_response_success() { + let resp = JsonRpcResponse::success(json!({"session_id": "new-123"}), 1.into()); + assert!(resp.is_success()); + assert!(!resp.is_error()); + assert_eq!(resp.result, Some(json!({"session_id": "new-123"}))); + assert!(resp.error.is_none()); + } + + #[test] + fn test_response_error() { + let error = JsonRpcErrorObject::method_not_found("unknown.method"); + let resp = JsonRpcResponse::error(error, 1.into()); + assert!(!resp.is_success()); + assert!(resp.is_error()); + assert!(resp.result.is_none()); + assert!(resp.error.is_some()); + } + + #[test] + fn test_response_serialization() { + // Success response + let success = JsonRpcResponse::success(json!({"ok": true}), 42.into()); + let json = serde_json::to_string(&success).unwrap(); + assert!(json.contains("\"result\"")); + assert!(!json.contains("\"error\"")); + + // Error response + let error = JsonRpcResponse::error( + JsonRpcErrorObject::internal_error("Something broke"), + 42.into(), + ); + let json = serde_json::to_string(&error).unwrap(); + assert!(!json.contains("\"result\"")); + assert!(json.contains("\"error\"")); + } + + #[test] + fn test_batch_request_single() { + let req = JsonRpcRequest::new("test", None, 1.into()); + let batch = JsonRpcRequestBatch::Single(req); + assert_eq!(batch.len(), 1); + assert!(!batch.is_empty()); + + let vec = batch.into_vec(); + assert_eq!(vec.len(), 1); + } + + #[test] + fn test_batch_request_multiple() { + let req1 = JsonRpcRequest::new("test1", None, 1.into()); + let req2 = JsonRpcRequest::new("test2", None, 2.into()); + let batch = JsonRpcRequestBatch::Batch(vec![req1, req2]); + assert_eq!(batch.len(), 2); + + let vec = batch.into_vec(); + assert_eq!(vec.len(), 2); + } + + #[test] + fn test_batch_request_empty() { + let batch = JsonRpcRequestBatch::Batch(vec![]); + assert!(batch.is_empty()); + assert_eq!(batch.len(), 0); + } + + #[test] + fn test_batch_request_deserialization() { + // Single request + let single_json = r#"{"jsonrpc":"2.0","method":"test","id":1}"#; + let single: JsonRpcRequestBatch = serde_json::from_str(single_json).unwrap(); + assert_eq!(single.len(), 1); + + // Batch request + let batch_json = r#"[{"jsonrpc":"2.0","method":"test1","id":1},{"jsonrpc":"2.0","method":"test2","id":2}]"#; + let batch: JsonRpcRequestBatch = serde_json::from_str(batch_json).unwrap(); + assert_eq!(batch.len(), 2); + } + + #[test] + fn test_batch_response_from_vec() { + // Single response becomes Single variant + let responses = vec![JsonRpcResponse::success(json!(null), 1.into())]; + let batch = JsonRpcResponseBatch::from_vec(responses); + matches!(batch, JsonRpcResponseBatch::Single(_)); + + // Multiple responses become Batch variant + let responses = vec![ + JsonRpcResponse::success(json!(null), 1.into()), + JsonRpcResponse::success(json!(null), 2.into()), + ]; + let batch = JsonRpcResponseBatch::from_vec(responses); + matches!(batch, JsonRpcResponseBatch::Batch(_)); + } +} diff --git a/crates/dirigent_acp_api/src/lib.rs b/crates/dirigent_acp_api/src/lib.rs new file mode 100644 index 0000000..127639a --- /dev/null +++ b/crates/dirigent_acp_api/src/lib.rs @@ -0,0 +1,116 @@ +//! Dirigent ACP API +//! +//! This crate exposes an ACP (Agent-Client Protocol) API for Dirigent, +//! allowing other ACP clients to interact with Dirigent. +//! +//! ## Modules +//! +//! - [`config`] - Server configuration types +//! - [`error`] - Error types and JSON-RPC error conversion +//! - [`event_bridge`] - Event forwarding from source streams to SSE clients +//! - [`jsonrpc`] - JSON-RPC 2.0 request/response types +//! - [`router`] - Axum router and HTTP handlers +//! - [`rpc`] - JSON-RPC request handler and method dispatch +//! - [`session_manager`] - Session mapping and client connection tracking +//! - [`sse`] - SSE notifications for streaming events to clients +//! +//! ## Example +//! +//! ```rust,ignore +//! use dirigent_acp_api::{ +//! AcpServerConfig, NoOpConnectorOperations, +//! router::{AcpServerState, create_acp_server_router}, +//! }; +//! +//! // Create server configuration +//! let config = AcpServerConfig::enabled() +//! .set_port(3001) +//! .set_max_connections(100); +//! +//! // Create server state +//! let state = AcpServerState::new(config); +//! +//! // Create the router with connector operations +//! let router = create_acp_server_router(state, NoOpConnectorOperations); +//! +//! // Run with axum +//! let listener = tokio::net::TcpListener::bind("0.0.0.0:3001").await?; +//! axum::serve(listener, router).await?; +//! ``` + +// Modules +pub mod agent_requests; +pub mod config; +pub mod error; +pub mod event_bridge; +pub mod jsonrpc; +pub mod router; +pub mod rpc; +pub mod session_manager; +pub mod sse; + +// Re-exports for convenience +pub use agent_requests::AgentRequestTracker; +pub use config::AcpServerConfig; +pub use error::{AcpServerError, JsonRpcErrorObject}; +pub use event_bridge::{EventBridge, EventBridgeConfig}; +pub use jsonrpc::{ + JsonRpcId, JsonRpcRequest, JsonRpcRequestBatch, JsonRpcResponse, JsonRpcResponseBatch, + JSONRPC_VERSION, +}; +pub use router::{create_acp_server_router, AcpServerState, RouterState}; +pub use rpc::{ + ConnectorInfo, ConnectorOperations, NoOpConnectorOperations, RpcHandler, SessionInfo, + ACP_PROTOCOL_VERSION, SERVER_NAME, +}; +pub use session_manager::{ClientConnection, ClientInfo, SessionManager, SessionMapping}; +pub use sse::{AcpNotification, SseNotifier, translate_event}; + +use axum::{response::Json, routing::get, Router}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct ApiInfo { + pub name: String, + pub version: String, + pub description: String, +} + +/// Create the ACP API router +pub fn create_api_router() -> Router { + Router::new() + .route("/", get(api_info)) + .route("/health", get(health_check)) +} + +/// Get API information +async fn api_info() -> Json { + Json(ApiInfo { + name: "dirigent_acp_api".to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + description: "ACP API for Dirigent - orchestrates agentic clients".to_string(), + }) +} + +/// Health check endpoint +async fn health_check() -> Json { + Json(serde_json::json!({ + "status": "healthy", + "message": "Dirigent ACP API is running" + })) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_api_info() { + let info = ApiInfo { + name: "test".to_string(), + version: "0.1.0".to_string(), + description: "test description".to_string(), + }; + assert_eq!(info.name, "test"); + } +} diff --git a/crates/dirigent_acp_api/src/router.rs b/crates/dirigent_acp_api/src/router.rs new file mode 100644 index 0000000..64b1483 --- /dev/null +++ b/crates/dirigent_acp_api/src/router.rs @@ -0,0 +1,950 @@ +//! Axum Router Integration for ACP Server +//! +//! This module provides the Axum router and HTTP handlers for the ACP Server. +//! It exposes a JSON-RPC endpoint at `/rpc` and an SSE events endpoint at `/events`. +//! +//! ## Architecture +//! +//! The router is generic over `ConnectorOperations`, allowing different implementations +//! to be used (e.g., one backed by `CoreHandle` for production, or `NoOpConnectorOperations` +//! for testing). +//! +//! ## Endpoints +//! +//! - `POST /rpc` - JSON-RPC 2.0 endpoint for method calls +//! - `GET /events?client_id=...` - SSE stream for real-time notifications +//! - `GET /health` - Health check endpoint +//! +//! ## Example +//! +//! ```rust,ignore +//! use dirigent_acp_api::router::{AcpServerState, create_acp_server_router}; +//! use dirigent_acp_api::{AcpServerConfig, NoOpConnectorOperations}; +//! +//! // Create state +//! let state = AcpServerState::new(AcpServerConfig::enabled()); +//! +//! // Create router with no-op operations for testing +//! let router = create_acp_server_router(state, NoOpConnectorOperations); +//! ``` + +use std::convert::Infallible; +use std::sync::Arc; +use std::time::Instant; + +use axum::{ + extract::{Query, State}, + http::{HeaderMap, StatusCode}, + response::{ + sse::{Event, KeepAlive, Sse}, + IntoResponse, Json, + }, + routing::{get, post}, + Router, +}; +use serde::{Deserialize, Serialize}; +use tokio_stream::StreamExt; +use tower_http::cors::{Any, CorsLayer}; +use tracing::{debug, info, trace, warn}; + +use crate::config::AcpServerConfig; +use crate::rpc::{ConnectorOperations, RpcHandler}; +use crate::session_manager::SessionManager; +use crate::sse::SseNotifier; + +// ============================================================================ +// AcpServerState (T031) +// ============================================================================ + +/// Internal state shared across handlers +struct AcpServerStateInner { + /// Session manager for tracking client sessions + session_manager: SessionManager, + + /// SSE notifier for broadcasting events to clients + sse_notifier: SseNotifier, + + /// Agent request tracker for bidirectional request/response + agent_request_tracker: Arc, + + /// Server configuration + config: AcpServerConfig, +} + +/// Shared state for the ACP Server Axum handlers +/// +/// This struct contains all the state needed by the HTTP handlers: +/// - Session manager for tracking client sessions and mappings +/// - SSE notifier for broadcasting events to connected clients +/// - Server configuration +/// +/// The state is wrapped in `Arc` internally, making it cheap to clone and +/// share across async tasks and handlers. +/// +/// Note: `RpcHandler` is created per-request with `ConnectorOperations` passed in, +/// as the connector operations implementation cannot be stored in the shared state +/// (it may contain non-Clone types like `CoreHandle`). +#[derive(Clone)] +pub struct AcpServerState { + inner: Arc, +} + +impl AcpServerState { + /// Create a new ACP server state with the given configuration + /// + /// # Parameters + /// + /// - `config`: The server configuration + /// + /// # Example + /// + /// ```rust + /// use dirigent_acp_api::router::AcpServerState; + /// use dirigent_acp_api::AcpServerConfig; + /// + /// let state = AcpServerState::new(AcpServerConfig::enabled()); + /// ``` + pub fn new(config: AcpServerConfig) -> Self { + Self { + inner: Arc::new(AcpServerStateInner { + session_manager: SessionManager::new(), + sse_notifier: SseNotifier::new(), + agent_request_tracker: Arc::new(crate::agent_requests::AgentRequestTracker::new()), + config, + }), + } + } + + /// Create a new ACP server state with custom components + /// + /// This is useful for testing or when you need to share a session manager + /// or SSE notifier with other parts of the application. + /// + /// # Parameters + /// + /// - `session_manager`: The session manager instance + /// - `sse_notifier`: The SSE notifier instance + /// - `agent_request_tracker`: The agent request tracker instance + /// - `config`: The server configuration + pub fn with_components( + session_manager: SessionManager, + sse_notifier: SseNotifier, + agent_request_tracker: Arc, + config: AcpServerConfig, + ) -> Self { + Self { + inner: Arc::new(AcpServerStateInner { + session_manager, + sse_notifier, + agent_request_tracker, + config, + }), + } + } + + /// Get a reference to the session manager + pub fn session_manager(&self) -> &SessionManager { + &self.inner.session_manager + } + + /// Get a reference to the SSE notifier + pub fn sse_notifier(&self) -> &SseNotifier { + &self.inner.sse_notifier + } + + /// Get a reference to the agent request tracker + pub fn agent_request_tracker(&self) -> &Arc { + &self.inner.agent_request_tracker + } + + /// Get a reference to the configuration + pub fn config(&self) -> &AcpServerConfig { + &self.inner.config + } +} + +impl std::fmt::Debug for AcpServerState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AcpServerState") + .field("session_count", &self.inner.session_manager.mapping_count()) + .field("client_count", &self.inner.sse_notifier.client_count()) + .field("config", &self.inner.config) + .finish() + } +} + +// ============================================================================ +// Router State with ConnectorOperations (T028) +// ============================================================================ + +/// Combined state for Axum handlers that includes both server state and connector operations +/// +/// This struct combines the shared `AcpServerState` with a `ConnectorOperations` +/// implementation. It's used as the Axum state to provide handlers with access +/// to both session management and connector functionality. +#[derive(Clone)] +pub struct RouterState { + /// The ACP server state (session manager, SSE notifier, config) + pub state: AcpServerState, + + /// The connector operations implementation + pub connector_ops: C, +} + +impl RouterState { + /// Create a new router state + pub fn new(state: AcpServerState, connector_ops: C) -> Self { + Self { + state, + connector_ops, + } + } +} + +// ============================================================================ +// Router Factory (T028) +// ============================================================================ + +/// Create the ACP server router with all endpoints configured +/// +/// This function creates an Axum router with the following endpoints: +/// - `POST /rpc` - JSON-RPC 2.0 endpoint +/// - `GET /events` - SSE events stream +/// - `GET /health` - Health check +/// +/// CORS middleware is applied based on the configuration. +/// +/// # Type Parameters +/// +/// - `C`: The connector operations implementation +/// +/// # Parameters +/// +/// - `state`: The ACP server state +/// - `connector_ops`: The connector operations implementation +/// +/// # Example +/// +/// ```rust,ignore +/// use dirigent_acp_api::router::{AcpServerState, create_acp_server_router}; +/// use dirigent_acp_api::{AcpServerConfig, NoOpConnectorOperations}; +/// +/// let state = AcpServerState::new(AcpServerConfig::enabled()); +/// let router = create_acp_server_router(state, NoOpConnectorOperations); +/// +/// // Use with axum server +/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3001").await?; +/// axum::serve(listener, router).await?; +/// ``` +pub fn create_acp_server_router(state: AcpServerState, connector_ops: C) -> Router +where + C: ConnectorOperations + Clone + Send + Sync + 'static, +{ + // Build CORS layer based on configuration + let cors = build_cors_layer(&state.config()); + + // Create combined router state + let router_state = RouterState::new(state, connector_ops); + + // Build the router (T023: added /agent_response route) + Router::new() + .route("/rpc", post(handle_rpc::)) + .route("/events", get(handle_sse::)) + .route("/health", get(handle_health::)) + .route("/agent_response", post(handle_agent_response::)) + .layer(cors) + .with_state(router_state) +} + +/// Build CORS layer from configuration +fn build_cors_layer(config: &AcpServerConfig) -> CorsLayer { + let cors = CorsLayer::new() + .allow_methods([ + axum::http::Method::GET, + axum::http::Method::POST, + axum::http::Method::OPTIONS, + ]) + .allow_headers(Any); + + match &config.allowed_origins { + Some(origins) if !origins.is_empty() => { + // Parse origins and set specific allowed origins + let parsed_origins: Vec<_> = origins + .iter() + .filter_map(|o| o.parse().ok()) + .collect(); + + if parsed_origins.is_empty() { + warn!("No valid origins in allowed_origins, allowing any origin"); + cors.allow_origin(Any) + } else { + debug!("CORS configured with {} allowed origins", parsed_origins.len()); + cors.allow_origin(parsed_origins) + } + } + _ => { + debug!("CORS configured to allow any origin"); + cors.allow_origin(Any) + } + } +} + +// ============================================================================ +// Request/Response Types +// ============================================================================ + +/// Query parameters for the SSE events endpoint +#[derive(Debug, Deserialize)] +pub struct SseQuery { + /// The client ID (required for SSE subscription) + pub client_id: Option, +} + +/// Health check response +#[derive(Debug, Serialize, Deserialize)] +pub struct HealthResponse { + /// Health status + pub status: String, + + /// Server message + pub message: String, + + /// Number of connected clients + pub clients: usize, + + /// Number of active sessions + pub sessions: usize, +} + +/// Error response for SSE endpoint +#[derive(Debug, Serialize, Deserialize)] +pub struct SseErrorResponse { + /// Error message + pub error: String, + + /// Error code + pub code: String, +} + +// ============================================================================ +// Endpoint Handlers (T029, T030) +// ============================================================================ + +/// Handle POST /rpc requests (T029) +/// +/// Extracts the JSON body, processes it through the RPC handler, and returns +/// the JSON-RPC response. +async fn handle_rpc( + State(router_state): State>, + headers: HeaderMap, + body: String, +) -> impl IntoResponse +where + C: ConnectorOperations + Clone + Send + Sync + 'static, +{ + debug!("Received RPC request: {} bytes", body.len()); + + // Extract client_id from X-Client-ID header + let client_id = headers + .get("X-Client-ID") + .and_then(|v| v.to_str().ok()); + + // Extract select_connector from X-Select-Connector header (sent during initialize) + let select_connector = headers + .get("X-Select-Connector") + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + + if let Some(id) = client_id { + info!(client_id = %id, "Received X-Client-ID header from RPC request"); + debug!("RPC headers: {:?}", headers); + } else { + warn!("No X-Client-ID header provided in RPC request"); + debug!("Available headers: {:?}", headers.keys().collect::>()); + } + + // Log the incoming request + debug!("Received RPC request body: {}", serde_json::to_string(&body).unwrap_or_else(|_| "failed to serialize".to_string())); + + // Create an RPC handler for this request + let handler = RpcHandler::new( + router_state.state.session_manager().clone(), + router_state.connector_ops.clone(), + router_state.state.sse_notifier().clone(), + router_state.state.agent_request_tracker().clone(), + ); + + // Process the request with the client_id from header + let response = handler.handle_request(&body, client_id).await; + + // If this was an initialize request with X-Select-Connector header, + // store the preference after the client is registered + if let Some(ref connector) = select_connector { + // Check if this is an initialize response by looking for clientId in result + if let crate::jsonrpc::JsonRpcResponseBatch::Single(ref resp) = response { + if let Some(ref result) = resp.result { + if let Some(new_client_id) = result.get("clientId").and_then(|v| v.as_str()) { + router_state.state.session_manager().update_client_preferred_connector( + new_client_id, + Some(connector.clone()), + ); + info!( + client_id = %new_client_id, + select_connector = %connector, + "Stored preferred connector from X-Select-Connector header" + ); + } + } + } + } + + // Log the response being sent + debug!("Sending RPC response: {}", serde_json::to_string(&response).unwrap_or_else(|_| "failed to serialize".to_string())); + + // Return JSON response + Json(response) +} + +/// Handle GET /events SSE requests (T030) +/// +/// Creates an SSE stream subscription for the client. The client_id must be +/// provided as a query parameter. +async fn handle_sse( + State(router_state): State>, + Query(query): Query, +) -> Result>>, (StatusCode, Json)> +where + C: ConnectorOperations + Clone + Send + Sync + 'static, +{ + // Validate client_id + let client_id = match query.client_id { + Some(id) if !id.is_empty() => id, + _ => { + warn!("SSE request missing client_id"); + return Err(( + StatusCode::BAD_REQUEST, + Json(SseErrorResponse { + error: "Missing required parameter: client_id".to_string(), + code: "MISSING_CLIENT_ID".to_string(), + }), + )); + } + }; + + info!("SSE subscription requested for client: {}", client_id); + debug!( + "SSE notifier state before subscribe: client_count={}, subscribed_clients={:?}", + router_state.state.sse_notifier().client_count(), + router_state.state.sse_notifier().subscribed_clients() + ); + + // Subscribe to notifications + let notifier = router_state.state.sse_notifier(); + let notification_stream = notifier.subscribe(&client_id); + + info!( + "SSE client subscribed: client_id={}, total_clients={}, is_subscribed={}", + client_id, + notifier.client_count(), + notifier.is_subscribed(&client_id) + ); + + // Check if this client has any sessions that need tool updates + // This handles the race condition where session/new happens before SSE subscription + let session_manager = router_state.state.session_manager(); + let client_sessions = session_manager.list_client_sessions(&client_id); + + if !client_sessions.is_empty() { + debug!( + "Client {} has {} existing sessions, sending tool updates", + client_id, + client_sessions.len() + ); + + // For each session, get the connector and send tool updates + for session_id in client_sessions { + if let Some(mapping) = session_manager.get_mapping(&session_id) { + debug!( + "Fetching tool updates for session {} on connector {}", + session_id, + mapping.connector_id + ); + + // Get available commands from the connector + match router_state.connector_ops.get_connector_commands(&mapping.connector_id).await { + Ok(commands) => { + let update_params = crate::sse::SessionUpdateParams { + session_id: session_id.clone(), + update: crate::sse::SessionUpdateVariant::AvailableCommandsUpdate { + available_commands: commands.clone(), + }, + event_type_override: None, + }; + + // Broadcast the tool update now that client is subscribed + match notifier.broadcast(&client_id, update_params) { + Ok(n) => { + info!( + "Sent deferred available_commands_update for session {}: {} commands to {} receivers", + session_id, + commands.len(), + n + ); + } + Err(_) => { + warn!( + "Failed to send deferred available_commands_update for session {}", + session_id + ); + } + } + } + Err(e) => { + warn!( + "Failed to get connector commands for session {}: {}", + session_id, + e + ); + } + } + } + } + } + + // Map the notification stream to SSE events + let client_id_for_log = client_id.clone(); + let sse_stream = notification_stream.map(move |result| { + match result { + Ok(notification) => { + // T013: Start timing for SSE event creation + let event_start = Instant::now(); + + // Convert notification to SSE event + // Use to_sse_json() which handles raw events correctly + let event_type = notification.event_type(); + let data = notification.to_sse_json(); + + // T011: Log before SSE event is written (T023: includes session_id for correlation) + trace!( + client_id = %client_id_for_log, + event_type = %event_type, + session_id = %notification.session_id, + data_len = data.len(), + "Writing SSE event to client stream" + ); + + debug!( + "SSE: Sending event to client {}: type={}, session_id={}, data_len={}", + client_id_for_log, + event_type, + notification.session_id, + data.len() + ); + trace!("SSE event data: {}", data); + + let event = Event::default() + .event(event_type) + .data(data); + + // T012: Log after Event::default() construction (T023: includes session_id for correlation) + trace!( + client_id = %client_id_for_log, + event_type = %event_type, + session_id = %notification.session_id, + "SSE event constructed, sending to client" + ); + + // T013: Log SSE event write completion with timing (T023: includes session_id for correlation) + let elapsed_ms = event_start.elapsed().as_millis(); + trace!( + client_id = %client_id_for_log, + event_type = %event_type, + session_id = %notification.session_id, + elapsed_ms = elapsed_ms, + "SSE event write completed" + ); + + // T014: Warn for slow SSE event writes (T023: includes session_id for correlation) + if elapsed_ms > 100 { + warn!( + elapsed_ms = elapsed_ms, + client_id = %client_id_for_log, + event_type = %event_type, + session_id = %notification.session_id, + "Slow SSE event write detected" + ); + } + + Ok(event) + } + Err(e) => { + // Broadcast stream error (e.g., lagged receiver) + // We send an error event but keep the stream open + warn!("SSE stream error for client {}: {:?}", client_id_for_log, e); + Ok(Event::default() + .event("error") + .data(format!(r#"{{"error":"Stream error: {:?}"}}"#, e))) + } + } + }); + + // Return SSE response with keep-alive + Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default())) +} + +/// Handle GET /health requests +/// +/// Returns basic health information about the server. +async fn handle_health( + State(router_state): State>, +) -> Json +where + C: ConnectorOperations + Clone + Send + Sync + 'static, +{ + Json(HealthResponse { + status: "healthy".to_string(), + message: "Dirigent ACP Server is running".to_string(), + clients: router_state.state.sse_notifier().client_count(), + sessions: router_state.state.session_manager().mapping_count(), + }) +} + +/// Response for agent response endpoint +#[derive(Debug, Serialize, Deserialize)] +pub struct AgentResponseResult { + /// Status of the response + pub status: String, +} + +/// Handle POST /agent_response requests (T019) +/// +/// Accepts JSON-RPC responses from clients for pending agent requests. +/// When a client approves/denies a permission request, they POST the +/// response to this endpoint, which completes the pending request. +/// (Note: When nested at /acp, the full path becomes /acp/agent_response) +/// +/// Expected body format: +/// ```json +/// { +/// "jsonrpc": "2.0", +/// "id": 0, +/// "result": { "selectedOptionId": "allow" } +/// } +/// ``` +/// +/// The response delivery to the connector happens through the oneshot channel +/// registered in the event bridge (Phase 2.4). The tracker's `complete()` method +/// sends the response value through the oneshot channel, and the event bridge task +/// (which registered the request) will receive it and send the `ConnectorCommand::AgentResponse`. +async fn handle_agent_response( + State(router_state): State>, + headers: HeaderMap, + Json(response): Json, +) -> Result, (StatusCode, String)> +where + C: ConnectorOperations + Clone + Send + Sync + 'static, +{ + // Extract client_id from X-Client-ID header (T020) + let client_id = match headers.get("X-Client-ID").and_then(|v| v.to_str().ok()) { + Some(id) if !id.is_empty() => id, + _ => { + warn!("Agent response request missing X-Client-ID header"); + return Err(( + StatusCode::BAD_REQUEST, + "Missing required header: X-Client-ID".to_string(), + )); + } + }; + + debug!( + "Received agent response from client: {}", + client_id + ); + trace!("Agent response body: {}", response); + + // Extract request_id from JSON body (T020) + let request_id = match response.get("id") { + Some(id) => id.clone(), + None => { + warn!( + client_id = %client_id, + "Agent response missing 'id' field in body" + ); + return Err(( + StatusCode::BAD_REQUEST, + "Missing required field: id".to_string(), + )); + } + }; + + info!( + client_id = %client_id, + request_id = %request_id, + "Processing agent response" + ); + + // Call AgentRequestTracker::complete() (T021) + let tracker = router_state.state.agent_request_tracker(); + match tracker.complete(client_id, request_id.clone(), response.clone()) { + Ok(()) => { + info!( + client_id = %client_id, + request_id = %request_id, + "Agent response delivered successfully" + ); + + // Note (T022): The response delivery to the connector happens through + // the oneshot channel in the event bridge. When the event bridge handles + // an `Event::AgentRequest`, it registers the request with the tracker and + // gets a receiver. After we complete the request here, the receiver in the + // event bridge gets the response value and sends the `ConnectorCommand::AgentResponse` + // to the connector. This flow is implemented in Phase 2.4 (T024-T030). + + Ok(Json(AgentResponseResult { + status: "ok".to_string(), + })) + } + Err(e) => { + warn!( + client_id = %client_id, + request_id = %request_id, + error = %e, + "Agent response failed: request not found (may have timed out)" + ); + + Err(( + StatusCode::NOT_FOUND, + format!("Request ID {} not found or already completed", request_id), + )) + } + } +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + use crate::jsonrpc::JsonRpcResponseBatch; + use crate::rpc::NoOpConnectorOperations; + use axum::{ + body::Body, + http::{Request, StatusCode}, + }; + use tower::ServiceExt; + + fn create_test_router() -> Router { + let state = AcpServerState::new(AcpServerConfig::enabled()); + create_acp_server_router(state, NoOpConnectorOperations) + } + + #[tokio::test] + async fn test_health_endpoint() { + let router = create_test_router(); + + let request = Request::builder() + .uri("/health") + .body(Body::empty()) + .unwrap(); + + let response = router.oneshot(request).await.unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX) + .await + .unwrap(); + let health: HealthResponse = serde_json::from_slice(&body).unwrap(); + + assert_eq!(health.status, "healthy"); + assert_eq!(health.clients, 0); + assert_eq!(health.sessions, 0); + } + + #[tokio::test] + async fn test_rpc_initialize() { + let router = create_test_router(); + + let request_body = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#; + + let request = Request::builder() + .method("POST") + .uri("/rpc") + .header("content-type", "application/json") + .body(Body::from(request_body)) + .unwrap(); + + let response = router.oneshot(request).await.unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX) + .await + .unwrap(); + let resp: JsonRpcResponseBatch = serde_json::from_slice(&body).unwrap(); + + match resp { + JsonRpcResponseBatch::Single(r) => { + assert!(r.is_success()); + let result = r.result.unwrap(); + assert_eq!(result["agentInfo"]["name"], "dirigent-acp-server"); + } + _ => panic!("Expected single response"), + } + } + + #[tokio::test] + async fn test_rpc_invalid_json() { + let router = create_test_router(); + + let request = Request::builder() + .method("POST") + .uri("/rpc") + .header("content-type", "application/json") + .body(Body::from("not valid json")) + .unwrap(); + + let response = router.oneshot(request).await.unwrap(); + + assert_eq!(response.status(), StatusCode::OK); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX) + .await + .unwrap(); + let resp: JsonRpcResponseBatch = serde_json::from_slice(&body).unwrap(); + + match resp { + JsonRpcResponseBatch::Single(r) => { + assert!(r.is_error()); + let error = r.error.unwrap(); + assert_eq!(error.code, crate::error::error_codes::PARSE_ERROR); + } + _ => panic!("Expected single response"), + } + } + + #[tokio::test] + async fn test_sse_missing_client_id() { + let router = create_test_router(); + + let request = Request::builder() + .uri("/events") + .body(Body::empty()) + .unwrap(); + + let response = router.oneshot(request).await.unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + + let body = axum::body::to_bytes(response.into_body(), usize::MAX) + .await + .unwrap(); + let error: SseErrorResponse = serde_json::from_slice(&body).unwrap(); + + assert_eq!(error.code, "MISSING_CLIENT_ID"); + } + + #[tokio::test] + async fn test_sse_with_client_id() { + let router = create_test_router(); + + let request = Request::builder() + .uri("/events?client_id=test-client") + .body(Body::empty()) + .unwrap(); + + let response = router.oneshot(request).await.unwrap(); + + // Should return 200 OK with SSE content type + assert_eq!(response.status(), StatusCode::OK); + + let content_type = response + .headers() + .get("content-type") + .and_then(|v| v.to_str().ok()); + + assert!(content_type.is_some()); + assert!(content_type.unwrap().contains("text/event-stream")); + } + + #[test] + fn test_acp_server_state_new() { + let state = AcpServerState::new(AcpServerConfig::enabled()); + + assert!(state.config().enabled); + assert_eq!(state.session_manager().mapping_count(), 0); + assert_eq!(state.sse_notifier().client_count(), 0); + } + + #[test] + fn test_acp_server_state_with_components() { + let session_manager = SessionManager::new(); + let sse_notifier = SseNotifier::new(); + let config = AcpServerConfig::enabled().set_port(4000); + + // Pre-populate some state + session_manager.register_client(None); + + let agent_request_tracker = Arc::new(crate::agent_requests::AgentRequestTracker::new()); + + let state = AcpServerState::with_components( + session_manager, + sse_notifier, + agent_request_tracker, + config, + ); + + assert!(state.config().enabled); + assert_eq!(state.config().port, 4000); + assert_eq!(state.session_manager().client_count(), 1); + } + + #[test] + fn test_acp_server_state_clone() { + let state1 = AcpServerState::new(AcpServerConfig::enabled()); + let state2 = state1.clone(); + + // Both should share the same underlying state + let client_id = state1.session_manager().register_client(None); + assert_eq!(state2.session_manager().client_count(), 1); + assert!(state2.session_manager().get_client(&client_id).is_some()); + } + + #[test] + fn test_acp_server_state_debug() { + let state = AcpServerState::new(AcpServerConfig::enabled()); + let debug_str = format!("{:?}", state); + + assert!(debug_str.contains("AcpServerState")); + assert!(debug_str.contains("session_count")); + assert!(debug_str.contains("client_count")); + } + + #[test] + fn test_build_cors_layer_any_origin() { + let config = AcpServerConfig::default(); + let _cors = build_cors_layer(&config); + // No assertion needed - just verify it doesn't panic + } + + #[test] + fn test_build_cors_layer_specific_origins() { + let config = AcpServerConfig::default() + .set_allowed_origins(Some(vec![ + "http://localhost:3000".to_string(), + "https://app.example.com".to_string(), + ])); + let _cors = build_cors_layer(&config); + // No assertion needed - just verify it doesn't panic + } + + #[test] + fn test_build_cors_layer_empty_origins() { + let config = AcpServerConfig::default() + .set_allowed_origins(Some(vec![])); + let _cors = build_cors_layer(&config); + // Should fall back to any origin + } +} diff --git a/crates/dirigent_acp_api/src/rpc/mod.rs b/crates/dirigent_acp_api/src/rpc/mod.rs new file mode 100644 index 0000000..19e7ee2 --- /dev/null +++ b/crates/dirigent_acp_api/src/rpc/mod.rs @@ -0,0 +1,1657 @@ +//! JSON-RPC Request Handler for ACP Server +//! +//! This module implements the RPC handler that processes incoming JSON-RPC requests +//! from ACP clients. It dispatches requests to appropriate method handlers and manages +//! the request/response lifecycle. +//! +//! ## Architecture +//! +//! The `RpcHandler` uses a `ConnectorOperations` trait to abstract connector access. +//! This design avoids circular dependencies - the web server implements this trait +//! to bridge to `CoreHandle`, allowing `dirigent_acp_api` to remain independent of +//! `dirigent_core`. +//! +//! ## Module Organization +//! +//! - `types` - Request/response parameter types and `ConnectorOperations` trait +//! - `noop` - No-op implementation for testing +//! +//! ## Supported Methods +//! +//! - `initialize` - Client handshake, returns server capabilities +//! - `session/new` - Create a new session +//! - `session/load` - Load an existing session +//! - `session/prompt` - Send a prompt to a session +//! - `session/cancel` - Cancel active generation +//! - `session/close` - Close a session +//! - `session/set_mode` - Set the session mode (legacy) +//! - `session/set_model` - Set the session model (legacy) +//! - `session/set_config_option` - Set a configuration option (unified mode/model setter) +//! - `session/resume` - Resume an existing session without history replay + +// Submodules +mod noop; +pub mod types; + +// Re-exports +pub use noop::NoOpConnectorOperations; +pub use types::{ + AgentCapabilities, AgentInfo, AuthMethod, ConnectorInfo, ConnectorOperations, ContentBlock, + InitializeParams, InitializeResult, McpCapabilities, PromptCapabilities, PromptContent, + SessionCancelParams, SessionCancelResult, SessionCapabilities, SessionCloseParams, + SessionCloseResult, SessionInfo, SessionListEntry, SessionListParams, SessionListResult, + SessionLoadParams, SessionLoadResult, SessionNewParams, SessionNewResult, SessionPromptParams, + SessionPromptResult, SessionResumeParams, SessionResumeResult, SessionSetConfigOptionParams, + SessionSetModeParams, SessionSetModelParams, +}; + +use tracing::{debug, error, info, warn}; + +use crate::error::JsonRpcErrorObject; +use crate::jsonrpc::{ + JsonRpcId, JsonRpcRequest, JsonRpcRequestBatch, JsonRpcResponse, JsonRpcResponseBatch, +}; +use crate::session_manager::SessionManager; +use crate::sse::{models_to_config_option, modes_to_config_option}; + +/// Protocol version supported by this server +pub const ACP_PROTOCOL_VERSION: u32 = 1; + +/// Server name for capability responses +pub const SERVER_NAME: &str = "dirigent-acp-server"; + +// ============================================================================ +// RpcHandler Implementation +// ============================================================================ + +/// JSON-RPC request handler for the ACP Server +/// +/// The RpcHandler processes incoming JSON-RPC requests, dispatches them to +/// the appropriate method handlers, and manages session state through the +/// SessionManager. +/// +/// ## Thread Safety +/// +/// The RpcHandler is designed to be cloned and shared across async tasks. +/// The SessionManager uses internal locking for thread-safe access. +pub struct RpcHandler { + /// Session manager for tracking client sessions + session_manager: SessionManager, + + /// Connector operations implementation + connector_ops: C, + + /// SSE notifier for broadcasting events to clients + sse_notifier: crate::sse::SseNotifier, + + /// Agent request tracker for bidirectional request/response + agent_request_tracker: std::sync::Arc, +} + +impl RpcHandler { + /// Create a new RPC handler + /// + /// # Parameters + /// - `session_manager`: The session manager for tracking sessions + /// - `connector_ops`: Implementation of connector operations + /// - `sse_notifier`: The SSE notifier for broadcasting events to clients + /// - `agent_request_tracker`: The agent request tracker for bidirectional requests + pub fn new( + session_manager: SessionManager, + connector_ops: C, + sse_notifier: crate::sse::SseNotifier, + agent_request_tracker: std::sync::Arc, + ) -> Self { + Self { + session_manager, + connector_ops, + sse_notifier, + agent_request_tracker, + } + } + + /// Get a reference to the session manager + pub fn session_manager(&self) -> &SessionManager { + &self.session_manager + } + + /// Get a reference to the agent request tracker + pub fn agent_request_tracker( + &self, + ) -> &std::sync::Arc { + &self.agent_request_tracker + } + + /// Resolve a connector ID from a preference string (ID or magic word) + /// + /// # Parameters + /// - `preference`: Either a connector ID or a magic word (claude, codex, gemini) + /// + /// # Returns + /// The resolved connector ID, or an error if not found + async fn resolve_connector_by_preference( + &self, + preference: &str, + ) -> Result { + // First, try exact ID match by listing all connectors + let connectors = self.connector_ops.list_connectors().await?; + + // Try exact ID match + if connectors.iter().any(|c| c.id == preference) { + return Ok(preference.to_string()); + } + + // Try to match by agent_type magic word + // Magic words: claude, codex (alias: openai), gemini (alias: google) + let magic_word = preference.to_lowercase(); + let is_magic_word = matches!( + magic_word.as_str(), + "claude" | "codex" | "openai" | "gemini" | "google" + ); + + if is_magic_word { + // Find first connector whose name/type contains the magic word + for connector in &connectors { + // Match if connector name/type contains the magic word (or its canonical form) + let name_lower = connector.name.to_lowercase(); + let type_lower = connector.connector_type.to_lowercase(); + + let matches = match magic_word.as_str() { + "claude" => name_lower.contains("claude") || type_lower.contains("claude"), + "codex" | "openai" => { + name_lower.contains("codex") + || name_lower.contains("openai") + || type_lower.contains("codex") + || type_lower.contains("openai") + } + "gemini" | "google" => { + name_lower.contains("gemini") + || name_lower.contains("google") + || type_lower.contains("gemini") + || type_lower.contains("google") + } + _ => false, + }; + + if matches { + debug!( + "Matched magic word '{}' to connector: {} ({})", + preference, connector.id, connector.name + ); + return Ok(connector.id.clone()); + } + } + + return Err(crate::error::AcpServerError::ConnectorNotFound(format!( + "No connector found matching agent type '{}'", + preference + ))); + } + + // No match found + Err(crate::error::AcpServerError::ConnectorNotFound(format!( + "No connector found matching preference '{}'", + preference + ))) + } + + /// Handle an incoming request (single or batch) + /// + /// This is the main entry point for processing JSON-RPC requests. + /// It handles both single requests and batch requests. + /// + /// # Parameters + /// - `body`: The raw JSON body of the request + /// - `client_id`: Optional client ID for authenticated requests + /// + /// # Returns + /// A JSON-RPC response (single or batch) + pub async fn handle_request( + &self, + body: &str, + client_id: Option<&str>, + ) -> JsonRpcResponseBatch { + // Try to parse as batch or single request + match serde_json::from_str::(body) { + Ok(batch) => self.handle_batch(batch, client_id).await, + Err(e) => { + error!("Failed to parse JSON-RPC request: {}", e); + JsonRpcResponseBatch::Single(JsonRpcResponse::error_with_null_id( + JsonRpcErrorObject::parse_error(format!("Invalid JSON: {}", e)), + )) + } + } + } + + /// Handle a batch of requests + /// + /// Processes each request in the batch and returns a batch of responses. + /// Notifications (requests without id) do not generate responses. + /// + /// # Parameters + /// - `batch`: The batch of requests + /// - `client_id`: Optional client ID + /// + /// # Returns + /// A batch of responses + pub async fn handle_batch( + &self, + batch: JsonRpcRequestBatch, + client_id: Option<&str>, + ) -> JsonRpcResponseBatch { + match batch { + JsonRpcRequestBatch::Single(request) => { + if request.is_notification() { + // Notifications don't get responses, but we still process them + self.dispatch_request(&request, client_id).await; + // Return an empty batch for notifications + JsonRpcResponseBatch::Batch(vec![]) + } else { + let response = self.dispatch_request(&request, client_id).await; + JsonRpcResponseBatch::Single(response) + } + } + JsonRpcRequestBatch::Batch(requests) => { + if requests.is_empty() { + // Empty batch is an invalid request + return JsonRpcResponseBatch::Single(JsonRpcResponse::error_with_null_id( + JsonRpcErrorObject::invalid_request("Empty batch request"), + )); + } + + let mut responses = Vec::new(); + + for request in requests { + if request.is_notification() { + // Process notification but don't add response + self.dispatch_request(&request, client_id).await; + } else { + let response = self.dispatch_request(&request, client_id).await; + responses.push(response); + } + } + + // If all requests were notifications, return empty batch + if responses.is_empty() { + JsonRpcResponseBatch::Batch(vec![]) + } else { + JsonRpcResponseBatch::from_vec(responses) + } + } + } + } + + /// Dispatch a single request to the appropriate handler + async fn dispatch_request( + &self, + request: &JsonRpcRequest, + client_id: Option<&str>, + ) -> JsonRpcResponse { + let id = request.id.clone().unwrap_or(JsonRpcId::Null); + + // Validate the request + if let Err(e) = request.validate() { + return JsonRpcResponse::error(JsonRpcErrorObject::invalid_request(e), id); + } + + debug!("Dispatching RPC method: {}", request.method); + + // Route to appropriate handler + let result = match request.method.as_str() { + "initialize" => self.handle_initialize(request, client_id).await, + "session/new" => self.handle_session_new(request, client_id).await, + "session/load" => self.handle_session_load(request, client_id).await, + "session/prompt" => self.handle_session_prompt(request, client_id).await, + "session/cancel" => self.handle_session_cancel(request, client_id).await, + "session/close" => self.handle_session_close(request, client_id).await, + "session/set_mode" => self.handle_session_set_mode(request, client_id).await, + "session/set_model" => self.handle_session_set_model(request, client_id).await, + "session/set_config_option" => { + self.handle_session_set_config_option(request, client_id) + .await + } + "session/list" => self.handle_session_list(request, client_id).await, + "session/resume" => self.handle_session_resume(request, client_id).await, + _ => { + warn!("Unknown method: {}", request.method); + Err(JsonRpcErrorObject::method_not_found(&request.method)) + } + }; + + match result { + Ok(value) => JsonRpcResponse::success(value, id), + Err(error) => JsonRpcResponse::error(error, id), + } + } + + // ======================================================================== + // Request Handlers + // ======================================================================== + + /// Handle the initialize request + /// + /// Registers the client and returns server capabilities. + async fn handle_initialize( + &self, + request: &JsonRpcRequest, + client_id: Option<&str>, + ) -> Result { + // Parse parameters (optional) + let params: InitializeParams = request + .params + .as_ref() + .map(|p| serde_json::from_value(p.clone())) + .transpose() + .map_err(|e| { + JsonRpcErrorObject::invalid_params(format!("Invalid initialize params: {}", e)) + })? + .unwrap_or(InitializeParams { + capabilities: None, + client_name: None, + client_version: None, + }); + + // Use provided client_id from X-Client-Id header if available, + // otherwise register a new client + let new_client_id = if let Some(id) = client_id { + info!( + client_id = %id, + "Using provided client_id from X-Client-Id header" + ); + // Register this client_id with the session manager + self.session_manager + .register_client_with_id(id.to_string(), params.capabilities); + id.to_string() + } else { + // Fallback: generate a new client_id + let id = self.session_manager.register_client(params.capabilities); + info!( + client_id = %id, + "Generated new client_id (no X-Client-Id header provided)" + ); + id + }; + + info!( + "Client initialized: {} (name: {:?}, version: {:?})", + new_client_id, params.client_name, params.client_version + ); + + let result = InitializeResult { + protocol_version: ACP_PROTOCOL_VERSION, + agent_capabilities: AgentCapabilities::default(), + agent_info: AgentInfo { + name: SERVER_NAME.to_string(), + title: "Dirigent ACP Server".to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + }, + auth_methods: vec![], + client_id: new_client_id, + }; + + serde_json::to_value(result) + .map_err(|e| JsonRpcErrorObject::internal_error(format!("Serialization error: {}", e))) + } + + /// Handle session/new request + /// + /// Creates a new session via the target connector. + async fn handle_session_new( + &self, + request: &JsonRpcRequest, + client_id: Option<&str>, + ) -> Result { + // Parse parameters + let params: SessionNewParams = request + .params + .as_ref() + .map(|p| serde_json::from_value(p.clone())) + .transpose() + .map_err(|e| { + JsonRpcErrorObject::invalid_params(format!("Invalid session/new params: {}", e)) + })? + .unwrap_or(SessionNewParams { + connector_id: None, + cwd: None, + session_id: None, + }); + + // Get or generate a client ID + let client_id = client_id + .map(|s| s.to_string()) + .unwrap_or_else(|| format!("http-client-{}", uuid::Uuid::now_v7())); + + // Determine which connector to use + let connector_id = match params.connector_id { + Some(id) => id, + None => { + // Check if client has preferred_connector (from --select-connector) + let client_info = self.session_manager.get_client(&client_id); + let preferred_connector = client_info + .as_ref() + .and_then(|c| c.preferred_connector.clone()); + + if let Some(ref preferred) = preferred_connector { + // Try to resolve the preferred connector + match self.resolve_connector_by_preference(preferred).await { + Ok(connector_id) => { + info!( + "Resolved preferred connector '{}' to connector_id: {}", + preferred, connector_id + ); + connector_id + } + Err(e) => { + warn!( + "Failed to resolve preferred connector '{}': {}. Falling back to default.", + preferred, e + ); + self.connector_ops + .default_connector_id() + .await + .ok_or_else(|| { + JsonRpcErrorObject::invalid_params( + "No connector_id specified and no default connector configured", + ) + })? + } + } + } else { + // Use default connector if available + self.connector_ops + .default_connector_id() + .await + .ok_or_else(|| { + JsonRpcErrorObject::invalid_params( + "No connector_id specified and no default connector configured", + ) + })? + } + } + }; + + // Get client capabilities from the session manager + let client_capabilities: Option = self + .session_manager + .get_client(&client_id) + .and_then(|client| client.capabilities); + + // Create external ownership with capability forwarding + let ownership = dirigent_protocol::SessionOwnership::external_forwarded( + client_id.clone(), + client_capabilities, + ); + + // Create session via connector operations + let session_info = self + .connector_ops + .create_session(&connector_id, params.cwd, ownership.clone()) + .await + .map_err(JsonRpcErrorObject::from)?; + + // Create the session mapping + let mapping = self.session_manager.create_mapping( + &client_id, + params.session_id, + connector_id.clone(), + ); + + // Update the mapping with the actual connector session ID + self.session_manager.update_mapping_internal_session( + &mapping.client_session_id, + session_info.session_id.clone(), + ); + + // Apply ownership model to the mapping for permission request routing + self.session_manager + .update_mapping_ownership(&mapping.client_session_id, ownership); + + info!( + "Created session: {} -> {} on connector {}", + mapping.client_session_id, mapping.internal_session_id, connector_id + ); + + // Build config_options from modes and models + let config_options = { + let mut options = Vec::new(); + if let Some(ref modes) = session_info.modes { + options.push(modes_to_config_option(modes)); + } + if let Some(ref models) = session_info.models { + options.push(models_to_config_option(models)); + } + if options.is_empty() { + None + } else { + Some(options) + } + }; + + let result = SessionNewResult { + session_id: mapping.client_session_id.clone(), + title: session_info.title, + connector_id: session_info.connector_id, + created_at: session_info.created_at, + models: session_info.models.clone(), + modes: session_info.modes.clone(), + config_options, + }; + + // Send available_commands_update notification + let commands = match self + .connector_ops + .get_connector_commands(&connector_id) + .await + { + Ok(cmds) => cmds, + Err(e) => { + warn!( + "Failed to get connector commands for {}: {}. Using empty list.", + connector_id, e + ); + Vec::new() + } + }; + + let update_params = crate::sse::SessionUpdateParams { + session_id: result.session_id.clone(), + update: crate::sse::SessionUpdateVariant::AvailableCommandsUpdate { + available_commands: commands.clone(), + }, + event_type_override: None, + }; + + // Broadcast to this client (ignore errors) + if let Err(_) = self.sse_notifier.broadcast(&client_id, update_params) { + debug!( + "Client {} not subscribed to SSE yet, commands update not sent", + client_id + ); + } + + serde_json::to_value(result) + .map_err(|e| JsonRpcErrorObject::internal_error(format!("Serialization error: {}", e))) + } + + /// Handle session/load request + /// + /// Loads an existing session from a connector. + async fn handle_session_load( + &self, + request: &JsonRpcRequest, + client_id: Option<&str>, + ) -> Result { + let params_value = request + .params + .as_ref() + .ok_or_else(|| JsonRpcErrorObject::invalid_params("Missing params for session/load"))? + .clone(); + + let params: SessionLoadParams = serde_json::from_value(params_value).map_err(|e| { + JsonRpcErrorObject::invalid_params(format!("Invalid session/load params: {}", e)) + })?; + + let client_id = client_id + .map(|s| s.to_string()) + .unwrap_or_else(|| format!("http-client-{}", uuid::Uuid::now_v7())); + + // Resolve connector ID: explicit param -> session mapping -> archivist lookup -> client sessions -> default + let connector_id = if let Some(ref cid) = params.connector_id { + cid.clone() + } else if let Some(mapping) = self.session_manager.get_mapping(¶ms.session_id) { + // Session was already mapped (e.g. from a prior session/list call) + mapping.connector_id + } else if let Some(connector) = self.connector_ops.resolve_session_connector(¶ms.session_id).await { + // Archivist knows which connector owns this session + connector + } else if let Some(connector) = self.resolve_connector_from_client_sessions(&client_id) { + connector + } else if let Some(default) = self.connector_ops.default_connector_id().await { + default + } else { + return Err(JsonRpcErrorObject::invalid_params( + "No connector_id specified and no default connector available", + )); + }; + + // Pre-route: if resolved to gateway but archivist knows the real owner, override + let connector_id = if self.is_gateway_connector(&connector_id).await { + if let Some(real_connector) = self.connector_ops.resolve_session_connector(¶ms.session_id).await { + info!( + "Pre-routing session {} from gateway {} to owning connector {}", + params.session_id, connector_id, real_connector + ); + real_connector + } else { + connector_id + } + } else { + connector_id + }; + + // Load session via connector operations + let mcp_servers_value = params.mcp_servers.as_ref().map(|servers| { + serde_json::Value::Array(servers.clone()) + }); + let session_info = self + .connector_ops + .load_session(&connector_id, ¶ms.session_id, params.cwd.clone(), mcp_servers_value) + .await + .map_err(JsonRpcErrorObject::from)?; + + // Create the session mapping + let mapping = self.session_manager.create_mapping( + &client_id, + Some(params.session_id.clone()), + connector_id.clone(), + ); + + // Update the mapping with the actual connector session ID + self.session_manager.update_mapping_internal_session( + &mapping.client_session_id, + session_info.session_id.clone(), + ); + + info!( + "Loaded session: {} -> {} from connector {}", + mapping.client_session_id, mapping.internal_session_id, connector_id + ); + + // Build config_options from modes and models + let config_options = { + let mut options = Vec::new(); + if let Some(ref modes) = session_info.modes { + options.push(modes_to_config_option(modes)); + } + if let Some(ref models) = session_info.models { + options.push(models_to_config_option(models)); + } + if options.is_empty() { + None + } else { + Some(options) + } + }; + + let result = SessionLoadResult { + session_id: mapping.client_session_id, + title: session_info.title, + connector_id: session_info.connector_id, + created_at: session_info.created_at, + models: session_info.models.clone(), + modes: session_info.modes.clone(), + config_options, + }; + + serde_json::to_value(result) + .map_err(|e| JsonRpcErrorObject::internal_error(format!("Serialization error: {}", e))) + } + + /// Handle session/list request + /// + /// Lists available sessions from the target connector. + async fn handle_session_list( + &self, + request: &JsonRpcRequest, + client_id: Option<&str>, + ) -> Result { + // params are optional for session/list + let params: SessionListParams = request + .params + .as_ref() + .map(|v| serde_json::from_value(v.clone())) + .transpose() + .map_err(|e| { + JsonRpcErrorObject::invalid_params(format!("Invalid session/list params: {}", e)) + })? + .unwrap_or(SessionListParams { + connector_id: None, + cwd: None, + cursor: None, + }); + + let client_id = client_id + .map(|s| s.to_string()) + .unwrap_or_else(|| format!("http-client-{}", uuid::Uuid::now_v7())); + + // Resolve connector ID: explicit param → client's active session → default + let connector_id = if let Some(ref cid) = params.connector_id { + cid.clone() + } else if let Some(connector) = self.resolve_connector_from_client_sessions(&client_id) { + connector + } else if let Some(default) = self.connector_ops.default_connector_id().await { + default + } else { + return Err(JsonRpcErrorObject::invalid_params( + "No connector_id specified and no default connector available", + )); + }; + + // Gateways: list sessions from all connectors via archivist + if self.is_gateway_connector(&connector_id).await { + info!( + "Connector {} is a gateway, listing sessions across all connectors (client: {})", + connector_id, client_id + ); + + let all_sessions = self + .connector_ops + .list_all_sessions() + .await + .map_err(JsonRpcErrorObject::from)?; + + let entries: Vec = all_sessions + .into_iter() + .map(|info| SessionListEntry { + session_id: info.session_id, + cwd: info.cwd.unwrap_or_else(|| ".".to_string()), + title: info.title, + updated_at: Some(info.created_at), + meta: Some(serde_json::json!({ "connectorId": info.connector_id })), + }) + .collect(); + + info!( + "Listed {} sessions across all connectors for client {}", + entries.len(), + client_id + ); + + let result = SessionListResult { + sessions: entries, + next_cursor: None, + }; + return serde_json::to_value(result).map_err(|e| { + JsonRpcErrorObject::internal_error(format!("Serialization error: {}", e)) + }); + } + + info!( + "Listing sessions for connector: {} (client: {})", + connector_id, client_id + ); + + // List sessions via connector operations + let sessions = self + .connector_ops + .list_sessions(&connector_id) + .await + .map_err(JsonRpcErrorObject::from)?; + + // Convert to ACP-spec SessionListEntry format + // Note: We pass through the connector's raw session IDs here. + // No session mappings are created — session/list is a read-only query. + // Mappings are created later when session/load or session/resume is called. + let entries: Vec = sessions + .into_iter() + .map(|info| SessionListEntry { + session_id: info.session_id, + cwd: info.cwd.unwrap_or_else(|| ".".to_string()), + title: info.title, + updated_at: Some(info.created_at), + meta: None, + }) + .collect(); + + info!( + "Listed {} sessions from connector {} for client {}", + entries.len(), + connector_id, + client_id + ); + + // No server-side pagination for now (connector already resolved all pages) + let result = SessionListResult { + sessions: entries, + next_cursor: None, + }; + + serde_json::to_value(result) + .map_err(|e| JsonRpcErrorObject::internal_error(format!("Serialization error: {}", e))) + } + + /// Handle session/resume request + /// + /// Resumes an existing session without history replay. + /// Structurally similar to session/load but semantically different + /// (no session/update history notifications are sent). + async fn handle_session_resume( + &self, + request: &JsonRpcRequest, + client_id: Option<&str>, + ) -> Result { + let params_value = request + .params + .as_ref() + .ok_or_else(|| { + JsonRpcErrorObject::invalid_params("Missing params for session/resume") + })? + .clone(); + + let params: SessionResumeParams = serde_json::from_value(params_value).map_err(|e| { + JsonRpcErrorObject::invalid_params(format!("Invalid session/resume params: {}", e)) + })?; + + let client_id = client_id + .map(|s| s.to_string()) + .unwrap_or_else(|| format!("http-client-{}", uuid::Uuid::now_v7())); + + // Resolve connector ID: explicit param -> session mapping -> client sessions -> default + let connector_id = if let Some(ref cid) = params.connector_id { + cid.clone() + } else if let Some(mapping) = self.session_manager.get_mapping(¶ms.session_id) { + // Session was already mapped (e.g. from a prior session/list call) + mapping.connector_id + } else if let Some(connector) = self.resolve_connector_from_client_sessions(&client_id) { + connector + } else if let Some(default) = self.connector_ops.default_connector_id().await { + default + } else { + return Err(JsonRpcErrorObject::invalid_params( + "No connector_id specified and no default connector available", + )); + }; + + // Resume uses load_session under the hood — the connector decides + // whether to send session/resume or session/load to the upstream agent + let mcp_servers_value = params.mcp_servers.as_ref().map(|servers| { + serde_json::Value::Array(servers.clone()) + }); + let session_info = self + .connector_ops + .load_session(&connector_id, ¶ms.session_id, params.cwd.clone(), mcp_servers_value) + .await + .map_err(JsonRpcErrorObject::from)?; + + // Create the session mapping + let mapping = self.session_manager.create_mapping( + &client_id, + Some(params.session_id.clone()), + connector_id.clone(), + ); + + self.session_manager.update_mapping_internal_session( + &mapping.client_session_id, + session_info.session_id.clone(), + ); + + info!( + "Resumed session: {} -> {} from connector {}", + mapping.client_session_id, mapping.internal_session_id, connector_id + ); + + // Build config_options from modes and models + let config_options = { + let mut options = Vec::new(); + if let Some(ref modes) = session_info.modes { + options.push(modes_to_config_option(modes)); + } + if let Some(ref models) = session_info.models { + options.push(models_to_config_option(models)); + } + if options.is_empty() { + None + } else { + Some(options) + } + }; + + let result = SessionResumeResult { + session_id: mapping.client_session_id, + title: session_info.title, + connector_id: session_info.connector_id, + created_at: session_info.created_at, + models: session_info.models.clone(), + modes: session_info.modes.clone(), + config_options, + }; + + serde_json::to_value(result) + .map_err(|e| JsonRpcErrorObject::internal_error(format!("Serialization error: {}", e))) + } + + /// Handle session/prompt request + /// + /// Sends a prompt to a session and starts streaming the response. + async fn handle_session_prompt( + &self, + request: &JsonRpcRequest, + _client_id: Option<&str>, + ) -> Result { + let params_value = request + .params + .as_ref() + .ok_or_else(|| JsonRpcErrorObject::invalid_params("Missing params for session/prompt"))? + .clone(); + + let params: SessionPromptParams = serde_json::from_value(params_value).map_err(|e| { + JsonRpcErrorObject::invalid_params(format!("Invalid session/prompt params: {}", e)) + })?; + + // Look up the session mapping + let mapping = self + .session_manager + .get_mapping(¶ms.session_id) + .ok_or_else(|| { + JsonRpcErrorObject::new( + crate::error::error_codes::SESSION_NOT_FOUND, + format!("Session not found: {}", params.session_id), + ) + })?; + + // Convert prompt content to text + let text = params.prompt.to_text(); + + info!( + "Received prompt for session {}: {} chars", + params.session_id, + text.len() + ); + + // Send message via connector operations and wait for completion + let stop_reason = self + .connector_ops + .send_message(&mapping.connector_id, &mapping.internal_session_id, text) + .await + .map_err(JsonRpcErrorObject::from)?; + + info!( + "Prompt completed for session {} with stop_reason: {}", + params.session_id, stop_reason + ); + + let result = SessionPromptResult { stop_reason }; + + serde_json::to_value(&result) + .map_err(|e| JsonRpcErrorObject::internal_error(format!("Serialization error: {}", e))) + } + + /// Handle session/cancel request + /// + /// Cancels any active generation on the session. + async fn handle_session_cancel( + &self, + request: &JsonRpcRequest, + _client_id: Option<&str>, + ) -> Result { + let params_value = request + .params + .as_ref() + .ok_or_else(|| JsonRpcErrorObject::invalid_params("Missing params for session/cancel"))? + .clone(); + + let params: SessionCancelParams = serde_json::from_value(params_value).map_err(|e| { + JsonRpcErrorObject::invalid_params(format!("Invalid session/cancel params: {}", e)) + })?; + + // Look up the session mapping + let mapping = self + .session_manager + .get_mapping(¶ms.session_id) + .ok_or_else(|| { + JsonRpcErrorObject::new( + crate::error::error_codes::SESSION_NOT_FOUND, + format!("Session not found: {}", params.session_id), + ) + })?; + + info!("Cancelling generation for session {}", params.session_id); + + // Cancel via connector operations + self.connector_ops + .cancel_generation(&mapping.connector_id, &mapping.internal_session_id) + .await + .map_err(JsonRpcErrorObject::from)?; + + let result = SessionCancelResult { cancelled: true }; + + serde_json::to_value(result) + .map_err(|e| JsonRpcErrorObject::internal_error(format!("Serialization error: {}", e))) + } + + /// Handle session/close request + /// + /// Closes a session and removes the session mapping. + async fn handle_session_close( + &self, + request: &JsonRpcRequest, + _client_id: Option<&str>, + ) -> Result { + let params_value = request + .params + .as_ref() + .ok_or_else(|| JsonRpcErrorObject::invalid_params("Missing params for session/close"))? + .clone(); + + let params: SessionCloseParams = serde_json::from_value(params_value).map_err(|e| { + JsonRpcErrorObject::invalid_params(format!("Invalid session/close params: {}", e)) + })?; + + // Remove the session mapping + let removed = self.session_manager.remove_mapping(¶ms.session_id); + + if let Some(mapping) = removed { + info!( + "Closed session: {} (internal: {}, connector: {})", + params.session_id, mapping.internal_session_id, mapping.connector_id + ); + } else { + warn!( + "Session not found during close: {} (treating as success)", + params.session_id + ); + } + + let result = SessionCloseResult { closed: true }; + serde_json::to_value(result) + .map_err(|e| JsonRpcErrorObject::internal_error(format!("Serialization error: {}", e))) + } + + /// Handle session/set_mode request + /// + /// Sets the mode for a session. This is the legacy mode-setting method. + async fn handle_session_set_mode( + &self, + request: &JsonRpcRequest, + _client_id: Option<&str>, + ) -> Result { + let params_value = request + .params + .as_ref() + .ok_or_else(|| { + JsonRpcErrorObject::invalid_params("Missing params for session/set_mode") + })? + .clone(); + + let params: SessionSetModeParams = serde_json::from_value(params_value).map_err(|e| { + JsonRpcErrorObject::invalid_params(format!("Invalid session/set_mode params: {}", e)) + })?; + + // Look up the session mapping + let mapping = self + .session_manager + .get_mapping(¶ms.session_id) + .ok_or_else(|| { + JsonRpcErrorObject::new( + crate::error::error_codes::SESSION_NOT_FOUND, + format!("Session not found: {}", params.session_id), + ) + })?; + + // Apply forward mapping for legacy mode (Gateway → Agent) + let mapped_mode_id = self + .apply_forward_mode_mapping(&mapping.connector_id, ¶ms.mode_id) + .await; + + info!( + "Setting mode for session {} to mode_id: {} (mapped from: {})", + params.session_id, mapped_mode_id, params.mode_id + ); + + // Forward to connector via connector operations + self.connector_ops + .set_session_mode( + &mapping.connector_id, + &mapping.internal_session_id, + &mapped_mode_id, + ) + .await + .map_err(JsonRpcErrorObject::from)?; + + // Per ACP spec, return empty object on success + Ok(serde_json::json!({})) + } + + /// Handle session/set_model request + /// + /// Sets the model for a session. This is the legacy model-setting method. + async fn handle_session_set_model( + &self, + request: &JsonRpcRequest, + _client_id: Option<&str>, + ) -> Result { + let params_value = request + .params + .as_ref() + .ok_or_else(|| { + JsonRpcErrorObject::invalid_params("Missing params for session/set_model") + })? + .clone(); + + let params: SessionSetModelParams = serde_json::from_value(params_value).map_err(|e| { + JsonRpcErrorObject::invalid_params(format!("Invalid session/set_model params: {}", e)) + })?; + + // Look up the session mapping + let mapping = self + .session_manager + .get_mapping(¶ms.session_id) + .ok_or_else(|| { + JsonRpcErrorObject::new( + crate::error::error_codes::SESSION_NOT_FOUND, + format!("Session not found: {}", params.session_id), + ) + })?; + + // Apply forward mapping for legacy model (Gateway → Agent) + let mapped_model_id = self + .apply_forward_model_mapping(&mapping.connector_id, ¶ms.model_id) + .await; + + info!( + "Setting model for session {} to model_id: {} (mapped from: {})", + params.session_id, mapped_model_id, params.model_id + ); + + // Forward to connector via connector operations + self.connector_ops + .set_session_model( + &mapping.connector_id, + &mapping.internal_session_id, + &mapped_model_id, + ) + .await + .map_err(JsonRpcErrorObject::from)?; + + // Per ACP spec, return empty object on success + Ok(serde_json::json!({})) + } + + /// Handle session/set_config_option request + /// + /// Sets a configuration option for a session. + async fn handle_session_set_config_option( + &self, + request: &JsonRpcRequest, + _client_id: Option<&str>, + ) -> Result { + let params_value = request + .params + .as_ref() + .ok_or_else(|| { + JsonRpcErrorObject::invalid_params("Missing params for session/set_config_option") + })? + .clone(); + + let params: SessionSetConfigOptionParams = + serde_json::from_value(params_value).map_err(|e| { + JsonRpcErrorObject::invalid_params(format!( + "Invalid session/set_config_option params: {}", + e + )) + })?; + + // Look up the session mapping + let mapping = self + .session_manager + .get_mapping(¶ms.session_id) + .ok_or_else(|| { + JsonRpcErrorObject::new( + crate::error::error_codes::SESSION_NOT_FOUND, + format!("Session not found: {}", params.session_id), + ) + })?; + + info!( + "Setting config option for session {}: {}={}", + params.session_id, params.config_id, params.value + ); + + // Map config_id to appropriate handler + match params.config_id.as_str() { + "mode" => { + self.connector_ops + .set_session_mode( + &mapping.connector_id, + &mapping.internal_session_id, + ¶ms.value, + ) + .await + .map_err(JsonRpcErrorObject::from)?; + } + "model" => { + self.connector_ops + .set_session_model( + &mapping.connector_id, + &mapping.internal_session_id, + ¶ms.value, + ) + .await + .map_err(JsonRpcErrorObject::from)?; + } + other => { + warn!( + "Unknown config option ID: {} - ignoring (session: {})", + other, params.session_id + ); + } + } + + // Per ACP spec, return empty object on success + Ok(serde_json::json!({})) + } + + // ======================================================================== + // Forward Mapping Helpers (Gateway → Agent) + // ======================================================================== + + /// Apply forward mapping for mode (Gateway → Agent) + async fn apply_forward_mode_mapping(&self, connector_id: &str, gateway_mode: &str) -> String { + // Get the agent type for this connector + let agent_type = self + .connector_ops + .get_connector_agent_type(connector_id) + .await + .ok() + .flatten(); + + // Apply forward mapping based on agent type + let mapped = match agent_type.as_deref() { + Some("claude") => match gateway_mode { + "ask" => "default", + "plan" => "plan", + "readonly" => "plan", + "write" => "acceptEdits", + "yolo" => "bypassPermissions", + _ => gateway_mode, + }, + Some("codex") | Some("gemini") => gateway_mode, + _ => gateway_mode, + }; + + if mapped != gateway_mode { + tracing::debug!( + connector_id = %connector_id, + gateway_mode = %gateway_mode, + mapped_mode = %mapped, + "Forward mode mapping applied" + ); + } + + mapped.to_string() + } + + /// Apply forward mapping for model (Gateway → Agent) + async fn apply_forward_model_mapping(&self, connector_id: &str, gateway_model: &str) -> String { + // Get the agent type for this connector + let agent_type = self + .connector_ops + .get_connector_agent_type(connector_id) + .await + .ok() + .flatten(); + + // Apply forward mapping based on agent type + let mapped = match agent_type.as_deref() { + Some("claude") => match gateway_model { + "simple" => "haiku", + "dailydriver" => "sonnet", + "high" => "opus", + _ => gateway_model, + }, + Some("codex") | Some("gemini") => gateway_model, + _ => gateway_model, + }; + + if mapped != gateway_model { + tracing::debug!( + connector_id = %connector_id, + gateway_model = %gateway_model, + mapped_model = %mapped, + "Forward model mapping applied" + ); + } + + mapped.to_string() + } + + /// Check if a connector is a gateway type + /// + /// Gateway connectors don't have real sessions — they act as routers + /// that transfer clients to actual agent connectors. When a session/list + /// Resolve the connector ID from the client's existing session mappings. + /// Returns the connector from the most recent session, if any. + fn resolve_connector_from_client_sessions(&self, client_id: &str) -> Option { + let session_ids = self.session_manager.list_client_sessions(client_id); + // Take the last session (most recently added) and look up its connector + session_ids.last().and_then(|sid| { + self.session_manager + .get_mapping(sid) + .map(|m| m.connector_id) + }) + } + + /// request targets a gateway, the RPC handler should delay the response + /// until a transfer occurs. + async fn is_gateway_connector(&self, connector_id: &str) -> bool { + match self.connector_ops.list_connectors().await { + Ok(connectors) => connectors + .iter() + .any(|c| c.id == connector_id && c.connector_type.to_lowercase() == "gateway"), + Err(_) => false, + } + } +} + +// ============================================================================ +// Tests +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + fn create_test_handler() -> RpcHandler { + RpcHandler::new( + SessionManager::new(), + NoOpConnectorOperations, + crate::sse::SseNotifier::new(), + std::sync::Arc::new(crate::agent_requests::AgentRequestTracker::new()), + ) + } + + #[tokio::test] + async fn test_handle_initialize() { + let handler = create_test_handler(); + + let request = json!({ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "client_name": "test-client", + "client_version": "1.0.0" + }, + "id": 1 + }); + + let response = handler.handle_request(&request.to_string(), None).await; + + match response { + JsonRpcResponseBatch::Single(resp) => { + assert!(resp.is_success()); + let result = resp.result.unwrap(); + assert_eq!(result["agentInfo"]["name"], SERVER_NAME); + assert_eq!(result["protocolVersion"], ACP_PROTOCOL_VERSION); + assert!(!result["clientId"].as_str().unwrap().is_empty()); + } + _ => panic!("Expected single response"), + } + } + + #[tokio::test] + async fn test_handle_session_new() { + let handler = create_test_handler(); + + // First initialize to get a client ID + let init_request = json!({ + "jsonrpc": "2.0", + "method": "initialize", + "id": 1 + }); + + let init_response = handler + .handle_request(&init_request.to_string(), None) + .await; + let client_id = match init_response { + JsonRpcResponseBatch::Single(resp) => resp.result.unwrap()["clientId"] + .as_str() + .unwrap() + .to_string(), + _ => panic!("Expected single response"), + }; + + // Now create a session + let request = json!({ + "jsonrpc": "2.0", + "method": "session/new", + "params": { + "connector_id": "test-connector" + }, + "id": 2 + }); + + let response = handler + .handle_request(&request.to_string(), Some(&client_id)) + .await; + + match response { + JsonRpcResponseBatch::Single(resp) => { + assert!(resp.is_success()); + let result = resp.result.unwrap(); + assert!(!result["sessionId"].as_str().unwrap().is_empty()); + } + _ => panic!("Expected single response"), + } + } + + #[tokio::test] + async fn test_handle_session_close() { + let handler = create_test_handler(); + + // Initialize + let init_request = json!({ + "jsonrpc": "2.0", + "method": "initialize", + "id": 1 + }); + let init_response = handler + .handle_request(&init_request.to_string(), None) + .await; + let client_id = match init_response { + JsonRpcResponseBatch::Single(resp) => resp.result.unwrap()["clientId"] + .as_str() + .unwrap() + .to_string(), + _ => panic!("Expected single response"), + }; + + // Create session + let new_request = json!({ + "jsonrpc": "2.0", + "method": "session/new", + "params": {}, + "id": 2 + }); + let new_response = handler + .handle_request(&new_request.to_string(), Some(&client_id)) + .await; + let session_id = match new_response { + JsonRpcResponseBatch::Single(resp) => resp.result.unwrap()["sessionId"] + .as_str() + .unwrap() + .to_string(), + _ => panic!("Expected single response"), + }; + + // Close session + let close_request = json!({ + "jsonrpc": "2.0", + "method": "session/close", + "params": { + "session_id": session_id + }, + "id": 3 + }); + + let response = handler + .handle_request(&close_request.to_string(), Some(&client_id)) + .await; + + match response { + JsonRpcResponseBatch::Single(resp) => { + assert!(resp.is_success()); + let result = resp.result.unwrap(); + assert_eq!(result["closed"], true); + } + _ => panic!("Expected single response"), + } + } + + #[tokio::test] + async fn test_handle_batch_request() { + let handler = create_test_handler(); + + let batch = json!([ + { + "jsonrpc": "2.0", + "method": "initialize", + "id": 1 + }, + { + "jsonrpc": "2.0", + "method": "initialize", + "id": 2 + } + ]); + + let response = handler.handle_request(&batch.to_string(), None).await; + + match response { + JsonRpcResponseBatch::Batch(responses) => { + assert_eq!(responses.len(), 2); + for resp in responses { + assert!(resp.is_success()); + } + } + _ => panic!("Expected batch response"), + } + } + + #[tokio::test] + async fn test_handle_unknown_method() { + let handler = create_test_handler(); + + let request = json!({ + "jsonrpc": "2.0", + "method": "unknown/method", + "id": 1 + }); + + let response = handler.handle_request(&request.to_string(), None).await; + + match response { + JsonRpcResponseBatch::Single(resp) => { + assert!(!resp.is_success()); + let error = resp.error.unwrap(); + assert_eq!(error.code, -32601); // Method not found + } + _ => panic!("Expected single response"), + } + } + + #[tokio::test] + async fn test_handle_invalid_json() { + let handler = create_test_handler(); + + let response = handler.handle_request("invalid json", None).await; + + match response { + JsonRpcResponseBatch::Single(resp) => { + assert!(!resp.is_success()); + let error = resp.error.unwrap(); + assert_eq!(error.code, -32700); // Parse error + } + _ => panic!("Expected single response"), + } + } + + #[tokio::test] + async fn test_handle_empty_batch() { + let handler = create_test_handler(); + + let response = handler.handle_request("[]", None).await; + + match response { + JsonRpcResponseBatch::Single(resp) => { + assert!(!resp.is_success()); + let error = resp.error.unwrap(); + assert_eq!(error.code, -32600); // Invalid request + } + _ => panic!("Expected single response"), + } + } + + #[test] + fn test_prompt_content_to_text() { + // Test simple text + let text_content = PromptContent::Text("Hello world".to_string()); + assert_eq!(text_content.to_text(), "Hello world"); + + // Test blocks with text + let blocks_content = PromptContent::Blocks(vec![ + ContentBlock::Text { + text: "First line".to_string(), + }, + ContentBlock::Text { + text: "Second line".to_string(), + }, + ]); + assert_eq!(blocks_content.to_text(), "First line\nSecond line"); + + // Test blocks with mixed content (images filtered out) + let mixed_content = PromptContent::Blocks(vec![ + ContentBlock::Text { + text: "Text only".to_string(), + }, + ContentBlock::Image { + data: "base64data".to_string(), + media_type: "image/png".to_string(), + }, + ]); + assert_eq!(mixed_content.to_text(), "Text only"); + } + + #[test] + fn test_session_info_camel_case() { + let info = SessionInfo { + session_id: "test-123".to_string(), + title: Some("Test Session".to_string()), + connector_id: "conn-1".to_string(), + cwd: None, + created_at: "2024-01-01T00:00:00Z".to_string(), + models: None, + modes: None, + }; + + let json = serde_json::to_value(&info).unwrap(); + assert!(json.get("sessionId").is_some()); + assert!(json.get("connectorId").is_some()); + assert!(json.get("createdAt").is_some()); + } + + #[test] + fn test_connector_info_camel_case() { + let info = ConnectorInfo { + id: "conn-1".to_string(), + name: "Test Connector".to_string(), + connector_type: "test".to_string(), + available: true, + }; + + let json = serde_json::to_value(&info).unwrap(); + assert!(json.get("connectorType").is_some()); + } + + #[test] + fn test_initialize_result_camel_case() { + let result = InitializeResult { + protocol_version: 1, + agent_capabilities: AgentCapabilities::default(), + agent_info: AgentInfo { + name: "test".to_string(), + title: "Test".to_string(), + version: "1.0.0".to_string(), + }, + auth_methods: vec![], + client_id: "client-1".to_string(), + }; + + let json = serde_json::to_value(&result).unwrap(); + assert!(json.get("protocolVersion").is_some()); + assert!(json.get("agentCapabilities").is_some()); + assert!(json.get("agentInfo").is_some()); + assert!(json.get("authMethods").is_some()); + assert!(json.get("clientId").is_some()); + } +} diff --git a/crates/dirigent_acp_api/src/rpc/noop.rs b/crates/dirigent_acp_api/src/rpc/noop.rs new file mode 100644 index 0000000..09ef1fd --- /dev/null +++ b/crates/dirigent_acp_api/src/rpc/noop.rs @@ -0,0 +1,185 @@ +//! NoOp Connector Operations for Testing +//! +//! This module provides a stub implementation of `ConnectorOperations` that +//! can be used for testing the RPC handler without actual connector access. + +use async_trait::async_trait; +use tracing::debug; + +use crate::error::AcpServerError; + +use super::types::{ConnectorInfo, ConnectorOperations, SessionInfo}; + +/// A no-op implementation of ConnectorOperations for testing +/// +/// This stub provides minimal implementations that return success without +/// actually doing anything. Useful for unit testing the RPC dispatch logic. +#[derive(Debug, Clone, Copy, Default)] +pub struct NoOpConnectorOperations; + +#[async_trait] +impl ConnectorOperations for NoOpConnectorOperations { + async fn create_session( + &self, + connector_id: &str, + _cwd: Option, + _ownership: dirigent_protocol::SessionOwnership, + ) -> Result { + Ok(SessionInfo { + session_id: uuid::Uuid::new_v4().to_string(), + title: Some("New Session".to_string()), + connector_id: connector_id.to_string(), + cwd: None, + created_at: chrono::Utc::now().to_rfc3339(), + models: None, + modes: None, + }) + } + + async fn load_session( + &self, + connector_id: &str, + session_id: &str, + _cwd: Option, + _mcp_servers: Option, + ) -> Result { + Ok(SessionInfo { + session_id: session_id.to_string(), + title: Some("Loaded Session".to_string()), + connector_id: connector_id.to_string(), + cwd: None, + created_at: chrono::Utc::now().to_rfc3339(), + models: None, + modes: None, + }) + } + + async fn send_message( + &self, + _connector_id: &str, + _session_id: &str, + text: String, + ) -> Result { + debug!("NoOp: send_message - {} chars", text.len()); + Ok("end_turn".to_string()) + } + + async fn cancel_generation( + &self, + _connector_id: &str, + _session_id: &str, + ) -> Result<(), AcpServerError> { + debug!("NoOp: cancel_generation"); + Ok(()) + } + + async fn list_connectors(&self) -> Result, AcpServerError> { + Ok(vec![ConnectorInfo { + id: "stub-connector".to_string(), + name: "Stub Connector".to_string(), + connector_type: "stub".to_string(), + available: true, + }]) + } + + async fn default_connector_id(&self) -> Option { + Some("stub-connector".to_string()) + } + + async fn get_connector_commands( + &self, + _connector_id: &str, + ) -> Result, AcpServerError> { + // Return a stub list of commands + Ok(vec![crate::sse::SlashCommand { + name: "echo".to_string(), + description: "Echo command (stub)".to_string(), + input: None, + }]) + } + + async fn send_agent_response( + &self, + connector_id: &str, + request_id: serde_json::Value, + _response: serde_json::Value, + ) -> Result<(), AcpServerError> { + debug!( + "NoOp: send_agent_response - connector: {}, request: {}", + connector_id, request_id + ); + Ok(()) + } + + async fn get_session_metadata( + &self, + _connector_id: &str, + _session_id: &str, + ) -> Result< + ( + Option, + Option, + ), + AcpServerError, + > { + debug!("NoOp: get_session_metadata"); + // Return None for both - no metadata available in stub + Ok((None, None)) + } + + async fn set_session_mode( + &self, + _connector_id: &str, + _session_id: &str, + mode_id: &str, + ) -> Result<(), AcpServerError> { + debug!("NoOp: set_session_mode - mode_id: {}", mode_id); + Ok(()) + } + + async fn set_session_model( + &self, + _connector_id: &str, + _session_id: &str, + model_id: &str, + ) -> Result<(), AcpServerError> { + debug!("NoOp: set_session_model - model_id: {}", model_id); + Ok(()) + } + + async fn get_connector_agent_type( + &self, + _connector_id: &str, + ) -> Result, AcpServerError> { + debug!("NoOp: get_connector_agent_type"); + Ok(None) + } + + async fn list_sessions( + &self, + connector_id: &str, + ) -> Result, AcpServerError> { + debug!("NoOp: list_sessions for connector: {}", connector_id); + Ok(vec![ + SessionInfo { + session_id: "stub-session-1".to_string(), + title: Some("Stub Session".to_string()), + connector_id: connector_id.to_string(), + cwd: None, + created_at: chrono::Utc::now().to_rfc3339(), + models: None, + modes: None, + }, + ]) + } + + async fn resolve_session_connector(&self, _session_id: &str) -> Option { + debug!("NoOp: resolve_session_connector"); + None + } + + async fn list_all_sessions(&self) -> Result, AcpServerError> { + debug!("NoOp: list_all_sessions"); + Ok(vec![]) + } +} diff --git a/crates/dirigent_acp_api/src/rpc/types.rs b/crates/dirigent_acp_api/src/rpc/types.rs new file mode 100644 index 0000000..ab33923 --- /dev/null +++ b/crates/dirigent_acp_api/src/rpc/types.rs @@ -0,0 +1,705 @@ +//! JSON-RPC Request/Response Types for ACP Server +//! +//! This module contains all the parameter and result types used in +//! the ACP JSON-RPC protocol. + +use serde::{Deserialize, Serialize}; + +use crate::error::AcpServerError; +use crate::sse::ConfigOption; + +// ============================================================================ +// Session and Connector Info Types +// ============================================================================ + +/// Information about a session +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionInfo { + /// The session ID + pub session_id: String, + + /// Optional title for the session + pub title: Option, + + /// The connector handling this session + pub connector_id: String, + + /// Working directory for this session + pub cwd: Option, + + /// When the session was created (ISO 8601 format) + pub created_at: String, + + /// Available models and current model (optional, connector-dependent) + pub models: Option, + + /// Available modes and current mode (optional, connector-dependent) + pub modes: Option, +} + +/// Information about a connector +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ConnectorInfo { + /// Unique identifier for this connector + pub id: String, + + /// Human-readable name + pub name: String, + + /// Type of connector (e.g., "opencode", "gateway") + pub connector_type: String, + + /// Whether the connector is currently available + pub available: bool, +} + +// ============================================================================ +// Initialize Types +// ============================================================================ + +/// Parameters for the initialize request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InitializeParams { + /// Client capabilities + #[serde(default)] + pub capabilities: Option, + + /// Client name + #[serde(default)] + pub client_name: Option, + + /// Client version + #[serde(default)] + pub client_version: Option, +} + +/// Result of the initialize request (ACP spec compliant) +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InitializeResult { + /// Protocol version (integer) + pub protocol_version: u32, + + /// Agent capabilities + pub agent_capabilities: AgentCapabilities, + + /// Agent info (name, title, version) + pub agent_info: AgentInfo, + + /// Authentication methods (empty for now) + pub auth_methods: Vec, + + /// Client ID for HTTP-based transports (required for SSE routing) + /// Note: This is not in the official ACP spec but is required for stateless + /// HTTP connections where the client needs an ID to subscribe to SSE events. + pub client_id: String, +} + +/// Agent information +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AgentInfo { + /// Agent name (programmatic) + pub name: String, + + /// Agent title (human-readable) + pub title: String, + + /// Agent version + pub version: String, +} + +/// Agent capabilities advertised to clients +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AgentCapabilities { + /// Whether session/load is supported + #[serde(skip_serializing_if = "Option::is_none")] + pub load_session: Option, + + /// Whether session/list is supported + #[serde(skip_serializing_if = "Option::is_none")] + pub list_sessions: Option, + + /// Session capabilities (resume, fork, etc.) + #[serde(skip_serializing_if = "Option::is_none")] + pub session_capabilities: Option, + + /// Prompt capabilities (content types supported) + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_capabilities: Option, + + /// MCP server support + #[serde(skip_serializing_if = "Option::is_none")] + pub mcp: Option, +} + +/// Extended session capabilities +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionCapabilities { + /// Whether session/list is supported (empty object = supported) + #[serde(skip_serializing_if = "Option::is_none")] + pub list: Option, + + /// Whether session/resume is supported (empty object = supported) + #[serde(skip_serializing_if = "Option::is_none")] + pub resume: Option, +} + +/// Prompt capabilities (content types) +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PromptCapabilities { + /// Image content support + #[serde(skip_serializing_if = "Option::is_none")] + pub image: Option, + + /// Audio content support + #[serde(skip_serializing_if = "Option::is_none")] + pub audio: Option, + + /// Embedded context support + #[serde(skip_serializing_if = "Option::is_none")] + pub embedded_context: Option, +} + +/// MCP capabilities +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct McpCapabilities { + /// HTTP transport support + #[serde(skip_serializing_if = "Option::is_none")] + pub http: Option, + + /// SSE transport support (deprecated) + #[serde(skip_serializing_if = "Option::is_none")] + pub sse: Option, +} + +/// Authentication method +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AuthMethod { + /// Method ID + pub id: String, + + /// Method name + pub name: String, + + /// Method description + pub description: String, +} + +impl Default for AgentCapabilities { + fn default() -> Self { + Self { + load_session: Some(true), + list_sessions: Some(true), + session_capabilities: Some(SessionCapabilities { + list: Some(serde_json::json!({})), + resume: Some(serde_json::json!({})), + }), + prompt_capabilities: Some(PromptCapabilities { + image: Some(true), + audio: None, + embedded_context: Some(true), + }), + mcp: Some(McpCapabilities { + http: Some(true), + sse: Some(true), + }), + } + } +} + +// ============================================================================ +// Session/New Types +// ============================================================================ + +/// Parameters for session/new request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionNewParams { + /// Optional connector ID to create the session on + #[serde(default)] + pub connector_id: Option, + + /// Optional working directory for the session + #[serde(default)] + pub cwd: Option, + + /// Optional client-provided session ID + #[serde(default)] + pub session_id: Option, +} + +/// Result of session/new request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionNewResult { + /// The client-facing session ID + pub session_id: String, + + /// Optional title for the session + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + + /// The connector handling this session + pub connector_id: String, + + /// When the session was created (ISO 8601 format) + pub created_at: String, + + /// Available models and current model (UNSTABLE in ACP spec) + #[serde(skip_serializing_if = "Option::is_none")] + pub models: Option, + + /// Available modes and current mode + #[serde(skip_serializing_if = "Option::is_none")] + pub modes: Option, + + /// Configuration options (modes, models as unified config) + /// This is the preferred way to expose configuration in ACP + #[serde(skip_serializing_if = "Option::is_none")] + pub config_options: Option>, +} + +// ============================================================================ +// Session/Load Types +// ============================================================================ + +/// Parameters for session/load request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionLoadParams { + /// The session ID to load + pub session_id: String, + + /// The connector ID where the session exists (optional; resolved automatically if omitted) + #[serde(default)] + pub connector_id: Option, + + /// Optional working directory (standard ACP field sent by clients like Zed) + #[serde(default)] + pub cwd: Option, + + /// Optional MCP server configurations (standard ACP field) + #[serde(default)] + pub mcp_servers: Option>, +} + +/// Result of session/load request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionLoadResult { + /// The client-facing session ID + pub session_id: String, + + /// Optional title for the session + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + + /// The connector handling this session + pub connector_id: String, + + /// When the session was created (ISO 8601 format) + pub created_at: String, + + /// Available models and current model (UNSTABLE in ACP spec) + #[serde(skip_serializing_if = "Option::is_none")] + pub models: Option, + + /// Available modes and current mode + #[serde(skip_serializing_if = "Option::is_none")] + pub modes: Option, + + /// Configuration options (modes, models as unified config) + /// This is the preferred way to expose configuration in ACP + #[serde(skip_serializing_if = "Option::is_none")] + pub config_options: Option>, +} + +// ============================================================================ +// Session/List Types +// ============================================================================ + +/// Parameters for session/list request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionListParams { + /// Optional connector ID to list sessions from + #[serde(default)] + pub connector_id: Option, + + /// Optional working directory filter + #[serde(default)] + pub cwd: Option, + + /// Optional pagination cursor from previous response + #[serde(default)] + pub cursor: Option, +} + +/// A session entry in a session/list response (ACP spec) +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionListEntry { + /// Unique session identifier + pub session_id: String, + + /// Working directory for this session + pub cwd: String, + + /// Human-readable title + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + + /// Last activity timestamp (ISO 8601) + #[serde(skip_serializing_if = "Option::is_none")] + pub updated_at: Option, + + /// Agent-specific metadata + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "_meta")] + pub meta: Option, +} + +/// Result of session/list request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionListResult { + /// List of available sessions + pub sessions: Vec, + + /// Pagination cursor for next page (absent when no more results) + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, +} + +// ============================================================================ +// Session/Resume Types +// ============================================================================ + +/// Parameters for session/resume request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionResumeParams { + /// The session ID to resume + pub session_id: String, + + /// The connector ID where the session exists (optional; resolved automatically if omitted) + #[serde(default)] + pub connector_id: Option, + + /// Optional working directory (standard ACP field sent by clients like Zed) + #[serde(default)] + pub cwd: Option, + + /// Optional MCP server configurations (standard ACP field) + #[serde(default)] + pub mcp_servers: Option>, +} + +/// Result of session/resume request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionResumeResult { + /// The client-facing session ID + pub session_id: String, + + /// Optional title for the session + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + + /// The connector handling this session + pub connector_id: String, + + /// When the session was created (ISO 8601 format) + pub created_at: String, + + /// Available models and current model + #[serde(skip_serializing_if = "Option::is_none")] + pub models: Option, + + /// Available modes and current mode + #[serde(skip_serializing_if = "Option::is_none")] + pub modes: Option, + + /// Configuration options + #[serde(skip_serializing_if = "Option::is_none")] + pub config_options: Option>, +} + +// ============================================================================ +// Session/Prompt Types +// ============================================================================ + +/// Parameters for session/prompt request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionPromptParams { + /// The session ID to send the prompt to + pub session_id: String, + + /// The prompt content (array of content blocks) + pub prompt: PromptContent, +} + +/// Content for a prompt - either simple text or structured blocks +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum PromptContent { + /// Simple text content + Text(String), + + /// Structured content blocks + Blocks(Vec), +} + +impl PromptContent { + /// Convert to plain text representation + pub fn to_text(&self) -> String { + match self { + PromptContent::Text(text) => text.clone(), + PromptContent::Blocks(blocks) => blocks + .iter() + .filter_map(|b| { + if let ContentBlock::Text { text } = b { + Some(text.clone()) + } else { + None + } + }) + .collect::>() + .join("\n"), + } + } +} + +/// A content block in a prompt +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ContentBlock { + /// Text content + #[serde(rename = "text")] + Text { text: String }, + + /// Image content (base64 encoded) + #[serde(rename = "image")] + Image { data: String, media_type: String }, +} + +/// Result of session/prompt request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionPromptResult { + /// The reason the turn stopped + pub stop_reason: String, +} + +// ============================================================================ +// Session/Cancel Types +// ============================================================================ + +/// Parameters for session/cancel request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionCancelParams { + /// The session ID to cancel + pub session_id: String, +} + +/// Result of session/cancel request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionCancelResult { + /// Whether the cancellation was successful + pub cancelled: bool, +} + +// ============================================================================ +// Session/Close Types +// ============================================================================ + +/// Parameters for session/close request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionCloseParams { + /// The session ID to close + pub session_id: String, +} + +/// Result of session/close request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionCloseResult { + /// Whether the session was successfully closed + pub closed: bool, +} + +// ============================================================================ +// Session/SetMode and Session/SetModel Types +// ============================================================================ + +/// Parameters for session/set_mode request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionSetModeParams { + /// The session ID to update + pub session_id: String, + + /// The mode ID to switch to + pub mode_id: String, +} + +/// Parameters for session/set_model request +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionSetModelParams { + /// The session ID to update + pub session_id: String, + + /// The model ID to switch to + pub model_id: String, +} + +/// Parameters for session/set_config_option request +/// +/// This is the unified way to set configuration options (modes, models, etc.) +/// for clients using the new config_options system. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionSetConfigOptionParams { + /// The session ID to update + pub session_id: String, + + /// The config option ID (e.g., "mode", "model") + pub config_id: String, + + /// The value to set + pub value: String, +} + +// ============================================================================ +// ConnectorOperations Trait +// ============================================================================ + +/// Trait for connector operations that will be implemented by the web server +/// +/// This trait abstracts the operations that require access to `CoreHandle`, +/// allowing the RPC handler to remain decoupled from `dirigent_core`. +/// The web server implements this trait to bridge between the ACP server +/// and the actual connector implementations. +#[async_trait::async_trait] +pub trait ConnectorOperations: Send + Sync { + /// Create a new session on the specified connector + async fn create_session( + &self, + connector_id: &str, + cwd: Option, + ownership: dirigent_protocol::SessionOwnership, + ) -> Result; + + /// Load an existing session from a connector + async fn load_session( + &self, + connector_id: &str, + session_id: &str, + cwd: Option, + mcp_servers: Option, + ) -> Result; + + /// Send a message to a session and wait for completion + async fn send_message( + &self, + connector_id: &str, + session_id: &str, + text: String, + ) -> Result; + + /// Cancel active generation on a session + async fn cancel_generation( + &self, + connector_id: &str, + session_id: &str, + ) -> Result<(), AcpServerError>; + + /// List all available connectors + async fn list_connectors(&self) -> Result, AcpServerError>; + + /// Get the default connector ID (if configured) + async fn default_connector_id(&self) -> Option; + + /// Get available commands/tools from a connector + async fn get_connector_commands( + &self, + connector_id: &str, + ) -> Result, AcpServerError>; + + /// Send an agent response back to a connector + async fn send_agent_response( + &self, + connector_id: &str, + request_id: serde_json::Value, + response: serde_json::Value, + ) -> Result<(), AcpServerError>; + + /// Get session metadata (models/modes) from a session + async fn get_session_metadata( + &self, + connector_id: &str, + session_id: &str, + ) -> Result< + ( + Option, + Option, + ), + AcpServerError, + >; + + /// Set the session mode + async fn set_session_mode( + &self, + connector_id: &str, + session_id: &str, + mode_id: &str, + ) -> Result<(), AcpServerError>; + + /// Set the session model + async fn set_session_model( + &self, + connector_id: &str, + session_id: &str, + model_id: &str, + ) -> Result<(), AcpServerError>; + + /// Get the agent type for a connector (for mode/model mapping) + async fn get_connector_agent_type( + &self, + connector_id: &str, + ) -> Result, AcpServerError>; + + /// List all sessions on a connector + async fn list_sessions( + &self, + connector_id: &str, + ) -> Result, AcpServerError>; + + /// Look up which connector owns a session by its session ID. + /// Uses archivist to find the connector. Returns connector_id (not UID). + /// Default: returns None (no archivist available). + async fn resolve_session_connector(&self, _session_id: &str) -> Option { + None + } + + /// List sessions across all connectors (archivist-backed). + /// Used when the resolved connector is a gateway to provide cross-connector view. + /// Default: returns empty vec (no archivist available). + async fn list_all_sessions(&self) -> Result, AcpServerError> { + Ok(vec![]) + } +} diff --git a/crates/dirigent_acp_api/src/session_manager.rs b/crates/dirigent_acp_api/src/session_manager.rs new file mode 100644 index 0000000..da1c218 --- /dev/null +++ b/crates/dirigent_acp_api/src/session_manager.rs @@ -0,0 +1,1405 @@ +//! Session Manager for ACP Server +//! +//! This module provides session mapping and client connection tracking for the ACP Server. +//! It manages the relationship between client-facing session IDs and internal session IDs, +//! as well as tracking connected clients and their capabilities. +//! +//! ## Architecture +//! +//! The `SessionManager` uses `Arc>` internally for thread-safe access, +//! allowing multiple readers or a single writer at a time. This enables concurrent +//! read operations while maintaining data consistency during writes. +//! +//! ## Example +//! +//! ```rust +//! use dirigent_acp_api::session_manager::SessionManager; +//! +//! // Create a new session manager +//! let manager = SessionManager::new(); +//! +//! // Register a client +//! let client_id = manager.register_client(None); +//! +//! // Create a session mapping for the client +//! let mapping = manager.create_mapping( +//! &client_id, +//! None, // Auto-generate client_session_id +//! "connector-1".to_string(), +//! ); +//! ``` + +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; + +use chrono::{DateTime, Utc}; +use dirigent_protocol::streaming::{BusEvent, EventKind, EventOrigin, EventRouting}; +use dirigent_protocol::{Event, SessionOwnership}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::error::AcpServerError; + +/// Closure that publishes a `BusEvent` to the shared event bus. +/// +/// The ACP api crate can't depend on `dirigent_core::SharingBus` directly +/// (dependency direction is the other way round). Callers in +/// `web::acp_server` / `api::acp_controller` wrap the core's +/// `SharingBus::publish` call in a closure and install it via +/// [`SessionManagerBuilder::with_bus_publisher`]. The closure is +/// synchronous; implementations are expected to spawn the publish onto a +/// tokio task if the publish API is async. +pub type BusPublishFn = Arc; + +/// Mapping between client-facing session ID and internal session ID +/// +/// This struct tracks the relationship between the session ID exposed to ACP clients +/// and the internal session ID used by connectors. It also records which client and +/// connector the session belongs to. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionMapping { + /// The session ID exposed to the ACP client + pub client_session_id: String, + + /// The internal session ID used by the connector + pub internal_session_id: String, + + /// The ID of the connector handling this session + pub connector_id: String, + + /// The ID of the client that owns this session + pub client_id: String, + + /// When this mapping was created + pub created_at: DateTime, + + /// Ownership model for this session + /// + /// Tracks whether this session is internal to Dirigent or originates from + /// an external client, and how tool calls should be handled. + pub ownership: SessionOwnership, +} + +impl SessionMapping { + /// Create a new session mapping + pub fn new( + client_session_id: String, + internal_session_id: String, + connector_id: String, + client_id: String, + ) -> Self { + Self { + client_session_id, + internal_session_id, + connector_id, + client_id, + created_at: Utc::now(), + ownership: SessionOwnership::default(), + } + } +} + +/// Information about a connected ACP client +/// +/// This struct tracks client connections, including when they connected, +/// their last activity, and which sessions they have active. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClientConnection { + /// Unique identifier for this client + pub client_id: String, + + /// When the client first connected + pub connected_at: DateTime, + + /// When the client was last active (e.g., sent a request) + pub last_activity: DateTime, + + /// List of client_session_ids owned by this client + pub sessions: Vec, + + /// Optional client capabilities from the initialize handshake + /// + /// Contains information about what the client supports, such as + /// protocol version, supported methods, etc. + pub capabilities: Option, + + /// Preferred agent type for connector routing (from --select-connector) + pub preferred_connector: Option, +} + +impl ClientConnection { + /// Create a new client connection + pub fn new( + client_id: String, + capabilities: Option, + preferred_connector: Option, + ) -> Self { + let now = Utc::now(); + Self { + client_id, + connected_at: now, + last_activity: now, + sessions: Vec::new(), + capabilities, + preferred_connector, + } + } + + /// Add a session to this client's session list + pub fn add_session(&mut self, client_session_id: String) { + if !self.sessions.contains(&client_session_id) { + self.sessions.push(client_session_id); + } + } + + /// Remove a session from this client's session list + pub fn remove_session(&mut self, client_session_id: &str) -> bool { + if let Some(pos) = self.sessions.iter().position(|s| s == client_session_id) { + self.sessions.remove(pos); + true + } else { + false + } + } + + /// Update the last activity timestamp to now + pub fn touch(&mut self) { + self.last_activity = Utc::now(); + } +} + +/// Lightweight client information for API responses +/// +/// This struct provides a summary of a connected client suitable for +/// listing in admin interfaces and API responses. It contains only +/// the essential information without internal details. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ClientInfo { + /// Unique identifier for this client + pub client_id: String, + + /// When the client first connected (ISO 8601 format) + pub connected_at: String, + + /// Number of active sessions owned by this client + pub active_sessions_count: usize, +} + +impl ClientInfo { + /// Create a new ClientInfo from a ClientConnection + pub fn from_connection(conn: &ClientConnection) -> Self { + Self { + client_id: conn.client_id.clone(), + connected_at: conn.connected_at.to_rfc3339(), + active_sessions_count: conn.sessions.len(), + } + } +} + +/// Internal state for the session manager +#[derive(Debug, Default)] +struct SessionManagerState { + /// Map from client_session_id to SessionMapping + mappings: HashMap, + + /// Map from client_id to ClientConnection + clients: HashMap, +} + +/// Builder configuration for SessionManager +pub struct SessionManagerBuilder { + bus_publisher: Option, + acceptor_connector_uid: Option, +} + +impl SessionManagerBuilder { + /// Create a new builder + pub fn new() -> Self { + Self { + bus_publisher: None, + acceptor_connector_uid: None, + } + } + + /// Set the bus publisher for emitting ACP client events. + /// + /// When set, the SessionManager will publish events like + /// `AcpClientConnected` and `AcpClientDisconnected` onto the shared + /// event bus via this callback. + pub fn with_bus_publisher(mut self, publisher: BusPublishFn) -> Self { + self.bus_publisher = Some(publisher); + self + } + + /// Set the Acceptor connector's UID for meta session creation + /// + /// This UID is included in `AcpClientConnected` events so the archivist + /// can create meta sessions under the correct connector. + pub fn with_acceptor_connector_uid(mut self, uid: Uuid) -> Self { + self.acceptor_connector_uid = Some(uid); + self + } + + /// Build the SessionManager + pub fn build(self) -> SessionManager { + SessionManager { + state: Arc::new(RwLock::new(SessionManagerState::default())), + bus_publisher: self.bus_publisher, + acceptor_connector_uid: self.acceptor_connector_uid, + } + } +} + +impl Default for SessionManagerBuilder { + fn default() -> Self { + Self::new() + } +} + +/// Thread-safe session manager for the ACP Server +/// +/// Manages session mappings between client-facing and internal session IDs, +/// as well as tracking connected clients and their capabilities. +/// +/// The `SessionManager` is designed to be cloned and shared across async tasks. +/// All operations are thread-safe due to internal `Arc>` synchronization. +/// +/// ## Event Emission +/// +/// When configured with a bus publisher (via `SessionManagerBuilder`), the +/// manager will emit ACP client connection events onto the shared event bus: +/// - `AcpClientConnected` when a client registers +/// - `AcpClientDisconnected` when a client unregisters or is disconnected +#[derive(Clone)] +pub struct SessionManager { + state: Arc>, + /// Optional bus publisher for emitting ACP client events + bus_publisher: Option, + /// The Acceptor connector's UID for meta session creation + acceptor_connector_uid: Option, +} + +impl std::fmt::Debug for SessionManager { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SessionManager") + .field("state", &self.state) + .field("has_bus_publisher", &self.bus_publisher.is_some()) + .field("acceptor_connector_uid", &self.acceptor_connector_uid) + .finish() + } +} + +impl Default for SessionManager { + fn default() -> Self { + Self::new() + } +} + +impl SessionManager { + /// Create a new session manager + pub fn new() -> Self { + Self { + state: Arc::new(RwLock::new(SessionManagerState::default())), + bus_publisher: None, + acceptor_connector_uid: None, + } + } + + /// Create a builder for more advanced configuration + pub fn builder() -> SessionManagerBuilder { + SessionManagerBuilder::new() + } + + /// Emit an event if a bus publisher is configured. + /// + /// Wraps the event in a default-routed `BusEvent` with + /// `EventOrigin::Runtime`. Subscribers interested in more precise + /// routing (e.g., ACP-specific filtering) can match on the event + /// payload itself. + fn emit_event(&self, event: Event) { + if let Some(ref publisher) = self.bus_publisher { + let kind = match &event { + Event::AcpClientConnected { .. } + | Event::AcpClientDisconnected { .. } + | Event::AcpClientSessionOpened { .. } + | Event::AcpClientSessionRouted { .. } => EventKind::System, + _ => EventKind::SessionLifecycle, + }; + let bus_event = BusEvent { + routing: EventRouting { + kind, + ..Default::default() + }, + origin: EventOrigin::Runtime, + event: Arc::new(event), + }; + publisher(bus_event); + } + } + + /// Create a new session mapping + /// + /// # Parameters + /// - `client_id`: The ID of the client creating this session + /// - `client_session_id`: Optional client-provided session ID. If None, a UUID will be generated. + /// - `connector_id`: The ID of the connector that will handle this session + /// + /// # Returns + /// The created `SessionMapping` + /// + /// # Panics + /// Panics if the internal lock is poisoned. + pub fn create_mapping( + &self, + client_id: &str, + client_session_id: Option, + connector_id: String, + ) -> SessionMapping { + // client_session_id is an external ID from the connector client (can be UUID4, string, etc.) + // We preserve it as-is for external system compatibility. + let client_session_id = client_session_id.unwrap_or_else(|| Uuid::new_v4().to_string()); + + // internal_session_id is used internally and passed to archivist as native_session_id. + // Use UUID7 for time-ordered, sortable identifiers. + let internal_session_id = Uuid::now_v7().to_string(); + + let mapping = SessionMapping::new( + client_session_id.clone(), + internal_session_id, + connector_id, + client_id.to_string(), + ); + + let mut state = self.state.write().expect("Lock poisoned"); + + // Store the mapping + state.mappings.insert(client_session_id.clone(), mapping.clone()); + + // Add session to client's session list + if let Some(client) = state.clients.get_mut(client_id) { + client.add_session(client_session_id.clone()); + } + + // Emit AcpClientSessionOpened event + self.emit_event(Event::AcpClientSessionOpened { + client_id: client_id.to_string(), + gateway_session_id: mapping.internal_session_id.clone(), + client_session_id: client_session_id, + timestamp: chrono::Utc::now().to_rfc3339(), + }); + + mapping + } + + /// Get a session mapping by client session ID + /// + /// # Parameters + /// - `client_session_id`: The client-facing session ID to look up + /// + /// # Returns + /// `Some(SessionMapping)` if found, `None` otherwise + pub fn get_mapping(&self, client_session_id: &str) -> Option { + let state = self.state.read().expect("Lock poisoned"); + state.mappings.get(client_session_id).cloned() + } + + /// Remove a session mapping by client session ID + /// + /// This also removes the session from the owning client's session list. + /// + /// # Parameters + /// - `client_session_id`: The client-facing session ID to remove + /// + /// # Returns + /// `Some(SessionMapping)` if found and removed, `None` otherwise + pub fn remove_mapping(&self, client_session_id: &str) -> Option { + let mut state = self.state.write().expect("Lock poisoned"); + + // Remove the mapping + let mapping = state.mappings.remove(client_session_id); + + // Also remove from client's session list if mapping existed + if let Some(ref m) = mapping { + if let Some(client) = state.clients.get_mut(&m.client_id) { + client.remove_session(client_session_id); + } + } + + mapping + } + + /// Register a new client connection + /// + /// # Parameters + /// - `capabilities`: Optional client capabilities from the initialize handshake + /// + /// # Returns + /// The generated unique client ID (UUID7 for time-ordered tracking) + pub fn register_client(&self, capabilities: Option) -> String { + let client_id = Uuid::now_v7().to_string(); + let connected_at = Utc::now(); + let client = ClientConnection::new(client_id.clone(), capabilities.clone(), None); + + let mut state = self.state.write().expect("Lock poisoned"); + state.clients.insert(client_id.clone(), client); + + tracing::info!(client_id = %client_id, "ACP client connected"); + + // Emit AcpClientConnected event + self.emit_event(Event::AcpClientConnected { + client_id: client_id.clone(), + connected_at: connected_at.to_rfc3339(), + capabilities, + connector_uid: self.acceptor_connector_uid + .map(|u| u.to_string()) + .unwrap_or_else(|| Uuid::nil().to_string()), + }); + + client_id + } + + /// Register a client with a specific client_id (for HTTP clients with X-Client-Id header) + /// + /// # Parameters + /// - `client_id`: The client ID provided by the client (from X-Client-Id header) + /// - `capabilities`: Optional client capabilities + /// + /// # Panics + /// Panics if the internal lock is poisoned. + pub fn register_client_with_id(&self, client_id: String, capabilities: Option) { + let connected_at = Utc::now(); + let is_new_client; + + { + let mut state = self.state.write().expect("Lock poisoned"); + is_new_client = !state.clients.contains_key(&client_id); + + let client = ClientConnection::new(client_id.clone(), capabilities.clone(), None); + state.clients.insert(client_id.clone(), client); + } + + if is_new_client { + tracing::info!(client_id = %client_id, "ACP client connected with provided ID"); + + // Emit AcpClientConnected event for new clients + self.emit_event(Event::AcpClientConnected { + client_id: client_id.clone(), + connected_at: connected_at.to_rfc3339(), + capabilities, + connector_uid: self.acceptor_connector_uid + .map(|u| u.to_string()) + .unwrap_or_else(|| Uuid::nil().to_string()), + }); + } else { + tracing::debug!(client_id = %client_id, "ACP client reconnected with existing ID"); + } + } + + /// Unregister a client and remove all associated session mappings + /// + /// # Parameters + /// - `client_id`: The ID of the client to unregister + /// + /// # Returns + /// A list of session mappings that were removed (for cleanup purposes) + pub fn unregister_client(&self, client_id: &str) -> Vec { + let disconnected_at = Utc::now(); + let removed_mappings; + let was_connected; + + { + let mut state = self.state.write().expect("Lock poisoned"); + + // Remove the client and get their sessions + let client = state.clients.remove(client_id); + + let mut mappings = Vec::new(); + + if let Some(client) = client { + tracing::info!(client_id = %client_id, "ACP client disconnected"); + + // Remove all session mappings owned by this client + for session_id in client.sessions { + if let Some(mapping) = state.mappings.remove(&session_id) { + mappings.push(mapping); + } + } + was_connected = true; + } else { + was_connected = false; + } + + removed_mappings = mappings; + } + + // Emit AcpClientDisconnected event if client was found + if was_connected { + self.emit_event(Event::AcpClientDisconnected { + client_id: client_id.to_string(), + disconnected_at: disconnected_at.to_rfc3339(), + reason: None, + }); + } + + removed_mappings + } + + /// Update the last activity timestamp for a client + /// + /// # Parameters + /// - `client_id`: The ID of the client to update + /// + /// # Returns + /// `true` if the client was found and updated, `false` otherwise + pub fn update_client_activity(&self, client_id: &str) -> bool { + let mut state = self.state.write().expect("Lock poisoned"); + + if let Some(client) = state.clients.get_mut(client_id) { + client.touch(); + true + } else { + false + } + } + + /// Update the preferred connector for a client + /// + /// # Parameters + /// - `client_id`: The ID of the client to update + /// - `preferred_connector`: The preferred connector ID or magic word + /// + /// # Returns + /// `true` if the client was found and updated, `false` otherwise + pub fn update_client_preferred_connector( + &self, + client_id: &str, + preferred_connector: Option, + ) -> bool { + let mut state = self.state.write().expect("Lock poisoned"); + + if let Some(client) = state.clients.get_mut(client_id) { + client.preferred_connector = preferred_connector; + true + } else { + false + } + } + + /// List all sessions for a client + /// + /// # Parameters + /// - `client_id`: The ID of the client + /// + /// # Returns + /// A list of client_session_ids owned by the client, or empty vec if client not found + pub fn list_client_sessions(&self, client_id: &str) -> Vec { + let state = self.state.read().expect("Lock poisoned"); + + state + .clients + .get(client_id) + .map(|c| c.sessions.clone()) + .unwrap_or_default() + } + + /// Get information about a client + /// + /// # Parameters + /// - `client_id`: The ID of the client + /// + /// # Returns + /// `Some(ClientConnection)` if found, `None` otherwise + pub fn get_client(&self, client_id: &str) -> Option { + let state = self.state.read().expect("Lock poisoned"); + state.clients.get(client_id).cloned() + } + + /// Get the number of active session mappings + pub fn mapping_count(&self) -> usize { + let state = self.state.read().expect("Lock poisoned"); + state.mappings.len() + } + + /// Get the number of connected clients + pub fn client_count(&self) -> usize { + let state = self.state.read().expect("Lock poisoned"); + state.clients.len() + } + + /// List all connected clients + pub fn list_clients(&self) -> Vec { + let state = self.state.read().expect("Lock poisoned"); + state.clients.values().cloned().collect() + } + + /// List all session mappings + pub fn list_mappings(&self) -> Vec { + let state = self.state.read().expect("Lock poisoned"); + state.mappings.values().cloned().collect() + } + + /// Get a mapping by internal session ID + /// + /// This is useful when receiving events from a connector that uses internal IDs. + /// + /// # Parameters + /// - `internal_session_id`: The internal session ID to look up + /// + /// # Returns + /// `Some(SessionMapping)` if found, `None` otherwise + pub fn get_mapping_by_internal_id(&self, internal_session_id: &str) -> Option { + let state = self.state.read().expect("Lock poisoned"); + state + .mappings + .values() + .find(|m| m.internal_session_id == internal_session_id) + .cloned() + } + + /// List all connected clients with lightweight info + /// + /// Returns a summary of all connected clients suitable for API responses. + /// This is more efficient than `list_clients()` when full client details + /// are not needed. + /// + /// # Returns + /// A vector of `ClientInfo` containing summary information about each client + pub fn list_clients_info(&self) -> Vec { + let state = self.state.read().expect("Lock poisoned"); + state + .clients + .values() + .map(ClientInfo::from_connection) + .collect() + } + + /// Force disconnect a client and clean up all associated sessions + /// + /// This method forcibly removes a client and all their session mappings. + /// Use this for administrative disconnection of clients. + /// + /// # Parameters + /// - `client_id`: The ID of the client to disconnect + /// + /// # Returns + /// - `Ok(())` if the client was found and disconnected + /// - `Err(AcpServerError::ClientNotFound)` if the client does not exist + /// + /// # Example + /// + /// ```rust + /// use dirigent_acp_api::session_manager::SessionManager; + /// + /// let manager = SessionManager::new(); + /// let client_id = manager.register_client(None); + /// + /// // Disconnect the client + /// assert!(manager.disconnect_client(&client_id).is_ok()); + /// + /// // Disconnecting again returns an error + /// assert!(manager.disconnect_client(&client_id).is_err()); + /// ``` + pub fn disconnect_client(&self, client_id: &str) -> Result<(), AcpServerError> { + let disconnected_at = Utc::now(); + + { + let mut state = self.state.write().expect("Lock poisoned"); + + // Check if the client exists + if !state.clients.contains_key(client_id) { + return Err(AcpServerError::ClientNotFound(client_id.to_string())); + } + + // Remove the client and get their sessions + if let Some(client) = state.clients.remove(client_id) { + tracing::info!(client_id = %client_id, "ACP client disconnected (force disconnect)"); + + // Remove all session mappings owned by this client + for session_id in client.sessions { + state.mappings.remove(&session_id); + } + } + } + + // Emit AcpClientDisconnected event + self.emit_event(Event::AcpClientDisconnected { + client_id: client_id.to_string(), + disconnected_at: disconnected_at.to_rfc3339(), + reason: Some("Force disconnect".to_string()), + }); + + Ok(()) + } + + /// Update a session mapping to point to a different connector + /// + /// Used when a session is transferred from one connector to another. + /// The client_session_id remains the same, but the internal routing changes. + /// + /// # Arguments + /// * `client_session_id` - The client-facing session ID (unchanged) + /// * `new_connector_id` - The new connector to route to + /// * `new_connector_title` - Display title for the new connector + /// * `new_internal_session_id` - The new internal session ID in the target connector + /// + /// # Returns + /// The old mapping if it existed, None otherwise + pub fn update_mapping_connector( + &self, + client_session_id: &str, + new_connector_id: String, + new_connector_title: Option, + new_internal_session_id: String, + ) -> Option { + let mut state = self.state.write().expect("Lock poisoned"); + + let mapping = state.mappings.get_mut(client_session_id)?; + + // Capture old state for return + let old_mapping = mapping.clone(); + + // Update routing + let from_session_id = mapping.internal_session_id.clone(); + mapping.connector_id = new_connector_id.clone(); + mapping.internal_session_id = new_internal_session_id.clone(); + + // Get client_id before releasing lock + let client_id = mapping.client_id.clone(); + + tracing::info!( + client_session_id = %client_session_id, + old_connector = %old_mapping.connector_id, + new_connector = %mapping.connector_id, + "Session mapping updated for transfer" + ); + + // Must drop state lock before emitting event (emit_event might do I/O) + drop(state); + + // Emit AcpClientSessionRouted event + let connector_title = new_connector_title.unwrap_or_else(|| new_connector_id.clone()); + self.emit_event(Event::AcpClientSessionRouted { + client_id, + from_session_id, + to_session_id: new_internal_session_id, + connector_id: new_connector_id, + connector_title, + connector_kind: None, // TODO: Populate from connector info + model: None, // TODO: Populate from session/connector state + agent_info: None, // TODO: Populate from connector capabilities + timestamp: chrono::Utc::now().to_rfc3339(), + }); + + Some(old_mapping) + } + + /// Update the ownership model for a session mapping + /// + /// # Parameters + /// * `client_session_id` - The client-facing session ID + /// * `ownership` - The new ownership model + /// + /// # Returns + /// True if the mapping was found and updated, false otherwise + pub fn update_mapping_ownership( + &self, + client_session_id: &str, + ownership: SessionOwnership, + ) -> bool { + let mut state = self.state.write().expect("Lock poisoned"); + + if let Some(mapping) = state.mappings.get_mut(client_session_id) { + mapping.ownership = ownership; + tracing::debug!( + client_session_id = %client_session_id, + "Updated session ownership model" + ); + true + } else { + false + } + } + + /// Update the internal session ID for a mapping + /// + /// This is called after session creation to link the mapping to the actual + /// connector session ID. This is necessary because the mapping is created + /// before the connector session, and we generate a placeholder internal_session_id. + /// + /// # Parameters + /// * `client_session_id` - The client-facing session ID + /// * `new_internal_session_id` - The actual session ID from the connector + /// + /// # Returns + /// The old internal_session_id if the mapping was found, None otherwise + pub fn update_mapping_internal_session( + &self, + client_session_id: &str, + new_internal_session_id: String, + ) -> Option { + let mut state = self.state.write().expect("Lock poisoned"); + + if let Some(mapping) = state.mappings.get_mut(client_session_id) { + let old_internal_id = std::mem::replace(&mut mapping.internal_session_id, new_internal_session_id.clone()); + tracing::info!( + client_session_id = %client_session_id, + old_internal_id = %old_internal_id, + new_internal_id = %new_internal_session_id, + "Updated internal session ID to match connector session" + ); + Some(old_internal_id) + } else { + None + } + } + + /// Find a mapping by the Gateway session ID (for transfer correlation) + /// + /// Used to find the client mapping when we receive a SessionTransferred event + /// that references the original Gateway session. + pub fn get_mapping_by_gateway_session( + &self, + gateway_session_id: &str, + ) -> Option { + let state = self.state.read().expect("Lock poisoned"); + + // Look for a mapping where internal_session_id matches the gateway session + state.mappings.values().find(|m| { + m.internal_session_id == gateway_session_id + }).cloned() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_session_mapping_creation() { + let mapping = SessionMapping::new( + "client-sess-1".to_string(), + "internal-sess-1".to_string(), + "connector-1".to_string(), + "client-1".to_string(), + ); + + assert_eq!(mapping.client_session_id, "client-sess-1"); + assert_eq!(mapping.internal_session_id, "internal-sess-1"); + assert_eq!(mapping.connector_id, "connector-1"); + assert_eq!(mapping.client_id, "client-1"); + } + + #[test] + fn test_client_connection_creation() { + let capabilities = serde_json::json!({"version": "1.0"}); + let client = ClientConnection::new("client-1".to_string(), Some(capabilities.clone()), None); + + assert_eq!(client.client_id, "client-1"); + assert!(client.sessions.is_empty()); + assert_eq!(client.capabilities, Some(capabilities)); + } + + #[test] + fn test_client_connection_sessions() { + let mut client = ClientConnection::new("client-1".to_string(), None, None); + + // Add sessions + client.add_session("session-1".to_string()); + client.add_session("session-2".to_string()); + assert_eq!(client.sessions.len(), 2); + + // Adding duplicate should not increase count + client.add_session("session-1".to_string()); + assert_eq!(client.sessions.len(), 2); + + // Remove session + assert!(client.remove_session("session-1")); + assert_eq!(client.sessions.len(), 1); + + // Removing non-existent session returns false + assert!(!client.remove_session("session-1")); + } + + #[test] + fn test_session_manager_new() { + let manager = SessionManager::new(); + assert_eq!(manager.client_count(), 0); + assert_eq!(manager.mapping_count(), 0); + } + + #[test] + fn test_session_manager_register_client() { + let manager = SessionManager::new(); + + let client_id = manager.register_client(None); + assert!(!client_id.is_empty()); + assert_eq!(manager.client_count(), 1); + + let client = manager.get_client(&client_id); + assert!(client.is_some()); + let client = client.unwrap(); + assert_eq!(client.client_id, client_id); + assert!(client.capabilities.is_none()); + } + + #[test] + fn test_session_manager_register_client_with_capabilities() { + let manager = SessionManager::new(); + let capabilities = serde_json::json!({"protocol": "acp/1.0"}); + + let client_id = manager.register_client(Some(capabilities.clone())); + let client = manager.get_client(&client_id).unwrap(); + assert_eq!(client.capabilities, Some(capabilities)); + } + + #[test] + fn test_session_manager_create_mapping() { + let manager = SessionManager::new(); + + // Register a client first + let client_id = manager.register_client(None); + + // Create a mapping + let mapping = manager.create_mapping(&client_id, None, "connector-1".to_string()); + + assert!(!mapping.client_session_id.is_empty()); + assert!(!mapping.internal_session_id.is_empty()); + assert_eq!(mapping.connector_id, "connector-1"); + assert_eq!(mapping.client_id, client_id); + + // Verify mapping was stored + assert_eq!(manager.mapping_count(), 1); + + // Verify session was added to client + let sessions = manager.list_client_sessions(&client_id); + assert_eq!(sessions.len(), 1); + assert_eq!(sessions[0], mapping.client_session_id); + } + + #[test] + fn test_session_manager_create_mapping_with_custom_id() { + let manager = SessionManager::new(); + let client_id = manager.register_client(None); + + let mapping = manager.create_mapping( + &client_id, + Some("my-custom-session-id".to_string()), + "connector-1".to_string(), + ); + + assert_eq!(mapping.client_session_id, "my-custom-session-id"); + } + + #[test] + fn test_session_manager_get_mapping() { + let manager = SessionManager::new(); + let client_id = manager.register_client(None); + + let mapping = manager.create_mapping(&client_id, None, "connector-1".to_string()); + + // Get existing mapping + let retrieved = manager.get_mapping(&mapping.client_session_id); + assert!(retrieved.is_some()); + let retrieved = retrieved.unwrap(); + assert_eq!(retrieved.client_session_id, mapping.client_session_id); + + // Get non-existent mapping + assert!(manager.get_mapping("non-existent").is_none()); + } + + #[test] + fn test_session_manager_remove_mapping() { + let manager = SessionManager::new(); + let client_id = manager.register_client(None); + + let mapping = manager.create_mapping(&client_id, None, "connector-1".to_string()); + let session_id = mapping.client_session_id.clone(); + + assert_eq!(manager.mapping_count(), 1); + + // Remove the mapping + let removed = manager.remove_mapping(&session_id); + assert!(removed.is_some()); + assert_eq!(removed.unwrap().client_session_id, session_id); + assert_eq!(manager.mapping_count(), 0); + + // Verify session was removed from client + let sessions = manager.list_client_sessions(&client_id); + assert!(sessions.is_empty()); + + // Removing again returns None + assert!(manager.remove_mapping(&session_id).is_none()); + } + + #[test] + fn test_session_manager_unregister_client() { + let manager = SessionManager::new(); + let client_id = manager.register_client(None); + + // Create some sessions + manager.create_mapping(&client_id, None, "connector-1".to_string()); + manager.create_mapping(&client_id, None, "connector-2".to_string()); + + assert_eq!(manager.client_count(), 1); + assert_eq!(manager.mapping_count(), 2); + + // Unregister the client + let removed_mappings = manager.unregister_client(&client_id); + + assert_eq!(removed_mappings.len(), 2); + assert_eq!(manager.client_count(), 0); + assert_eq!(manager.mapping_count(), 0); + } + + #[test] + fn test_session_manager_unregister_nonexistent_client() { + let manager = SessionManager::new(); + let removed = manager.unregister_client("nonexistent"); + assert!(removed.is_empty()); + } + + #[test] + fn test_session_manager_update_client_activity() { + let manager = SessionManager::new(); + let client_id = manager.register_client(None); + + let client_before = manager.get_client(&client_id).unwrap(); + let activity_before = client_before.last_activity; + + // Small delay to ensure time difference + std::thread::sleep(std::time::Duration::from_millis(10)); + + assert!(manager.update_client_activity(&client_id)); + + let client_after = manager.get_client(&client_id).unwrap(); + assert!(client_after.last_activity > activity_before); + } + + #[test] + fn test_session_manager_update_nonexistent_client_activity() { + let manager = SessionManager::new(); + assert!(!manager.update_client_activity("nonexistent")); + } + + #[test] + fn test_session_manager_list_client_sessions() { + let manager = SessionManager::new(); + let client_id = manager.register_client(None); + + // Empty initially + assert!(manager.list_client_sessions(&client_id).is_empty()); + + // Create sessions + let m1 = manager.create_mapping(&client_id, None, "connector-1".to_string()); + let m2 = manager.create_mapping(&client_id, None, "connector-2".to_string()); + + let sessions = manager.list_client_sessions(&client_id); + assert_eq!(sessions.len(), 2); + assert!(sessions.contains(&m1.client_session_id)); + assert!(sessions.contains(&m2.client_session_id)); + } + + #[test] + fn test_session_manager_list_clients() { + let manager = SessionManager::new(); + + manager.register_client(None); + manager.register_client(Some(serde_json::json!({"test": true}))); + + let clients = manager.list_clients(); + assert_eq!(clients.len(), 2); + } + + #[test] + fn test_session_manager_list_mappings() { + let manager = SessionManager::new(); + let client_id = manager.register_client(None); + + manager.create_mapping(&client_id, None, "connector-1".to_string()); + manager.create_mapping(&client_id, None, "connector-2".to_string()); + + let mappings = manager.list_mappings(); + assert_eq!(mappings.len(), 2); + } + + #[test] + fn test_session_manager_get_mapping_by_internal_id() { + let manager = SessionManager::new(); + let client_id = manager.register_client(None); + + let mapping = manager.create_mapping(&client_id, None, "connector-1".to_string()); + let internal_id = mapping.internal_session_id.clone(); + + // Find by internal ID + let found = manager.get_mapping_by_internal_id(&internal_id); + assert!(found.is_some()); + assert_eq!(found.unwrap().internal_session_id, internal_id); + + // Not found + assert!(manager.get_mapping_by_internal_id("nonexistent").is_none()); + } + + #[test] + fn test_session_manager_clone() { + let manager = SessionManager::new(); + let cloned = manager.clone(); + + // Both should share the same state + let client_id = manager.register_client(None); + assert_eq!(cloned.client_count(), 1); + assert!(cloned.get_client(&client_id).is_some()); + } + + #[test] + fn test_session_manager_default() { + let manager: SessionManager = Default::default(); + assert_eq!(manager.client_count(), 0); + } + + #[test] + fn test_session_mapping_serialization() { + let mapping = SessionMapping::new( + "client-sess".to_string(), + "internal-sess".to_string(), + "connector-1".to_string(), + "client-1".to_string(), + ); + + let json = serde_json::to_string(&mapping).unwrap(); + assert!(json.contains("client-sess")); + assert!(json.contains("internal-sess")); + + let deserialized: SessionMapping = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.client_session_id, mapping.client_session_id); + } + + #[test] + fn test_client_connection_serialization() { + let client = ClientConnection::new( + "client-1".to_string(), + Some(serde_json::json!({"version": "1.0"})), + None, + ); + + let json = serde_json::to_string(&client).unwrap(); + assert!(json.contains("client-1")); + assert!(json.contains("version")); + + let deserialized: ClientConnection = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.client_id, client.client_id); + } + + #[test] + fn test_client_info_from_connection() { + let mut client = ClientConnection::new("client-1".to_string(), None, None); + client.add_session("session-1".to_string()); + client.add_session("session-2".to_string()); + + let info = ClientInfo::from_connection(&client); + + assert_eq!(info.client_id, "client-1"); + assert_eq!(info.active_sessions_count, 2); + // connected_at should be a valid ISO 8601 string + assert!(info.connected_at.contains("T")); // ISO 8601 format contains 'T' + } + + #[test] + fn test_client_info_serialization() { + let info = ClientInfo { + client_id: "client-123".to_string(), + connected_at: "2024-01-01T00:00:00Z".to_string(), + active_sessions_count: 3, + }; + + let json = serde_json::to_string(&info).unwrap(); + assert!(json.contains("client-123")); + assert!(json.contains("2024-01-01T00:00:00Z")); + assert!(json.contains("3")); + + let deserialized: ClientInfo = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.client_id, info.client_id); + assert_eq!(deserialized.connected_at, info.connected_at); + assert_eq!(deserialized.active_sessions_count, info.active_sessions_count); + } + + #[test] + fn test_session_manager_list_clients_info() { + let manager = SessionManager::new(); + + let client1_id = manager.register_client(None); + let client2_id = manager.register_client(None); + + // Add sessions to client1 + manager.create_mapping(&client1_id, None, "connector-1".to_string()); + manager.create_mapping(&client1_id, None, "connector-2".to_string()); + + let clients_info = manager.list_clients_info(); + assert_eq!(clients_info.len(), 2); + + // Find client1 info + let client1_info = clients_info + .iter() + .find(|c| c.client_id == client1_id) + .expect("Client 1 should be in list"); + assert_eq!(client1_info.active_sessions_count, 2); + + // Find client2 info + let client2_info = clients_info + .iter() + .find(|c| c.client_id == client2_id) + .expect("Client 2 should be in list"); + assert_eq!(client2_info.active_sessions_count, 0); + } + + #[test] + fn test_session_manager_disconnect_client() { + let manager = SessionManager::new(); + + let client_id = manager.register_client(None); + manager.create_mapping(&client_id, None, "connector-1".to_string()); + manager.create_mapping(&client_id, None, "connector-2".to_string()); + + assert_eq!(manager.client_count(), 1); + assert_eq!(manager.mapping_count(), 2); + + // Disconnect the client + let result = manager.disconnect_client(&client_id); + assert!(result.is_ok()); + + // Verify client and sessions are removed + assert_eq!(manager.client_count(), 0); + assert_eq!(manager.mapping_count(), 0); + } + + #[test] + fn test_session_manager_disconnect_nonexistent_client() { + let manager = SessionManager::new(); + + let result = manager.disconnect_client("nonexistent-client"); + + assert!(result.is_err()); + match result { + Err(AcpServerError::ClientNotFound(id)) => { + assert_eq!(id, "nonexistent-client"); + } + _ => panic!("Expected ClientNotFound error"), + } + } + + #[test] + fn test_session_manager_disconnect_client_twice() { + let manager = SessionManager::new(); + + let client_id = manager.register_client(None); + + // First disconnect should succeed + assert!(manager.disconnect_client(&client_id).is_ok()); + + // Second disconnect should fail + assert!(manager.disconnect_client(&client_id).is_err()); + } + + #[test] + fn test_update_mapping_connector() { + let manager = SessionManager::new(); + + // Create initial mapping + let client_id = manager.register_client(None); + let mapping = manager.create_mapping( + &client_id, + Some("client-session-1".to_string()), + "gateway-1".to_string(), + ); + + // Store the original internal session ID + let original_internal_id = mapping.internal_session_id.clone(); + + // Update to new connector + let old = manager.update_mapping_connector( + "client-session-1", + "opencode-1".to_string(), + Some("OpenCode".to_string()), + "opencode-internal-1".to_string(), + ); + + assert!(old.is_some()); + let old = old.unwrap(); + assert_eq!(old.connector_id, "gateway-1"); + assert_eq!(old.internal_session_id, original_internal_id); + + // Verify new mapping + let updated = manager.get_mapping("client-session-1").unwrap(); + assert_eq!(updated.connector_id, "opencode-1"); + assert_eq!(updated.internal_session_id, "opencode-internal-1"); + assert_eq!(updated.client_session_id, "client-session-1"); // Should remain unchanged + } + + #[test] + fn test_update_mapping_connector_nonexistent() { + let manager = SessionManager::new(); + + // Try to update a nonexistent mapping + let result = manager.update_mapping_connector( + "nonexistent", + "new-connector".to_string(), + None, + "new-session".to_string(), + ); + + assert!(result.is_none()); + } + + #[test] + fn test_get_mapping_by_gateway_session() { + let manager = SessionManager::new(); + + let client_id = manager.register_client(None); + let mapping = manager.create_mapping( + &client_id, + Some("client-session-1".to_string()), + "gateway-1".to_string(), + ); + + let gateway_session_id = mapping.internal_session_id.clone(); + + let found = manager.get_mapping_by_gateway_session(&gateway_session_id); + assert!(found.is_some()); + let found = found.unwrap(); + assert_eq!(found.client_session_id, "client-session-1"); + assert_eq!(found.internal_session_id, gateway_session_id); + + let not_found = manager.get_mapping_by_gateway_session("nonexistent"); + assert!(not_found.is_none()); + } + + #[test] + fn test_get_mapping_by_gateway_session_after_transfer() { + let manager = SessionManager::new(); + + let client_id = manager.register_client(None); + let mapping = manager.create_mapping( + &client_id, + Some("client-session-1".to_string()), + "gateway-1".to_string(), + ); + + let gateway_session_id = mapping.internal_session_id.clone(); + + // Transfer the session to opencode + manager.update_mapping_connector( + "client-session-1", + "opencode-1".to_string(), + Some("OpenCode".to_string()), + "opencode-internal-1".to_string(), + ); + + // Should no longer find by old gateway session ID since internal_session_id changed + let not_found = manager.get_mapping_by_gateway_session(&gateway_session_id); + assert!(not_found.is_none()); + + // Should find by new internal session ID + let found = manager.get_mapping_by_gateway_session("opencode-internal-1"); + assert!(found.is_some()); + assert_eq!(found.unwrap().client_session_id, "client-session-1"); + } +} diff --git a/crates/dirigent_acp_api/src/sse/content_transform.rs b/crates/dirigent_acp_api/src/sse/content_transform.rs new file mode 100644 index 0000000..e5c59c1 --- /dev/null +++ b/crates/dirigent_acp_api/src/sse/content_transform.rs @@ -0,0 +1,435 @@ +//! Content and Metadata Transformation Utilities +//! +//! This module provides helper functions for transforming content between +//! the internal Dirigent protocol and the ACP wire format. + +use dirigent_protocol::ContentBlock; +use serde_json::json; + +/// Extract 'kind' from tool call metadata (stored as acp_kind) +pub fn extract_kind(tool_call: &dirigent_protocol::ToolCall) -> Option { + tool_call + .metadata + .as_ref() + .and_then(|m| m.get("acp_kind")) + .and_then(|k| k.as_str()) + .map(String::from) +} + +/// Convert ToolCallStatus to ACP string format +pub fn tool_call_status_to_string(status: &dirigent_protocol::ToolCallStatus) -> String { + match status { + dirigent_protocol::ToolCallStatus::Pending => "pending".to_string(), + dirigent_protocol::ToolCallStatus::Running => "in_progress".to_string(), + dirigent_protocol::ToolCallStatus::Completed => "completed".to_string(), + dirigent_protocol::ToolCallStatus::Error => "failed".to_string(), + } +} + +/// Unwrap ToolCallContent wrappers to extract ContentBlocks for SSE output +/// +/// The internal protocol uses ToolCallContent wrappers, but the SSE/ACP wire format +/// expects flat ContentBlock arrays. This function extracts Content variants only. +pub fn unwrap_tool_call_content( + content: &[dirigent_protocol::ToolCallContent], +) -> Vec { + content + .iter() + .filter_map(|wrapper| match wrapper { + dirigent_protocol::ToolCallContent::Content { content } => Some(content.clone()), + // Diff and Terminal variants are not yet supported in SSE output + _ => None, + }) + .collect() +} + +/// Rebuild _meta for ACP output (ensures claudeCode.toolName is present) +pub fn rebuild_meta( + meta: &Option, + tool_call: &dirigent_protocol::ToolCall, +) -> Option { + let mut meta_obj = serde_json::Map::new(); + + // Add claudeCode.toolName + let mut claude_code = serde_json::Map::new(); + claude_code.insert("toolName".to_string(), json!(tool_call.tool_name)); + + // Merge any existing _meta (preserves toolResponse, etc.) + if let Some(existing_meta) = meta { + // Convert Meta to JSON Value for merging + if let Ok(meta_value) = serde_json::to_value(existing_meta) { + if let Some(obj) = meta_value.as_object() { + for (k, v) in obj { + if k == "claudeCode" { + // Merge into our claudeCode object + if let Some(cc) = v.as_object() { + for (ck, cv) in cc { + // Don't override toolName we just set + if ck != "toolName" { + claude_code.insert(ck.clone(), cv.clone()); + } + } + } + } else { + meta_obj.insert(k.clone(), v.clone()); + } + } + } + } + } + + meta_obj.insert("claudeCode".to_string(), json!(claude_code)); + Some(json!(meta_obj)) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_extract_kind_present() { + let mut metadata = serde_json::Map::new(); + metadata.insert("acp_kind".to_string(), json!("search")); + + let tool_call = dirigent_protocol::ToolCall { + id: "call_123".to_string(), + tool_name: "bash".to_string(), + status: dirigent_protocol::ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: Some(json!(metadata)), + origin: None, + }; + + assert_eq!(extract_kind(&tool_call), Some("search".to_string())); + } + + #[test] + fn test_extract_kind_missing_metadata() { + let tool_call = dirigent_protocol::ToolCall { + id: "call_123".to_string(), + tool_name: "bash".to_string(), + status: dirigent_protocol::ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + assert_eq!(extract_kind(&tool_call), None); + } + + #[test] + fn test_extract_kind_missing_field() { + let mut metadata = serde_json::Map::new(); + metadata.insert("other_field".to_string(), json!("value")); + + let tool_call = dirigent_protocol::ToolCall { + id: "call_123".to_string(), + tool_name: "bash".to_string(), + status: dirigent_protocol::ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: Some(json!(metadata)), + origin: None, + }; + + assert_eq!(extract_kind(&tool_call), None); + } + + #[test] + fn test_extract_kind_malformed_metadata() { + // acp_kind is not a string + let mut metadata = serde_json::Map::new(); + metadata.insert("acp_kind".to_string(), json!(123)); + + let tool_call = dirigent_protocol::ToolCall { + id: "call_123".to_string(), + tool_name: "bash".to_string(), + status: dirigent_protocol::ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: Some(json!(metadata)), + origin: None, + }; + + assert_eq!(extract_kind(&tool_call), None); + } + + #[test] + fn test_tool_call_status_to_string_all_variants() { + assert_eq!( + tool_call_status_to_string(&dirigent_protocol::ToolCallStatus::Pending), + "pending" + ); + assert_eq!( + tool_call_status_to_string(&dirigent_protocol::ToolCallStatus::Running), + "in_progress" + ); + assert_eq!( + tool_call_status_to_string(&dirigent_protocol::ToolCallStatus::Completed), + "completed" + ); + assert_eq!( + tool_call_status_to_string(&dirigent_protocol::ToolCallStatus::Error), + "failed" + ); + } + + #[test] + fn test_rebuild_meta_basic() { + let tool_call = dirigent_protocol::ToolCall { + id: "call_123".to_string(), + tool_name: "some_tool".to_string(), + status: dirigent_protocol::ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + let meta = rebuild_meta(&None, &tool_call); + + assert!(meta.is_some()); + let meta_obj = meta.unwrap(); + assert_eq!(meta_obj["claudeCode"]["toolName"], "some_tool"); + } + + #[test] + fn test_rebuild_meta_preserves_existing_fields() { + let existing_meta = dirigent_protocol::Meta { + provider: Some(dirigent_protocol::ProviderMeta { + name: "test_provider".to_string(), + original_ids: None, + raw_excerpt: None, + }), + extra: std::collections::HashMap::from([( + "customField".to_string(), + json!("customValue"), + )]), + }; + + let tool_call = dirigent_protocol::ToolCall { + id: "call_123".to_string(), + tool_name: "bash".to_string(), + status: dirigent_protocol::ToolCallStatus::Running, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + let meta = rebuild_meta(&Some(existing_meta), &tool_call); + + assert!(meta.is_some()); + let meta_obj = meta.unwrap(); + + // Verify claudeCode.toolName is present + assert_eq!(meta_obj["claudeCode"]["toolName"], "bash"); + + // Verify existing fields are preserved + assert_eq!(meta_obj["provider"]["name"], "test_provider"); + assert_eq!(meta_obj["customField"], "customValue"); + } + + #[test] + fn test_rebuild_meta_merges_claude_code_fields() { + let mut existing_meta = dirigent_protocol::Meta::default(); + existing_meta.extra.insert( + "claudeCode".to_string(), + json!({ + "toolResponse": "some response", + "otherField": "other value" + }), + ); + + let tool_call = dirigent_protocol::ToolCall { + id: "call_123".to_string(), + tool_name: "test_tool".to_string(), + status: dirigent_protocol::ToolCallStatus::Completed, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + let meta = rebuild_meta(&Some(existing_meta), &tool_call); + + assert!(meta.is_some()); + let meta_obj = meta.unwrap(); + + // Verify toolName is present (we set it) + assert_eq!(meta_obj["claudeCode"]["toolName"], "test_tool"); + + // Verify existing claudeCode fields are preserved + assert_eq!(meta_obj["claudeCode"]["toolResponse"], "some response"); + assert_eq!(meta_obj["claudeCode"]["otherField"], "other value"); + } + + #[test] + fn test_rebuild_meta_does_not_override_tool_name() { + // Existing meta has a different toolName + let mut existing_meta = dirigent_protocol::Meta::default(); + existing_meta.extra.insert( + "claudeCode".to_string(), + json!({ + "toolName": "wrong_tool", + "toolResponse": "response" + }), + ); + + let tool_call = dirigent_protocol::ToolCall { + id: "call_123".to_string(), + tool_name: "correct_tool".to_string(), + status: dirigent_protocol::ToolCallStatus::Completed, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + let meta = rebuild_meta(&Some(existing_meta), &tool_call); + + assert!(meta.is_some()); + let meta_obj = meta.unwrap(); + + // Verify our toolName is used (not the one from existing meta) + assert_eq!(meta_obj["claudeCode"]["toolName"], "correct_tool"); + + // Verify other fields are preserved + assert_eq!(meta_obj["claudeCode"]["toolResponse"], "response"); + } + + #[test] + fn test_rebuild_meta_preserves_tool_response() { + // Test that rebuild_meta preserves toolResponse from incoming Claude updates + let mut existing_meta = dirigent_protocol::Meta::default(); + existing_meta.extra.insert( + "claudeCode".to_string(), + json!({ + "toolResponse": { + "mode": "content", + "numFiles": 0, + "filenames": [], + "content": "some output", + "numLines": 58, + "appliedLimit": 100 + }, + "toolName": "OldToolName" + }), + ); + + let tool_call = dirigent_protocol::ToolCall { + id: "call_456".to_string(), + tool_name: "Grep".to_string(), + status: dirigent_protocol::ToolCallStatus::Running, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + let meta = rebuild_meta(&Some(existing_meta), &tool_call); + + assert!(meta.is_some()); + let meta_obj = meta.unwrap(); + + // toolName should be updated to current tool_call.tool_name + assert_eq!(meta_obj["claudeCode"]["toolName"], "Grep"); + + // toolResponse should be preserved completely + assert!(meta_obj["claudeCode"]["toolResponse"].is_object()); + assert_eq!(meta_obj["claudeCode"]["toolResponse"]["mode"], "content"); + assert_eq!(meta_obj["claudeCode"]["toolResponse"]["numFiles"], 0); + assert_eq!(meta_obj["claudeCode"]["toolResponse"]["numLines"], 58); + } + + #[test] + fn test_rebuild_meta_round_trip_no_data_loss() { + // Test incoming → internal → outgoing preserves all fields + let mut existing_meta = dirigent_protocol::Meta::default(); + existing_meta.extra.insert( + "claudeCode".to_string(), + json!({ + "toolResponse": { + "mode": "content", + "numFiles": 3, + "filenames": ["file1.rs", "file2.rs", "file3.rs"], + "content": "grep results here", + "numLines": 42, + "appliedLimit": 100, + "customField": "should be preserved", + "nestedObject": { + "deep": "value" + } + }, + "additionalField": "also preserved" + }), + ); + existing_meta + .extra + .insert("customTopLevel".to_string(), json!("preserved too")); + + let tool_call = dirigent_protocol::ToolCall { + id: "call_789".to_string(), + tool_name: "Grep".to_string(), + status: dirigent_protocol::ToolCallStatus::Completed, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + let meta = rebuild_meta(&Some(existing_meta), &tool_call); + + assert!(meta.is_some()); + let meta_obj = meta.unwrap(); + + // Verify NO data loss + assert_eq!(meta_obj["claudeCode"]["toolName"], "Grep"); + assert_eq!(meta_obj["claudeCode"]["toolResponse"]["mode"], "content"); + assert_eq!(meta_obj["claudeCode"]["toolResponse"]["numFiles"], 3); + assert_eq!(meta_obj["claudeCode"]["toolResponse"]["numLines"], 42); + assert_eq!( + meta_obj["claudeCode"]["toolResponse"]["customField"], + "should be preserved" + ); + assert_eq!( + meta_obj["claudeCode"]["toolResponse"]["nestedObject"]["deep"], + "value" + ); + assert_eq!(meta_obj["claudeCode"]["additionalField"], "also preserved"); + assert_eq!(meta_obj["customTopLevel"], "preserved too"); + } +} diff --git a/crates/dirigent_acp_api/src/sse/event_translator.rs b/crates/dirigent_acp_api/src/sse/event_translator.rs new file mode 100644 index 0000000..cb27ccf --- /dev/null +++ b/crates/dirigent_acp_api/src/sse/event_translator.rs @@ -0,0 +1,453 @@ +//! Event Translation Layer +//! +//! This module provides translation from Dirigent protocol Events to ACP notifications. +//! It handles the mapping between internal event representations and the wire format +//! expected by ACP clients. + +use dirigent_protocol::{Event, SessionUpdate}; + +use super::content_transform::{ + extract_kind, rebuild_meta, tool_call_status_to_string, unwrap_tool_call_content, +}; +use super::types::{ + models_to_config_option, modes_to_config_option, ConfigOption, ConfigOptionChoice, + ConfigOptionType, SessionUpdateParams, SessionUpdateVariant, +}; + +/// Translate a Dirigent protocol Event to ACP notifications +/// +/// Not all Dirigent events need to be forwarded to ACP clients. This function +/// returns a vector of notifications for events that should be sent, and an empty +/// vector for events that should be filtered out. +/// +/// ## Mapped Events +/// +/// - `Event::SessionUpdate` with `SessionUpdate::AgentMessageChunk` -> `MessageChunk` +/// - `Event::SessionUpdate` with `SessionUpdate::AgentThoughtChunk` -> `MessageChunk` (with content_type) +/// - `Event::SessionUpdate` with `SessionUpdate::ToolCall` -> `ToolCallUpdate` +/// - `Event::SessionUpdate` with `SessionUpdate::ToolCallUpdate` -> `ToolCallUpdate` +/// - `Event::MessageCompleted` -> `MessageComplete` +/// - `Event::SessionIdle` -> `SessionIdle` +/// - `Event::SessionMetadataReceived` -> `ConfigOptionUpdate` (with modes and models) +/// - `Event::MessageFailed` -> `SessionError` +/// - `Event::Error` -> `SessionError` (system-wide error) +/// +/// ## Filtered Events +/// +/// - `Event::SessionsListed` - List operations don't need streaming +/// - `Event::SessionCreated` - Handled via RPC response +/// - `Event::SessionUpdated` - Metadata updates, not content +/// - `Event::SessionMetadataUpdated` - Metadata updates +/// - `Event::SessionDeleted` - Handled via RPC +/// - `Event::Connected` / `Event::Disconnected` - System events +/// - `Event::ConnectorCreated` / `Event::ConnectorRemoved` - System events +/// - `Event::MessagesListed` - List operations +/// - `Event::MessageStarted` - Initial message, content comes via chunks +/// +/// # Parameters +/// +/// - `event`: The Dirigent protocol event to translate +/// +/// # Returns +/// +/// Vec of `SessionUpdateParams` to be sent. Most events return 0 or 1 update. +pub fn translate_event(event: &Event) -> Vec { + match event { + // SessionUpdate events - the main streaming content + Event::SessionUpdate { + session_id, update, .. + } => translate_session_update(session_id, update) + .map(|u| vec![u]) + .unwrap_or_default(), + + // Message completion + Event::MessageCompleted { message, .. } => vec![SessionUpdateParams { + session_id: message.session_id.clone(), + update: SessionUpdateVariant::MessageComplete { + message_id: Some(message.id.clone()), + }, + ..Default::default() + }], + + // Session idle + Event::SessionIdle { session_id, .. } => vec![SessionUpdateParams { + session_id: session_id.clone(), + update: SessionUpdateVariant::SessionIdle {}, + ..Default::default() + }], + + // Session metadata received - forward as BOTH config_option_update AND current_mode_update + // + // We emit both notification types to support: + // - New clients (acp-beta flag): Use config_option_update + // - Legacy clients: Use current_mode_update (mode only - no legacy model updates exist) + // + // Clients safely ignore notification types they don't understand. + // See: https://agentclientprotocol.com/rfds/session-config-options + Event::SessionMetadataReceived { + session_id, + models, + modes, + config_options: event_config_options, + .. + } => { + let mut updates = vec![]; + + // If the event already has config_options from the agent, prefer those + // over converting from modes/models (agent-provided options are authoritative). + if let Some(agent_config_options) = event_config_options { + if !agent_config_options.is_empty() { + // Emit ConfigOptionUpdate with agent-provided options + updates.push(SessionUpdateParams { + session_id: session_id.clone(), + update: SessionUpdateVariant::ConfigOptionUpdate { + config_options: agent_config_options.iter().map(|co| ConfigOption { + id: co.id.clone(), + name: co.name.clone(), + description: co.description.clone(), + category: co.category.clone(), + option_type: match co.option_type { + dirigent_protocol::ConfigOptionType::Select => ConfigOptionType::Select, + }, + current_value: co.current_value.clone(), + options: Some(co.options.iter().map(|v| ConfigOptionChoice { + value: v.value.clone(), + name: v.name.clone(), + description: v.description.clone(), + }).collect()), + }).collect(), + }, + ..Default::default() + }); + + // Also emit CurrentModeUpdate for legacy clients if a mode option exists + if let Some(mode_opt) = agent_config_options.iter().find(|co| co.id == "mode" || co.category.as_deref() == Some("mode")) { + updates.push(SessionUpdateParams { + session_id: session_id.clone(), + update: SessionUpdateVariant::CurrentModeUpdate { + mode_id: mode_opt.current_value.clone(), + }, + ..Default::default() + }); + } + + return updates; + } + } + + // Fall back to building config_options from legacy modes/models fields + let mut config_options: Vec = vec![]; + if let Some(modes_state) = modes { + config_options.push(modes_to_config_option(modes_state)); + + // Also emit CurrentModeUpdate for legacy clients + // Legacy protocol only supports mode updates (no model updates) + updates.push(SessionUpdateParams { + session_id: session_id.clone(), + update: SessionUpdateVariant::CurrentModeUpdate { + mode_id: modes_state.current_mode_id.clone(), + }, + ..Default::default() + }); + } + + if let Some(models_state) = models { + config_options.push(models_to_config_option(models_state)); + // Note: Legacy protocol has no CurrentModelUpdate - only new ConfigOptionUpdate + } + + if !config_options.is_empty() { + updates.push(SessionUpdateParams { + session_id: session_id.clone(), + update: SessionUpdateVariant::ConfigOptionUpdate { config_options }, + ..Default::default() + }); + } + + updates + } + + // Message failure - not forwarded in new structure (no equivalent variant) + Event::MessageFailed { .. } => vec![], + + // System-wide error - not forwarded in new structure (no equivalent variant) + Event::Error { .. } => vec![], + + // Agent request - will be handled in Phase 2 (bidirectional request/response) + // For now, not forwarded (requires new SSE variant and response mechanism) + Event::AgentRequest { .. } => vec![], + + // Events that should not be forwarded via this translation path + // (they may be handled elsewhere or are not relevant to ACP clients) + Event::SessionsListed { .. } + | Event::SessionCreated { .. } + | Event::SessionUpdated { .. } + | Event::SessionMetadataUpdated { .. } + | Event::SessionDeleted { .. } + | Event::SessionClosed { .. } + | Event::SessionSystemMessageSet { .. } + | Event::SessionError { .. } + | Event::SessionTransferred { .. } + | Event::ForwardingPanic { .. } + | Event::MessagesListed { .. } + | Event::MessageStarted { .. } + | Event::TurnComplete { .. } + | Event::ConnectorCreated { .. } + | Event::ConnectorRemoved { .. } + | Event::ConnectorStateChanged { .. } + | Event::Connected + | Event::Disconnected => vec![], + + // ACP client connection events - these are handled by the UI directly + // via the main SSE endpoint, not translated to session updates + Event::AcpClientConnected { .. } + | Event::AcpClientDisconnected { .. } + | Event::AcpClientSessionOpened { .. } + | Event::AcpClientSessionRouted { .. } => vec![], + + // Inspector events - not relevant for ACP session updates + Event::InspectorSnapshot { .. } + | Event::InspectorNodeRegistered { .. } + | Event::InspectorNodeRemoved { .. } + | Event::InspectorStateChanged { .. } + | Event::InspectorPropertiesUpdated { .. } => vec![], + + // Archivist-to-frontend signal — not relevant for ACP clients + Event::SessionRegistered { .. } => vec![], + + // System task events — internal UI concern, not relevant for ACP clients + Event::SystemTaskStatusChanged { .. } => vec![], + } +} + +/// Translate a SessionUpdate to a SessionUpdateParams +fn translate_session_update( + session_id: &str, + update: &SessionUpdate, +) -> Option { + match update { + // Agent message content chunk + SessionUpdate::AgentMessageChunk { content, .. } => Some(SessionUpdateParams { + session_id: session_id.to_string(), + update: SessionUpdateVariant::AgentMessageChunk { + content: content.clone(), + }, + ..Default::default() + }), + + // Agent thought chunk - render as agent message (thoughts shown as agent text in UI) + SessionUpdate::AgentThoughtChunk { content, .. } => Some(SessionUpdateParams { + session_id: session_id.to_string(), + update: SessionUpdateVariant::AgentMessageChunk { + content: content.clone(), + }, + ..Default::default() + }), + + // User message chunks are typically not streamed to other clients + // Filter them out for now + SessionUpdate::UserMessageChunk { .. } => None, + + // Tool call initiated - forward to client + SessionUpdate::ToolCall { + tool_call, _meta, .. + } => Some(SessionUpdateParams { + session_id: session_id.to_string(), + update: SessionUpdateVariant::ToolCall { + tool_call_id: tool_call.id.clone(), + title: tool_call.title.clone(), + kind: extract_kind(tool_call), + raw_input: tool_call.raw_input.clone(), + status: Some(tool_call_status_to_string(&tool_call.status)), + content: unwrap_tool_call_content(&tool_call.content), + _meta: rebuild_meta(_meta, tool_call), + }, + ..Default::default() + }), + + // Tool call update - forward to client + SessionUpdate::ToolCallUpdate { + tool_call, _meta, .. + } => Some(SessionUpdateParams { + session_id: session_id.to_string(), + update: SessionUpdateVariant::ToolCallUpdate { + tool_call_id: tool_call.id.clone(), + status: Some(tool_call_status_to_string(&tool_call.status)), + content: unwrap_tool_call_content(&tool_call.content), + raw_output: tool_call.raw_output.clone(), + error: tool_call.error.clone(), + _meta: rebuild_meta(_meta, tool_call), + }, + ..Default::default() + }), + + // Unknown update type - forward as-is for transparent proxying + SessionUpdate::Unknown { data } => Some(SessionUpdateParams { + session_id: session_id.to_string(), + update: SessionUpdateVariant::Unknown { data: data.clone() }, + ..Default::default() + }), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use dirigent_protocol::{ContentBlock, Message, MessageRole, MessageStatus}; + + #[test] + fn test_translate_session_update_agent_chunk() { + let event = Event::SessionUpdate { + connector_id: "conn-1".to_string(), + session_id: "sess-1".to_string(), + update: SessionUpdate::AgentMessageChunk { + message_id: "msg-1".to_string(), + content: ContentBlock::Text { + text: "Hello".to_string(), + }, + _meta: None, + }, + }; + + let updates = translate_event(&event); + assert_eq!(updates.len(), 1); + + let update_params = &updates[0]; + assert_eq!(update_params.session_id, "sess-1"); + + match &update_params.update { + SessionUpdateVariant::AgentMessageChunk { content } => match content { + ContentBlock::Text { text } => { + assert_eq!(text, "Hello"); + } + _ => panic!("Expected Text content"), + }, + _ => panic!("Expected AgentMessageChunk variant"), + } + } + + #[test] + fn test_translate_session_update_thought_chunk() { + let event = Event::SessionUpdate { + connector_id: "conn-1".to_string(), + session_id: "sess-1".to_string(), + update: SessionUpdate::AgentThoughtChunk { + message_id: "msg-1".to_string(), + content: ContentBlock::Text { + text: "Thinking...".to_string(), + }, + _meta: None, + }, + }; + + let updates = translate_event(&event); + assert_eq!(updates.len(), 1); + + let update_params = &updates[0]; + assert_eq!(update_params.session_id, "sess-1"); + + // Thought chunks are rendered as agent message chunks + match &update_params.update { + SessionUpdateVariant::AgentMessageChunk { content } => match content { + ContentBlock::Text { text } => { + assert_eq!(text, "Thinking..."); + } + _ => panic!("Expected Text content"), + }, + _ => panic!("Expected AgentMessageChunk variant"), + } + } + + #[test] + fn test_translate_message_completed() { + let now = chrono::Utc::now(); + let event = Event::MessageCompleted { + connector_id: "conn-1".to_string(), + message: Message { + id: "msg-1".to_string(), + session_id: "sess-1".to_string(), + role: MessageRole::Assistant, + created_at: now, + content: vec![], + status: MessageStatus::Completed, + metadata: None, + }, + }; + + let updates = translate_event(&event); + assert_eq!(updates.len(), 1); + + let update_params = &updates[0]; + assert_eq!(update_params.session_id, "sess-1"); + + match &update_params.update { + SessionUpdateVariant::MessageComplete { message_id } => { + assert_eq!(message_id, &Some("msg-1".to_string())); + } + _ => panic!("Expected MessageComplete variant"), + } + } + + #[test] + fn test_translate_session_idle() { + let event = Event::SessionIdle { + connector_id: "test-connector".to_string(), + session_id: "sess-1".to_string(), + }; + + let updates = translate_event(&event); + assert_eq!(updates.len(), 1); + + let update_params = &updates[0]; + assert_eq!(update_params.session_id, "sess-1"); + + match &update_params.update { + SessionUpdateVariant::SessionIdle {} => { + // Expected + } + _ => panic!("Expected SessionIdle variant"), + } + } + + #[test] + fn test_translate_message_failed() { + let event = Event::MessageFailed { + message_id: "msg-1".to_string(), + error: "Connection lost".to_string(), + }; + + // MessageFailed events are now filtered (return empty vec) + let updates = translate_event(&event); + assert!( + updates.is_empty(), + "MessageFailed events should be filtered" + ); + } + + #[test] + fn test_translate_filtered_events() { + // These events should not produce notifications + let filtered_events = vec![ + Event::Connected, + Event::Disconnected, + Event::SessionDeleted { + session_id: "s".to_string(), + }, + Event::Error { + message: "System error".to_string(), + }, + Event::MessageFailed { + message_id: "msg-1".to_string(), + error: "Failed".to_string(), + }, + ]; + + for event in filtered_events { + assert!( + translate_event(&event).is_empty(), + "Expected {:?} to be filtered", + event + ); + } + } +} diff --git a/crates/dirigent_acp_api/src/sse/mod.rs b/crates/dirigent_acp_api/src/sse/mod.rs new file mode 100644 index 0000000..603b3c8 --- /dev/null +++ b/crates/dirigent_acp_api/src/sse/mod.rs @@ -0,0 +1,478 @@ +//! SSE (Server-Sent Events) Notifications for ACP Server +//! +//! This module provides the SSE notification infrastructure for streaming events +//! to connected ACP clients. It handles per-client subscriptions using broadcast +//! channels and provides translation from Dirigent protocol events to ACP notifications. +//! +//! ## Architecture +//! +//! The module is organized into several sub-modules: +//! +//! - `types` - Data structures for SSE notifications (SessionUpdateParams, SessionUpdateVariant, etc.) +//! - `notifier` - SseNotifier for managing client subscriptions and broadcasting +//! - `event_translator` - Translation from Dirigent Events to ACP notifications +//! - `content_transform` - Helper functions for content and metadata transformation +//! +//! ## SSE Wire Format +//! +//! The SSE stream uses the following format: +//! +//! ```text +//! event: session/update +//! data: {"sessionId": "...", "update": {"sessionUpdate": "agent_message_chunk", ...}} +//! +//! event: session/update +//! data: {"sessionId": "...", "update": {"sessionUpdate": "message_complete", ...}} +//! ``` +//! +//! ## Example +//! +//! ```rust,ignore +//! use dirigent_acp_api::sse::{SseNotifier, SessionUpdateParams, translate_event}; +//! +//! // Create the notifier +//! let notifier = SseNotifier::new(); +//! +//! // Subscribe a client +//! let stream = notifier.subscribe("client-123"); +//! +//! // Translate and broadcast an event +//! let updates = translate_event(&event); +//! for update in updates { +//! notifier.broadcast("client-123", update); +//! } +//! +//! // Unsubscribe when done +//! notifier.unsubscribe("client-123"); +//! ``` + +mod content_transform; +mod event_translator; +mod notifier; +mod types; + +// Re-export public types +pub use content_transform::{ + extract_kind, rebuild_meta, tool_call_status_to_string, unwrap_tool_call_content, +}; +pub use event_translator::translate_event; +pub use notifier::SseNotifier; +pub use types::{ + models_to_config_option, modes_to_config_option, AcpNotification, ConfigOption, + ConfigOptionChoice, ConfigOptionType, SessionUpdateParams, SessionUpdateVariant, SlashCommand, +}; + +// ============================================================================ +// Tests (moved from original sse.rs) +// ============================================================================ + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + // ======================================================================== + // AcpNotification Tests + // ======================================================================== + + #[test] + fn test_message_chunk_serialization() { + let notification = AcpNotification::MessageChunk { + session_id: "sess-123".to_string(), + message_id: "msg-456".to_string(), + content: "Hello, world!".to_string(), + content_type: Some("text".to_string()), + }; + + let json = serde_json::to_string(¬ification).unwrap(); + assert!(json.contains(r#""type":"message_chunk"#)); + assert!(json.contains(r#""sessionId":"sess-123"#)); + assert!(json.contains(r#""messageId":"msg-456"#)); + assert!(json.contains(r#""content":"Hello, world!"#)); + assert!(json.contains(r#""contentType":"text"#)); + } + + #[test] + fn test_message_complete_serialization() { + let notification = AcpNotification::MessageComplete { + session_id: "sess-123".to_string(), + message_id: "msg-456".to_string(), + }; + + let json = serde_json::to_string(¬ification).unwrap(); + assert!(json.contains(r#""type":"message_complete"#)); + assert!(json.contains(r#""sessionId":"sess-123"#)); + assert!(json.contains(r#""messageId":"msg-456"#)); + } + + #[test] + fn test_session_idle_serialization() { + let notification = AcpNotification::SessionIdle { + session_id: "sess-123".to_string(), + }; + + let json = serde_json::to_string(¬ification).unwrap(); + assert!(json.contains(r#""type":"session_idle"#)); + assert!(json.contains(r#""sessionId":"sess-123"#)); + } + + #[test] + fn test_session_error_serialization() { + let notification = AcpNotification::SessionError { + session_id: "sess-123".to_string(), + error: "Something went wrong".to_string(), + }; + + let json = serde_json::to_string(¬ification).unwrap(); + assert!(json.contains(r#""type":"session_error"#)); + assert!(json.contains(r#""error":"Something went wrong"#)); + } + + #[test] + fn test_tool_call_update_serialization() { + let notification = AcpNotification::ToolCallUpdate { + session_id: "sess-123".to_string(), + message_id: "msg-456".to_string(), + tool_call_id: "call-789".to_string(), + tool_name: "bash".to_string(), + status: "running".to_string(), + title: Some("Running command".to_string()), + error: None, + }; + + let json = serde_json::to_string(¬ification).unwrap(); + assert!(json.contains(r#""type":"tool_call_update"#)); + assert!(json.contains(r#""toolName":"bash"#)); + assert!(json.contains(r#""status":"running"#)); + assert!(json.contains(r#""title":"Running command"#)); + // error should not be present when None + assert!(!json.contains(r#""error""#)); + } + + #[test] + fn test_event_type() { + assert_eq!( + AcpNotification::MessageChunk { + session_id: "s".to_string(), + message_id: "m".to_string(), + content: "c".to_string(), + content_type: None, + } + .event_type(), + "message/chunk" + ); + + assert_eq!( + AcpNotification::MessageComplete { + session_id: "s".to_string(), + message_id: "m".to_string(), + } + .event_type(), + "message/complete" + ); + + assert_eq!( + AcpNotification::SessionIdle { + session_id: "s".to_string(), + } + .event_type(), + "session/idle" + ); + + assert_eq!( + AcpNotification::SessionError { + session_id: "s".to_string(), + error: "e".to_string(), + } + .event_type(), + "session/error" + ); + + assert_eq!( + AcpNotification::ToolCallUpdate { + session_id: "s".to_string(), + message_id: "m".to_string(), + tool_call_id: "t".to_string(), + tool_name: "bash".to_string(), + status: "running".to_string(), + title: None, + error: None, + } + .event_type(), + "tool/update" + ); + } + + #[test] + fn test_to_sse_string() { + let notification = AcpNotification::SessionIdle { + session_id: "sess-123".to_string(), + }; + + let sse = notification.to_sse_string(); + assert!(sse.starts_with("event: session/idle\n")); + assert!(sse.contains("data: ")); + assert!(sse.contains("sess-123")); + assert!(sse.ends_with("\n\n")); + } + + #[test] + fn test_notification_roundtrip() { + let notifications = vec![ + AcpNotification::MessageChunk { + session_id: "s".to_string(), + message_id: "m".to_string(), + content: "c".to_string(), + content_type: Some("text".to_string()), + }, + AcpNotification::MessageComplete { + session_id: "s".to_string(), + message_id: "m".to_string(), + }, + AcpNotification::SessionIdle { + session_id: "s".to_string(), + }, + AcpNotification::SessionError { + session_id: "s".to_string(), + error: "e".to_string(), + }, + AcpNotification::ToolCallUpdate { + session_id: "s".to_string(), + message_id: "m".to_string(), + tool_call_id: "t".to_string(), + tool_name: "bash".to_string(), + status: "running".to_string(), + title: Some("title".to_string()), + error: None, + }, + ]; + + for notification in notifications { + let json = serde_json::to_string(¬ification).unwrap(); + let deserialized: AcpNotification = serde_json::from_str(&json).unwrap(); + assert_eq!(notification, deserialized); + } + } + + // ======================================================================== + // SessionUpdate Tag Verification Tests + // ======================================================================== + + #[test] + fn test_session_update_variant_uses_session_update_tag_tool_call() { + // Test ToolCall variant with flattened structure + let variant = SessionUpdateVariant::ToolCall { + tool_call_id: "call_456".to_string(), + title: Some("Running command".to_string()), + kind: Some("command".to_string()), + raw_input: Some(json!({"command": "ls"})), + status: Some("in_progress".to_string()), + content: vec![], + _meta: Some(json!({ + "claudeCode": { + "toolName": "bash" + } + })), + }; + + let json = serde_json::to_value(&variant).unwrap(); + + // Verify correct tag name "sessionUpdate" is present + assert!( + json.get("sessionUpdate").is_some(), + "Missing 'sessionUpdate' field in JSON: {:?}", + json + ); + assert_eq!( + json["sessionUpdate"], "tool_call", + "Expected sessionUpdate value 'tool_call', got: {:?}", + json["sessionUpdate"] + ); + + // Verify incorrect tag "type" is NOT present + assert!( + json.get("type").is_none(), + "Should not have 'type' field in ACP output, found: {:?}", + json.get("type") + ); + + // Verify flattened fields are present + assert!(json.get("toolCallId").is_some(), "Missing toolCallId field"); + assert!(json.get("title").is_some(), "Missing title field"); + assert!(json.get("status").is_some(), "Missing status field"); + } + + #[test] + fn test_session_update_variant_roundtrip_with_session_update_tag() { + // Test that serialization and deserialization work correctly with sessionUpdate tag + let variants = vec![ + SessionUpdateVariant::SessionIdle {}, + SessionUpdateVariant::MessageComplete { + message_id: Some("msg_1".to_string()), + }, + SessionUpdateVariant::AgentMessageChunk { + content: dirigent_protocol::ContentBlock::Text { + text: "test".to_string(), + }, + }, + SessionUpdateVariant::ToolCall { + tool_call_id: "call_1".to_string(), + title: None, + kind: None, + raw_input: None, + status: Some("pending".to_string()), + content: vec![], + _meta: Some(json!({ + "claudeCode": { + "toolName": "test_tool" + } + })), + }, + ]; + + for variant in variants { + // Serialize + let json_str = serde_json::to_string(&variant).unwrap(); + let json_val: serde_json::Value = serde_json::from_str(&json_str).unwrap(); + + // Verify sessionUpdate tag in JSON + assert!( + json_val.get("sessionUpdate").is_some(), + "Missing sessionUpdate in serialized JSON for variant: {:?}", + variant + ); + + // Verify no type tag + assert!( + json_val.get("type").is_none(), + "Found unexpected 'type' tag for variant: {:?}", + variant + ); + + // Deserialize back + let deserialized: SessionUpdateVariant = serde_json::from_str(&json_str).unwrap(); + + // Verify equality (roundtrip) + assert_eq!( + variant, deserialized, + "Roundtrip failed for variant: {:?}", + variant + ); + } + } + + // ======================================================================== + // camelCase Field Naming Compliance Tests + // ======================================================================== + + #[test] + fn test_all_fields_use_camel_case() { + // Test SessionUpdateParams with ToolCall variant + let params = SessionUpdateParams { + session_id: "sess_123".to_string(), + update: SessionUpdateVariant::ToolCall { + tool_call_id: "tool_123".to_string(), + title: Some("test".to_string()), + kind: Some("search".to_string()), + raw_input: Some(json!({"key": "value"})), + status: Some("pending".to_string()), + content: vec![], + _meta: None, + }, + ..Default::default() + }; + + let json = serde_json::to_value(¶ms).unwrap(); + + // Verify top-level camelCase + assert!( + json.get("sessionId").is_some(), + "Should have sessionId (not session_id)" + ); + assert!( + json.get("session_id").is_none(), + "Should NOT have session_id" + ); + + let update = &json["update"]; + + // Verify ToolCall variant fields use camelCase + assert!( + update.get("toolCallId").is_some(), + "Should have toolCallId (not tool_call_id)" + ); + assert!( + update.get("tool_call_id").is_none(), + "Should NOT have tool_call_id" + ); + assert!( + update.get("rawInput").is_some(), + "Should have rawInput (not raw_input)" + ); + assert!( + update.get("raw_input").is_none(), + "Should NOT have raw_input" + ); + } + + #[test] + fn test_no_snake_case_in_serialized_json() { + // Create various SessionUpdateParams with different variants + let test_cases = vec![ + SessionUpdateParams { + session_id: "s1".to_string(), + update: SessionUpdateVariant::ToolCall { + tool_call_id: "tc1".to_string(), + title: Some("t".to_string()), + kind: Some("k".to_string()), + raw_input: Some(json!({})), + status: Some("pending".to_string()), + content: vec![], + _meta: None, + }, + ..Default::default() + }, + SessionUpdateParams { + session_id: "s2".to_string(), + update: SessionUpdateVariant::ToolCallUpdate { + tool_call_id: "tc2".to_string(), + status: Some("completed".to_string()), + content: vec![], + raw_output: Some(json!({})), + error: None, + _meta: None, + }, + ..Default::default() + }, + SessionUpdateParams { + session_id: "s3".to_string(), + update: SessionUpdateVariant::MessageComplete { + message_id: Some("m3".to_string()), + }, + ..Default::default() + }, + ]; + + for params in test_cases { + let json_str = serde_json::to_string(¶ms).unwrap(); + + // Check for common snake_case fields that should NOT appear + let forbidden_patterns = vec![ + "session_id", + "tool_call_id", + "raw_input", + "raw_output", + "message_id", + ]; + + for pattern in forbidden_patterns { + assert!( + !json_str.contains(&format!("\"{}\"", pattern)), + "Found forbidden snake_case field '{}' in JSON: {}", + pattern, + json_str + ); + } + } + } +} diff --git a/crates/dirigent_acp_api/src/sse/notifier.rs b/crates/dirigent_acp_api/src/sse/notifier.rs new file mode 100644 index 0000000..f222605 --- /dev/null +++ b/crates/dirigent_acp_api/src/sse/notifier.rs @@ -0,0 +1,403 @@ +//! SSE Notifier for managing client subscriptions +//! +//! This module provides the subscription management infrastructure for SSE streaming. +//! The `SseNotifier` manages per-client broadcast channels for targeted notifications. + +use std::collections::HashMap; +use std::pin::Pin; +use std::sync::{Arc, RwLock}; + +use tokio::sync::broadcast; +use tokio_stream::wrappers::errors::BroadcastStreamRecvError; +use tokio_stream::wrappers::BroadcastStream; +use tokio_stream::Stream; +use tracing::{debug, trace, warn}; + +use super::types::SessionUpdateParams; + +/// Default broadcast channel capacity +const DEFAULT_CHANNEL_CAPACITY: usize = 256; + +/// Internal state for SSE subscriptions +#[derive(Debug, Default)] +struct SseNotifierState { + /// Map from client_id to their broadcast sender + subscriptions: HashMap>, +} + +/// SSE Notifier for managing client subscriptions and broadcasting notifications +/// +/// The `SseNotifier` manages per-client subscriptions using tokio broadcast channels. +/// Each client has their own broadcast channel, allowing targeted notifications. +/// +/// ## Thread Safety +/// +/// The `SseNotifier` is designed to be cloned and shared across async tasks. +/// Internal state is protected by `Arc>` for thread-safe access. +/// +/// ## Example +/// +/// ```rust,ignore +/// let notifier = SseNotifier::new(); +/// +/// // Subscribe a client and get their stream +/// let stream = notifier.subscribe("client-123"); +/// +/// // In another task, broadcast notifications +/// notifier.broadcast("client-123", notification); +/// +/// // When client disconnects +/// notifier.unsubscribe("client-123"); +/// ``` +#[derive(Debug, Clone)] +pub struct SseNotifier { + state: Arc>, + channel_capacity: usize, +} + +impl Default for SseNotifier { + fn default() -> Self { + Self::new() + } +} + +impl SseNotifier { + /// Create a new SSE notifier with default capacity + /// + /// Creates a new notifier with the default channel capacity of 256 messages. + pub fn new() -> Self { + Self { + state: Arc::new(RwLock::new(SseNotifierState::default())), + channel_capacity: DEFAULT_CHANNEL_CAPACITY, + } + } + + /// Create a new SSE notifier with custom channel capacity + /// + /// # Parameters + /// + /// - `capacity`: The maximum number of messages that can be buffered per client + pub fn with_capacity(capacity: usize) -> Self { + Self { + state: Arc::new(RwLock::new(SseNotifierState::default())), + channel_capacity: capacity, + } + } + + /// Subscribe a client and return an async stream of notifications + /// + /// Creates a broadcast channel for the client and returns a stream that can + /// be used with Axum's `Sse` response type. + /// + /// If the client is already subscribed, a new receiver is created from the + /// existing sender (allowing multiple connections per client). + /// + /// # Parameters + /// + /// - `client_id`: Unique identifier for the client + /// + /// # Returns + /// + /// A pinned stream of `SessionUpdateParams` wrapped in `Result` for error handling. + /// The stream is compatible with Axum's SSE handler. + pub fn subscribe( + &self, + client_id: &str, + ) -> Pin> + Send>> + { + let mut state = self.state.write().expect("Lock poisoned"); + + // Get or create the sender for this client + let sender = state + .subscriptions + .entry(client_id.to_string()) + .or_insert_with(|| { + debug!("Creating new broadcast channel for client: {}", client_id); + let (tx, _rx) = broadcast::channel(self.channel_capacity); + tx + }); + + // Create a new receiver + let receiver = sender.subscribe(); + + debug!( + "Client subscribed to SSE: {} (subscribers: {})", + client_id, + sender.receiver_count() + ); + + // Convert to stream using tokio_stream + let stream = BroadcastStream::new(receiver); + + Box::pin(stream) + } + + /// Unsubscribe a client and cleanup resources + /// + /// Removes the client's subscription and drops the sender, which will + /// cause all associated receivers to receive an error on their next poll. + /// + /// # Parameters + /// + /// - `client_id`: The ID of the client to unsubscribe + /// + /// # Returns + /// + /// `true` if the client was subscribed and removed, `false` if not found + pub fn unsubscribe(&self, client_id: &str) -> bool { + let mut state = self.state.write().expect("Lock poisoned"); + + if let Some(sender) = state.subscriptions.remove(client_id) { + debug!( + "Client unsubscribed from SSE: {} (was {} receivers)", + client_id, + sender.receiver_count() + ); + // Sender is dropped here, which will close the channel + true + } else { + debug!("Client was not subscribed: {}", client_id); + false + } + } + + /// Broadcast a session update to a specific client + /// + /// Sends the session update to all receivers subscribed to the specified client. + /// If no receivers are active (e.g., all have dropped or lagged), the send + /// is handled gracefully. + /// + /// # Parameters + /// + /// - `client_id`: The ID of the client to send to + /// - `update_params`: The session update parameters to send + /// + /// # Returns + /// + /// - `Ok(n)` where n is the number of receivers that received the update + /// - `Err(())` if the client is not subscribed + pub fn broadcast( + &self, + client_id: &str, + update_params: SessionUpdateParams, + ) -> Result { + let state = self.state.read().expect("Lock poisoned"); + + if let Some(sender) = state.subscriptions.get(client_id) { + match sender.send(update_params) { + Ok(n) => { + trace!( + "Broadcast session update to client {}: {} receivers", + client_id, + n + ); + Ok(n) + } + Err(_) => { + // No active receivers - this is okay, the update is just dropped + trace!("Broadcast to client {}: no active receivers", client_id); + Ok(0) + } + } + } else { + warn!("Attempted broadcast to unsubscribed client: {}", client_id); + Err(()) + } + } + + /// Broadcast a session update to all subscribed clients + /// + /// Sends the session update to all connected clients. Failed sends to + /// individual clients are logged but don't stop the broadcast. + /// + /// # Parameters + /// + /// - `update_params`: The session update parameters to broadcast + /// + /// # Returns + /// + /// The total number of receivers that received the update + pub fn broadcast_all(&self, update_params: SessionUpdateParams) -> usize { + let state = self.state.read().expect("Lock poisoned"); + let mut total_receivers = 0; + + for (client_id, sender) in state.subscriptions.iter() { + match sender.send(update_params.clone()) { + Ok(n) => { + total_receivers += n; + trace!("Broadcast to client {}: {} receivers", client_id, n); + } + Err(_) => { + trace!("Broadcast to client {}: no active receivers", client_id); + } + } + } + + debug!( + "Broadcast to all clients: {} total receivers across {} clients", + total_receivers, + state.subscriptions.len() + ); + + total_receivers + } + + /// Get the number of subscribed clients + pub fn client_count(&self) -> usize { + let state = self.state.read().expect("Lock poisoned"); + state.subscriptions.len() + } + + /// Check if a client is subscribed + pub fn is_subscribed(&self, client_id: &str) -> bool { + let state = self.state.read().expect("Lock poisoned"); + state.subscriptions.contains_key(client_id) + } + + /// Get the list of subscribed client IDs + pub fn subscribed_clients(&self) -> Vec { + let state = self.state.read().expect("Lock poisoned"); + state.subscriptions.keys().cloned().collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::sse::types::SessionUpdateVariant; + use tokio_stream::StreamExt; + + #[test] + fn test_sse_notifier_new() { + let notifier = SseNotifier::new(); + assert_eq!(notifier.client_count(), 0); + } + + #[test] + fn test_sse_notifier_with_capacity() { + let notifier = SseNotifier::with_capacity(512); + assert_eq!(notifier.channel_capacity, 512); + } + + #[tokio::test] + async fn test_subscribe_unsubscribe() { + let notifier = SseNotifier::new(); + + // Subscribe + let _stream = notifier.subscribe("client-1"); + assert!(notifier.is_subscribed("client-1")); + assert_eq!(notifier.client_count(), 1); + + // Unsubscribe + assert!(notifier.unsubscribe("client-1")); + assert!(!notifier.is_subscribed("client-1")); + assert_eq!(notifier.client_count(), 0); + + // Unsubscribe non-existent + assert!(!notifier.unsubscribe("client-1")); + } + + #[tokio::test] + async fn test_broadcast_to_client() { + let notifier = SseNotifier::new(); + + // Subscribe + let mut stream = notifier.subscribe("client-1"); + + // Broadcast + let update_params = SessionUpdateParams { + session_id: "sess-1".to_string(), + update: SessionUpdateVariant::SessionIdle {}, + ..Default::default() + }; + + let result = notifier.broadcast("client-1", update_params.clone()); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), 1); + + // Receive + let received = stream.next().await; + assert!(received.is_some()); + let received = received.unwrap(); + assert!(received.is_ok()); + assert_eq!(received.unwrap(), update_params); + } + + #[tokio::test] + async fn test_broadcast_to_unsubscribed_client() { + let notifier = SseNotifier::new(); + + let update_params = SessionUpdateParams { + session_id: "sess-1".to_string(), + update: SessionUpdateVariant::SessionIdle {}, + ..Default::default() + }; + + let result = notifier.broadcast("unknown-client", update_params); + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_broadcast_all() { + let notifier = SseNotifier::new(); + + // Subscribe multiple clients + let mut stream1 = notifier.subscribe("client-1"); + let mut stream2 = notifier.subscribe("client-2"); + + let update_params = SessionUpdateParams { + session_id: "sess-1".to_string(), + update: SessionUpdateVariant::SessionIdle {}, + ..Default::default() + }; + + let count = notifier.broadcast_all(update_params.clone()); + assert_eq!(count, 2); + + // Both should receive + let r1 = stream1.next().await.unwrap().unwrap(); + let r2 = stream2.next().await.unwrap().unwrap(); + + assert_eq!(r1, update_params); + assert_eq!(r2, update_params); + } + + #[test] + fn test_subscribed_clients() { + let notifier = SseNotifier::new(); + + let _s1 = notifier.subscribe("client-1"); + let _s2 = notifier.subscribe("client-2"); + + let clients = notifier.subscribed_clients(); + assert_eq!(clients.len(), 2); + assert!(clients.contains(&"client-1".to_string())); + assert!(clients.contains(&"client-2".to_string())); + } + + #[tokio::test] + async fn test_multiple_receivers_same_client() { + let notifier = SseNotifier::new(); + + // Subscribe same client twice + let mut stream1 = notifier.subscribe("client-1"); + let mut stream2 = notifier.subscribe("client-1"); + + let update_params = SessionUpdateParams { + session_id: "sess-1".to_string(), + update: SessionUpdateVariant::SessionIdle {}, + ..Default::default() + }; + + let count = notifier.broadcast("client-1", update_params.clone()); + assert!(count.is_ok()); + assert_eq!(count.unwrap(), 2); // Both receivers + + // Both should receive + let r1 = stream1.next().await.unwrap().unwrap(); + let r2 = stream2.next().await.unwrap().unwrap(); + + assert_eq!(r1, update_params); + assert_eq!(r2, update_params); + } +} diff --git a/crates/dirigent_acp_api/src/sse/types.rs b/crates/dirigent_acp_api/src/sse/types.rs new file mode 100644 index 0000000..f1ce003 --- /dev/null +++ b/crates/dirigent_acp_api/src/sse/types.rs @@ -0,0 +1,548 @@ +//! SSE Type Definitions for ACP Server +//! +//! This module contains all the data structures and types used for SSE notifications. +//! These types define the wire format for ACP protocol messages. + +use serde::{Deserialize, Serialize}; + +use dirigent_protocol::ContentBlock; + +// ============================================================================ +// SessionUpdateParams +// ============================================================================ + +/// Params for session/update JSON-RPC notification (ACP spec compliant) +/// +/// This matches the structure sent by Claude and expected by Zed. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] +#[serde(rename_all = "camelCase", default)] +pub struct SessionUpdateParams { + /// The session ID this update belongs to + pub session_id: String, + + /// The update content + pub update: SessionUpdateVariant, + + /// Optional event type override (default: "session/update") + /// Reserved for future protocol extensions. Currently unused. + #[serde(skip)] + pub event_type_override: Option, +} + +impl SessionUpdateParams { + /// Get the SSE event type for session updates + /// + /// According to ACP spec, session updates use "session/update" event type. + pub fn event_type(&self) -> &str { + self.event_type_override + .as_deref() + .unwrap_or("session/update") + } + + /// Create a raw event with a custom event type + /// + /// This creates an event that will be serialized as raw JSON data + /// (without the SessionUpdateParams wrapper) when sent over SSE. + /// Used for `_rpc_response` deferred responses pushed via SSE after gateway transfer. + pub fn raw_event(event_type: &str, data: serde_json::Value) -> Self { + Self { + session_id: String::new(), // Not used for raw events + update: SessionUpdateVariant::Unknown { data }, + event_type_override: Some(event_type.to_string()), + } + } + + /// Check if this is a raw event (should serialize data directly) + pub fn is_raw_event(&self) -> bool { + self.event_type_override.is_some() + } + + /// Get the data to serialize for SSE + /// + /// For raw events (with event_type_override), returns the raw JSON data. + /// For normal events, returns None (caller should serialize the full struct). + pub fn raw_data(&self) -> Option<&serde_json::Value> { + if self.event_type_override.is_some() { + if let SessionUpdateVariant::Unknown { data } = &self.update { + return Some(data); + } + } + None + } + + /// Serialize to JSON string for SSE transmission + /// + /// For raw events, serializes just the raw data. + /// For normal events, serializes the full SessionUpdateParams struct. + pub fn to_sse_json(&self) -> String { + if let Some(raw_data) = self.raw_data() { + // Raw event: serialize just the data + serde_json::to_string(raw_data).unwrap_or_else(|_| "{}".to_string()) + } else { + // Normal event: serialize the full struct + serde_json::to_string(self).unwrap_or_else(|_| "{}".to_string()) + } + } +} + +// ============================================================================ +// ConfigOption Types +// ============================================================================ + +/// A single configuration option for session settings (outgoing ACP wire format) +/// +/// Used in `config_option_update` notifications to send the complete set of +/// configuration options (modes, models, etc.) to the client. +/// +/// Note: This is the **outgoing** wire format for ACP Server SSE. +/// For the **incoming** format (parsed from agents), see `dirigent_protocol::ConfigOption`. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct ConfigOption { + /// Unique identifier for this option (e.g., "mode", "model") + pub id: String, + /// Human-readable display name (e.g., "Session Mode", "Model") + pub name: String, + /// Optional description + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Semantic category for UX grouping (e.g., "mode", "model", "thought_level") + #[serde(skip_serializing_if = "Option::is_none")] + pub category: Option, + /// Input type (e.g., "select", "toggle", "text") + #[serde(rename = "type")] + pub option_type: ConfigOptionType, + /// Currently selected value + pub current_value: String, + /// Available options for "select" type + #[serde(skip_serializing_if = "Option::is_none")] + pub options: Option>, +} + +/// Type of configuration option input +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum ConfigOptionType { + Select, + Toggle, + Text, +} + +/// A single choice within a select-type config option +/// +/// Matches ACP's `SessionConfigSelectOption` structure. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct ConfigOptionChoice { + /// Value identifier for this choice (sent back when selected) + pub value: String, + /// Human-readable display name + pub name: String, + /// Optional description + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, +} + +// ============================================================================ +// SessionUpdateVariant +// ============================================================================ + +/// Variants of session updates (ACP spec compliant) +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "sessionUpdate", rename_all = "snake_case")] +pub enum SessionUpdateVariant { + /// Agent message chunk (streaming text) + #[serde(rename = "agent_message_chunk")] + AgentMessageChunk { + /// The content block (text, image, etc.) + content: ContentBlock, + }, + + /// Message generation complete + #[serde(rename = "message_complete")] + MessageComplete { + /// The message ID that completed + #[serde(rename = "messageId", skip_serializing_if = "Option::is_none")] + message_id: Option, + }, + + /// Session is idle and ready for input + #[serde(rename = "session_idle")] + SessionIdle {}, + + /// Available slash commands update + #[serde(rename = "available_commands_update")] + AvailableCommandsUpdate { + /// List of available slash commands + #[serde(rename = "availableCommands")] + available_commands: Vec, + }, + + /// Connector changed (session transferred to a different connector) + #[serde(rename = "connector_changed")] + ConnectorChanged { + /// The new connector ID + #[serde(rename = "newConnectorId")] + new_connector_id: String, + /// The new internal session ID in the target connector + #[serde(rename = "newInternalSessionId")] + new_internal_session_id: String, + /// Whether a new session was created (true) or existing session loaded (false) + #[serde(rename = "isNewSession")] + is_new_session: bool, + }, + + /// Tool call started/initiated + #[serde(rename = "tool_call")] + ToolCall { + /// The tool call ID + #[serde(rename = "toolCallId")] + tool_call_id: String, + + /// Optional title for the tool call + #[serde(skip_serializing_if = "Option::is_none")] + title: Option, + + /// Optional kind/category (e.g., "search", "edit") + #[serde(skip_serializing_if = "Option::is_none")] + kind: Option, + + /// Raw input parameters + #[serde(rename = "rawInput", skip_serializing_if = "Option::is_none")] + raw_input: Option, + + /// Current status (pending, in_progress, completed, failed) + #[serde(skip_serializing_if = "Option::is_none")] + status: Option, + + /// Content blocks (e.g., tool output) + #[serde(default, skip_serializing_if = "Vec::is_empty")] + content: Vec, + + /// Metadata (can include claudeCode.toolName, toolResponse, etc.) + #[serde(skip_serializing_if = "Option::is_none")] + _meta: Option, + }, + + /// Tool call update (status change, output, etc.) + #[serde(rename = "tool_call_update")] + ToolCallUpdate { + /// The tool call ID being updated + #[serde(rename = "toolCallId")] + tool_call_id: String, + + /// Updated status + #[serde(skip_serializing_if = "Option::is_none")] + status: Option, + + /// Updated content blocks + #[serde(default, skip_serializing_if = "Vec::is_empty")] + content: Vec, + + /// Raw output from the tool + #[serde(rename = "rawOutput", skip_serializing_if = "Option::is_none")] + raw_output: Option, + + /// Error message if tool call failed + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, + + /// Metadata (can include toolResponse from Claude) + #[serde(skip_serializing_if = "Option::is_none")] + _meta: Option, + }, + + /// Agent request needing client response (e.g., permission prompt) + /// + /// This variant is used for bidirectional requests where the agent initiates + /// a request (like `session/request_permission`) that requires a response from + /// the client. The client should respond via the `/acp/agent_response` endpoint. + #[serde(rename = "agent_request")] + AgentRequest { + /// The request ID from the agent (to correlate response) + #[serde(rename = "requestId")] + request_id: serde_json::Value, + + /// The method being requested (e.g., "session/request_permission") + method: String, + + /// The request parameters + params: serde_json::Value, + }, + + /// Session modes update (full state) - NOT SUPPORTED BY ZED + /// Kept for future ACP protocol support + #[serde(rename = "modes_update")] + ModesUpdate { + /// Session mode state (flattened) + #[serde(flatten)] + modes: dirigent_protocol::SessionModeState, + }, + + /// Session models update (full state) - NOT SUPPORTED BY ZED + /// Kept for future ACP protocol support + #[serde(rename = "models_update")] + ModelsUpdate { + /// Session model state (flattened) + #[serde(flatten)] + models: dirigent_protocol::SessionModelState, + }, + + /// Current mode update (standard ACP notification) + /// Signals which mode is currently selected. Zed accepts this. + #[serde(rename = "current_mode_update")] + CurrentModeUpdate { + /// The ID of the currently selected mode + /// Note: Zed expects "currentModeId" not "modeId" + #[serde(rename = "currentModeId")] + mode_id: String, + }, + + /// Current model update - NOT SUPPORTED BY ZED + /// Kept for completeness but Zed only accepts current_mode_update + #[serde(rename = "current_model_update")] + CurrentModelUpdate { + /// The ID of the currently selected model + #[serde(rename = "modelId")] + model_id: String, + }, + + /// Config option update - send full configuration options to client + /// + /// This is used after session transfer to provide the target connector's + /// actual modes/models instead of Gateway's placeholder values. + /// See: https://agentclientprotocol.com/rfds/session-config-options + #[serde(rename = "config_option_update")] + ConfigOptionUpdate { + /// List of configuration options to display + #[serde(rename = "configOptions")] + config_options: Vec, + }, + + /// Unknown update type (forward compatibility - pass through as raw JSON) + #[serde(untagged)] + Unknown { + #[serde(flatten)] + data: serde_json::Value, + }, +} + +impl SessionUpdateVariant { + /// Returns the variant name for logging (without data) + pub fn variant_name(&self) -> &'static str { + match self { + SessionUpdateVariant::AgentMessageChunk { .. } => "AgentMessageChunk", + SessionUpdateVariant::MessageComplete { .. } => "MessageComplete", + SessionUpdateVariant::SessionIdle { .. } => "SessionIdle", + SessionUpdateVariant::AvailableCommandsUpdate { .. } => "AvailableCommandsUpdate", + SessionUpdateVariant::ConnectorChanged { .. } => "ConnectorChanged", + SessionUpdateVariant::ToolCall { .. } => "ToolCall", + SessionUpdateVariant::ToolCallUpdate { .. } => "ToolCallUpdate", + SessionUpdateVariant::AgentRequest { .. } => "AgentRequest", + SessionUpdateVariant::ModesUpdate { .. } => "ModesUpdate", + SessionUpdateVariant::ModelsUpdate { .. } => "ModelsUpdate", + SessionUpdateVariant::CurrentModeUpdate { .. } => "CurrentModeUpdate", + SessionUpdateVariant::CurrentModelUpdate { .. } => "CurrentModelUpdate", + SessionUpdateVariant::ConfigOptionUpdate { .. } => "ConfigOptionUpdate", + SessionUpdateVariant::Unknown { .. } => "Unknown", + } + } +} + +impl Default for SessionUpdateVariant { + fn default() -> Self { + SessionUpdateVariant::SessionIdle {} + } +} + +// ============================================================================ +// SlashCommand +// ============================================================================ + +/// Slash command definition +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct SlashCommand { + pub name: String, + pub description: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub input: Option, +} + +// ============================================================================ +// AcpNotification (Legacy) +// ============================================================================ + +/// ACP Notification types for SSE streaming (DEPRECATED - keeping for compatibility) +/// +/// These notifications are sent to clients over the SSE connection to inform +/// them about session events, message streaming, and errors. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde( + tag = "type", + rename_all = "snake_case", + rename_all_fields = "camelCase" +)] +pub enum AcpNotification { + /// A chunk of message content has been received + /// + /// Sent during streaming to provide incremental content updates. + /// The client should append this content to the message being built. + MessageChunk { + /// The session ID this chunk belongs to + session_id: String, + + /// The message ID this chunk belongs to + message_id: String, + + /// The content chunk (typically text) + content: String, + + /// Optional content type (e.g., "text", "thought") + #[serde(skip_serializing_if = "Option::is_none")] + content_type: Option, + }, + + /// A message has been completed + /// + /// Sent when the agent has finished generating a message. + /// The client should finalize the message display. + MessageComplete { + /// The session ID + session_id: String, + + /// The completed message ID + message_id: String, + }, + + /// A session has become idle + /// + /// Sent when the session transitions to an idle state, meaning + /// no active generation is occurring. + SessionIdle { + /// The session ID that became idle + session_id: String, + }, + + /// An error occurred in a session + /// + /// Sent when an error occurs during processing. + SessionError { + /// The session ID where the error occurred + session_id: String, + + /// Human-readable error message + error: String, + }, + + /// A tool call has started or been updated + /// + /// Sent when a tool call is initiated or its status changes. + ToolCallUpdate { + /// The session ID + session_id: String, + + /// The message ID containing this tool call + message_id: String, + + /// The tool call ID + tool_call_id: String, + + /// The tool name + tool_name: String, + + /// Current status (pending, running, completed, error) + status: String, + + /// Optional title for the tool call + #[serde(skip_serializing_if = "Option::is_none")] + title: Option, + + /// Optional error message + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, + }, +} + +impl AcpNotification { + /// Get the SSE event type for this notification + /// + /// This is used as the `event:` field in the SSE stream. + pub fn event_type(&self) -> &'static str { + match self { + AcpNotification::MessageChunk { .. } => "message/chunk", + AcpNotification::MessageComplete { .. } => "message/complete", + AcpNotification::SessionIdle { .. } => "session/idle", + AcpNotification::SessionError { .. } => "session/error", + AcpNotification::ToolCallUpdate { .. } => "tool/update", + } + } + + /// Convert to SSE format string + /// + /// Returns a string formatted for SSE transmission: + /// ```text + /// event: + /// data: + /// ``` + pub fn to_sse_string(&self) -> String { + let data = serde_json::to_string(self).unwrap_or_else(|_| "{}".to_string()); + format!("event: {}\ndata: {}\n\n", self.event_type(), data) + } +} + +// ============================================================================ +// ConfigOption Conversion Functions +// ============================================================================ + +/// Convert SessionModeState to a ConfigOption +/// +/// Transforms the protocol's mode state into the ACP `config_option_update` format. +/// Uses "mode" ID to match the ACP protocol standard. +pub fn modes_to_config_option(modes: &dirigent_protocol::SessionModeState) -> ConfigOption { + ConfigOption { + id: "mode".to_string(), + name: "Session Mode".to_string(), + description: None, + category: Some("mode".to_string()), + option_type: ConfigOptionType::Select, + current_value: modes.current_mode_id.clone(), + options: Some( + modes + .available_modes + .iter() + .map(|m| ConfigOptionChoice { + value: m.id.clone(), + name: m.name.clone(), + description: m.description.clone(), + }) + .collect(), + ), + } +} + +/// Convert SessionModelState to a ConfigOption +/// +/// Transforms the protocol's model state into the ACP `config_option_update` format. +/// Uses "model" ID to match the ACP protocol standard. +pub fn models_to_config_option(models: &dirigent_protocol::SessionModelState) -> ConfigOption { + ConfigOption { + id: "model".to_string(), + name: "Model".to_string(), + description: None, + category: Some("model".to_string()), + option_type: ConfigOptionType::Select, + current_value: models.current_model_id.clone(), + options: Some( + models + .available_models + .iter() + .map(|m| ConfigOptionChoice { + value: m.model_id.clone(), + name: m.name.clone(), + description: m.description.clone(), + }) + .collect(), + ), + } +} diff --git a/crates/dirigent_acp_api/tests/common/mod.rs b/crates/dirigent_acp_api/tests/common/mod.rs new file mode 100644 index 0000000..a1b720b --- /dev/null +++ b/crates/dirigent_acp_api/tests/common/mod.rs @@ -0,0 +1,338 @@ +//! Common test utilities for bidirectional flow testing. +//! +//! This module provides mock implementations and test helpers for testing +//! the bidirectional request/response flow in the ACP Server. + +use anyhow::Result; +use serde_json::{json, Value}; +use std::sync::Arc; +use tokio::sync::{mpsc, oneshot, Mutex}; +use std::collections::HashMap; + +/// Mock event for testing event bridge +#[derive(Debug, Clone)] +pub struct MockEvent { + pub event_type: String, + pub data: Value, +} + +/// Mock SSE client for testing +/// +/// Simulates an HTTP client that receives SSE events and posts responses. +pub struct MockSseClient { + pub client_id: String, + /// Events received via SSE + pub received_events: Arc>>, + /// Sender for simulating HTTP POST responses + pub response_tx: mpsc::UnboundedSender<(Value, oneshot::Sender>)>, +} + +impl MockSseClient { + pub fn new(client_id: String) -> (Self, mpsc::UnboundedReceiver<(Value, oneshot::Sender>)>) { + let (response_tx, response_rx) = mpsc::unbounded_channel(); + + ( + Self { + client_id, + received_events: Arc::new(Mutex::new(Vec::new())), + response_tx, + }, + response_rx, + ) + } + + /// Simulate receiving an SSE event + pub async fn receive_sse(&self, event_type: String, data: Value) { + let mut events = self.received_events.lock().await; + events.push(MockEvent { event_type, data }); + } + + /// Get all received events + pub async fn get_events(&self) -> Vec { + let events = self.received_events.lock().await; + events.clone() + } + + /// Get the most recent event of a specific type + pub async fn get_latest_event(&self, event_type: &str) -> Option { + let events = self.received_events.lock().await; + events.iter() + .filter(|e| e.event_type == event_type) + .last() + .cloned() + } + + /// Clear received events + pub async fn clear_events(&self) { + let mut events = self.received_events.lock().await; + events.clear(); + } + + /// Simulate sending a response to /acp/agent_response + pub async fn send_response(&self, response: Value) -> Result<()> { + let (tx, rx) = oneshot::channel(); + self.response_tx.send((response, tx)) + .map_err(|_| anyhow::anyhow!("Failed to send response"))?; + rx.await? + } +} + +/// Mock connector for testing +/// +/// Simulates an ACP connector that can send agent requests and receive responses. +pub struct MockConnector { + pub connector_id: String, + /// Channel for receiving agent requests from this connector + pub request_tx: mpsc::UnboundedSender<(Value, String, String, Value)>, + /// Pending responses (request_id -> sender) + pub pending_responses: Arc>>>, +} + +impl MockConnector { + pub fn new( + connector_id: String, + ) -> (Self, mpsc::UnboundedReceiver<(Value, String, String, Value)>) { + let (request_tx, request_rx) = mpsc::unbounded_channel(); + + ( + Self { + connector_id, + request_tx, + pending_responses: Arc::new(Mutex::new(HashMap::new())), + }, + request_rx, + ) + } + + /// Simulate sending an agent request + pub async fn send_agent_request( + &self, + session_id: String, + request_id: Value, + method: String, + params: Value, + ) -> oneshot::Receiver { + let (tx, rx) = oneshot::channel(); + + // Store the response sender + let mut pending = self.pending_responses.lock().await; + pending.insert(request_id.clone(), tx); + drop(pending); + + // Send the request + self.request_tx.send((request_id, session_id, method, params)) + .expect("Failed to send agent request"); + + rx + } + + /// Complete a pending response (simulating response from ACP Server) + pub async fn complete_response(&self, request_id: Value, response: Value) -> Result<()> { + let mut pending = self.pending_responses.lock().await; + + if let Some(tx) = pending.remove(&request_id) { + tx.send(response) + .map_err(|_| anyhow::anyhow!("Failed to send response to connector"))?; + Ok(()) + } else { + Err(anyhow::anyhow!("No pending response for request_id: {}", request_id)) + } + } +} + +/// Test context for integration tests +/// +/// Provides a complete test environment with mocked components. +pub struct TestContext { + /// Mock SSE clients by client_id + pub clients: Arc>>, + /// Mock connectors by connector_id + pub connectors: Arc>>, +} + +impl TestContext { + pub fn new() -> Self { + Self { + clients: Arc::new(Mutex::new(HashMap::new())), + connectors: Arc::new(Mutex::new(HashMap::new())), + } + } + + /// Create a mock SSE client + pub async fn create_client( + &self, + client_id: String, + ) -> (MockSseClient, mpsc::UnboundedReceiver<(Value, oneshot::Sender>)>) { + let (client, response_rx) = MockSseClient::new(client_id.clone()); + + let mut clients = self.clients.lock().await; + clients.insert(client_id.clone(), client.clone()); + + (client, response_rx) + } + + /// Create a mock connector + pub async fn create_connector( + &self, + connector_id: String, + ) -> (MockConnector, mpsc::UnboundedReceiver<(Value, String, String, Value)>) { + let (connector, request_rx) = MockConnector::new(connector_id.clone()); + + let mut connectors = self.connectors.lock().await; + connectors.insert(connector_id.clone(), connector.clone()); + + (connector, request_rx) + } + + /// Get a client by ID + pub async fn get_client(&self, client_id: &str) -> Option { + let clients = self.clients.lock().await; + clients.get(client_id).cloned() + } + + /// Get a connector by ID + pub async fn get_connector(&self, connector_id: &str) -> Option { + let connectors = self.connectors.lock().await; + connectors.get(connector_id).cloned() + } +} + +impl Default for TestContext { + fn default() -> Self { + Self::new() + } +} + +/// Helper to create a sample permission request +pub fn sample_permission_request(request_id: u64) -> Value { + json!({ + "jsonrpc": "2.0", + "id": request_id, + "method": "session/request_permission", + "params": { + "sessionId": "test-session", + "tool": "Write", + "parameters": { + "path": "/tmp/test.txt", + "content": "test" + } + } + }) +} + +/// Helper to create a sample permission response +pub fn sample_permission_response(request_id: u64, allow: bool) -> Value { + json!({ + "jsonrpc": "2.0", + "id": request_id, + "result": { + "selectedOptionId": if allow { "allow" } else { "deny" } + } + }) +} + +/// Helper to extract agent_request data from SSE event +pub fn extract_agent_request(event: &MockEvent) -> Option<(Value, String, Value)> { + if event.event_type != "session/update" { + return None; + } + + let update = event.data.get("update")?; + if update.get("sessionUpdate")?.as_str()? != "agent_request" { + return None; + } + + let request_id = update.get("requestId")?.clone(); + let method = update.get("method")?.as_str()?.to_string(); + let params = update.get("params")?.clone(); + + Some((request_id, method, params)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_mock_client() { + let (client, _response_rx) = MockSseClient::new("test-client".to_string()); + + // Simulate receiving an event + client.receive_sse("session/update".to_string(), json!({"test": "data"})).await; + + // Verify event was received + let events = client.get_events().await; + assert_eq!(events.len(), 1); + assert_eq!(events[0].event_type, "session/update"); + } + + #[tokio::test] + async fn test_mock_connector() { + let (connector, mut request_rx) = MockConnector::new("test-connector".to_string()); + + // Simulate sending an agent request + let response_fut = connector.send_agent_request( + "session-1".to_string(), + json!(0), + "session/request_permission".to_string(), + json!({"tool": "Write"}), + ).await; + + // Verify request was sent + let request = request_rx.recv().await.unwrap(); + assert_eq!(request.0, json!(0)); + assert_eq!(request.1, "session-1"); + assert_eq!(request.2, "session/request_permission"); + + // Simulate completing the response + connector.complete_response(json!(0), json!({"result": "success"})).await.unwrap(); + + // Verify response was received + let response = response_fut.await.unwrap(); + assert_eq!(response, json!({"result": "success"})); + } + + #[tokio::test] + async fn test_test_context() { + let ctx = TestContext::new(); + + // Create client and connector + let (client, _) = ctx.create_client("client-1".to_string()).await; + let (connector, _) = ctx.create_connector("connector-1".to_string()).await; + + // Verify we can retrieve them + assert!(ctx.get_client("client-1").await.is_some()); + assert!(ctx.get_connector("connector-1").await.is_some()); + assert!(ctx.get_client("non-existent").await.is_none()); + } + + #[test] + fn test_sample_helpers() { + let request = sample_permission_request(0); + assert_eq!(request["method"], "session/request_permission"); + + let response = sample_permission_response(0, true); + assert_eq!(response["result"]["selectedOptionId"], "allow"); + } + + #[test] + fn test_extract_agent_request() { + let event = MockEvent { + event_type: "session/update".to_string(), + data: json!({ + "sessionId": "session-1", + "update": { + "sessionUpdate": "agent_request", + "requestId": 0, + "method": "session/request_permission", + "params": {"tool": "Write"} + } + }), + }; + + let (request_id, method, params) = extract_agent_request(&event).unwrap(); + assert_eq!(request_id, json!(0)); + assert_eq!(method, "session/request_permission"); + assert_eq!(params["tool"], "Write"); + } +} diff --git a/crates/dirigent_acp_api/tests/concurrent_requests_test.rs b/crates/dirigent_acp_api/tests/concurrent_requests_test.rs new file mode 100644 index 0000000..c6afe7e --- /dev/null +++ b/crates/dirigent_acp_api/tests/concurrent_requests_test.rs @@ -0,0 +1,327 @@ +//! Integration test for concurrent agent requests (T050) +//! +//! This test verifies that the system can handle multiple agent requests +//! simultaneously without cross-contamination. +//! +//! Test scenario: +//! 1. Register multiple pending requests concurrently +//! 2. Complete them in random/different order +//! 3. Verify each response goes to the correct request +//! 4. Verify no cross-contamination + +use dirigent_acp_api::agent_requests::AgentRequestTracker; +use serde_json::json; +use tokio::task::JoinSet; + +#[tokio::test] +async fn test_concurrent_requests_basic() { + let tracker = AgentRequestTracker::new(); + + let client_id = "test-client"; + + // Register 5 concurrent requests + let mut receivers = Vec::new(); + for i in 0..5 { + let rx = tracker.register(client_id, json!(i)); + receivers.push((i, rx)); + } + + assert_eq!(tracker.pending_count(), 5); + + // Complete them in reverse order + for i in (0..5).rev() { + let response = json!({"request": i, "result": "success"}); + tracker.complete(client_id, json!(i), response).unwrap(); + } + + assert_eq!(tracker.pending_count(), 0); + + // Verify each receiver got the correct response + for (i, rx) in receivers { + let response = rx.await.unwrap(); + assert_eq!(response["request"], i); + } +} + +#[tokio::test] +async fn test_concurrent_requests_multiple_clients() { + let tracker = AgentRequestTracker::new(); + + // Two clients, each with 3 requests + let client1 = "client-1"; + let client2 = "client-2"; + + let mut receivers1 = Vec::new(); + let mut receivers2 = Vec::new(); + + for i in 0..3 { + let rx1 = tracker.register(client1, json!(i)); + let rx2 = tracker.register(client2, json!(i)); + receivers1.push((i, rx1)); + receivers2.push((i, rx2)); + } + + assert_eq!(tracker.pending_count(), 6); + assert_eq!(tracker.client_pending_count(client1), 3); + assert_eq!(tracker.client_pending_count(client2), 3); + + // Complete client1's requests + for i in 0..3 { + let response = json!({"client": 1, "request": i}); + tracker.complete(client1, json!(i), response).unwrap(); + } + + assert_eq!(tracker.client_pending_count(client1), 0); + assert_eq!(tracker.client_pending_count(client2), 3); + + // Complete client2's requests + for i in 0..3 { + let response = json!({"client": 2, "request": i}); + tracker.complete(client2, json!(i), response).unwrap(); + } + + assert_eq!(tracker.pending_count(), 0); + + // Verify each receiver got the correct response + for (i, rx) in receivers1 { + let response = rx.await.unwrap(); + assert_eq!(response["client"], 1); + assert_eq!(response["request"], i); + } + + for (i, rx) in receivers2 { + let response = rx.await.unwrap(); + assert_eq!(response["client"], 2); + assert_eq!(response["request"], i); + } +} + +#[tokio::test] +async fn test_concurrent_requests_same_id_different_clients() { + // Test that same request_id for different clients are handled independently + let tracker = AgentRequestTracker::new(); + + let client1 = "client-1"; + let client2 = "client-2"; + let request_id = json!(0); // Same ID for both + + let rx1 = tracker.register(client1, request_id.clone()); + let rx2 = tracker.register(client2, request_id.clone()); + + assert_eq!(tracker.pending_count(), 2); + + // Complete client1's request + let response1 = json!({"client": "client-1"}); + tracker.complete(client1, request_id.clone(), response1.clone()).unwrap(); + + // Complete client2's request + let response2 = json!({"client": "client-2"}); + tracker.complete(client2, request_id, response2.clone()).unwrap(); + + assert_eq!(tracker.pending_count(), 0); + + // Verify each got the correct response + let received1 = rx1.await.unwrap(); + let received2 = rx2.await.unwrap(); + + assert_eq!(received1["client"], "client-1"); + assert_eq!(received2["client"], "client-2"); +} + +#[tokio::test] +async fn test_concurrent_async_completion() { + // Test completing requests from multiple async tasks concurrently + let tracker = AgentRequestTracker::new(); + + let client_id = "test-client"; + let num_requests = 10; + + // Register requests + let mut receivers = Vec::new(); + for i in 0..num_requests { + let rx = tracker.register(client_id, json!(i)); + receivers.push((i, rx)); + } + + assert_eq!(tracker.pending_count(), num_requests); + + // Spawn tasks to complete requests concurrently + let mut join_set = JoinSet::new(); + + for i in 0..num_requests { + let tracker_clone = tracker.clone(); + join_set.spawn(async move { + // Small delay to ensure concurrency + tokio::time::sleep(tokio::time::Duration::from_millis(((i % 3) * 10) as u64)).await; + let response = json!({"request": i, "result": "success"}); + tracker_clone.complete(client_id, json!(i), response) + }); + } + + // Wait for all completions + while let Some(result) = join_set.join_next().await { + assert!(result.unwrap().is_ok()); + } + + assert_eq!(tracker.pending_count(), 0); + + // Verify all receivers got correct responses + for (i, rx) in receivers { + let response = rx.await.unwrap(); + assert_eq!(response["request"], i); + } +} + +#[tokio::test] +async fn test_concurrent_register_and_complete() { + // Test registering and completing requests concurrently + let tracker = AgentRequestTracker::new(); + + let client_id = "test-client"; + let num_requests = 20; + + let mut join_set = JoinSet::new(); + + // Spawn tasks to register and complete requests + for i in 0..num_requests { + let tracker_clone = tracker.clone(); + join_set.spawn(async move { + // Register + let rx = tracker_clone.register(client_id, json!(i)); + + // Small delay + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + // Complete + let response = json!({"request": i}); + tracker_clone.complete(client_id, json!(i), response.clone()).unwrap(); + + // Wait for response + let received = rx.await.unwrap(); + assert_eq!(received["request"], i); + + i + }); + } + + // Wait for all tasks + let mut completed = Vec::new(); + while let Some(result) = join_set.join_next().await { + completed.push(result.unwrap()); + } + + // All requests should have completed + assert_eq!(completed.len(), num_requests); + assert_eq!(tracker.pending_count(), 0); +} + +#[tokio::test] +async fn test_concurrent_mixed_operations() { + // Test mix of register, complete, and timeout operations + let tracker = AgentRequestTracker::new(); + + let client_id = "test-client"; + + let mut join_set = JoinSet::new(); + + // Spawn 15 tasks with different behaviors + for i in 0..15 { + let tracker_clone = tracker.clone(); + join_set.spawn(async move { + let rx = tracker_clone.register(client_id, json!(i)); + + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + match i % 3 { + 0 => { + // Complete normally + let response = json!({"request": i, "type": "complete"}); + tracker_clone.complete(client_id, json!(i), response.clone()).unwrap(); + let received = rx.await.unwrap(); + assert_eq!(received["type"], "complete"); + "completed" + } + 1 => { + // Timeout + tracker_clone.timeout(client_id, json!(i)); + assert!(rx.await.is_err()); + "timeout" + } + _ => { + // Complete with delay + tokio::time::sleep(tokio::time::Duration::from_millis(20)).await; + let response = json!({"request": i, "type": "delayed"}); + tracker_clone.complete(client_id, json!(i), response.clone()).unwrap(); + let received = rx.await.unwrap(); + assert_eq!(received["type"], "delayed"); + "delayed" + } + } + }); + } + + // Wait for all tasks + let mut results = Vec::new(); + while let Some(result) = join_set.join_next().await { + results.push(result.unwrap()); + } + + assert_eq!(results.len(), 15); + + // Count outcomes + let completed = results.iter().filter(|&r| r == &"completed").count(); + let timeout = results.iter().filter(|&r| r == &"timeout").count(); + let delayed = results.iter().filter(|&r| r == &"delayed").count(); + + assert_eq!(completed, 5); + assert_eq!(timeout, 5); + assert_eq!(delayed, 5); + + // All should be cleaned up + assert_eq!(tracker.pending_count(), 0); +} + +#[tokio::test] +async fn test_high_concurrency() { + // Stress test with many concurrent requests + let tracker = AgentRequestTracker::new(); + + let client_id = "test-client"; + let num_requests = 100; + + let mut join_set = JoinSet::new(); + + for i in 0..num_requests { + let tracker_clone = tracker.clone(); + join_set.spawn(async move { + let rx = tracker_clone.register(client_id, json!(i)); + + // Random-ish delay + let delay = ((i * 7) % 20) as u64; + tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await; + + let response = json!({"request": i}); + tracker_clone.complete(client_id, json!(i), response).unwrap(); + + rx.await.unwrap()["request"].as_u64().unwrap() + }); + } + + // Collect all results + let mut results = Vec::new(); + while let Some(result) = join_set.join_next().await { + results.push(result.unwrap()); + } + + // Verify all requests completed + assert_eq!(results.len(), num_requests); + + // Verify all request IDs are present + results.sort(); + for (idx, &val) in results.iter().enumerate() { + assert_eq!(val, idx as u64); + } + + // All cleaned up + assert_eq!(tracker.pending_count(), 0); +} diff --git a/crates/dirigent_acp_api/tests/disconnect_test.rs b/crates/dirigent_acp_api/tests/disconnect_test.rs new file mode 100644 index 0000000..cd8cb40 --- /dev/null +++ b/crates/dirigent_acp_api/tests/disconnect_test.rs @@ -0,0 +1,286 @@ +//! Integration test for client disconnection (T051) +//! +//! This test verifies that pending agent requests are cleaned up when a client +//! disconnects, and that other clients are unaffected. +//! +//! Test scenario: +//! 1. Register pending requests for multiple clients +//! 2. Simulate client disconnection +//! 3. Verify cleanup occurs for disconnected client +//! 4. Verify other clients are unaffected + +use dirigent_acp_api::agent_requests::AgentRequestTracker; +use serde_json::json; + +#[tokio::test] +async fn test_disconnect_single_client() { + let tracker = AgentRequestTracker::new(); + + let client_id = "test-client"; + + // Register 3 pending requests + let rx1 = tracker.register(client_id, json!(0)); + let rx2 = tracker.register(client_id, json!(1)); + let rx3 = tracker.register(client_id, json!(2)); + + assert_eq!(tracker.pending_count(), 3); + assert_eq!(tracker.client_pending_count(client_id), 3); + + // Simulate client disconnection - clear all requests for this client + tracker.clear(Some(client_id)); + + // All requests should be removed + assert_eq!(tracker.pending_count(), 0); + assert_eq!(tracker.client_pending_count(client_id), 0); + + // All receivers should get errors (channels closed) + assert!(rx1.await.is_err()); + assert!(rx2.await.is_err()); + assert!(rx3.await.is_err()); +} + +#[tokio::test] +async fn test_disconnect_multiple_clients() { + let tracker = AgentRequestTracker::new(); + + let client1 = "client-1"; + let client2 = "client-2"; + let client3 = "client-3"; + + // Register requests for all clients + let rx1_1 = tracker.register(client1, json!(0)); + let rx1_2 = tracker.register(client1, json!(1)); + + let rx2_1 = tracker.register(client2, json!(0)); + let rx2_2 = tracker.register(client2, json!(1)); + let rx2_3 = tracker.register(client2, json!(2)); + + let rx3_1 = tracker.register(client3, json!(0)); + + assert_eq!(tracker.pending_count(), 6); + assert_eq!(tracker.client_pending_count(client1), 2); + assert_eq!(tracker.client_pending_count(client2), 3); + assert_eq!(tracker.client_pending_count(client3), 1); + + // Disconnect client2 + tracker.clear(Some(client2)); + + // Only client2's requests should be removed + assert_eq!(tracker.pending_count(), 3); + assert_eq!(tracker.client_pending_count(client1), 2); + assert_eq!(tracker.client_pending_count(client2), 0); + assert_eq!(tracker.client_pending_count(client3), 1); + + // Client2's receivers should error + assert!(rx2_1.await.is_err()); + assert!(rx2_2.await.is_err()); + assert!(rx2_3.await.is_err()); + + // Client1 and client3 should still work + tracker.complete(client1, json!(0), json!({"result": "client1-0"})).unwrap(); + assert_eq!(rx1_1.await.unwrap()["result"], "client1-0"); + + tracker.complete(client3, json!(0), json!({"result": "client3-0"})).unwrap(); + assert_eq!(rx3_1.await.unwrap()["result"], "client3-0"); + + // Complete remaining client1 request + tracker.complete(client1, json!(1), json!({"result": "client1-1"})).unwrap(); + assert_eq!(rx1_2.await.unwrap()["result"], "client1-1"); + + // All cleaned up + assert_eq!(tracker.pending_count(), 0); +} + +#[tokio::test] +async fn test_disconnect_then_reconnect() { + let tracker = AgentRequestTracker::new(); + + let client_id = "test-client"; + + // First connection - register requests + let rx1 = tracker.register(client_id, json!(0)); + let rx2 = tracker.register(client_id, json!(1)); + + assert_eq!(tracker.pending_count(), 2); + + // Disconnect - cleanup + tracker.clear(Some(client_id)); + + assert_eq!(tracker.pending_count(), 0); + + // Old receivers should error + assert!(rx1.await.is_err()); + assert!(rx2.await.is_err()); + + // Reconnect - register new requests (same client_id, same request_ids) + let rx3 = tracker.register(client_id, json!(0)); + let rx4 = tracker.register(client_id, json!(1)); + + assert_eq!(tracker.pending_count(), 2); + + // Complete new requests + tracker.complete(client_id, json!(0), json!({"result": "new-0"})).unwrap(); + tracker.complete(client_id, json!(1), json!({"result": "new-1"})).unwrap(); + + // New receivers should get responses + assert_eq!(rx3.await.unwrap()["result"], "new-0"); + assert_eq!(rx4.await.unwrap()["result"], "new-1"); + + assert_eq!(tracker.pending_count(), 0); +} + +#[tokio::test] +async fn test_disconnect_no_pending_requests() { + let tracker = AgentRequestTracker::new(); + + let client_id = "test-client"; + + // No pending requests + assert_eq!(tracker.client_pending_count(client_id), 0); + + // Disconnect should be no-op + tracker.clear(Some(client_id)); + + assert_eq!(tracker.pending_count(), 0); +} + +#[tokio::test] +async fn test_clear_all_clients() { + let tracker = AgentRequestTracker::new(); + + let client1 = "client-1"; + let client2 = "client-2"; + let client3 = "client-3"; + + // Register requests for multiple clients + let rx1 = tracker.register(client1, json!(0)); + let rx2 = tracker.register(client2, json!(0)); + let rx3 = tracker.register(client3, json!(0)); + + assert_eq!(tracker.pending_count(), 3); + + // Clear all (simulating server shutdown) + tracker.clear(None); + + // All should be removed + assert_eq!(tracker.pending_count(), 0); + assert_eq!(tracker.client_pending_count(client1), 0); + assert_eq!(tracker.client_pending_count(client2), 0); + assert_eq!(tracker.client_pending_count(client3), 0); + + // All receivers should error + assert!(rx1.await.is_err()); + assert!(rx2.await.is_err()); + assert!(rx3.await.is_err()); +} + +#[tokio::test] +async fn test_disconnect_race_with_completion() { + let tracker = AgentRequestTracker::new(); + + let client_id = "test-client"; + + // Register multiple requests + let rx1 = tracker.register(client_id, json!(0)); + let rx2 = tracker.register(client_id, json!(1)); + let rx3 = tracker.register(client_id, json!(2)); + + assert_eq!(tracker.pending_count(), 3); + + // Complete one request + tracker.complete(client_id, json!(0), json!({"result": "0"})).unwrap(); + + // Verify it was removed + assert_eq!(tracker.pending_count(), 2); + + // Now disconnect (should only clear remaining requests) + tracker.clear(Some(client_id)); + + assert_eq!(tracker.pending_count(), 0); + + // First receiver should have gotten response + assert_eq!(rx1.await.unwrap()["result"], "0"); + + // Other receivers should error + assert!(rx2.await.is_err()); + assert!(rx3.await.is_err()); +} + +#[tokio::test] +async fn test_partial_disconnect_completion() { + // Test that completing a request after disconnect fails gracefully + let tracker = AgentRequestTracker::new(); + + let client_id = "test-client"; + + let _rx = tracker.register(client_id, json!(0)); + + assert_eq!(tracker.pending_count(), 1); + + // Disconnect + tracker.clear(Some(client_id)); + + assert_eq!(tracker.pending_count(), 0); + + // Try to complete after disconnect - should fail + let result = tracker.complete(client_id, json!(0), json!({"result": "late"})); + assert!(result.is_err()); +} + +#[tokio::test] +async fn test_concurrent_disconnect_and_complete() { + use tokio::task::JoinSet; + + let tracker = AgentRequestTracker::new(); + + let client_id = "test-client"; + + // Register many requests + for i in 0..50 { + tracker.register(client_id, json!(i)); + } + + assert_eq!(tracker.pending_count(), 50); + + let mut join_set = JoinSet::new(); + + // Spawn task to disconnect after delay + { + let tracker_clone = tracker.clone(); + join_set.spawn(async move { + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + tracker_clone.clear(Some(client_id)); + "disconnected" + }); + } + + // Spawn tasks to complete requests + for i in 0..50 { + let tracker_clone = tracker.clone(); + join_set.spawn(async move { + // Small delay to ensure some complete before disconnect + tokio::time::sleep(tokio::time::Duration::from_millis((i % 5) as u64)).await; + let response = json!({"request": i}); + match tracker_clone.complete(client_id, json!(i), response) { + Ok(_) => "completed", + Err(_) => "failed", + } + }); + } + + // Wait for all tasks + let mut results = Vec::new(); + while let Some(result) = join_set.join_next().await { + results.push(result.unwrap()); + } + + // Should have 1 disconnect + 50 completion attempts + assert_eq!(results.len(), 51); + + // Some completions succeeded, some failed (after disconnect) + let disconnects = results.iter().filter(|r| r == &&"disconnected").count(); + assert_eq!(disconnects, 1); + + // All requests should be cleaned up + assert_eq!(tracker.pending_count(), 0); +} diff --git a/crates/dirigent_acp_api/tests/session_list_test.rs b/crates/dirigent_acp_api/tests/session_list_test.rs new file mode 100644 index 0000000..81bb199 --- /dev/null +++ b/crates/dirigent_acp_api/tests/session_list_test.rs @@ -0,0 +1,153 @@ +//! Integration tests for session/list RPC method + +use dirigent_acp_api::{ + NoOpConnectorOperations, RpcHandler, SessionManager, +}; +use dirigent_acp_api::agent_requests::AgentRequestTracker; +use dirigent_acp_api::sse::SseNotifier; +use serde_json::json; +use std::sync::Arc; + +fn create_test_handler() -> RpcHandler { + RpcHandler::new( + SessionManager::new(), + NoOpConnectorOperations, + SseNotifier::new(), + Arc::new(AgentRequestTracker::new()), + ) +} + +#[tokio::test] +async fn test_session_list_returns_sessions() { + let handler = create_test_handler(); + + let request_body = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "session/list", + "params": { + "connectorId": "stub-connector" + } + }); + + let response = handler.handle_request(&request_body.to_string(), None).await; + let response_json = serde_json::to_value(&response).unwrap(); + + // NoOp returns one stub session + let result = &response_json["result"]; + assert!(result["sessions"].is_array()); + let sessions = result["sessions"].as_array().unwrap(); + assert!(!sessions.is_empty()); + assert!(sessions[0]["sessionId"].is_string()); +} + +#[tokio::test] +async fn test_session_list_no_params() { + let handler = create_test_handler(); + + // session/list with no params should use default connector + let request_body = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "session/list" + }); + + let response = handler.handle_request(&request_body.to_string(), None).await; + let response_json = serde_json::to_value(&response).unwrap(); + + // Should succeed using default connector from NoOp + let result = &response_json["result"]; + assert!(result["sessions"].is_array()); +} + +#[tokio::test] +async fn test_session_list_creates_mappings() { + let session_manager = SessionManager::new(); + let handler = RpcHandler::new( + session_manager.clone(), + NoOpConnectorOperations, + SseNotifier::new(), + Arc::new(AgentRequestTracker::new()), + ); + + // First, initialize to register a client + let init_body = json!({ + "jsonrpc": "2.0", + "id": 0, + "method": "initialize", + "params": {} + }); + let init_response = handler.handle_request(&init_body.to_string(), Some("test-client")).await; + let init_json = serde_json::to_value(&init_response).unwrap(); + let client_id = init_json["result"]["clientId"].as_str().unwrap(); + + // Now list sessions + let request_body = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "session/list", + "params": { + "connectorId": "stub-connector" + } + }); + + let response = handler.handle_request(&request_body.to_string(), Some(client_id)).await; + let response_json = serde_json::to_value(&response).unwrap(); + + let sessions = response_json["result"]["sessions"].as_array().unwrap(); + assert!(!sessions.is_empty()); + + // Session mapping should exist for the returned session + let session_id = sessions[0]["sessionId"].as_str().unwrap(); + let mapping = session_manager.get_mapping(session_id); + assert!(mapping.is_some(), "Session mapping should be created for listed sessions"); +} + +#[tokio::test] +async fn test_session_load_without_connector_id() { + let handler = create_test_handler(); + + // Standard ACP: only sessionId + cwd + mcpServers, no connectorId + let request_body = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "session/load", + "params": { + "sessionId": "sess-abc", + "cwd": "G:\\dev\\projects\\test", + "mcpServers": [] + } + }); + + let response = handler.handle_request(&request_body.to_string(), None).await; + let response_json = serde_json::to_value(&response).unwrap(); + + let result = &response_json["result"]; + assert!(result["sessionId"].is_string(), "session/load should succeed without connectorId"); + assert!(result["createdAt"].is_string()); +} + +#[tokio::test] +async fn test_initialize_advertises_list_sessions() { + let handler = create_test_handler(); + + let request_body = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": {} + }); + + let response = handler.handle_request(&request_body.to_string(), None).await; + let response_json = serde_json::to_value(&response).unwrap(); + + let caps = &response_json["result"]["agentCapabilities"]; + assert_eq!(caps["listSessions"], true, "listSessions capability should be advertised"); + assert_eq!(caps["loadSession"], true, "loadSession capability should still be advertised"); + + // Verify nested sessionCapabilities.list is advertised (required by Zed v0.9.4+) + assert!( + caps["sessionCapabilities"]["list"].is_object(), + "sessionCapabilities.list should be advertised as an empty object" + ); +} diff --git a/crates/dirigent_acp_api/tests/session_resume_test.rs b/crates/dirigent_acp_api/tests/session_resume_test.rs new file mode 100644 index 0000000..a0a3cf6 --- /dev/null +++ b/crates/dirigent_acp_api/tests/session_resume_test.rs @@ -0,0 +1,134 @@ +//! Integration tests for session/resume RPC method + +use dirigent_acp_api::{ + NoOpConnectorOperations, RpcHandler, SessionManager, +}; +use dirigent_acp_api::agent_requests::AgentRequestTracker; +use dirigent_acp_api::sse::SseNotifier; +use serde_json::json; +use std::sync::Arc; + +fn create_test_handler() -> RpcHandler { + RpcHandler::new( + SessionManager::new(), + NoOpConnectorOperations, + SseNotifier::new(), + Arc::new(AgentRequestTracker::new()), + ) +} + +#[tokio::test] +async fn test_session_resume_returns_session() { + let handler = create_test_handler(); + + let request_body = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "session/resume", + "params": { + "sessionId": "sess-123", + "connectorId": "stub-connector" + } + }); + + let response = handler.handle_request(&request_body.to_string(), None).await; + let response_json = serde_json::to_value(&response).unwrap(); + + let result = &response_json["result"]; + assert!(result["sessionId"].is_string()); + assert!(result["connectorId"].is_string()); + assert!(result["createdAt"].is_string()); +} + +#[tokio::test] +async fn test_session_resume_missing_params() { + let handler = create_test_handler(); + + let request_body = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "session/resume" + }); + + let response = handler.handle_request(&request_body.to_string(), None).await; + let response_json = serde_json::to_value(&response).unwrap(); + + // Should return error for missing params + assert!(response_json["error"].is_object()); +} + +#[tokio::test] +async fn test_session_resume_creates_mapping() { + let session_manager = SessionManager::new(); + let handler = RpcHandler::new( + session_manager.clone(), + NoOpConnectorOperations, + SseNotifier::new(), + Arc::new(AgentRequestTracker::new()), + ); + + let request_body = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "session/resume", + "params": { + "sessionId": "sess-456", + "connectorId": "stub-connector" + } + }); + + let response = handler.handle_request(&request_body.to_string(), None).await; + let response_json = serde_json::to_value(&response).unwrap(); + + let session_id = response_json["result"]["sessionId"].as_str().unwrap(); + let mapping = session_manager.get_mapping(session_id); + assert!(mapping.is_some(), "Session mapping should be created for resumed session"); +} + +#[tokio::test] +async fn test_session_resume_without_connector_id() { + let handler = create_test_handler(); + + // Standard ACP: only sessionId, no connectorId — should resolve via default connector + let request_body = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "session/resume", + "params": { + "sessionId": "sess-789", + "cwd": "G:\\dev\\projects\\test" + } + }); + + let response = handler.handle_request(&request_body.to_string(), None).await; + let response_json = serde_json::to_value(&response).unwrap(); + + let result = &response_json["result"]; + assert!(result["sessionId"].is_string(), "Should succeed without connectorId"); + assert!(result["createdAt"].is_string()); +} + +#[tokio::test] +async fn test_initialize_advertises_session_resume() { + let handler = create_test_handler(); + + let request_body = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": {} + }); + + let response = handler.handle_request(&request_body.to_string(), None).await; + let response_json = serde_json::to_value(&response).unwrap(); + + let caps = &response_json["result"]["agentCapabilities"]; + assert!( + caps["sessionCapabilities"]["list"].is_object(), + "sessionCapabilities.list should be advertised as an empty object" + ); + assert!( + caps["sessionCapabilities"]["resume"].is_object(), + "sessionCapabilities.resume should be advertised as an empty object" + ); +} diff --git a/crates/dirigent_acp_api/tests/timeout_test.rs b/crates/dirigent_acp_api/tests/timeout_test.rs new file mode 100644 index 0000000..1d8c26e --- /dev/null +++ b/crates/dirigent_acp_api/tests/timeout_test.rs @@ -0,0 +1,229 @@ +//! Integration test for timeout handling (T049) +//! +//! This test verifies that the system handles timeouts gracefully when a client +//! fails to respond to an agent request within the timeout period. +//! +//! Test scenario: +//! 1. Register a pending agent request +//! 2. Wait for timeout (using reduced timeout for testing) +//! 3. Verify timeout occurs +//! 4. Verify cleanup happens correctly +//! 5. Verify no resource leaks + +use dirigent_acp_api::agent_requests::AgentRequestTracker; +use serde_json::json; +use tokio::time::{timeout, Duration}; + +#[tokio::test] +async fn test_timeout_basic() { + // Create tracker + let tracker = AgentRequestTracker::new(); + + let client_id = "test-client"; + let request_id = json!(0); + + // Register request + let receiver = tracker.register(client_id, request_id.clone()); + assert_eq!(tracker.pending_count(), 1); + + // Wait for timeout (use short timeout for testing) + let result = timeout(Duration::from_millis(100), receiver).await; + + // Should timeout + assert!(result.is_err(), "Expected timeout but request completed"); + + // Manually trigger cleanup (in production, event bridge does this) + tracker.timeout(client_id, request_id); + + // Verify cleanup + assert_eq!(tracker.pending_count(), 0); +} + +#[tokio::test] +async fn test_timeout_cleanup() { + let tracker = AgentRequestTracker::new(); + + let client_id = "test-client"; + let request_id = json!(123); + + // Register request + let receiver = tracker.register(client_id, request_id.clone()); + assert_eq!(tracker.pending_count(), 1); + + // Trigger timeout before waiting + tracker.timeout(client_id, request_id); + + // Verify cleanup happened + assert_eq!(tracker.pending_count(), 0); + + // Receiver should get error (channel closed) + let result = receiver.await; + assert!(result.is_err(), "Expected receiver to get error after timeout"); +} + +#[tokio::test] +async fn test_timeout_multiple_clients() { + let tracker = AgentRequestTracker::new(); + + let client1 = "client-1"; + let client2 = "client-2"; + + // Register requests for both clients + let rx1 = tracker.register(client1, json!(0)); + let rx2 = tracker.register(client2, json!(0)); + + assert_eq!(tracker.pending_count(), 2); + assert_eq!(tracker.client_pending_count(client1), 1); + assert_eq!(tracker.client_pending_count(client2), 1); + + // Timeout only client1's request + tracker.timeout(client1, json!(0)); + + // Verify only client1's request is removed + assert_eq!(tracker.pending_count(), 1); + assert_eq!(tracker.client_pending_count(client1), 0); + assert_eq!(tracker.client_pending_count(client2), 1); + + // Client1's receiver should error + assert!(rx1.await.is_err()); + + // Complete client2's request normally + let response = json!({"result": "success"}); + let result = tracker.complete(client2, json!(0), response); + assert!(result.is_ok()); + + // Client2's receiver should get response + let received = rx2.await.unwrap(); + assert_eq!(received, json!({"result": "success"})); + + // All cleaned up + assert_eq!(tracker.pending_count(), 0); +} + +#[tokio::test] +async fn test_timeout_no_double_cleanup() { + let tracker = AgentRequestTracker::new(); + + let client_id = "test-client"; + let request_id = json!(0); + + // Register request + let _receiver = tracker.register(client_id, request_id.clone()); + assert_eq!(tracker.pending_count(), 1); + + // First timeout - should remove + tracker.timeout(client_id, request_id.clone()); + assert_eq!(tracker.pending_count(), 0); + + // Second timeout - should be no-op (not panic) + tracker.timeout(client_id, request_id); + assert_eq!(tracker.pending_count(), 0); +} + +#[tokio::test] +async fn test_timeout_race_with_complete() { + let tracker = AgentRequestTracker::new(); + + let client_id = "test-client"; + let request_id = json!(0); + + // Register request + let receiver = tracker.register(client_id, request_id.clone()); + assert_eq!(tracker.pending_count(), 1); + + // Complete the request + let response = json!({"result": "success"}); + let result = tracker.complete(client_id, request_id.clone(), response.clone()); + assert!(result.is_ok()); + assert_eq!(tracker.pending_count(), 0); + + // Try to timeout after completion - should be no-op + tracker.timeout(client_id, request_id); + assert_eq!(tracker.pending_count(), 0); + + // Receiver should still get the response + let received = receiver.await.unwrap(); + assert_eq!(received, response); +} + +#[tokio::test] +async fn test_concurrent_timeouts() { + let tracker = AgentRequestTracker::new(); + + let client_id = "test-client"; + + // Register 10 requests + let mut receivers = Vec::new(); + for i in 0..10 { + let rx = tracker.register(client_id, json!(i)); + receivers.push((i, rx)); + } + + assert_eq!(tracker.pending_count(), 10); + + // Spawn tasks to timeout each request after random delays + let tracker_clone = tracker.clone(); + let timeout_handles: Vec<_> = (0..10) + .map(|i| { + let tracker = tracker_clone.clone(); + tokio::spawn(async move { + // Small random-ish delay based on index + tokio::time::sleep(Duration::from_millis((i * 10) as u64)).await; + tracker.timeout(client_id, json!(i)); + }) + }) + .collect(); + + // Wait for all timeouts to complete + for handle in timeout_handles { + handle.await.unwrap(); + } + + // All should be cleaned up + assert_eq!(tracker.pending_count(), 0); + + // All receivers should get errors + for (_i, rx) in receivers { + assert!(rx.await.is_err()); + } +} + +#[tokio::test] +async fn test_timeout_with_actual_delay() { + // This test uses actual time delays to verify timeout behavior more realistically + let tracker = AgentRequestTracker::new(); + + let client_id = "test-client"; + let request_id = json!(0); + + let start = std::time::Instant::now(); + + // Register request + let receiver = tracker.register(client_id, request_id.clone()); + + // Spawn task to timeout after 200ms + let tracker_clone = tracker.clone(); + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(200)).await; + tracker_clone.timeout(client_id, json!(0)); + }); + + // Wait on receiver with longer timeout + let result = timeout(Duration::from_secs(1), receiver).await; + + let elapsed = start.elapsed(); + + // Should complete due to timeout() call, not tokio::time::timeout + assert!(result.is_ok(), "Should complete when timeout() is called"); + assert!(result.unwrap().is_err(), "Receiver should get error"); + + // Should take approximately 200ms + assert!( + elapsed >= Duration::from_millis(180) && elapsed < Duration::from_millis(300), + "Expected ~200ms but took {:?}", + elapsed + ); + + // Should be cleaned up + assert_eq!(tracker.pending_count(), 0); +} diff --git a/crates/dirigent_anth/CLAUDE.md b/crates/dirigent_anth/CLAUDE.md new file mode 100644 index 0000000..051c429 --- /dev/null +++ b/crates/dirigent_anth/CLAUDE.md @@ -0,0 +1,148 @@ +# Package: dirigent_anth + +Claude Code JSONL session parser and toolkit. + +## Quick Facts +- **Type**: Library +- **Main Entry**: src/lib.rs +- **Dependencies**: serde, serde_json, chrono, uuid, camino, thiserror, tracing, dirs +- **Status**: Core parsing complete — ready for downstream consumers + +## Purpose + +Reads Claude Code's local JSONL session storage (`~/.claude/projects/`) and produces typed, deduplicated, correlated Rust data structures. The types are the product — downstream consumers (archivist import, shell usage analyzers, session browsers) depend on these structs. + +## Key Features + +- **Session Discovery**: Scan `~/.claude/projects/` for all Claude Code projects and sessions +- **JSONL Parsing**: Lenient line-by-line parser that handles unknown fields and message types +- **Streaming Dedup**: Collapse streamed assistant messages to their final version +- **Tool Correlation**: ID-based pairing of tool_use → tool_result across parallel calls +- **Conversation Tree**: Reconstruct uuid/parentUuid threading with branch detection +- **Noise Classification**: Identify meta messages, warmup, interruptions, API errors +- **Sub-Agent Loading**: Recursive parsing of sub-agent JSONL with metadata +- **Timestamp Parsing**: Handle ISO 8601, Unix seconds, and Unix milliseconds + +## Architecture + +### Design Principles + +1. **Types are the product** — Well-typed Rust structs that downstream consumers import +2. **Lenient parsing** — Unknown fields ignored, unknown message types logged and skipped +3. **Stream-oriented** — Line-by-line BufReader parsing, never loads entire files +4. **Sync-first** — File parsing is CPU-bound; no async overhead +5. **Cross-platform** — camino::Utf8PathBuf throughout for Windows/Unix compatibility + +### Module Organization + +- **`types.rs`** — All public data types (Content, ContentBlock, RawMessage variants, ToolCall, etc.) +- **`error.rs`** — AntError enum with I/O, JSON parse, home-not-found, invalid-path variants +- **`parser.rs`** — JSONL line parser and file parser with lenient error handling +- **`dedup.rs`** — Streaming deduplication of assistant messages by uuid +- **`correlation.rs`** — Tool call ↔ result pairing by tool_use_id +- **`tree.rs`** — Conversation tree from uuid/parentUuid relationships +- **`noise.rs`** — Noise pattern classification (meta, warmup, interruptions, etc.) +- **`discovery.rs`** — Filesystem scanning for Claude projects and sessions +- **`subagent.rs`** — Sub-agent JSONL and metadata loading +- **`util.rs`** — Timestamp parsing utilities + +## Public API + +### Quick Start + +```rust +use dirigent_anth::{discover_claude_home, discover_projects, load_session}; + +// Discover all projects +let home = discover_claude_home()?; +let projects = discover_projects(&home)?; + +// Load a session with full parsing +for project in &projects { + for session_ref in &project.sessions { + let session = load_session(session_ref)?; + println!("Messages: {}, Tools: {}, Subagents: {}", + session.messages.len(), + session.tool_exchanges.len(), + session.subagents.len()); + } +} +``` + +### Key Functions + +| Function | Purpose | +|----------|---------| +| `discover_claude_home()` | Find `~/.claude/` directory | +| `discover_projects(home)` | Scan for all project directories | +| `parse_session(path)` | Parse a JSONL file into messages | +| `parse_session_deduped(path)` | Parse with streaming dedup applied | +| `dedup_messages(msgs)` | Deduplicate streamed assistant messages | +| `correlate_tools(msgs)` | Pair tool calls with results by ID | +| `ConversationTree::build(msgs)` | Build conversation tree | +| `classify_noise(msg)` | Classify a message as noise | +| `load_subagents(dir)` | Load sub-agent sessions from artifacts | +| `load_session(ref)` | Full parse: dedup + correlate + tree + subagents | +| `parse_timestamp(value)` | Parse ISO/Unix timestamps | + +## Data Model + +### Claude Code JSONL Format + +Each line in `~/.claude/projects//.jsonl` is a JSON object with a `type` field discriminator. Five types: `user`, `assistant`, `progress`, `system`, `queue-operation`. + +- **Outer wrapper**: camelCase fields (sessionId, parentUuid, isSidechain, gitBranch) +- **Inner message body**: snake_case fields (stop_reason, tool_use_id, is_error) +- **Content**: Either a plain string or array of typed content blocks + +### Content Blocks + +| Type | Fields | +|------|--------| +| text | `text` | +| tool_use | `id`, `name`, `input` | +| tool_result | `tool_use_id`, `content`, `is_error` | +| thinking | `thinking` | +| image | `source` | + +Unknown content block types are silently dropped (lenient deserialization). + +## Testing + +```bash +cargo test --package dirigent_anth +``` + +Tests use synthetic JSONL fixtures in `tests/fixtures/`: +- `minimal_session.jsonl` — Basic session with all message types +- `streaming_dedup.jsonl` — Streaming dedup scenario +- `tool_correlation.jsonl` — Parallel and sequential tool calls +- `branching_tree.jsonl` — Conversation with branches +- `noise_patterns.jsonl` — All noise pattern types +- `subagent/` — Sub-agent session with parent and metadata + +## Error Handling + +- Individual unparseable JSONL lines are logged and skipped (lenient) +- I/O errors and missing directories are propagated as AntError +- Unknown message types are skipped via serde +- Unknown content blocks are silently filtered + +## Related Packages + +- **dirigent_archivist** — Future consumer for session import +- No current dependencies on other dirigent packages (standalone) + +## Future Enhancements + +- Bash command analysis module (shell usage analytics) +- Archivist event transform/import +- CLI tool with scan/analyze/import subcommands +- SQLite caching layer +- Watch mode for new session monitoring + +## Documentation + +- **Package README**: `./README.md` - User-facing overview +- **API Docs**: Run `cargo doc --package dirigent_anth --open` +- **Design Plan**: `docs/superpowers/plans/2026-03-23-dirigent-ant-design.md` diff --git a/crates/dirigent_anth/Cargo.toml b/crates/dirigent_anth/Cargo.toml new file mode 100644 index 0000000..082e4f0 --- /dev/null +++ b/crates/dirigent_anth/Cargo.toml @@ -0,0 +1,37 @@ +[package] +name = "dirigent_anth" +version = "0.1.0" +edition = "2021" + +[lib] +path = "src/lib.rs" + +[[bin]] +name = "anth_bear" +path = "src/bin/anth.rs" + +[[bin]] +name = "anth_usage" +path = "src/bin/anth_usage.rs" + +[features] +default = [] +dirigent-paths = ["dep:dirigent_config"] + +[dependencies] +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +chrono = { version = "0.4", features = ["serde"] } +chrono-tz = "0.10" +uuid = { version = "1.11", features = ["serde"] } +camino = { version = "1.1", features = ["serde1"] } +dirs = "6.0" +thiserror = "2.0" +tracing = "0.1" +regex = "1" +portable-pty = "0.8" +vt100 = "0.15" +dirigent_config = { path = "../dirigent_config", optional = true } + +[dev-dependencies] +tempfile = "3.0" diff --git a/crates/dirigent_anth/src/anth_usage.rs b/crates/dirigent_anth/src/anth_usage.rs new file mode 100644 index 0000000..b9213e3 --- /dev/null +++ b/crates/dirigent_anth/src/anth_usage.rs @@ -0,0 +1,331 @@ +use chrono::{Datelike, NaiveDate, NaiveTime, Utc}; +use chrono_tz::Tz; +use serde::Serialize; + +#[derive(Debug, Serialize, Default)] +pub struct UsageData { + pub gauges: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub contributions: Option, +} + +#[derive(Debug, Serialize)] +pub struct UsageGauge { + pub name: String, + pub percent_used: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub resets: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub resets_iso: Option, +} + +#[derive(Debug, Serialize, Default)] +pub struct ContributionInfo { + #[serde(skip_serializing_if = "Vec::is_empty")] + pub factors: Vec, + #[serde(skip_serializing_if = "Vec::is_empty")] + pub subagents: Vec, +} + +#[derive(Debug, Serialize)] +pub struct ContributionFactor { + pub description: String, + pub percent: u32, +} + +#[derive(Debug, Serialize)] +pub struct SubagentUsage { + pub name: String, + pub percent: u32, +} + +pub struct ProcessedOutput { + pub raw_screen: String, + pub data: UsageData, +} + +pub fn process_usage_screen(raw: &str) -> ProcessedOutput { + let lines: Vec<&str> = raw.lines().collect(); + + let start = lines + .iter() + .position(|l| { + let t = l.trim(); + t.starts_with('─') && t.chars().filter(|&c| c == '─').count() >= 6 + }) + .unwrap_or(0); + + let end = lines + .iter() + .rposition(|l| !l.trim().is_empty()) + .map(|i| i + 1) + .unwrap_or(lines.len()); + + let clean_lines = &lines[start..end]; + let raw_screen = clean_lines.join("\n"); + + let data = extract_usage_data(clean_lines); + + ProcessedOutput { raw_screen, data } +} + +fn extract_usage_data(lines: &[&str]) -> UsageData { + let mut data = UsageData::default(); + let mut i = 0; + + while i < lines.len() { + let trimmed = lines[i].trim(); + + if (trimmed.starts_with("Current session") || trimmed.starts_with("Current week")) + && !trimmed.contains('%') + { + let name = trimmed.to_string(); + if let Some(gauge) = find_gauge(&lines[i..], &name) { + data.gauges.push(gauge); + } + } + + if let Some(factor) = parse_contribution_factor(trimmed) { + data.contributions + .get_or_insert_with(ContributionInfo::default) + .factors + .push(factor); + } + + if trimmed.starts_with("Subagents") { + let subs = parse_subagent_table(&lines[i + 1..]); + if !subs.is_empty() { + data.contributions + .get_or_insert_with(ContributionInfo::default) + .subagents = subs; + } + } + + i += 1; + } + + data +} + +fn find_gauge(lines: &[&str], name: &str) -> Option { + let mut percent = None; + let mut resets_raw = None; + + for line in lines.iter().skip(1).take(4) { + let t = line.trim(); + if let Some(pct) = extract_percent_used(t) { + percent = Some(pct); + } + if t.starts_with("Resets ") { + resets_raw = Some(t.trim_start_matches("Resets ").to_string()); + } + } + + percent.map(|p| { + let resets_iso = resets_raw.as_deref().and_then(parse_reset_to_iso); + UsageGauge { + name: name.to_string(), + percent_used: p, + resets: resets_raw, + resets_iso, + } + }) +} + +/// Parse reset strings like: +/// "12:30pm (Europe/Vienna)" → today at 12:30 in that tz +/// "May 12, 9am (Europe/Vienna)" → May 12 at 09:00 +/// "May 12, 9:30am (Europe/Vienna)" → May 12 at 09:30 +/// "Jun 1, 12pm (America/New_York)" → Jun 1 at 12:00 +/// +/// Claude Code uses JS `Intl.DateTimeFormat` style output. +fn parse_reset_to_iso(s: &str) -> Option { + // Split off the timezone from parentheses + let (datetime_part, tz_str) = { + let open = s.rfind('(')?; + let close = s.rfind(')')?; + let tz = s[open + 1..close].trim(); + let dt = s[..open].trim(); + (dt, tz) + }; + + let tz: Tz = tz_str.parse().ok()?; + let now = Utc::now().with_timezone(&tz); + + let (date, time_str) = if datetime_part.contains(',') { + // "May 12, 9am" or "May 12, 9:30am" + let comma_pos = datetime_part.find(',')?; + let date_part = datetime_part[..comma_pos].trim(); + let time_part = datetime_part[comma_pos + 1..].trim(); + + let date = parse_month_day(date_part, now.year())?; + (date, time_part) + } else { + // "12:30pm" — today in the given timezone + (now.date_naive(), datetime_part) + }; + + let time = parse_12h_time(time_str)?; + let naive = date.and_time(time); + let local = naive.and_local_timezone(tz).earliest()?; + let utc = local.with_timezone(&Utc); + + Some(utc.to_rfc3339()) +} + +/// Parse "May 12", "Jun 1", "December 25", etc. +fn parse_month_day(s: &str, year: i32) -> Option { + let parts: Vec<&str> = s.split_whitespace().collect(); + if parts.len() != 2 { + return None; + } + let month = match parts[0].to_lowercase().as_str() { + "jan" | "january" => 1, + "feb" | "february" => 2, + "mar" | "march" => 3, + "apr" | "april" => 4, + "may" => 5, + "jun" | "june" => 6, + "jul" | "july" => 7, + "aug" | "august" => 8, + "sep" | "september" => 9, + "oct" | "october" => 10, + "nov" | "november" => 11, + "dec" | "december" => 12, + _ => return None, + }; + let day: u32 = parts[1].parse().ok()?; + NaiveDate::from_ymd_opt(year, month, day) +} + +/// Parse "9am", "12pm", "9:30am", "12:30pm" +fn parse_12h_time(s: &str) -> Option { + let s = s.trim().to_lowercase(); + let is_pm = s.ends_with("pm"); + let is_am = s.ends_with("am"); + if !is_pm && !is_am { + return None; + } + + let num_part = &s[..s.len() - 2]; + + let (hour, minute) = if let Some((h, m)) = num_part.split_once(':') { + (h.parse::().ok()?, m.parse::().ok()?) + } else { + (num_part.parse::().ok()?, 0) + }; + + let hour_24 = match (hour, is_pm) { + (12, true) => 12, + (12, false) => 0, + (h, true) => h + 12, + (h, false) => h, + }; + + NaiveTime::from_hms_opt(hour_24, minute, 0) +} + +fn extract_percent_used(line: &str) -> Option { + let line = line.trim(); + if !line.ends_with("% used") { + return None; + } + let before_pct = line.trim_end_matches("% used").trim(); + before_pct + .rsplit_once(char::is_whitespace) + .map(|(_, n)| n) + .unwrap_or(before_pct) + .parse() + .ok() +} + +fn parse_contribution_factor(line: &str) -> Option { + if !line.contains("% of your usage") { + return None; + } + let pct_str = line.split('%').next()?; + let percent: u32 = pct_str.trim().parse().ok()?; + let description = line.to_string(); + Some(ContributionFactor { + description, + percent, + }) +} + +fn parse_subagent_table(lines: &[&str]) -> Vec { + let mut subs = Vec::new(); + for line in lines { + let t = line.trim(); + if t.is_empty() || t.starts_with('─') || t.contains("to day") || t.contains("to cancel") { + break; + } + if let Some(pos) = t.rfind('%') { + let num_start = t[..pos] + .rfind(char::is_whitespace) + .map(|i| i + 1) + .unwrap_or(0); + if let Ok(pct) = t[num_start..pos].parse::() { + let name = t[..num_start].trim().to_string(); + if !name.is_empty() && !name.contains("% of") { + subs.push(SubagentUsage { + name, + percent: pct, + }); + } + } + } + } + subs +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_time_only() { + let t = parse_12h_time("12:30pm").unwrap(); + assert_eq!(t, NaiveTime::from_hms_opt(12, 30, 0).unwrap()); + } + + #[test] + fn parse_time_am() { + let t = parse_12h_time("9am").unwrap(); + assert_eq!(t, NaiveTime::from_hms_opt(9, 0, 0).unwrap()); + } + + #[test] + fn parse_time_12am() { + let t = parse_12h_time("12am").unwrap(); + assert_eq!(t, NaiveTime::from_hms_opt(0, 0, 0).unwrap()); + } + + #[test] + fn parse_time_with_minutes() { + let t = parse_12h_time("9:30am").unwrap(); + assert_eq!(t, NaiveTime::from_hms_opt(9, 30, 0).unwrap()); + } + + #[test] + fn parse_reset_time_only() { + let iso = parse_reset_to_iso("12:30pm (Europe/Vienna)"); + assert!(iso.is_some()); + let iso = iso.unwrap(); + assert!(iso.contains("T")); + // Should end in +00:00 (UTC via rfc3339) + assert!(iso.ends_with("+00:00")); + } + + #[test] + fn parse_reset_date_and_time() { + let iso = parse_reset_to_iso("May 12, 9am (Europe/Vienna)").unwrap(); + assert!(iso.contains("T07:00:00") || iso.contains("T08:00:00")); + // CEST is UTC+2, CET is UTC+1 — depends on whether May 12 is summer time + } + + #[test] + fn parse_month_day_basic() { + let d = parse_month_day("May 12", 2026).unwrap(); + assert_eq!(d, NaiveDate::from_ymd_opt(2026, 5, 12).unwrap()); + } +} diff --git a/crates/dirigent_anth/src/bin/anth.rs b/crates/dirigent_anth/src/bin/anth.rs new file mode 100644 index 0000000..5ecdefa --- /dev/null +++ b/crates/dirigent_anth/src/bin/anth.rs @@ -0,0 +1,252 @@ +//! Minimal CLI for dirigent_anth — validate parsing and search sessions. +//! +//! Usage: +//! cargo run --package dirigent_anth --bin ant # validate all sessions +//! cargo run --package dirigent_anth --bin ant -- search "query" # search user messages +//! cargo run --package dirigent_anth --bin ant -- stats # show statistics + +use dirigent_anth::*; +use std::io::BufRead; + +fn main() { + let args: Vec = std::env::args().skip(1).collect(); + + let home = match discover_claude_home() { + Ok(h) => h, + Err(e) => { + eprintln!("Could not find Claude home: {e}"); + std::process::exit(1); + } + }; + + let projects = match discover_projects(&home) { + Ok(p) => p, + Err(e) => { + eprintln!("Could not discover projects: {e}"); + std::process::exit(1); + } + }; + + match args.first().map(|s| s.as_str()) { + Some("search") => { + let query = args.get(1).map(|s| s.as_str()).unwrap_or(""); + if query.is_empty() { + eprintln!("Usage: ant search "); + std::process::exit(1); + } + cmd_search(&projects, query); + } + Some("stats") => cmd_stats(&projects), + Some("validate") | None => cmd_validate(&projects), + Some(other) => { + eprintln!("Unknown command: {other}"); + eprintln!("Commands: validate (default), search , stats"); + std::process::exit(1); + } + } +} + +/// Validate that the parser can handle all sessions without errors. +fn cmd_validate(projects: &[ClaudeProject]) { + let mut total_sessions = 0; + let mut total_ok = 0; + let mut total_messages = 0; + let mut total_skipped_lines = 0; + let mut errors: Vec<(String, String)> = Vec::new(); + + for project in projects { + println!( + "Project: {} ({} sessions)", + project.original_path, + project.sessions.len() + ); + + for session in &project.sessions { + total_sessions += 1; + + // Raw line-level validation: count how many lines parse vs skip + let (_raw_ok, raw_skip) = validate_lines(&session.jsonl_path); + total_skipped_lines += raw_skip; + + // Full pipeline validation + match load_session(session) { + Ok(parsed) => { + total_ok += 1; + total_messages += parsed.messages.len(); + let tools = parsed.tool_exchanges.len(); + let subs = parsed.subagents.len(); + let branches = if parsed.tree.is_linear() { + "linear" + } else { + "branched" + }; + + if raw_skip > 0 { + println!( + " {} — {} msgs, {} tools, {} subagents, {} | {raw_skip} lines skipped", + &session.id[..8.min(session.id.len())], + parsed.messages.len(), + tools, + subs, + branches, + ); + } + } + Err(e) => { + errors.push((session.id.clone(), e.to_string())); + eprintln!(" {} — ERROR: {e}", &session.id[..8.min(session.id.len())]); + } + } + } + } + + println!("\n--- Validation Summary ---"); + println!("Projects: {}", projects.len()); + println!("Sessions: {total_sessions} ({total_ok} ok, {} errors)", errors.len()); + println!("Messages: {total_messages}"); + if total_skipped_lines > 0 { + println!("Skipped: {total_skipped_lines} unparseable lines"); + } + + if !errors.is_empty() { + println!("\nErrors:"); + for (id, err) in &errors { + println!(" {id}: {err}"); + } + std::process::exit(1); + } +} + +/// Count parseable vs skipped lines in a JSONL file. +fn validate_lines(path: &camino::Utf8Path) -> (usize, usize) { + let file = match std::fs::File::open(path.as_std_path()) { + Ok(f) => f, + Err(_) => return (0, 0), + }; + let reader = std::io::BufReader::new(file); + let mut ok = 0; + let mut skip = 0; + + for (i, line) in reader.lines().enumerate() { + let line = match line { + Ok(l) => l, + Err(_) => { + skip += 1; + continue; + } + }; + if line.trim().is_empty() { + continue; + } + if parse_line(&line, i + 1).is_some() { + ok += 1; + } else { + skip += 1; + } + } + + (ok, skip) +} + +/// Search user messages for a query string (case-insensitive). +fn cmd_search(projects: &[ClaudeProject], query: &str) { + let query_lower = query.to_lowercase(); + let mut hits = 0; + + for project in projects { + for session in &project.sessions { + let messages = match parse_session_deduped(&session.jsonl_path) { + Ok(m) => m, + Err(_) => continue, + }; + + for msg in &messages { + let text = match msg { + types::RawMessage::User(u) => match &u.message.content { + types::Content::Text(s) => s.clone(), + types::Content::Blocks(_) => continue, + }, + types::RawMessage::Assistant(a) => { + let mut parts = Vec::new(); + for block in &a.message.content { + if let types::ContentBlock::Text { text } = block { + parts.push(text.as_str()); + } + } + parts.join(" ") + } + _ => continue, + }; + + if text.to_lowercase().contains(&query_lower) { + let role = match msg { + types::RawMessage::User(_) => "user", + types::RawMessage::Assistant(_) => "assistant", + _ => "other", + }; + let preview = truncate(&text, 120); + println!( + "[{}] {} {} | {}", + &project.original_path, + &session.id[..8.min(session.id.len())], + role, + preview + ); + hits += 1; + } + } + } + } + + println!("\n{hits} matches for \"{query}\""); +} + +/// Show aggregate statistics across all sessions. +fn cmd_stats(projects: &[ClaudeProject]) { + let mut total_sessions = 0; + let mut total_messages = 0; + let mut total_tools = 0; + let mut total_subagents = 0; + let mut tool_counts: std::collections::HashMap = std::collections::HashMap::new(); + + for project in projects { + for session in &project.sessions { + total_sessions += 1; + if let Ok(parsed) = load_session(session) { + total_messages += parsed.messages.len(); + total_tools += parsed.tool_exchanges.len(); + total_subagents += parsed.subagents.len(); + + for ex in &parsed.tool_exchanges { + let name = format!("{:?}", ex.call.name); + *tool_counts.entry(name).or_default() += 1; + } + } + } + } + + println!("--- Statistics ---"); + println!("Projects: {}", projects.len()); + println!("Sessions: {total_sessions}"); + println!("Messages: {total_messages}"); + println!("Tool calls: {total_tools}"); + println!("Sub-agents: {total_subagents}"); + + if !tool_counts.is_empty() { + println!("\nTool usage:"); + let mut sorted: Vec<_> = tool_counts.into_iter().collect(); + sorted.sort_by(|a, b| b.1.cmp(&a.1)); + for (name, count) in sorted.iter().take(15) { + println!(" {name:20} {count}"); + } + } +} + +fn truncate(s: &str, max: usize) -> String { + let s = s.replace('\n', " ").replace('\r', ""); + if s.len() <= max { + s + } else { + format!("{}...", &s[..max]) + } +} diff --git a/crates/dirigent_anth/src/bin/anth_usage.rs b/crates/dirigent_anth/src/bin/anth_usage.rs new file mode 100644 index 0000000..58fd3e0 --- /dev/null +++ b/crates/dirigent_anth/src/bin/anth_usage.rs @@ -0,0 +1,192 @@ +use portable_pty::{CommandBuilder, NativePtySystem, PtySize, PtySystem}; +use std::io::{Read, Write}; +use std::path::PathBuf; +use std::time::Duration; + +const ROWS: u16 = 80; +const COLS: u16 = 120; + +struct Args { + debug: bool, + raw: bool, + no_trust: bool, + workdir: Option, + use_cwd: bool, +} + +fn parse_args() -> Args { + let mut args = Args { + debug: false, + raw: false, + no_trust: false, + workdir: None, + use_cwd: false, + }; + let mut iter = std::env::args().skip(1); + while let Some(arg) = iter.next() { + match arg.as_str() { + "--debug" => args.debug = true, + "--raw" => args.raw = true, + "--no-trust" => args.no_trust = true, + "--cwd" => args.use_cwd = true, + "--workdir" => { + args.workdir = Some(PathBuf::from( + iter.next().expect("--workdir requires a path argument"), + )); + } + other => { + eprintln!("Unknown argument: {other}"); + eprintln!( + "Usage: anth_usage [--debug] [--raw] [--no-trust] [--workdir ] [--cwd]" + ); + std::process::exit(2); + } + } + } + args +} + +fn resolve_workdir(args: &Args) -> PathBuf { + if let Some(ref dir) = args.workdir { + return dir.clone(); + } + if args.use_cwd { + return std::env::current_dir().expect("failed to get current directory"); + } + + #[cfg(feature = "dirigent-paths")] + { + if let Ok(paths) = dirigent_config::DirigentPaths::resolve() { + let noproject = paths.noproject_home_dir(); + if noproject.exists() { + return noproject; + } + } + } + + dirs::home_dir().expect("failed to resolve home directory") +} + +fn grab_screen(parser: &vt100::Parser) -> String { + let screen = parser.screen(); + let mut output = String::new(); + for line in screen.rows(0, COLS) { + output.push_str(&line); + output.push('\n'); + } + output +} + +macro_rules! debug { + ($args:expr, $($tt:tt)*) => { + if $args.debug { + eprintln!($($tt)*); + } + }; +} + +fn main() { + let args = parse_args(); + let workdir = resolve_workdir(&args); + + debug!(args, "Working directory: {}", workdir.display()); + + let pty_system = NativePtySystem::default(); + let pair = pty_system + .openpty(PtySize { + rows: ROWS, + cols: COLS, + pixel_width: 0, + pixel_height: 0, + }) + .expect("failed to open pty"); + + let mut cmd = CommandBuilder::new("claude"); + cmd.cwd(&workdir); + let mut child = pair.slave.spawn_command(cmd).expect("failed to spawn claude"); + drop(pair.slave); + + let mut writer = pair.master.take_writer().expect("failed to get writer"); + let reader = pair.master.try_clone_reader().expect("failed to get reader"); + + let (tx, rx) = std::sync::mpsc::channel(); + std::thread::spawn(move || { + let mut reader = reader; + let mut buf = [0u8; 4096]; + loop { + match reader.read(&mut buf) { + Ok(0) => break, + Ok(n) => { + let _ = tx.send(buf[..n].to_vec()); + } + Err(_) => break, + } + } + }); + + // Wait for claude to render + std::thread::sleep(Duration::from_secs(5)); + + debug!( + args, + "Child alive: {}", + matches!(child.try_wait(), Ok(None)) + ); + + // Grab screen + let mut parser = vt100::Parser::new(ROWS, COLS, 0); + while let Ok(data) = rx.try_recv() { + parser.process(&data); + } + let output = grab_screen(&parser); + debug!(args, "=== SCREEN ===\n{output}=== END ==="); + + // Handle trust prompt + if output.contains("Yes, I trust this folder") { + if args.no_trust { + eprintln!("Folder is not trusted: {}", workdir.display()); + eprintln!("Run claude in this folder manually to trust it, or omit --no-trust."); + let _ = child.kill(); + std::process::exit(1); + } + debug!(args, "Sending enter for trust..."); + writer.write_all(b"\r").expect("failed to confirm trust"); + + std::thread::sleep(Duration::from_secs(3)); + + while let Ok(data) = rx.try_recv() { + parser.process(&data); + } + debug!( + args, + "=== AFTER TRUST ===\n{}=== END ===", + grab_screen(&parser) + ); + } + + // Send /usage + debug!(args, "Sending /usage..."); + writer + .write_all(b"/usage\r") + .expect("failed to send /usage"); + + std::thread::sleep(Duration::from_secs(3)); + + while let Ok(data) = rx.try_recv() { + parser.process(&data); + } + let raw_output = grab_screen(&parser); + + let processed = dirigent_anth::anth_usage::process_usage_screen(&raw_output); + + if args.raw { + println!("{}", processed.raw_screen); + } else { + println!( + "{}", + serde_json::to_string_pretty(&processed.data).expect("failed to serialize usage data") + ); + } + + let _ = child.kill(); +} diff --git a/crates/dirigent_anth/src/claude_grab.rs b/crates/dirigent_anth/src/claude_grab.rs new file mode 100644 index 0000000..1d47902 --- /dev/null +++ b/crates/dirigent_anth/src/claude_grab.rs @@ -0,0 +1,157 @@ +use portable_pty::{Child, CommandBuilder, NativePtySystem, PtySize, PtySystem}; +use std::io::{Read, Write}; +use std::sync::mpsc::{self, Receiver}; +use std::time::Duration; +use vt100::Parser; + +const DEFAULT_ROWS: u16 = 80; +const DEFAULT_COLS: u16 = 120; + +pub struct PtySession { + parser: Parser, + writer: Option>, + rx: Receiver>, + cols: u16, + #[allow(dead_code)] + child: Box, +} + +impl PtySession { + pub fn spawn_claude(args: &[&str]) -> Self { + Self::spawn_claude_with_size(args, DEFAULT_ROWS, DEFAULT_COLS) + } + + pub fn spawn_claude_with_size(args: &[&str], rows: u16, cols: u16) -> Self { + let pty_system = NativePtySystem::default(); + + let pair = pty_system + .openpty(PtySize { + rows, + cols, + pixel_width: 0, + pixel_height: 0, + }) + .expect("failed to open pty"); + + let mut cmd = CommandBuilder::new("claude"); + for arg in args { + cmd.arg(*arg); + } + if let Some(home) = dirs::home_dir() { + cmd.cwd(home); + } + let child = pair + .slave + .spawn_command(cmd) + .expect("failed to spawn claude"); + + drop(pair.slave); + + let writer = pair.master.take_writer().expect("failed to get writer"); + let reader = pair + .master + .try_clone_reader() + .expect("failed to get reader"); + + let (tx, rx) = mpsc::channel::>(); + std::thread::spawn(move || { + let mut reader = reader; + let mut chunk = [0u8; 4096]; + loop { + match reader.read(&mut chunk) { + Ok(0) => break, + Ok(n) => { + if tx.send(chunk[..n].to_vec()).is_err() { + break; + } + } + Err(_) => break, + } + } + }); + + Self { + parser: Parser::new(rows, cols, 0), + writer: Some(writer), + rx, + cols, + child, + } + } + + pub fn grab_screen(&mut self) -> String { + while let Ok(data) = self.rx.try_recv() { + self.parser.process(&data); + } + let deadline = std::time::Instant::now() + Duration::from_millis(200); + while std::time::Instant::now() < deadline { + match self.rx.recv_timeout(Duration::from_millis(50)) { + Ok(data) => self.parser.process(&data), + Err(_) => {} + } + } + + let screen = self.parser.screen(); + let mut output = String::new(); + for line in screen.rows(0, self.cols) { + output.push_str(&line); + output.push('\n'); + } + output + } + + pub fn wait_for(&mut self, needle: &str, timeout: Duration) -> bool { + self.wait_for_any(&[needle], timeout) + } + + pub fn wait_for_any(&mut self, needles: &[&str], timeout: Duration) -> bool { + let deadline = std::time::Instant::now() + timeout; + while std::time::Instant::now() < deadline { + match self.rx.recv_timeout(Duration::from_millis(100)) { + Ok(data) => self.parser.process(&data), + Err(_) => {} + } + let screen = self.parser.screen(); + let mut content = String::new(); + for line in screen.rows(0, self.cols) { + content.push_str(&line); + content.push('\n'); + } + for needle in needles { + if content.contains(needle) { + return true; + } + } + } + false + } + + pub fn is_alive(&mut self) -> bool { + matches!(self.child.try_wait(), Ok(None)) + } + + pub fn send(&mut self, input: &[u8]) { + self.writer.as_mut().expect("writer gone").write_all(input).expect("failed to write to pty"); + } + + pub fn try_send(&mut self, input: &[u8]) -> std::io::Result<()> { + match self.writer.as_mut() { + Some(w) => w.write_all(input), + None => Err(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "writer gone")), + } + } + + pub fn try_send_line(&mut self, text: &str) -> std::io::Result<()> { + self.try_send(text.as_bytes())?; + self.try_send(b"\r") + } + + pub fn send_enter(&mut self) { + self.send(b"\r"); + } + + pub fn send_line(&mut self, text: &str) { + self.send(text.as_bytes()); + self.send_enter(); + } +} diff --git a/crates/dirigent_anth/src/correlation.rs b/crates/dirigent_anth/src/correlation.rs new file mode 100644 index 0000000..5e96bef --- /dev/null +++ b/crates/dirigent_anth/src/correlation.rs @@ -0,0 +1,107 @@ +//! Tool call correlation — matches assistant ToolUse blocks with their +//! corresponding user ToolResult blocks by ID across a message sequence. + +use std::collections::HashMap; + +use crate::types::{ + Content, ContentBlock, RawAssistantMessage, RawMessage, RawUserMessage, ToolCall, + ToolExchange, ToolName, ToolResultData, +}; + +/// Extract tool calls from an assistant message's content blocks. +fn extract_tool_calls(msg: &RawAssistantMessage) -> Vec { + let source_uuid = msg.uuid.clone().unwrap_or_default(); + msg.message + .content + .iter() + .filter_map(|block| { + if let ContentBlock::ToolUse { id, name, input, .. } = block { + Some(ToolCall { + id: id.clone(), + name: ToolName::from(name.clone()), + input: input.clone(), + source_message_uuid: source_uuid.clone(), + }) + } else { + None + } + }) + .collect() +} + +/// Extract tool results from a user message's content blocks. +fn extract_tool_results(msg: &RawUserMessage) -> Vec { + let source_uuid = msg.uuid.clone().unwrap_or_default(); + match &msg.message.content { + Content::Blocks(blocks) => blocks + .iter() + .filter_map(|block| { + if let ContentBlock::ToolResult { tool_use_id, content, is_error } = block { + // Extract text content from the tool result + let text_content = content.as_ref().and_then(|c| match c { + Content::Text(s) => Some(s.clone()), + Content::Blocks(bs) => { + // Concatenate text blocks + let texts: Vec<&str> = bs + .iter() + .filter_map(|b| { + if let ContentBlock::Text { text } = b { + Some(text.as_str()) + } else { + None + } + }) + .collect(); + if texts.is_empty() { None } else { Some(texts.join("\n")) } + } + }); + Some(ToolResultData { + tool_use_id: tool_use_id.clone(), + content: text_content, + is_error: *is_error, + source_message_uuid: source_uuid.clone(), + }) + } else { + None + } + }) + .collect(), + Content::Text(_) => Vec::new(), + } +} + +/// Correlate tool calls with their results across a message sequence. +/// +/// Iterates messages in order, collecting ToolUse blocks from assistant +/// messages and matching them by ID to ToolResult blocks in subsequent user +/// messages. Any tool calls that never received a result are emitted with +/// `result: None`. +pub fn correlate_tools(messages: &[RawMessage]) -> Vec { + let mut pending: HashMap = HashMap::new(); + let mut exchanges: Vec = Vec::new(); + + for msg in messages { + match msg { + RawMessage::Assistant(asst) => { + for call in extract_tool_calls(asst) { + pending.insert(call.id.clone(), call); + } + } + RawMessage::User(user) => { + for result in extract_tool_results(user) { + if let Some(call) = pending.remove(&result.tool_use_id) { + exchanges.push(ToolExchange { call, result: Some(result) }); + } + } + } + _ => {} + } + } + + // Emit unmatched calls (no result found) + for (_id, call) in pending { + exchanges.push(ToolExchange { call, result: None }); + } + + exchanges +} diff --git a/crates/dirigent_anth/src/dedup.rs b/crates/dirigent_anth/src/dedup.rs new file mode 100644 index 0000000..27c2bd7 --- /dev/null +++ b/crates/dirigent_anth/src/dedup.rs @@ -0,0 +1,116 @@ +//! Streaming deduplication for assistant messages. + +use crate::types::{RawAssistantMessage, RawMessage}; + +/// Deduplicate streamed assistant messages. +/// +/// Claude Code writes multiple JSONL lines for the same assistant message +/// as it streams. Each shares the same `uuid` with progressively more +/// content blocks. We keep only the last entry per uuid. +/// +/// Non-assistant messages pass through unchanged. +pub fn dedup_messages(messages: Vec) -> Vec { + let mut result: Vec = Vec::new(); + let mut buffered_assistant: Option = None; + + for msg in messages { + match msg { + RawMessage::Assistant(ref asst) => { + let current_uuid = asst.uuid.as_deref(); + + if let Some(ref buffered) = buffered_assistant { + let buffered_uuid = buffered.uuid.as_deref(); + if current_uuid == buffered_uuid { + // Same uuid — replace buffer with newer (more complete) version + buffered_assistant = Some(asst.clone()); + } else { + // Different uuid — flush old buffer, start new + result.push(RawMessage::Assistant(buffered.clone())); + buffered_assistant = Some(asst.clone()); + } + } else { + // No buffer yet — start buffering + buffered_assistant = Some(asst.clone()); + } + } + _ => { + // Non-assistant: flush any buffered assistant first, then push this + if let Some(buffered) = buffered_assistant.take() { + result.push(RawMessage::Assistant(buffered)); + } + result.push(msg); + } + } + } + + // Flush remaining buffer + if let Some(buffered) = buffered_assistant { + result.push(RawMessage::Assistant(buffered)); + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::{AssistantInner, ContentBlock}; + + fn make_assistant(uuid: &str, stop_reason: Option<&str>, text: &str) -> RawMessage { + RawMessage::Assistant(RawAssistantMessage { + uuid: Some(uuid.to_string()), + parent_uuid: None, + timestamp: None, + session_id: None, + cwd: None, + version: None, + git_branch: None, + is_sidechain: false, + request_id: None, + message: AssistantInner { + model: None, + id: None, + message_type: None, + role: None, + content: vec![ContentBlock::Text { text: text.to_string() }], + stop_reason: stop_reason.map(str::to_string), + stop_sequence: None, + usage: None, + }, + }) + } + + #[test] + fn dedup_single_streamed_message() { + let msgs = vec![ + make_assistant("a-1", None, "Part 1"), + make_assistant("a-1", None, "Part 1 more"), + make_assistant("a-1", Some("end_turn"), "Part 1 final"), + ]; + let deduped = dedup_messages(msgs); + assert_eq!(deduped.len(), 1); + if let RawMessage::Assistant(a) = &deduped[0] { + assert_eq!(a.message.stop_reason.as_deref(), Some("end_turn")); + match &a.message.content[0] { + ContentBlock::Text { text } => assert_eq!(text, "Part 1 final"), + _ => panic!("Expected text block"), + } + } + } + + #[test] + fn dedup_two_distinct_assistants() { + let msgs = vec![ + make_assistant("a-1", Some("end_turn"), "First"), + make_assistant("a-2", Some("end_turn"), "Second"), + ]; + let deduped = dedup_messages(msgs); + assert_eq!(deduped.len(), 2); + } + + #[test] + fn dedup_empty_input() { + let deduped = dedup_messages(vec![]); + assert!(deduped.is_empty()); + } +} diff --git a/crates/dirigent_anth/src/discovery.rs b/crates/dirigent_anth/src/discovery.rs new file mode 100644 index 0000000..5e830f2 --- /dev/null +++ b/crates/dirigent_anth/src/discovery.rs @@ -0,0 +1,342 @@ +use std::collections::HashMap; +use camino::{Utf8Path, Utf8PathBuf}; +use crate::types::*; +use crate::error::{AntError, Result}; + +/// Discover the Claude Code home directory (~/.claude/). +pub fn discover_claude_home() -> Result { + let home = dirs::home_dir().ok_or(AntError::HomeNotFound)?; + let claude_dir = home.join(".claude"); + if !claude_dir.exists() { + return Err(AntError::HomeNotFound); + } + Utf8PathBuf::try_from(claude_dir.to_path_buf()) + .map_err(|e| AntError::InvalidPath(e.to_string())) +} + +/// Normalise a native path to forward slashes for consistent storage. +fn normalize_to_forward_slashes(path: &str) -> String { + path.replace('\\', "/") +} + +/// Resolve the original filesystem path for a Claude project directory. +/// +/// Priority: +/// 1. `projectPath` from `sessions-index.json` (authoritative, cheap) +/// 2. `cwd` from the first user message in any session JSONL (authoritative, costs one file parse) +/// 3. `decode_project_path` (lossy fallback for empty project directories) +pub fn resolve_original_path(dir_name: &str, sessions: &[SessionRef]) -> String { + // 1. Try sessions-index.json projectPath + for session in sessions { + if let Some(ref idx) = session.index_entry { + if let Some(ref path) = idx.project_path { + if !path.is_empty() { + return normalize_to_forward_slashes(path); + } + } + } + } + + // 2. Try cwd from first user message in any session + for session in sessions { + if let Ok(msgs) = crate::parser::parse_session(&session.jsonl_path) { + for msg in &msgs { + if let crate::types::RawMessage::User(user) = msg { + if let Some(ref cwd) = user.cwd { + if !cwd.is_empty() { + return normalize_to_forward_slashes(cwd); + } + } + } + } + } + } + + // 3. Lossy fallback + decode_project_path(dir_name) +} + +/// Discover all Claude Code project directories under the given home. +pub fn discover_projects(home: &Utf8Path) -> Result> { + let projects_dir = home.join("projects"); + if !projects_dir.as_std_path().exists() { + return Ok(Vec::new()); + } + + let mut projects = Vec::new(); + for entry in std::fs::read_dir(projects_dir.as_std_path())? { + let entry = entry?; + let path = entry.path(); + if !path.is_dir() { + continue; + } + let dir_name = match path.file_name().and_then(|n| n.to_str()) { + Some(name) => name.to_string(), + None => continue, + }; + + let utf8_path = match Utf8PathBuf::try_from(path.clone()) { + Ok(p) => p, + Err(_) => continue, + }; + + let sessions = discover_sessions(&utf8_path)?; + let original_path = resolve_original_path(&dir_name, &sessions); + + projects.push(ClaudeProject { + path: utf8_path, + original_path, + sessions, + }); + } + + Ok(projects) +} + +/// Decode an encoded project folder name back to the original path (lossy). +/// +/// **Warning**: Claude Code's encoding replaces `\`, `/`, AND `_` all with +/// `-`, making this decoding ambiguous. For example, `G--dev-projects-adk-rust` +/// could be `G:/dev/projects/adk-rust` or `G:/dev/projects/adk/rust`. Prefer +/// [`resolve_original_path`] which reads ground truth from `sessions-index.json` +/// or session JSONL files. This function is a last-resort fallback for empty +/// project directories with no sessions or index. +pub fn decode_project_path(encoded: &str) -> String { + // Split on "--" to recover path segments separated by the original separators. + let parts: Vec<&str> = encoded.split("--").collect(); + + if parts.is_empty() { + return encoded.to_string(); + } + + let mut result = String::new(); + + let first = parts[0]; + + if first.len() == 1 && first.chars().next().map_or(false, |c| c.is_ascii_uppercase()) { + // Windows drive letter: "G" → "G:" + result.push_str(first); + result.push(':'); + } else if first.starts_with('-') || first.is_empty() { + // Unix-style absolute path: the original path started with "/". + // The first segment has a leading "-" that encoded the root separator. + // Strip that leading "-" to recover the first directory component. + let component = first.trim_start_matches('-'); + result.push('/'); + if !component.is_empty() { + // Single dashes within the component are path separators. + result.push_str(&component.replace('-', "/")); + } + } else { + result.push_str(first); + } + + // Remaining "--"-separated parts are additional path components. + // Within each part, single "-" represent path separators. + for part in &parts[1..] { + result.push('/'); + result.push_str(&part.replace('-', "/")); + } + + result +} + +/// Discover all session JSONL files in a project directory. +pub fn discover_sessions(project_dir: &Utf8Path) -> Result> { + let index = load_session_index(project_dir); + let mut sessions = Vec::new(); + + for entry in std::fs::read_dir(project_dir.as_std_path())? { + let entry = entry?; + let path = entry.path(); + + // Only .jsonl files + let extension = path.extension().and_then(|e| e.to_str()); + if extension != Some("jsonl") { + continue; + } + + let stem = match path.file_stem().and_then(|s| s.to_str()) { + Some(s) => s.to_string(), + None => continue, + }; + + let utf8_path = match Utf8PathBuf::try_from(path) { + Ok(p) => p, + Err(_) => continue, + }; + + // Check for artifacts directory (same name as the session stem). + let artifacts_dir = { + let dir = project_dir.join(&stem); + if dir.as_std_path().is_dir() { + Some(dir) + } else { + None + } + }; + + let index_entry = index.as_ref().and_then(|idx| idx.get(&stem).cloned()); + + sessions.push(SessionRef { + id: stem, + jsonl_path: utf8_path, + artifacts_dir, + index_entry, + }); + } + + Ok(sessions) +} + +/// Load `sessions-index.json` if it exists in the given project directory. +fn load_session_index(project_dir: &Utf8Path) -> Option> { + let index_path = project_dir.join("sessions-index.json"); + if !index_path.as_std_path().exists() { + return None; + } + + let content = std::fs::read_to_string(index_path.as_std_path()).ok()?; + serde_json::from_str::>(&content).ok() +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + #[test] + fn decode_project_path_windows() { + assert_eq!( + decode_project_path("G--dev-projects-dirigent"), + "G:/dev/projects/dirigent" + ); + } + + #[test] + fn decode_project_path_windows_users() { + assert_eq!( + decode_project_path("C--Users-g4b-tmp"), + "C:/Users/g4b/tmp" + ); + } + + #[test] + fn decode_project_path_unix() { + assert_eq!( + decode_project_path("-home-user-projects-foo"), + "/home/user/projects/foo" + ); + } + + #[test] + fn discover_sessions_in_temp_dir() { + let tmp = TempDir::new().unwrap(); + let project_dir = Utf8Path::from_path(tmp.path()).unwrap(); + + // Create fake session files. + std::fs::write(project_dir.join("abc-def-123.jsonl").as_std_path(), "{}\n").unwrap(); + std::fs::write(project_dir.join("xyz-456-789.jsonl").as_std_path(), "{}\n").unwrap(); + // Create an artifacts directory for one session. + std::fs::create_dir(project_dir.join("abc-def-123").as_std_path()).unwrap(); + + let sessions = discover_sessions(project_dir).unwrap(); + assert_eq!(sessions.len(), 2); + + let with_artifacts = sessions.iter().find(|s| s.id == "abc-def-123").unwrap(); + assert!(with_artifacts.artifacts_dir.is_some()); + + let without_artifacts = sessions.iter().find(|s| s.id == "xyz-456-789").unwrap(); + assert!(without_artifacts.artifacts_dir.is_none()); + } + + #[test] + fn discover_sessions_ignores_non_jsonl() { + let tmp = TempDir::new().unwrap(); + let project_dir = Utf8Path::from_path(tmp.path()).unwrap(); + + std::fs::write(project_dir.join("session.jsonl").as_std_path(), "{}\n").unwrap(); + std::fs::write( + project_dir.join("sessions-index.json").as_std_path(), + "{}", + ) + .unwrap(); + std::fs::create_dir(project_dir.join("some-dir").as_std_path()).unwrap(); + + let sessions = discover_sessions(project_dir).unwrap(); + assert_eq!(sessions.len(), 1); + assert_eq!(sessions[0].id, "session"); + } + + #[test] + fn discover_sessions_loads_index_entry() { + let tmp = TempDir::new().unwrap(); + let project_dir = Utf8Path::from_path(tmp.path()).unwrap(); + + std::fs::write(project_dir.join("abc-123.jsonl").as_std_path(), "{}\n").unwrap(); + + let index_json = r#"{ + "abc-123": { + "sessionId": "abc-123", + "firstPrompt": "Hello", + "summary": "A test session", + "messageCount": 5 + } + }"#; + std::fs::write( + project_dir.join("sessions-index.json").as_std_path(), + index_json, + ) + .unwrap(); + + let sessions = discover_sessions(project_dir).unwrap(); + assert_eq!(sessions.len(), 1); + + let entry = sessions[0].index_entry.as_ref().unwrap(); + assert_eq!(entry.session_id.as_deref(), Some("abc-123")); + assert_eq!(entry.first_prompt.as_deref(), Some("Hello")); + assert_eq!(entry.message_count, Some(5)); + } + + #[test] + fn resolve_original_path_prefers_index_project_path() { + let sessions = vec![SessionRef { + id: "test-session".to_string(), + jsonl_path: Utf8PathBuf::from("/tmp/fake.jsonl"), + artifacts_dir: None, + index_entry: Some(SessionIndexEntry { + session_id: Some("test-session".to_string()), + first_prompt: None, + summary: None, + message_count: None, + created: None, + modified: None, + git_branch: None, + project_path: Some(r"G:\dev\projects\bevy_sprite3d".to_string()), + }), + }]; + let result = resolve_original_path("G--dev-projects-bevy-sprite3d", &sessions); + assert_eq!(result, "G:/dev/projects/bevy_sprite3d"); + } + + #[test] + fn resolve_original_path_falls_back_to_decode() { + let sessions: Vec = vec![]; + let result = resolve_original_path("G--dev-projects-dirigent", &sessions); + assert_eq!(result, "G:/dev/projects/dirigent"); + } + + #[test] + fn discover_projects_empty_when_no_projects_dir() { + let tmp = TempDir::new().unwrap(); + let home_dir = Utf8Path::from_path(tmp.path()).unwrap(); + + // No "projects" subdirectory — should return empty vec, not an error. + let projects = discover_projects(home_dir).unwrap(); + assert!(projects.is_empty()); + } +} diff --git a/crates/dirigent_anth/src/error.rs b/crates/dirigent_anth/src/error.rs new file mode 100644 index 0000000..ef33e4e --- /dev/null +++ b/crates/dirigent_anth/src/error.rs @@ -0,0 +1,19 @@ +#[derive(Debug, thiserror::Error)] +pub enum AntError { + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + + #[error("JSON parse error at line {line}: {source}")] + JsonParse { + line: usize, + source: serde_json::Error, + }, + + #[error("Claude home directory not found")] + HomeNotFound, + + #[error("Invalid path: {0}")] + InvalidPath(String), +} + +pub type Result = std::result::Result; diff --git a/crates/dirigent_anth/src/lib.rs b/crates/dirigent_anth/src/lib.rs new file mode 100644 index 0000000..e0b6621 --- /dev/null +++ b/crates/dirigent_anth/src/lib.rs @@ -0,0 +1,52 @@ +//! dirigent_anth — Claude Code Session Parser & Toolkit +//! +//! Reads Claude Code's local JSONL session storage and produces typed, +//! deduplicated, correlated Rust data structures. +//! +//! # Design +//! +//! See `docs/superpowers/plans/2026-03-23-dirigent-ant-design.md` + +pub mod claude_grab; +pub mod anth_usage; +pub mod correlation; +pub mod dedup; +pub mod discovery; +pub mod error; +pub mod noise; +pub mod parser; +pub mod subagent; +pub mod tree; +pub mod types; +pub mod util; + +/// Load and fully parse a session: dedup, correlate, tree, subagents. +pub fn load_session(session_ref: &types::SessionRef) -> error::Result { + let messages = parser::parse_session_deduped(&session_ref.jsonl_path)?; + let tree = tree::ConversationTree::build(&messages); + let tool_exchanges = correlation::correlate_tools(&messages); + let mut subagents = if let Some(ref dir) = session_ref.artifacts_dir { + subagent::load_subagents(dir)? + } else { + Vec::new() + }; + subagent::link_subagents_to_calls(&mut subagents, &tool_exchanges); + + Ok(types::ParsedSession { + messages, + tree, + tool_exchanges, + subagents, + }) +} + +pub use correlation::correlate_tools; +pub use dedup::dedup_messages; +pub use discovery::{decode_project_path, discover_claude_home, discover_projects, discover_sessions, resolve_original_path}; +pub use error::{AntError, Result}; +pub use noise::{classify_noise, NoiseKind}; +pub use parser::{parse_line, parse_session, parse_session_deduped}; +pub use subagent::{link_subagents_to_calls, load_subagents}; +pub use tree::{message_parent_uuid, message_uuid, ConversationNode, ConversationTree}; +pub use types::*; +pub use util::parse_timestamp; diff --git a/crates/dirigent_anth/src/noise.rs b/crates/dirigent_anth/src/noise.rs new file mode 100644 index 0000000..6e986aa --- /dev/null +++ b/crates/dirigent_anth/src/noise.rs @@ -0,0 +1,72 @@ +use crate::types::*; + +/// Classification of noise patterns in Claude Code JSONL. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum NoiseKind { + Meta, + Warmup, + Interrupted, + Continuation, + ApiError, + SystemCaveat, + QueueOp, +} + +/// Classify a message as noise, if applicable. +/// Returns None for normal messages. +pub fn classify_noise(message: &RawMessage) -> Option { + match message { + RawMessage::QueueOperation(_) => Some(NoiseKind::QueueOp), + RawMessage::User(user) => { + if user.is_meta.unwrap_or(false) { + return Some(NoiseKind::Meta); + } + if let Some(text) = extract_user_text(user) { + if text == "Warmup" { + return Some(NoiseKind::Warmup); + } + if text.starts_with("[Request interrupted") { + return Some(NoiseKind::Interrupted); + } + if text.starts_with("This session is being continued") { + return Some(NoiseKind::Continuation); + } + if text.starts_with("API Error") { + return Some(NoiseKind::ApiError); + } + if text.starts_with("Caveat: The messages below") { + return Some(NoiseKind::SystemCaveat); + } + } + None + } + _ => None, + } +} + +/// Extract plain text from a user message's content. +fn extract_user_text(user: &RawUserMessage) -> Option<&str> { + match &user.message.content { + Content::Text(s) => Some(s.as_str()), + Content::Blocks(_) => None, // tool_result blocks, not plain text + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn normal_assistant_is_not_noise() { + let json = r#"{"type":"assistant","uuid":"x","timestamp":"2026-01-01T00:00:00Z","sessionId":"s","message":{"id":"m","role":"assistant","content":[{"type":"text","text":"Hello"}],"stop_reason":"end_turn"}}"#; + let msg: RawMessage = serde_json::from_str(json).unwrap(); + assert_eq!(classify_noise(&msg), None); + } + + #[test] + fn queue_op_is_noise() { + let json = r#"{"type":"queue-operation","operation":"enqueue","timestamp":"2026-01-01T00:00:00Z","sessionId":"s"}"#; + let msg: RawMessage = serde_json::from_str(json).unwrap(); + assert_eq!(classify_noise(&msg), Some(NoiseKind::QueueOp)); + } +} diff --git a/crates/dirigent_anth/src/parser.rs b/crates/dirigent_anth/src/parser.rs new file mode 100644 index 0000000..e7db24d --- /dev/null +++ b/crates/dirigent_anth/src/parser.rs @@ -0,0 +1,50 @@ +//! JSONL line parser for Claude Code session files. + +use std::io::BufRead; + +use camino::Utf8Path; + +use crate::error::Result; +use crate::types::RawMessage; + +/// Parse a single JSONL line into a RawMessage. +/// Returns None for lines that cannot be parsed (logged via tracing). +pub fn parse_line(line: &str, line_number: usize) -> Option { + match serde_json::from_str::(line) { + Ok(msg) => Some(msg), + Err(e) => { + tracing::warn!(line = line_number, error = %e, "Skipping unparseable JSONL line"); + None + } + } +} + +/// Parse all messages from a JSONL file. +/// Skips unparseable lines (lenient). Returns I/O errors. +pub fn parse_session(path: &Utf8Path) -> Result> { + let file = std::fs::File::open(path.as_std_path())?; + let reader = std::io::BufReader::new(file); + let mut messages = Vec::new(); + + for (i, line) in reader.lines().enumerate() { + let line = line?; + if line.trim().is_empty() { + continue; + } + if let Some(msg) = parse_line(&line, i + 1) { + messages.push(msg); + } + } + + Ok(messages) +} + +/// Parse a session JSONL file with streaming deduplication applied. +/// +/// Claude Code writes multiple JSONL lines for the same assistant message as +/// it streams. This function collapses those into a single final version per +/// uuid. See [`crate::dedup::dedup_messages`] for details. +pub fn parse_session_deduped(path: &Utf8Path) -> Result> { + let messages = parse_session(path)?; + Ok(crate::dedup::dedup_messages(messages)) +} diff --git a/crates/dirigent_anth/src/subagent.rs b/crates/dirigent_anth/src/subagent.rs new file mode 100644 index 0000000..8fff571 --- /dev/null +++ b/crates/dirigent_anth/src/subagent.rs @@ -0,0 +1,215 @@ +//! Sub-agent session loading. +//! +//! Claude Code spawns sub-agents for Agent tool calls and stores their +//! conversations under `/subagents/`. Each sub-agent +//! has a JSONL file and an optional `.meta.json` with metadata such as the +//! agent type. + +use camino::Utf8Path; + +use crate::error::Result; +use crate::parser::parse_session; +use crate::types::{SubAgentMeta, SubAgentSession, ToolExchange}; + +/// Load all sub-agent sessions from a session's artifacts directory. +/// +/// Expects files at: `/subagents/agent-.jsonl` +/// with optional companion: `/subagents/agent-.meta.json` +/// +/// Returns an empty `Vec` if the `subagents/` subdirectory does not exist. +pub fn load_subagents(session_artifacts_dir: &Utf8Path) -> Result> { + let subagents_dir = session_artifacts_dir.join("subagents"); + if !subagents_dir.as_std_path().exists() { + return Ok(Vec::new()); + } + + let mut subagents = Vec::new(); + + for entry in std::fs::read_dir(subagents_dir.as_std_path())? { + let entry = entry?; + let path = entry.path(); + + // Only process agent-*.jsonl files + let file_name = match path.file_name().and_then(|n| n.to_str()) { + Some(name) => name.to_string(), + None => continue, + }; + + if !file_name.starts_with("agent-") || !file_name.ends_with(".jsonl") { + continue; + } + + // Extract agent ID: "agent-abc123.jsonl" → "abc123" + let agent_id = file_name + .strip_prefix("agent-") + .and_then(|s| s.strip_suffix(".jsonl")) + .unwrap_or(&file_name) + .to_string(); + + let jsonl_path = match camino::Utf8PathBuf::try_from(path.clone()) { + Ok(p) => p, + Err(_) => continue, + }; + + // Parse the JSONL session + let messages = parse_session(&jsonl_path)?; + + // Try to load companion metadata file + let meta_path = path.with_file_name(format!("agent-{}.meta.json", agent_id)); + let meta = if meta_path.exists() { + let content = std::fs::read_to_string(&meta_path)?; + serde_json::from_str::(&content) + .unwrap_or(SubAgentMeta { agent_type: None }) + } else { + SubAgentMeta { agent_type: None } + }; + + subagents.push(SubAgentSession { + agent_id, + meta, + messages, + parent_tool_call_id: None, + }); + } + + Ok(subagents) +} + +/// Try to link sub-agent sessions to their parent Agent tool calls. +/// +/// For each Agent tool call in `tool_exchanges`, parses the tool result text +/// for `agentId: ` and matches it against sub-agent sessions. On match, +/// sets `SubAgentSession.parent_tool_call_id` to the tool call's ID. +/// +/// This is best-effort: if the agentId text format changes or a result is +/// missing, the sub-agent is still usable but without tool_use linkage. +pub fn link_subagents_to_calls( + subagents: &mut [SubAgentSession], + tool_exchanges: &[ToolExchange], +) { + use regex::Regex; + + if subagents.is_empty() || tool_exchanges.is_empty() { + return; + } + + // Compile once, match many + let re = Regex::new(r"agentId:\s*(\S+)").expect("valid regex"); + + for exchange in tool_exchanges { + // Only look at Agent tool calls + if exchange.call.name != crate::types::ToolName::Agent { + continue; + } + + // Extract agentId from the tool result text + let agent_id = exchange + .result + .as_ref() + .and_then(|r| r.content.as_deref()) + .and_then(|text| re.captures(text)) + .and_then(|caps| caps.get(1)) + .map(|m| m.as_str()); + + let agent_id = match agent_id { + Some(id) => id, + None => continue, + }; + + // Find matching sub-agent and set the linkage + if let Some(subagent) = subagents.iter_mut().find(|s| s.agent_id == agent_id) { + subagent.parent_tool_call_id = Some(exchange.call.id.clone()); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::{ToolCall, ToolName, ToolResultData}; + + #[test] + fn test_link_subagents_to_calls_matches_agent_id() { + let mut subagents = vec![ + SubAgentSession { + agent_id: "abc123def".to_string(), + meta: SubAgentMeta { agent_type: Some("Explore".to_string()) }, + messages: vec![], + parent_tool_call_id: None, + }, + SubAgentSession { + agent_id: "xyz789".to_string(), + meta: SubAgentMeta { agent_type: None }, + messages: vec![], + parent_tool_call_id: None, + }, + ]; + + let exchanges = vec![ + ToolExchange { + call: ToolCall { + id: "toolu_01ABC".to_string(), + name: ToolName::Agent, + input: serde_json::json!({"description": "test"}), + source_message_uuid: "msg-1".to_string(), + }, + result: Some(ToolResultData { + tool_use_id: "toolu_01ABC".to_string(), + content: Some("agentId: abc123def (use SendMessage with to: 'abc123def' to continue)\ntotal_tokens: 1000".to_string()), + is_error: false, + source_message_uuid: "msg-2".to_string(), + }), + }, + ToolExchange { + call: ToolCall { + id: "toolu_02DEF".to_string(), + name: ToolName::Read, + input: serde_json::json!({}), + source_message_uuid: "msg-3".to_string(), + }, + result: None, + }, + ]; + + link_subagents_to_calls(&mut subagents, &exchanges); + + assert_eq!(subagents[0].parent_tool_call_id, Some("toolu_01ABC".to_string())); + assert_eq!(subagents[1].parent_tool_call_id, None); + } + + #[test] + fn test_link_subagents_empty_inputs() { + let mut empty_subagents: Vec = vec![]; + let empty_exchanges: Vec = vec![]; + link_subagents_to_calls(&mut empty_subagents, &empty_exchanges); + // No panic + } + + #[test] + fn test_link_subagents_no_match() { + let mut subagents = vec![SubAgentSession { + agent_id: "no_match".to_string(), + meta: SubAgentMeta { agent_type: None }, + messages: vec![], + parent_tool_call_id: None, + }]; + + let exchanges = vec![ToolExchange { + call: ToolCall { + id: "toolu_99".to_string(), + name: ToolName::Agent, + input: serde_json::json!({}), + source_message_uuid: "msg-1".to_string(), + }, + result: Some(ToolResultData { + tool_use_id: "toolu_99".to_string(), + content: Some("agentId: different_id\ntokens: 500".to_string()), + is_error: false, + source_message_uuid: "msg-2".to_string(), + }), + }]; + + link_subagents_to_calls(&mut subagents, &exchanges); + assert_eq!(subagents[0].parent_tool_call_id, None); + } +} diff --git a/crates/dirigent_anth/src/tree.rs b/crates/dirigent_anth/src/tree.rs new file mode 100644 index 0000000..ca8f37d --- /dev/null +++ b/crates/dirigent_anth/src/tree.rs @@ -0,0 +1,171 @@ +//! Conversation tree module — builds a parent/child tree from `RawMessage`s. +//! +//! Claude Code sessions are not purely linear: the user can edit earlier +//! messages, producing branches. Each message carries a `uuid` and a +//! `parentUuid` that describe the relationship. This module reconstructs +//! the tree so callers can walk threads, detect branches, and select the +//! main thread. + +use std::collections::HashMap; + +use crate::types::RawMessage; + +// --------------------------------------------------------------------------- +// Node & tree types +// --------------------------------------------------------------------------- + +/// A single node in the conversation tree. +#[derive(Debug)] +pub struct ConversationNode { + /// The UUID of this message. + pub uuid: String, + /// The raw message stored at this node. + pub message: RawMessage, + /// UUIDs of direct children, in insertion order. + pub children: Vec, +} + +/// The full conversation tree for a session. +/// +/// A session may have multiple roots when the first message has no +/// `parentUuid`, or when a message refers to a parent that is not present +/// in the slice provided to [`ConversationTree::build`]. +#[derive(Debug)] +pub struct ConversationTree { + /// Root node UUIDs (messages with no parent or with an unknown parent). + pub roots: Vec, + /// All nodes indexed by UUID. + pub nodes: HashMap, +} + +// --------------------------------------------------------------------------- +// UUID / parent-UUID helpers +// --------------------------------------------------------------------------- + +/// Extract the `uuid` from any `RawMessage` variant. +/// +/// Returns `None` for variants that carry no UUID (e.g. `QueueOperation`). +pub fn message_uuid(msg: &RawMessage) -> Option<&str> { + match msg { + RawMessage::User(m) => m.uuid.as_deref(), + RawMessage::Assistant(m) => m.uuid.as_deref(), + RawMessage::Progress(m) => m.uuid.as_deref(), + RawMessage::System(m) => m.uuid.as_deref(), + RawMessage::QueueOperation(_) + | RawMessage::FileHistorySnapshot(_) + | RawMessage::LastPrompt(_) => None, + } +} + +/// Extract the `parent_uuid` from any `RawMessage` variant. +/// +/// Returns `None` for variants that carry no parent UUID. +pub fn message_parent_uuid(msg: &RawMessage) -> Option<&str> { + match msg { + RawMessage::User(m) => m.parent_uuid.as_deref(), + RawMessage::Assistant(m) => m.parent_uuid.as_deref(), + RawMessage::Progress(m) => m.parent_uuid.as_deref(), + RawMessage::System(m) => m.parent_uuid.as_deref(), + RawMessage::QueueOperation(_) + | RawMessage::FileHistorySnapshot(_) + | RawMessage::LastPrompt(_) => None, + } +} + +// --------------------------------------------------------------------------- +// ConversationTree impl +// --------------------------------------------------------------------------- + +impl ConversationTree { + /// Build a conversation tree from a sequence of messages. + /// + /// Messages without a UUID (e.g. `QueueOperation`) are silently skipped. + /// If a message's `parentUuid` is present but not found in the set, + /// that message is treated as a root. + pub fn build(messages: &[RawMessage]) -> Self { + let mut nodes: HashMap = HashMap::new(); + let mut roots: Vec = Vec::new(); + + // First pass: insert every addressable message as a node. + for msg in messages { + if let Some(uuid) = message_uuid(msg) { + nodes.insert( + uuid.to_string(), + ConversationNode { + uuid: uuid.to_string(), + message: msg.clone(), + children: Vec::new(), + }, + ); + } + } + + // Second pass: collect (uuid, parent_uuid) pairs so we can wire up + // parent→child edges without a simultaneous mutable borrow. + let parent_links: Vec<(String, Option)> = messages + .iter() + .filter_map(|msg| { + let uuid = message_uuid(msg)?.to_string(); + let parent = message_parent_uuid(msg).map(|s| s.to_string()); + Some((uuid, parent)) + }) + .collect(); + + for (uuid, parent_uuid) in parent_links { + match parent_uuid { + Some(parent_id) if nodes.contains_key(&parent_id) => { + // Safe: parent_id != uuid (a message cannot be its own parent). + nodes + .get_mut(&parent_id) + .expect("parent key confirmed above") + .children + .push(uuid); + } + _ => { + // No parent, or parent not in the provided slice — treat as root. + roots.push(uuid); + } + } + } + + ConversationTree { roots, nodes } + } + + /// Walk the *main thread*: start from the first root and always follow + /// the first child at each step. + /// + /// In a linear session this is the complete conversation. In a branching + /// session this is the path taken before any edits. + pub fn main_thread(&self) -> Vec<&ConversationNode> { + let mut result = Vec::new(); + if let Some(root_id) = self.roots.first() { + let mut current = root_id.as_str(); + loop { + match self.nodes.get(current) { + Some(node) => { + result.push(node); + match node.children.first() { + Some(first_child) => current = first_child.as_str(), + None => break, + } + } + None => break, + } + } + } + result + } + + /// Returns `true` when every node has at most one child (no branches). + pub fn is_linear(&self) -> bool { + self.nodes.values().all(|n| n.children.len() <= 1) + } + + /// Returns all nodes that have more than one child (branch points). + pub fn branch_points(&self) -> Vec<&ConversationNode> { + self.nodes + .values() + .filter(|n| n.children.len() > 1) + .collect() + } +} diff --git a/crates/dirigent_anth/src/types.rs b/crates/dirigent_anth/src/types.rs new file mode 100644 index 0000000..e4c92ac --- /dev/null +++ b/crates/dirigent_anth/src/types.rs @@ -0,0 +1,847 @@ +//! Core types for parsing Claude Code JSONL session data. + +use camino::Utf8PathBuf; +use serde::{Deserialize, Serialize}; + +// --------------------------------------------------------------------------- +// Content types +// --------------------------------------------------------------------------- + +/// Content is either a plain string or an array of content blocks. +/// +/// Uses a custom deserializer so that `Blocks` variant applies lenient +/// deserialization — unknown content block types (e.g. `tool_reference`) +/// are silently skipped instead of failing the entire message. +#[derive(Debug, Clone, Serialize)] +#[serde(untagged)] +pub enum Content { + Text(String), + Blocks(Vec), +} + +impl<'de> serde::Deserialize<'de> for Content { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + let value = serde_json::Value::deserialize(deserializer)?; + match value { + serde_json::Value::String(s) => Ok(Content::Text(s)), + serde_json::Value::Array(arr) => { + let blocks = arr + .into_iter() + .filter_map(|v| { + serde_json::from_value::(v.clone()) + .ok() + .or_else(|| { + tracing::debug!( + "Skipping unknown content block: {:?}", + v.get("type") + ); + None + }) + }) + .collect(); + Ok(Content::Blocks(blocks)) + } + other => Err(serde::de::Error::custom(format!( + "expected string or array for Content, got {}", + match &other { + serde_json::Value::Null => "null", + serde_json::Value::Bool(_) => "bool", + serde_json::Value::Number(_) => "number", + serde_json::Value::Object(_) => "object", + _ => "unknown", + } + ))), + } + } +} + +/// Typed content block inside messages. +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ContentBlock { + Text { + text: String, + }, + ToolUse { + id: String, + name: String, + input: serde_json::Value, + #[serde(default)] + caller: Option, + }, + ToolResult { + tool_use_id: String, + #[serde(default)] + content: Option, + #[serde(default)] + is_error: bool, + }, + Thinking { + thinking: String, + }, + Image { + source: serde_json::Value, + }, +} + +// --------------------------------------------------------------------------- +// Lenient content block deserialization +// --------------------------------------------------------------------------- + +/// Deserializes a `Vec` leniently — unknown block types are +/// silently skipped instead of failing the entire message. +fn deserialize_content_blocks<'de, D>( + deserializer: D, +) -> std::result::Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + use serde::Deserialize as _; + let raw: Vec = Vec::deserialize(deserializer)?; + Ok(raw + .into_iter() + .filter_map(|v| { + serde_json::from_value::(v.clone()).ok().or_else(|| { + tracing::debug!("Skipping unknown content block: {:?}", v.get("type")); + None + }) + }) + .collect()) +} + +// --------------------------------------------------------------------------- +// Top-level JSONL line discriminator +// --------------------------------------------------------------------------- + +/// Top-level JSONL line discriminator. +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "type", rename_all = "kebab-case")] +pub enum RawMessage { + User(RawUserMessage), + Assistant(RawAssistantMessage), + Progress(RawProgressMessage), + System(RawSystemMessage), + QueueOperation(RawQueueOperation), + FileHistorySnapshot(RawFileHistorySnapshot), + LastPrompt(RawLastPrompt), +} + +// --------------------------------------------------------------------------- +// User message +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RawUserMessage { + #[serde(default)] + pub uuid: Option, + #[serde(default)] + pub parent_uuid: Option, + #[serde(default)] + pub timestamp: Option, + #[serde(default)] + pub session_id: Option, + #[serde(default)] + pub cwd: Option, + #[serde(default)] + pub version: Option, + #[serde(default)] + pub git_branch: Option, + #[serde(default)] + pub is_sidechain: bool, + #[serde(default)] + pub is_meta: Option, + #[serde(default)] + pub user_type: Option, + pub message: UserMessageInner, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct UserMessageInner { + pub role: String, + pub content: Content, +} + +// --------------------------------------------------------------------------- +// Assistant message +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RawAssistantMessage { + #[serde(default)] + pub uuid: Option, + #[serde(default)] + pub parent_uuid: Option, + #[serde(default)] + pub timestamp: Option, + #[serde(default)] + pub session_id: Option, + #[serde(default)] + pub cwd: Option, + #[serde(default)] + pub version: Option, + #[serde(default)] + pub git_branch: Option, + #[serde(default)] + pub is_sidechain: bool, + #[serde(default)] + pub request_id: Option, + pub message: AssistantInner, +} + +// NOTE: AssistantInner is the Anthropic API response object nested inside +// the Claude Code JSONL wrapper. The API uses snake_case (stop_reason, etc.) +// unlike the outer JSONL wrapper which uses camelCase. +#[derive(Debug, Clone, Deserialize)] +pub struct AssistantInner { + #[serde(default)] + pub model: Option, + #[serde(default)] + pub id: Option, + #[serde(default, rename = "type")] + pub message_type: Option, + #[serde(default)] + pub role: Option, + #[serde(default, deserialize_with = "deserialize_content_blocks")] + pub content: Vec, + #[serde(default)] + pub stop_reason: Option, + #[serde(default)] + pub stop_sequence: Option, + #[serde(default)] + pub usage: Option, +} + +// --------------------------------------------------------------------------- +// Progress message +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RawProgressMessage { + #[serde(default)] + pub uuid: Option, + #[serde(default)] + pub parent_uuid: Option, + #[serde(default)] + pub timestamp: Option, + #[serde(default)] + pub session_id: Option, + #[serde(default)] + pub cwd: Option, + #[serde(default)] + pub version: Option, + #[serde(default)] + pub git_branch: Option, + #[serde(default)] + pub is_sidechain: bool, + #[serde(default)] + pub data: Option, +} + +// --------------------------------------------------------------------------- +// System message +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RawSystemMessage { + #[serde(default)] + pub uuid: Option, + #[serde(default)] + pub parent_uuid: Option, + #[serde(default)] + pub timestamp: Option, + #[serde(default)] + pub session_id: Option, + #[serde(default)] + pub cwd: Option, + #[serde(default)] + pub version: Option, + #[serde(default)] + pub git_branch: Option, + #[serde(default)] + pub is_sidechain: bool, + #[serde(default)] + pub data: Option, +} + +// --------------------------------------------------------------------------- +// Queue operation +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RawQueueOperation { + pub operation: String, + #[serde(default)] + pub timestamp: Option, + #[serde(default)] + pub session_id: Option, +} + +// --------------------------------------------------------------------------- +// File history snapshot +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RawFileHistorySnapshot { + #[serde(default)] + pub message_id: Option, + #[serde(default)] + pub is_snapshot_update: bool, + #[serde(default)] + pub snapshot: Option, +} + +// --------------------------------------------------------------------------- +// Last prompt +// --------------------------------------------------------------------------- + +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RawLastPrompt { + #[serde(default)] + pub last_prompt: Option, + #[serde(default)] + pub session_id: Option, +} + +// --------------------------------------------------------------------------- +// Tool types (for correlation module later) +// --------------------------------------------------------------------------- + +/// Known tool names used by Claude Code. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum ToolName { + Bash, + Read, + Write, + Edit, + Grep, + Glob, + Agent, + Skill, + WebSearch, + WebFetch, + TodoWrite, + NotebookEdit, + Other(String), +} + +impl From for ToolName { + fn from(s: String) -> Self { + match s.as_str() { + "Bash" => ToolName::Bash, + "Read" => ToolName::Read, + "Write" => ToolName::Write, + "Edit" => ToolName::Edit, + "Grep" => ToolName::Grep, + "Glob" => ToolName::Glob, + "Agent" => ToolName::Agent, + "Skill" => ToolName::Skill, + "WebSearch" => ToolName::WebSearch, + "WebFetch" => ToolName::WebFetch, + "TodoWrite" => ToolName::TodoWrite, + "NotebookEdit" => ToolName::NotebookEdit, + other => ToolName::Other(other.to_string()), + } + } +} + +/// A tool call extracted from an assistant message. +#[derive(Debug, Clone)] +pub struct ToolCall { + pub id: String, + pub name: ToolName, + pub input: serde_json::Value, + pub source_message_uuid: String, +} + +/// A tool result extracted from a user message. +#[derive(Debug, Clone)] +pub struct ToolResultData { + pub tool_use_id: String, + pub content: Option, + pub is_error: bool, + pub source_message_uuid: String, +} + +/// A correlated tool call + result pair. +#[derive(Debug, Clone)] +pub struct ToolExchange { + pub call: ToolCall, + pub result: Option, +} + +// --------------------------------------------------------------------------- +// Discovery types (for discovery module later) +// --------------------------------------------------------------------------- + +/// A discovered Claude Code project directory. +#[derive(Debug, Clone)] +pub struct ClaudeProject { + pub path: Utf8PathBuf, + pub original_path: String, + pub sessions: Vec, +} + +/// Reference to a session (not yet parsed). +#[derive(Debug, Clone)] +pub struct SessionRef { + pub id: String, + pub jsonl_path: Utf8PathBuf, + pub artifacts_dir: Option, + pub index_entry: Option, +} + +/// From sessions-index.json (when available). +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionIndexEntry { + #[serde(default)] + pub session_id: Option, + #[serde(default)] + pub first_prompt: Option, + #[serde(default)] + pub summary: Option, + #[serde(default)] + pub message_count: Option, + #[serde(default)] + pub created: Option, + #[serde(default)] + pub modified: Option, + #[serde(default)] + pub git_branch: Option, + #[serde(default)] + pub project_path: Option, +} + +// --------------------------------------------------------------------------- +// Sub-agent types +// --------------------------------------------------------------------------- + +/// Sub-agent metadata from .meta.json. +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SubAgentMeta { + #[serde(default)] + pub agent_type: Option, +} + +/// A parsed sub-agent session. +#[derive(Debug, Clone)] +pub struct SubAgentSession { + pub agent_id: String, + pub meta: SubAgentMeta, + pub messages: Vec, + pub parent_tool_call_id: Option, +} + +// --------------------------------------------------------------------------- +// MessageMeta (convenience, future use) +// --------------------------------------------------------------------------- + +/// Common metadata extracted from any message. Defined for future consumers. +#[derive(Debug, Clone)] +pub struct MessageMeta { + pub uuid: String, + pub parent_uuid: Option, + pub timestamp: Option, + pub session_id: String, + pub cwd: Option, + pub version: Option, + pub git_branch: Option, + pub is_sidechain: bool, +} + +// --------------------------------------------------------------------------- +// ParsedSession +// --------------------------------------------------------------------------- + +/// A fully parsed session with all correlations built. +#[derive(Debug)] +pub struct ParsedSession { + pub messages: Vec, + pub tree: crate::tree::ConversationTree, + pub tool_exchanges: Vec, + pub subagents: Vec, +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_content_text_string() { + let json = r#""Hello world""#; + let content: Content = serde_json::from_str(json).unwrap(); + match content { + Content::Text(s) => assert_eq!(s, "Hello world"), + _ => panic!("Expected Content::Text"), + } + } + + #[test] + fn parse_content_blocks() { + let json = r#"[{"type": "text", "text": "Hello"}]"#; + let content: Content = serde_json::from_str(json).unwrap(); + match content { + Content::Blocks(blocks) => { + assert_eq!(blocks.len(), 1); + match &blocks[0] { + ContentBlock::Text { text } => assert_eq!(text, "Hello"), + _ => panic!("Expected ContentBlock::Text"), + } + } + _ => panic!("Expected Content::Blocks"), + } + } + + #[test] + fn parse_tool_use_block() { + let json = r#"{"type": "tool_use", "id": "toolu_123", "name": "Bash", "input": {"command": "ls"}}"#; + let block: ContentBlock = serde_json::from_str(json).unwrap(); + match block { + ContentBlock::ToolUse { id, name, .. } => { + assert_eq!(id, "toolu_123"); + assert_eq!(name, "Bash"); + } + _ => panic!("Expected ContentBlock::ToolUse"), + } + } + + #[test] + fn parse_tool_result_block() { + let json = r#"{"type": "tool_result", "tool_use_id": "toolu_123", "content": "output text", "is_error": false}"#; + let block: ContentBlock = serde_json::from_str(json).unwrap(); + match block { + ContentBlock::ToolResult { + tool_use_id, + is_error, + .. + } => { + assert_eq!(tool_use_id, "toolu_123"); + assert!(!is_error); + } + _ => panic!("Expected ContentBlock::ToolResult"), + } + } + + #[test] + fn parse_thinking_block() { + let json = r#"{"type": "thinking", "thinking": "Let me consider..."}"#; + let block: ContentBlock = serde_json::from_str(json).unwrap(); + match block { + ContentBlock::Thinking { thinking } => { + assert_eq!(thinking, "Let me consider..."); + } + _ => panic!("Expected ContentBlock::Thinking"), + } + } + + #[test] + fn parse_queue_operation() { + let json = r#"{"type": "queue-operation", "operation": "enqueue", "timestamp": "2026-03-14T21:15:17.531Z", "sessionId": "00f72d8d-fc54-485c-a082-310ffcabdb73"}"#; + let msg: RawMessage = serde_json::from_str(json).unwrap(); + match msg { + RawMessage::QueueOperation(op) => { + assert_eq!(op.operation, "enqueue"); + assert_eq!( + op.session_id.as_deref(), + Some("00f72d8d-fc54-485c-a082-310ffcabdb73") + ); + } + _ => panic!("Expected RawMessage::QueueOperation"), + } + } + + #[test] + fn parse_user_message_with_string_content() { + let json = r#"{ + "parentUuid": "b1ab1ac7-fdb6-4e25-bc17-4c060b470b4a", + "isSidechain": false, + "userType": "external", + "cwd": "G:\\dev\\projects\\dirigent", + "sessionId": "00f72d8d-fc54-485c-a082-310ffcabdb73", + "version": "2.1.71", + "gitBranch": "main", + "type": "user", + "message": { + "role": "user", + "content": "Hello world" + }, + "isMeta": false, + "uuid": "1d843a4a-b99d-4c02-91a3-7cfe3dcac9f0", + "timestamp": "2026-03-14T21:08:58.586Z" + }"#; + let msg: RawMessage = serde_json::from_str(json).unwrap(); + match msg { + RawMessage::User(u) => { + assert_eq!(u.uuid.as_deref(), Some("1d843a4a-b99d-4c02-91a3-7cfe3dcac9f0")); + assert_eq!(u.session_id.as_deref(), Some("00f72d8d-fc54-485c-a082-310ffcabdb73")); + assert_eq!(u.is_meta, Some(false)); + match &u.message.content { + Content::Text(s) => assert_eq!(s, "Hello world"), + _ => panic!("Expected Content::Text"), + } + } + _ => panic!("Expected RawMessage::User"), + } + } + + #[test] + fn parse_assistant_message_with_tool_use() { + let json = r#"{ + "parentUuid": "77793647-f957-4aec-8b04-a9c07e01e37b", + "isSidechain": false, + "userType": "external", + "cwd": "G:\\dev\\projects\\dirigent", + "sessionId": "00f72d8d-fc54-485c-a082-310ffcabdb73", + "version": "2.1.71", + "gitBranch": "main", + "message": { + "model": "claude-opus-4-6", + "id": "msg_01NcwYjEydGEyZCSCgwmcnYd", + "type": "message", + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_01DP5mkAQnAi2o54idq24cPn", + "name": "Agent", + "input": { + "description": "Investigate config sources of truth", + "subagent_type": "Explore", + "prompt": "test prompt" + }, + "caller": { "type": "direct" } + } + ], + "stop_reason": null, + "stop_sequence": null, + "usage": { + "input_tokens": 3, + "cache_creation_input_tokens": 20147, + "cache_read_input_tokens": 0, + "output_tokens": 9, + "service_tier": "standard" + } + }, + "requestId": "req_011CZ3fYWGjcQCgh5d58d2k8", + "type": "assistant", + "uuid": "6cad0d13-d0ae-47fa-a6b1-b7b45a2b5e0b", + "timestamp": "2026-03-14T21:15:27.916Z" + }"#; + let msg: RawMessage = serde_json::from_str(json).unwrap(); + match msg { + RawMessage::Assistant(a) => { + assert_eq!(a.uuid.as_deref(), Some("6cad0d13-d0ae-47fa-a6b1-b7b45a2b5e0b")); + assert_eq!(a.message.model.as_deref(), Some("claude-opus-4-6")); + assert_eq!(a.message.content.len(), 1); + match &a.message.content[0] { + ContentBlock::ToolUse { name, id, .. } => { + assert_eq!(name, "Agent"); + assert_eq!(id, "toolu_01DP5mkAQnAi2o54idq24cPn"); + } + _ => panic!("Expected ContentBlock::ToolUse"), + } + assert!(a.message.stop_reason.is_none()); + assert!(a.message.usage.is_some()); + } + _ => panic!("Expected RawMessage::Assistant"), + } + } + + #[test] + fn unknown_content_block_type_skipped_in_assistant() { + let json = r#"{ + "parentUuid": null, + "isSidechain": false, + "sessionId": "test", + "message": { + "role": "assistant", + "content": [ + {"type": "text", "text": "known"}, + {"type": "future_type", "data": "something"} + ] + }, + "type": "assistant", + "uuid": "test-uuid", + "timestamp": "2026-01-01T00:00:00Z" + }"#; + let msg: RawMessage = serde_json::from_str(json).unwrap(); + match msg { + RawMessage::Assistant(a) => { + assert_eq!(a.message.content.len(), 1); + match &a.message.content[0] { + ContentBlock::Text { text } => assert_eq!(text, "known"), + _ => panic!("Expected ContentBlock::Text"), + } + } + _ => panic!("Expected RawMessage::Assistant"), + } + } + + // ----------------------------------------------------------------------- + // Regression tests for parse failure audit (2026-04-04) + // ----------------------------------------------------------------------- + + #[test] + fn tool_reference_in_tool_result_content_does_not_fail() { + // Suggestion 1 & 3: tool_reference blocks inside tool_result.content + // should be silently skipped, not fail the entire message. + let json = r#"{ + "type": "user", + "uuid": "test-uuid", + "parentUuid": null, + "isSidechain": false, + "sessionId": "test-session", + "message": { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_abc123", + "content": [ + {"type": "text", "text": "File contents here"}, + {"type": "tool_reference", "tool_name": "TodoWrite"} + ], + "is_error": false + } + ] + } + }"#; + let msg: RawMessage = serde_json::from_str(json).unwrap(); + match msg { + RawMessage::User(u) => { + match &u.message.content { + Content::Blocks(blocks) => { + assert_eq!(blocks.len(), 1); + match &blocks[0] { + ContentBlock::ToolResult { tool_use_id, content, .. } => { + assert_eq!(tool_use_id, "toolu_abc123"); + // The inner content should have 1 block (text), tool_reference skipped + match content.as_ref().unwrap() { + Content::Blocks(inner) => { + assert_eq!(inner.len(), 1); + match &inner[0] { + ContentBlock::Text { text } => { + assert_eq!(text, "File contents here"); + } + _ => panic!("Expected inner ContentBlock::Text"), + } + } + _ => panic!("Expected inner Content::Blocks"), + } + } + _ => panic!("Expected ContentBlock::ToolResult"), + } + } + _ => panic!("Expected Content::Blocks"), + } + } + _ => panic!("Expected RawMessage::User"), + } + } + + #[test] + fn file_history_snapshot_parses() { + // Suggestion 2: file-history-snapshot lines should parse, not fail. + let json = r#"{ + "type": "file-history-snapshot", + "messageId": "abc-123", + "isSnapshotUpdate": false, + "snapshot": { + "messageId": "abc-123", + "trackedFileBackups": { + "src/main.rs": {"backupFileName": "main.rs.bak", "backupTime": "2026-01-01T00:00:00Z", "version": "1"} + }, + "timestamp": "2026-01-01T00:00:00Z" + } + }"#; + let msg: RawMessage = serde_json::from_str(json).unwrap(); + match msg { + RawMessage::FileHistorySnapshot(s) => { + assert_eq!(s.message_id.as_deref(), Some("abc-123")); + assert!(!s.is_snapshot_update); + assert!(s.snapshot.is_some()); + } + _ => panic!("Expected RawMessage::FileHistorySnapshot"), + } + } + + #[test] + fn last_prompt_parses() { + // Suggestion 2: last-prompt lines should parse, not fail. + let json = r#"{ + "type": "last-prompt", + "lastPrompt": "Fix the bug in auth middleware", + "sessionId": "session-uuid-123" + }"#; + let msg: RawMessage = serde_json::from_str(json).unwrap(); + match msg { + RawMessage::LastPrompt(lp) => { + assert_eq!(lp.last_prompt.as_deref(), Some("Fix the bug in auth middleware")); + assert_eq!(lp.session_id.as_deref(), Some("session-uuid-123")); + } + _ => panic!("Expected RawMessage::LastPrompt"), + } + } + + #[test] + fn unknown_content_block_in_user_message_skipped() { + // Suggestion 3: Unknown block types in user message content + // should be silently skipped (lenient everywhere). + let json = r#"{ + "type": "user", + "uuid": "test-uuid", + "isSidechain": false, + "sessionId": "test", + "message": { + "role": "user", + "content": [ + {"type": "text", "text": "known"}, + {"type": "future_unknown_type", "data": "something"} + ] + } + }"#; + let msg: RawMessage = serde_json::from_str(json).unwrap(); + match msg { + RawMessage::User(u) => { + match &u.message.content { + Content::Blocks(blocks) => { + assert_eq!(blocks.len(), 1); + match &blocks[0] { + ContentBlock::Text { text } => assert_eq!(text, "known"), + _ => panic!("Expected ContentBlock::Text"), + } + } + _ => panic!("Expected Content::Blocks"), + } + } + _ => panic!("Expected RawMessage::User"), + } + } + + #[test] + fn tool_name_from_string() { + assert_eq!(ToolName::from("Bash".to_string()), ToolName::Bash); + assert_eq!(ToolName::from("Read".to_string()), ToolName::Read); + assert_eq!(ToolName::from("Agent".to_string()), ToolName::Agent); + assert_eq!(ToolName::from("WebSearch".to_string()), ToolName::WebSearch); + assert_eq!( + ToolName::from("CustomTool".to_string()), + ToolName::Other("CustomTool".to_string()) + ); + } +} diff --git a/crates/dirigent_anth/src/util.rs b/crates/dirigent_anth/src/util.rs new file mode 100644 index 0000000..2c010f4 --- /dev/null +++ b/crates/dirigent_anth/src/util.rs @@ -0,0 +1,70 @@ +use chrono::{DateTime, Utc}; + +/// Parse a timestamp from various formats found in Claude Code data. +/// +/// Supports: +/// - ISO 8601 string: "2026-03-22T17:00:13.192Z" +/// - Unix milliseconds (number > 1e12): 1769461914249 +/// - Unix seconds (number <= 1e12): 1769461914 +pub fn parse_timestamp(value: &serde_json::Value) -> Option> { + match value { + serde_json::Value::String(s) => { + DateTime::parse_from_rfc3339(s) + .ok() + .map(|dt| dt.with_timezone(&Utc)) + } + serde_json::Value::Number(n) => { + if let Some(ms) = n.as_i64() { + if ms > 1_000_000_000_000 { + DateTime::from_timestamp_millis(ms) + } else { + DateTime::from_timestamp(ms, 0) + } + } else { + None + } + } + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use chrono::Datelike; + + #[test] + fn parse_timestamp_iso8601() { + let v = serde_json::json!("2026-03-22T17:00:13.192Z"); + let dt = parse_timestamp(&v).unwrap(); + assert_eq!(dt.year(), 2026); + assert_eq!(dt.month(), 3); + assert_eq!(dt.day(), 22); + } + + #[test] + fn parse_timestamp_unix_millis() { + let v = serde_json::json!(1769461914249_i64); + let dt = parse_timestamp(&v).unwrap(); + assert!(dt.year() >= 2025); + } + + #[test] + fn parse_timestamp_unix_seconds() { + let v = serde_json::json!(1769461914_i64); + let dt = parse_timestamp(&v).unwrap(); + assert!(dt.year() >= 2025); + } + + #[test] + fn parse_timestamp_null_returns_none() { + let v = serde_json::json!(null); + assert!(parse_timestamp(&v).is_none()); + } + + #[test] + fn parse_timestamp_invalid_string_returns_none() { + let v = serde_json::json!("not a date"); + assert!(parse_timestamp(&v).is_none()); + } +} diff --git a/crates/dirigent_anth/tests/fixtures/branching_tree.jsonl b/crates/dirigent_anth/tests/fixtures/branching_tree.jsonl new file mode 100644 index 0000000..93f440c --- /dev/null +++ b/crates/dirigent_anth/tests/fixtures/branching_tree.jsonl @@ -0,0 +1,6 @@ +{"type":"user","uuid":"r-001","parentUuid":null,"timestamp":"2026-03-23T10:00:00.000Z","sessionId":"test-session-tree","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":false,"userType":"external","message":{"role":"user","content":"Help me"}} +{"type":"assistant","uuid":"a-001","parentUuid":"r-001","timestamp":"2026-03-23T10:00:01.000Z","sessionId":"test-session-tree","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"requestId":"req-001","message":{"model":"claude-opus-4-6","id":"msg-001","type":"message","role":"assistant","content":[{"type":"text","text":"Sure"}],"stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}} +{"type":"user","uuid":"u-002","parentUuid":"a-001","timestamp":"2026-03-23T10:00:02.000Z","sessionId":"test-session-tree","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":false,"userType":"external","message":{"role":"user","content":"Do option A"}} +{"type":"assistant","uuid":"a-003","parentUuid":"u-002","timestamp":"2026-03-23T10:00:03.000Z","sessionId":"test-session-tree","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"requestId":"req-002","message":{"model":"claude-opus-4-6","id":"msg-003","type":"message","role":"assistant","content":[{"type":"text","text":"Doing A"}],"stop_reason":"end_turn","usage":{"input_tokens":15,"output_tokens":5}}} +{"type":"user","uuid":"u-002b","parentUuid":"a-001","timestamp":"2026-03-23T10:00:04.000Z","sessionId":"test-session-tree","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":false,"userType":"external","message":{"role":"user","content":"Actually, do option B"}} +{"type":"assistant","uuid":"a-003b","parentUuid":"u-002b","timestamp":"2026-03-23T10:00:05.000Z","sessionId":"test-session-tree","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"requestId":"req-003","message":{"model":"claude-opus-4-6","id":"msg-003b","type":"message","role":"assistant","content":[{"type":"text","text":"Doing B"}],"stop_reason":"end_turn","usage":{"input_tokens":15,"output_tokens":5}}} diff --git a/crates/dirigent_anth/tests/fixtures/minimal_session.jsonl b/crates/dirigent_anth/tests/fixtures/minimal_session.jsonl new file mode 100644 index 0000000..6476f6f --- /dev/null +++ b/crates/dirigent_anth/tests/fixtures/minimal_session.jsonl @@ -0,0 +1,6 @@ +{"type":"queue-operation","operation":"enqueue","timestamp":"2026-03-14T21:00:00.000Z","sessionId":"test-session-001"} +{"type":"queue-operation","operation":"dequeue","timestamp":"2026-03-14T21:00:00.001Z","sessionId":"test-session-001"} +{"type":"user","uuid":"u-001","parentUuid":null,"timestamp":"2026-03-14T21:00:01.000Z","sessionId":"test-session-001","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":false,"userType":"external","message":{"role":"user","content":"Hello, help me with this project"}} +{"type":"assistant","uuid":"a-001","parentUuid":"u-001","timestamp":"2026-03-14T21:00:02.000Z","sessionId":"test-session-001","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"requestId":"req-001","message":{"model":"claude-opus-4-6","id":"msg-001","type":"message","role":"assistant","content":[{"type":"text","text":"I'll help you."},{"type":"tool_use","id":"toolu_01","name":"Bash","input":{"command":"ls","description":"List files"}}],"stop_reason":"tool_use","usage":{"input_tokens":100,"output_tokens":50}}} +{"type":"user","uuid":"u-002","parentUuid":"a-001","timestamp":"2026-03-14T21:00:03.000Z","sessionId":"test-session-001","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":true,"userType":"external","message":{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_01","content":"file1.rs\nfile2.rs","is_error":false}]}} +{"type":"assistant","uuid":"a-002","parentUuid":"u-002","timestamp":"2026-03-14T21:00:04.000Z","sessionId":"test-session-001","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"requestId":"req-002","message":{"model":"claude-opus-4-6","id":"msg-002","type":"message","role":"assistant","content":[{"type":"text","text":"I can see two Rust files in the directory."}],"stop_reason":"end_turn","usage":{"input_tokens":200,"output_tokens":30}}} diff --git a/crates/dirigent_anth/tests/fixtures/noise_patterns.jsonl b/crates/dirigent_anth/tests/fixtures/noise_patterns.jsonl new file mode 100644 index 0000000..1369fe2 --- /dev/null +++ b/crates/dirigent_anth/tests/fixtures/noise_patterns.jsonl @@ -0,0 +1,9 @@ +{"type":"queue-operation","operation":"enqueue","timestamp":"2026-03-14T21:00:00.000Z","sessionId":"test-session-noise"} +{"type":"user","uuid":"u-n-001","parentUuid":null,"timestamp":"2026-03-14T21:00:01.000Z","sessionId":"test-session-noise","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":true,"message":{"role":"user","content":"system injected stuff"}} +{"type":"user","uuid":"u-n-002","parentUuid":"u-n-001","timestamp":"2026-03-14T21:00:02.000Z","sessionId":"test-session-noise","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":false,"message":{"role":"user","content":"Warmup"}} +{"type":"user","uuid":"u-n-003","parentUuid":"u-n-002","timestamp":"2026-03-14T21:00:03.000Z","sessionId":"test-session-noise","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":false,"message":{"role":"user","content":"[Request interrupted by user"}} +{"type":"user","uuid":"u-n-004","parentUuid":"u-n-003","timestamp":"2026-03-14T21:00:04.000Z","sessionId":"test-session-noise","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":false,"message":{"role":"user","content":"This session is being continued from a previous conversation"}} +{"type":"user","uuid":"u-n-005","parentUuid":"u-n-004","timestamp":"2026-03-14T21:00:05.000Z","sessionId":"test-session-noise","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":false,"message":{"role":"user","content":"API Error: rate limit exceeded"}} +{"type":"user","uuid":"u-n-006","parentUuid":"u-n-005","timestamp":"2026-03-14T21:00:06.000Z","sessionId":"test-session-noise","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":false,"message":{"role":"user","content":"Caveat: The messages below were generated by the user"}} +{"type":"user","uuid":"u-n-007","parentUuid":"u-n-006","timestamp":"2026-03-14T21:00:07.000Z","sessionId":"test-session-noise","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":false,"message":{"role":"user","content":"Please help me fix this bug"}} +{"type":"assistant","uuid":"a-n-001","parentUuid":"u-n-007","timestamp":"2026-03-14T21:00:08.000Z","sessionId":"test-session-noise","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"message":{"id":"msg-n-001","role":"assistant","content":[{"type":"text","text":"Sure, let me help."}],"stop_reason":"end_turn"}} diff --git a/crates/dirigent_anth/tests/fixtures/streaming_dedup.jsonl b/crates/dirigent_anth/tests/fixtures/streaming_dedup.jsonl new file mode 100644 index 0000000..e787149 --- /dev/null +++ b/crates/dirigent_anth/tests/fixtures/streaming_dedup.jsonl @@ -0,0 +1,6 @@ +{"type":"user","uuid":"u-100","parentUuid":null,"timestamp":"2026-03-23T10:00:00.000Z","sessionId":"test-session-dedup","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":false,"userType":"external","message":{"role":"user","content":"What files are here?"}} +{"type":"assistant","uuid":"a-100","parentUuid":"u-100","timestamp":"2026-03-23T10:00:01.000Z","sessionId":"test-session-dedup","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"requestId":"req-100","message":{"model":"claude-opus-4-6","id":"msg-100","type":"message","role":"assistant","content":[{"type":"text","text":"Let me"}],"stop_reason":null,"usage":{"input_tokens":50,"output_tokens":3}}} +{"type":"assistant","uuid":"a-100","parentUuid":"u-100","timestamp":"2026-03-23T10:00:01.100Z","sessionId":"test-session-dedup","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"requestId":"req-100","message":{"model":"claude-opus-4-6","id":"msg-100","type":"message","role":"assistant","content":[{"type":"text","text":"Let me look"},{"type":"tool_use","id":"toolu_100","name":"Bash","input":{"command":""}}],"stop_reason":null,"usage":{"input_tokens":50,"output_tokens":12}}} +{"type":"assistant","uuid":"a-100","parentUuid":"u-100","timestamp":"2026-03-23T10:00:01.200Z","sessionId":"test-session-dedup","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"requestId":"req-100","message":{"model":"claude-opus-4-6","id":"msg-100","type":"message","role":"assistant","content":[{"type":"text","text":"Let me look at this."},{"type":"tool_use","id":"toolu_100","name":"Bash","input":{"command":"ls"}}],"stop_reason":"tool_use","usage":{"input_tokens":50,"output_tokens":20}}} +{"type":"user","uuid":"u-101","parentUuid":"a-100","timestamp":"2026-03-23T10:00:02.000Z","sessionId":"test-session-dedup","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":true,"userType":"external","message":{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_100","content":"main.rs\nlib.rs","is_error":false}]}} +{"type":"assistant","uuid":"a-101","parentUuid":"u-101","timestamp":"2026-03-23T10:00:03.000Z","sessionId":"test-session-dedup","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"requestId":"req-101","message":{"model":"claude-opus-4-6","id":"msg-101","type":"message","role":"assistant","content":[{"type":"text","text":"Done."}],"stop_reason":"end_turn","usage":{"input_tokens":100,"output_tokens":5}}} diff --git a/crates/dirigent_anth/tests/fixtures/subagent/parent.jsonl b/crates/dirigent_anth/tests/fixtures/subagent/parent.jsonl new file mode 100644 index 0000000..c391246 --- /dev/null +++ b/crates/dirigent_anth/tests/fixtures/subagent/parent.jsonl @@ -0,0 +1,4 @@ +{"type":"user","uuid":"u-300","parentUuid":null,"timestamp":"2026-03-23T12:00:00.000Z","sessionId":"test-session-sub","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":false,"message":{"role":"user","content":"Search the codebase"}} +{"type":"assistant","uuid":"a-300","parentUuid":"u-300","timestamp":"2026-03-23T12:00:01.000Z","sessionId":"test-session-sub","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"requestId":"req-300","message":{"model":"claude-opus-4-6","id":"msg-300","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_300","name":"Agent","input":{"description":"Search codebase","subagent_type":"Explore","prompt":"Find all config files"}}],"stop_reason":"tool_use","usage":{"input_tokens":100,"output_tokens":20}}} +{"type":"user","uuid":"u-301","parentUuid":"a-300","timestamp":"2026-03-23T12:00:30.000Z","sessionId":"test-session-sub","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":true,"message":{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_300","content":"Found 3 config files","is_error":false}]}} +{"type":"assistant","uuid":"a-301","parentUuid":"u-301","timestamp":"2026-03-23T12:00:31.000Z","sessionId":"test-session-sub","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"requestId":"req-301","message":{"model":"claude-opus-4-6","id":"msg-301","type":"message","role":"assistant","content":[{"type":"text","text":"I found the config files."}],"stop_reason":"end_turn","usage":{"input_tokens":200,"output_tokens":10}}} diff --git a/crates/dirigent_anth/tests/fixtures/subagent/parent/subagents/agent-abc123.jsonl b/crates/dirigent_anth/tests/fixtures/subagent/parent/subagents/agent-abc123.jsonl new file mode 100644 index 0000000..7ed323e --- /dev/null +++ b/crates/dirigent_anth/tests/fixtures/subagent/parent/subagents/agent-abc123.jsonl @@ -0,0 +1,2 @@ +{"type":"user","uuid":"sa-u1","parentUuid":null,"timestamp":"2026-03-23T12:00:02.000Z","sessionId":"agent-abc123","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":true,"isMeta":false,"message":{"role":"user","content":"Find all config files"}} +{"type":"assistant","uuid":"sa-a1","parentUuid":"sa-u1","timestamp":"2026-03-23T12:00:03.000Z","sessionId":"agent-abc123","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":true,"requestId":"req-sa1","message":{"model":"claude-opus-4-6","id":"msg-sa1","type":"message","role":"assistant","content":[{"type":"text","text":"Found config.toml, settings.json, .env"}],"stop_reason":"end_turn","usage":{"input_tokens":50,"output_tokens":15}}} diff --git a/crates/dirigent_anth/tests/fixtures/subagent/parent/subagents/agent-abc123.meta.json b/crates/dirigent_anth/tests/fixtures/subagent/parent/subagents/agent-abc123.meta.json new file mode 100644 index 0000000..9fdf2d6 --- /dev/null +++ b/crates/dirigent_anth/tests/fixtures/subagent/parent/subagents/agent-abc123.meta.json @@ -0,0 +1 @@ +{"agentType": "Explore"} diff --git a/crates/dirigent_anth/tests/fixtures/tool_correlation.jsonl b/crates/dirigent_anth/tests/fixtures/tool_correlation.jsonl new file mode 100644 index 0000000..0f4801d --- /dev/null +++ b/crates/dirigent_anth/tests/fixtures/tool_correlation.jsonl @@ -0,0 +1,6 @@ +{"type":"user","uuid":"u-200","parentUuid":null,"timestamp":"2026-03-23T10:00:00.000Z","sessionId":"test-session-corr","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":false,"userType":"external","message":{"role":"user","content":"Fix the bug"}} +{"type":"assistant","uuid":"a-200","parentUuid":"u-200","timestamp":"2026-03-23T10:00:01.000Z","sessionId":"test-session-corr","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"requestId":"req-200","message":{"model":"claude-opus-4-6","id":"msg-200","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_200","name":"Bash","input":{"command":"cargo test"}},{"type":"tool_use","id":"toolu_201","name":"Read","input":{"file_path":"src/main.rs"}}],"stop_reason":"tool_use","usage":{"input_tokens":100,"output_tokens":50}}} +{"type":"user","uuid":"u-201","parentUuid":"a-200","timestamp":"2026-03-23T10:00:02.000Z","sessionId":"test-session-corr","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":true,"userType":"external","message":{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_200","content":"test result output","is_error":false},{"type":"tool_result","tool_use_id":"toolu_201","content":"fn main() {}","is_error":false}]}} +{"type":"assistant","uuid":"a-201","parentUuid":"u-201","timestamp":"2026-03-23T10:00:03.000Z","sessionId":"test-session-corr","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"requestId":"req-201","message":{"model":"claude-opus-4-6","id":"msg-201","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_202","name":"Write","input":{"file_path":"src/fix.rs","content":"fixed"}}],"stop_reason":"tool_use","usage":{"input_tokens":150,"output_tokens":30}}} +{"type":"user","uuid":"u-202","parentUuid":"a-201","timestamp":"2026-03-23T10:00:04.000Z","sessionId":"test-session-corr","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":true,"userType":"external","message":{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_202","content":"File written successfully","is_error":false}]}} +{"type":"assistant","uuid":"a-202","parentUuid":"u-202","timestamp":"2026-03-23T10:00:05.000Z","sessionId":"test-session-corr","cwd":"G:\\dev\\test","version":"2.1.71","gitBranch":"main","isSidechain":false,"requestId":"req-202","message":{"model":"claude-opus-4-6","id":"msg-202","type":"message","role":"assistant","content":[{"type":"text","text":"Bug is fixed."}],"stop_reason":"end_turn","usage":{"input_tokens":200,"output_tokens":20}}} diff --git a/crates/dirigent_anth/tests/integration_tests.rs b/crates/dirigent_anth/tests/integration_tests.rs new file mode 100644 index 0000000..cfd0134 --- /dev/null +++ b/crates/dirigent_anth/tests/integration_tests.rs @@ -0,0 +1,294 @@ +use camino::{Utf8Path, Utf8PathBuf}; +use chrono::Datelike; +use dirigent_anth::{ + correlation::correlate_tools, + dedup::dedup_messages, + noise::{classify_noise, NoiseKind}, + parse_session, + tree::ConversationTree, + types::{ContentBlock, RawMessage}, + util::parse_timestamp, +}; + +#[test] +fn parse_minimal_session() { + let path = Utf8Path::new("tests/fixtures/minimal_session.jsonl"); + let messages = parse_session(path).unwrap(); + + assert_eq!(messages.len(), 6, "Expected 6 messages, got {}", messages.len()); + + let type_names: Vec<&str> = messages + .iter() + .map(|m| match m { + RawMessage::User(_) => "user", + RawMessage::Assistant(_) => "assistant", + RawMessage::Progress(_) => "progress", + RawMessage::System(_) => "system", + RawMessage::QueueOperation(_) => "queue-operation", + RawMessage::FileHistorySnapshot(_) => "file-history-snapshot", + RawMessage::LastPrompt(_) => "last-prompt", + }) + .collect(); + + assert_eq!( + type_names.iter().filter(|&&t| t == "queue-operation").count(), + 2 + ); + assert_eq!(type_names.iter().filter(|&&t| t == "user").count(), 2); + assert_eq!( + type_names.iter().filter(|&&t| t == "assistant").count(), + 2 + ); +} + +#[test] +fn parse_line_returns_none_for_invalid_json() { + assert!(dirigent_anth::parse_line("not valid json", 1).is_none()); + assert!(dirigent_anth::parse_line("{}", 1).is_none()); +} + +#[test] +fn dedup_streaming_session() { + let path = Utf8Path::new("tests/fixtures/streaming_dedup.jsonl"); + let messages = parse_session(path).unwrap(); + + // Raw should have 6 lines (including 3 versions of same assistant message) + assert_eq!(messages.len(), 6, "Raw messages: expected 6, got {}", messages.len()); + + let deduped = dedup_messages(messages); + + // After dedup: U1, A1(final), U2, A2 = 4 + assert_eq!(deduped.len(), 4, "Deduped messages: expected 4, got {}", deduped.len()); + + // The kept assistant message must be the final version + let first_assistant = deduped.iter().find(|m| matches!(m, RawMessage::Assistant(_))).unwrap(); + if let RawMessage::Assistant(a) = first_assistant { + assert!(a.message.stop_reason.is_some(), "Deduped assistant should have stop_reason set"); + assert_eq!(a.message.stop_reason.as_deref(), Some("tool_use")); + assert_eq!(a.message.content.len(), 2, "Final version should have 2 content blocks"); + } else { + unreachable!(); + } +} + +#[test] +fn dedup_preserves_non_streamed_messages() { + let path = Utf8Path::new("tests/fixtures/minimal_session.jsonl"); + let messages = parse_session(path).unwrap(); + let count_before = messages.len(); + let deduped = dedup_messages(messages); + // No streaming in minimal_session, so count should be same + assert_eq!(deduped.len(), count_before); +} + +#[test] +fn correlate_parallel_tools() { + let path = Utf8Path::new("tests/fixtures/tool_correlation.jsonl"); + let messages = dirigent_anth::parse_session_deduped(path).unwrap(); + let exchanges = correlate_tools(&messages); + + // 3 tool calls: 2 parallel (Bash+Read) + 1 sequential (Write) + assert_eq!(exchanges.len(), 3); + + // All should have results + assert!(exchanges.iter().all(|e| e.result.is_some())); + + // Verify correct pairing by ID + for ex in &exchanges { + assert_eq!(ex.call.id, ex.result.as_ref().unwrap().tool_use_id); + } +} + +#[test] +fn correlate_no_tools_returns_empty() { + // Test with just a plain user message — no tool calls or results + let messages = vec![ + serde_json::from_str::( + r#"{"type":"user","uuid":"x","timestamp":"2026-01-01T00:00:00Z","sessionId":"s","message":{"role":"user","content":"hello"}}"#, + ) + .unwrap(), + ]; + let exchanges = correlate_tools(&messages); + assert!(exchanges.is_empty()); +} + +#[test] +fn build_branching_tree() { + let path = Utf8Path::new("tests/fixtures/branching_tree.jsonl"); + let messages = dirigent_anth::parse_session(path).unwrap(); + let tree = ConversationTree::build(&messages); + + assert_eq!(tree.roots.len(), 1); + assert!(!tree.is_linear()); + assert_eq!(tree.branch_points().len(), 1); // A1 has 2 children + + let main = tree.main_thread(); + assert_eq!(main.len(), 4); // R → A1 → U2 → A3 (first branch) +} + +#[test] +fn linear_conversation_is_linear() { + let path = Utf8Path::new("tests/fixtures/minimal_session.jsonl"); + let messages = dirigent_anth::parse_session(path).unwrap(); + let tree = ConversationTree::build(&messages); + assert!(tree.is_linear()); +} + +#[test] +fn classify_noise_from_fixture() { + let path = Utf8Path::new("tests/fixtures/noise_patterns.jsonl"); + let messages = dirigent_anth::parse_session(path).unwrap(); + + assert_eq!(messages.len(), 9, "Expected 9 messages in noise fixture"); + + let classifications: Vec> = messages.iter() + .map(classify_noise) + .collect(); + + assert_eq!(classifications[0], Some(NoiseKind::QueueOp)); + assert_eq!(classifications[1], Some(NoiseKind::Meta)); + assert_eq!(classifications[2], Some(NoiseKind::Warmup)); + assert_eq!(classifications[3], Some(NoiseKind::Interrupted)); + assert_eq!(classifications[4], Some(NoiseKind::Continuation)); + assert_eq!(classifications[5], Some(NoiseKind::ApiError)); + assert_eq!(classifications[6], Some(NoiseKind::SystemCaveat)); + assert_eq!(classifications[7], None); // normal user + assert_eq!(classifications[8], None); // normal assistant +} + +#[test] +fn load_subagent_from_fixture() { + let artifacts_dir = Utf8Path::new("tests/fixtures/subagent/parent"); + let subagents = dirigent_anth::load_subagents(artifacts_dir).unwrap(); + + assert_eq!(subagents.len(), 1); + assert_eq!(subagents[0].agent_id, "abc123"); + assert_eq!(subagents[0].meta.agent_type.as_deref(), Some("Explore")); + assert_eq!(subagents[0].messages.len(), 2); +} + +#[test] +fn load_subagents_empty_dir() { + // Non-existent artifacts dir should return empty vec + let artifacts_dir = Utf8Path::new("tests/fixtures/nonexistent"); + let subagents = dirigent_anth::load_subagents(artifacts_dir).unwrap(); + assert!(subagents.is_empty()); +} + +#[test] +fn load_full_session_with_subagents() { + use dirigent_anth::types::SessionRef; + + let session_ref = SessionRef { + id: "parent".to_string(), + jsonl_path: Utf8PathBuf::from("tests/fixtures/subagent/parent.jsonl"), + artifacts_dir: Some(Utf8PathBuf::from("tests/fixtures/subagent/parent")), + index_entry: None, + }; + + let session = dirigent_anth::load_session(&session_ref).unwrap(); + assert!(!session.messages.is_empty()); + assert!(!session.subagents.is_empty()); + assert!(!session.tree.roots.is_empty()); + assert!(!session.tool_exchanges.is_empty()); +} + +#[test] +fn load_session_without_artifacts() { + use dirigent_anth::types::SessionRef; + + let session_ref = SessionRef { + id: "minimal".to_string(), + jsonl_path: Utf8PathBuf::from("tests/fixtures/minimal_session.jsonl"), + artifacts_dir: None, + index_entry: None, + }; + + let session = dirigent_anth::load_session(&session_ref).unwrap(); + assert_eq!(session.messages.len(), 6); // 2 queue-ops + 2 users + 2 assistants + assert!(session.subagents.is_empty()); + assert!(session.tree.is_linear()); +} + +#[test] +fn content_as_string_or_blocks() { + // String content + let s: dirigent_anth::types::Content = serde_json::from_str(r#""hello""#).unwrap(); + assert!(matches!(s, dirigent_anth::types::Content::Text(_))); + + // Block content + let b: dirigent_anth::types::Content = + serde_json::from_str(r#"[{"type":"text","text":"hi"}]"#).unwrap(); + assert!(matches!(b, dirigent_anth::types::Content::Blocks(_))); + + // Empty blocks + let empty: dirigent_anth::types::Content = serde_json::from_str(r#"[]"#).unwrap(); + assert!(matches!(empty, dirigent_anth::types::Content::Blocks(ref v) if v.is_empty())); +} + +#[test] +fn missing_optional_fields_dont_crash() { + // Minimal assistant message with many fields missing + let json = r#"{ + "type": "assistant", + "message": { + "content": [{"type": "text", "text": "hi"}] + } + }"#; + let msg: RawMessage = serde_json::from_str(json).unwrap(); + assert!(matches!(msg, RawMessage::Assistant(_))); +} + +#[test] +fn tool_result_content_string_and_blocks() { + // tool_result with string content + let json = r#"{"type":"tool_result","tool_use_id":"t1","content":"output text","is_error":false}"#; + let block: ContentBlock = serde_json::from_str(json).unwrap(); + if let ContentBlock::ToolResult { content, is_error, .. } = block { + assert!(!is_error); + assert!(content.is_some()); + } else { + panic!("Expected ToolResult"); + } + + // tool_result with no content + let json2 = r#"{"type":"tool_result","tool_use_id":"t2"}"#; + let block2: ContentBlock = serde_json::from_str(json2).unwrap(); + if let ContentBlock::ToolResult { content, is_error, .. } = block2 { + assert!(!is_error); + assert!(content.is_none()); + } else { + panic!("Expected ToolResult"); + } +} + +#[test] +fn extra_unknown_fields_are_ignored() { + // Messages with extra fields not in our structs should parse fine + let json = r#"{ + "type": "user", + "uuid": "x", + "timestamp": "2026-01-01T00:00:00Z", + "sessionId": "s", + "unknownField": "should be ignored", + "anotherExtra": 42, + "message": {"role": "user", "content": "hello"} + }"#; + let msg: RawMessage = serde_json::from_str(json).unwrap(); + assert!(matches!(msg, RawMessage::User(_))); +} + +#[test] +fn timestamp_parsing_all_formats() { + // ISO 8601 + let iso = parse_timestamp(&serde_json::json!("2026-03-22T17:00:13.192Z")).unwrap(); + assert_eq!(iso.year(), 2026); + + // Unix millis + let ms = parse_timestamp(&serde_json::json!(1769461914249_i64)).unwrap(); + assert!(ms.year() >= 2025); + + // Unix seconds + let secs = parse_timestamp(&serde_json::json!(1769461914_i64)).unwrap(); + assert!(secs.year() >= 2025); +} diff --git a/crates/dirigent_anth/tests/usage_parse.rs b/crates/dirigent_anth/tests/usage_parse.rs new file mode 100644 index 0000000..fca1916 --- /dev/null +++ b/crates/dirigent_anth/tests/usage_parse.rs @@ -0,0 +1,101 @@ +use dirigent_anth::anth_usage::process_usage_screen; + +const SAMPLE: &str = r#" +──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────── + Status Config Usage Stats + + Session + Total cost: $0.0000 + Total duration (API): 0s + Total duration (wall): 4s + Total code changes: 0 lines added, 0 lines removed + Usage: 0 input, 0 output, 0 cache read, 0 cache write + + Current session + ███████ 14% used + Resets 12:30pm (Europe/Vienna) + + Current week (all models) + ██████ 12% used + Resets May 12, 9am (Europe/Vienna) + + Current week (Sonnet only) + 0% used + Resets May 12, 9am (Europe/Vienna) + + What's contributing to your limits usage? + Approximate, based on local sessions on this machine — does not include other devices or claude.ai + + Last 24h · these are independent characteristics of your usage, not a breakdown + + 97% of your usage came from subagent-heavy sessions + Each subagent runs its own requests. Be deliberate about spawning them — and + consider configuring a cheaper model for simpler subagents. + + 16% of your usage was at >150k context + Longer sessions are more expensive even when cached. /compact mid-task, /clear + when switches to new tasks. + + Subagents % of usage + Explore 3% + claude-code-guide 2% + + d to day · w to week + + Esc to cancel +"#; + +#[test] +fn parses_gauges() { + let result = process_usage_screen(SAMPLE); + assert_eq!(result.data.gauges.len(), 3); + + assert_eq!(result.data.gauges[0].name, "Current session"); + assert_eq!(result.data.gauges[0].percent_used, 14); + assert_eq!( + result.data.gauges[0].resets.as_deref(), + Some("12:30pm (Europe/Vienna)") + ); + + assert_eq!(result.data.gauges[1].name, "Current week (all models)"); + assert_eq!(result.data.gauges[1].percent_used, 12); + assert_eq!( + result.data.gauges[1].resets.as_deref(), + Some("May 12, 9am (Europe/Vienna)") + ); + + assert_eq!(result.data.gauges[2].name, "Current week (Sonnet only)"); + assert_eq!(result.data.gauges[2].percent_used, 0); + + // resets_iso should be present for all gauges with reset info + assert!(result.data.gauges[0].resets_iso.is_some()); + assert!(result.data.gauges[1].resets_iso.is_some()); + assert!(result.data.gauges[2].resets_iso.is_some()); + + // Week resets should contain the right date + let week_iso = result.data.gauges[1].resets_iso.as_ref().unwrap(); + assert!(week_iso.starts_with("2026-05-12") || week_iso.contains("05-12")); +} + +#[test] +fn parses_contributions() { + let result = process_usage_screen(SAMPLE); + let contrib = result.data.contributions.as_ref().unwrap(); + + assert_eq!(contrib.factors.len(), 2); + assert_eq!(contrib.factors[0].percent, 97); + assert!(contrib.factors[0].description.contains("subagent-heavy")); + assert_eq!(contrib.factors[1].percent, 16); + + assert_eq!(contrib.subagents.len(), 2); + assert_eq!(contrib.subagents[0].name, "Explore"); + assert_eq!(contrib.subagents[0].percent, 3); + assert_eq!(contrib.subagents[1].name, "claude-code-guide"); + assert_eq!(contrib.subagents[1].percent, 2); +} + +#[test] +fn raw_screen_starts_with_rule() { + let result = process_usage_screen(SAMPLE); + assert!(result.raw_screen.starts_with('─')); +} diff --git a/crates/dirigent_archivist/CLAUDE.md b/crates/dirigent_archivist/CLAUDE.md new file mode 100644 index 0000000..d87b46b --- /dev/null +++ b/crates/dirigent_archivist/CLAUDE.md @@ -0,0 +1,761 @@ +# Package: dirigent_archivist + +Persistent storage for all agentic interactions in Dirigent. + +## Quick Facts +- **Type**: Library +- **Main Entry**: src/lib.rs +- **Dependencies**: dirigent_protocol, uuid, chrono, serde, tokio, tracing, thiserror, sha2, hex, async-trait +- **Status**: Complete - Production ready with comprehensive tests + +## Purpose + +The Archivist provides file-based archival storage for all session data, messages, and attachments in Dirigent. It implements an archive-first architecture with connector API fallback, using NDJSON, JSON, and TSV formats for durability and human-readability. + +## Key Features + +- **File-based Storage**: NDJSON for messages, JSON for metadata, TSV for indices +- **Content-Addressable Files**: SHA-256 based storage for attachments with automatic deduplication +- **Session Lineage**: Track splits, continuations, and mutations with parent references +- **Connector Registry**: Coordinate UID assignment across connectors with collision detection +- **Event Streaming**: Real-time updates via EventHandler subscribing to dirigent_protocol events +- **Archive-First Design**: Read from archive first, fall back to connector API when needed +- **Caching**: In-memory caching of connector and session mappings for performance + +## Architecture + +The Archivist is built on three core architectural principles: + +### 1. Archive-First Read Strategy + +The Archivist is the primary source of truth for historical data: +- UI and APIs query the archive first +- Only fall back to connector APIs if data is not in archive +- This enables offline access and consistent history across restarts + +### 2. Write-Through Event Capture (Append-Only) + +The EventHandler subscribes to the global event stream from dirigent_core: +- Captures session creation, message streaming, and tool calls in real-time +- Uses MessageAccumulator to assemble streaming chunks into complete messages +- Writes complete messages to archive immediately upon finalization +- No polling required - fully event-driven +- **Append-only writes**: Messages are appended as events arrive, NOT in chronological order +- File order reflects event timing, not message timestamps + +### 3. File-Based Storage with Sort-on-Read + +All data is stored in human-readable, grep-able formats: +- **NDJSON** (Newline-Delimited JSON): Incremental append-only logs for messages and mappings +- **JSON**: Structured metadata for sessions and connectors +- **TSV** (Tab-Separated Values): Fast indices for cross-references +- **Content-Addressed Files**: Binary attachments stored by SHA-256 hash for deduplication +- **Sort-on-Read**: `get_messages()` sorts by timestamp and message_id to ensure chronological order despite append-only writes + +## Backend Trait Layer (Phase 2) + +The archivist uses a trait-based backend abstraction. `ArchiveBackend` +defines the mandatory session and message primitives every backend must +provide, plus `as_xxx()` accessors returning optional sub-traits: + +- `SearchBackend` — reserved for Phase 3+ indexed backends (not wired) +- `DagBackend` — session lineage DAG edges +- `MetaEventsBackend` — ACP connection lifecycle events +- `ConnectorRegistryBackend` — per-archive connector metadata +- `SessionMappingBackend` — native↔scroll session ID mapping + +`JsonlBackend` is the Phase 2 concrete implementation (file-based +NDJSON/JSON/TSV) and opts into every sub-trait except `SearchBackend` +(content search continues to be served by ripgrep via +`crates/api/src/archivist/search_task.rs`). + +The `Archivist` struct (in `src/coordinator/`) owns a registry of backends +keyed by archive name and performs orchestration (alias detection, session +lineage, move/copy, DAG walks, archive lifecycle). Consumers hold +`Arc` directly — the coordinator is concrete, not a trait. + +See `docs/plans/2026-04-18-archivist-phase2-design.md` for design rationale. + +## Multi-Backend Registry (Phase 3) + +The coordinator (`Archivist`) holds `Vec>` sorted +by `read_priority` instead of a flat `HashMap>`. +Each registration carries: + +- `backend: Arc` + its declared capabilities +- `failure_mode`: `Required` (must succeed) | `BestEffort` (errors log + drift health) +- `read_priority`: lower = tried first for reads; also selects the default + write target when no archive is named +- `write_active`: participates in fanout writes +- `enabled`: kill-switch without removing config +- `write_policy`: `Inline` (default; `await` per call) or `Queued` + (mpsc + batch_window + overflow policy) +- Runtime state: `last_health`, `last_error`, `consecutive_failures` + (all `Arc>`, shared with the writer task when queued) +- Optional `writer: Option` (Some iff `write_policy = Queued`) + +Backends are declared in `dirigent.toml` under `[[archives]]` and +constructed at boot via `Archivist::from_config(cfg, &BackendRegistry)`. +Add a new backend type by implementing `BackendFactory` and registering +it on the `BackendRegistry` before `from_config`. + +### Reads + +`get_session`, `get_messages_paged`, `count_messages`, `get_meta_events`, +`get_children`, etc. walk the registry in priority order via +`read_walk_per_session(scroll_id, predicate, op)`. The predicate +capability-filters; `Unavailable` backends are skipped. The first backend +that returns `Some(value)` wins and its name is cached against the +`scroll_id` in a positive LRU (capacity 10_000). Subsequent reads for the +same `scroll_id` short-circuit to the cached backend before falling back +to the full priority walk. + +Collection-shape reads (`list_sessions_paged`, `list_connectors`, +`list_meta_sessions`, `find_meta_session_by_client`) use +`read_walk_collection` — first enabled backend that can answer wins, no +cache, no aggregation across backends. Phase 3 explicitly defers +cross-backend merge/dedup to a later phase. + +### Writes + +Mutating methods (`append_messages`, `register_session`, `update_session_*`, +`append_meta_events`, `append_dag_edge`, `clear_session_messages`, +`update_connector_fingerprint`) resolve a primary (per-call `archive: +Some(name)` override or the default-write target) and fan out to every +other `enabled && write_active` backend that has the required capability. +Capability-mismatched backends are skipped with a debug `capability_skip` +log (never an error). `Required` failures propagate to the caller; +`BestEffort` failures log + drift health. + +`register_connector` currently does NOT fan out — alias detection + the +tri-state `Accepted`/`Aliased`/`Rejected` return shape make replication +non-trivial. Fanout for connectors is deferred; single-backend setups are +unaffected. + +For `write_policy = Queued` backends, the primary/secondary write paths +enqueue a `WriteOp` into the backend's writer task instead of awaiting. +Errors drift the backend's health but do not propagate to the caller. +Coalescing merges consecutive `AppendMessages`/`AppendMetaEvents` for the +same `scroll_id` within `batch_window_ms`. + +### Cross-backend operations + +- `delete_session(scroll_id, _)` fans out to every enabled backend that has + the session. Copies in `write_active=false` backends produce + `ArchivistError::DeleteOnReadOnlyBackend` (write-active copies are still + deleted); cache invalidated regardless of outcome. +- `copy_session(scroll_id, from, to)` reads from `from`, writes to `to`, + including DAG and meta-events when both sides have the capability. The + source remains canonical (the cache is NOT rewritten). +- `move_session(scroll_id, from, to)` is `copy + delete-from-source`. If + the source-side delete fails after the copy succeeded, + `ArchivistError::PartialMove { copied_to, delete_error }` is returned so + the caller knows the session now lives in both places. + +The Phase 2 connector-aware `move_session(scroll_id, target_connector_uid, _)` +and `copy_session(scroll_id, target_connector_uid, _)` survived the Phase +3 rename as `move_session_to_connector` / `copy_session_to_connector`. +Their bulk variant is `move_sessions_to_connector`. + +### Health + +`HealthStatus` drifts on every coordinator call that observes a backend: + +- Successful write → `Healthy`; `consecutive_failures` reset to 0. +- Successful read → `Healthy` (only rescues `Degraded`; does not reset the counter). +- Write failure → `Degraded { reason }`; `consecutive_failures += 1`; after + K = 5 consecutive failures drifts to `Unavailable { reason }`. Reads skip + `Unavailable` backends; writes against an `Unavailable` `Required` + backend fail, while writes against an `Unavailable` `BestEffort` backend + are still attempted. +- Read failure alone never drifts past `Degraded`; writes are the + authoritative health signal. + +`list_archives_with_health()` returns a `Vec` snapshot of +every registration: name, type, capabilities, health, last_error, and +queue_depth (for queued backends). + +### Lifecycle + +Phase 3 is **startup-only**. `add_archive` / `remove_archive` / +`set_default_archive` on the coordinator return +`ArchivistError::DynamicRegistryUnsupported`. To change the registry, +edit `dirigent.toml` and restart the server. `Archivist::shutdown()` +drains queued writer tasks (sends `WriteOp::Shutdown` on each writer's +mpsc and awaits ack); call it before process exit. + +Test-only constructors `Archivist::from_registrations(regs)` and +`SessionMetadata::stub(scroll_id)` live under `#[cfg(any(test, feature = +"test-utils"))]` for integration tests that bypass the factory. + +See `docs/plans/2026-04-19-archivist-phase3-design.md` for the full +design rationale, and `examples/multi_backend.rs` for a runnable +end-to-end example. + +## Module Organization + +### Core Modules + +- **`lib.rs`**: Public API surface and re-exports +- **`types.rs`**: Core data structures (session metadata, message records, connector info, API types) +- **`error.rs`**: Error types and Result alias for archivist operations + +### Backend Layer (`backend/`) + +- **`traits.rs`**: `ArchiveBackend` trait + 5 optional sub-traits +- **`capability.rs`**: `ArchiveCapability` enum + `CapabilitySet` type +- **`health.rs`**: `HealthStatus` enum returned by `health_check` +- **`contract.rs`**: Reusable behavioral tests for any `&dyn ArchiveBackend` (cfg-gated) +- **`mock.rs`**: In-memory `MockBackend` for coordinator unit tests (cfg-gated) + +### Concrete Backends (`backends/`) + +- **`jsonl/`**: The file-based `JsonlBackend` — the only Phase 2 backend. + Reuses `storage/` primitives for NDJSON/JSON/TSV operations. + +### Coordinator (`coordinator/`) + +- **`mod.rs`**: The `Archivist` struct + constructors +- **`archives.rs`**: Archive lifecycle (add/remove/list/default) +- **`connectors.rs`**: Connector registration + alias detection +- **`sessions.rs`**: Session registration, metadata updates, move/copy +- **`meta.rs`**: Meta events, DAG walks, cleanup + +### Storage Layer (`storage/`) + +Low-level file I/O primitives used by `JsonlBackend`. All storage operations are async and use tokio. + +- **`paths.rs`**: ArchivePaths utility for consistent directory structure and path resolution +- **`ndjson.rs`**: Newline-delimited JSON operations (read_ndjson, append_ndjson) +- **`json.rs`**: JSON operations (read_json, write_json) +- **`tsv.rs`**: Tab-separated value operations for connector index +- **`files.rs`**: Content-addressable file storage with SHA-256 hashing and deduplication + +### Supporting Modules + +- **`registry.rs`**: Archive registry persistence (multi-archive metadata) +- **`migration.rs`**: Single-archive → multi-archive migration path +- **`session.rs`**: Session lineage types shared across layers +- **`accumulator.rs`**: MessageAccumulator for assembling streaming message chunks +- **`backfill.rs`**: Backfill helpers for importing historical sessions +- **`import/`**: External conversation importers (e.g. Claude export) + +### Events + +- **`events.rs`**: EventHandler for subscribing to dirigent_protocol events and archiving them + +## Configuration + +The Archivist archive root is determined by `DirigentPaths` resolution: + +- Set `DIRIGENT_DATA_DIR` to override the data directory; archives will be stored at `/archives/` +- Defaults to `~/.local/share/dirigent/archives/` (or platform equivalent) + +```bash +DIRIGENT_DATA_DIR=/path/to/data dx serve +``` + +## Archive Structure + +``` +dirigent_archive/ +├── .contexts/ +│ └── {scroll_id:uuidv7}/ # One directory per session +│ ├── session.json # Session metadata +│ ├── messages.jsonl # Incremental message log (.ndjson also supported) +│ └── lineage.json # Session lineage info (optional) +├── .db/ +│ └── connectors/ +│ ├── index.tsv # Fast connector lookup (TSV) +│ └── {connector_uid}/ +│ ├── connector.json # Connector metadata +│ └── sessions.jsonl # Session mappings (.ndjson also supported) +└── .files/ + └── {sha256-hash} # Content-addressable file storage +``` + +### Why Hidden Directories? + +The `.contexts`, `.db`, and `.files` directories are hidden (prefixed with `.`) to keep the archive root clean for future rendered outputs (like `chat.md` exports). This is similar to how `.git` hides implementation details in a codebase. + +## File Formats + +### Session Metadata (`session.json`) + +```json +{ + "version": 1, + "scroll_id": "01936e8f-e5a7-7000-8000-000000000001", + "created_at": "2025-01-01T12:00:00Z", + "updated_at": "2025-01-01T12:30:00Z", + "title": "Implement user authentication", + "connector_uid": "01936e8f-e5a7-7000-8000-000000000002", + "native_session_id": "abc123", + "agent_id": null, + "parent_scroll_id": null, + "continuation": null, + "tags": ["backend", "auth"], + "metadata": { + "source": "OpenCode", + "model": "claude-3-5-sonnet" + } +} +``` + +### Messages Log (`messages.jsonl`) + +One JSON object per line, **append-only**: + +```jsonl +{"version":1,"message_id":"01936e8f-e5a7-7000-8000-000000000003","session":"01936e8f-e5a7-7000-8000-000000000001","parent_id":null,"ts":"2025-01-01T12:01:00Z","role":"user","author":"alice","content_md":"How do I implement JWT auth?","attachments":[],"metadata":{}} +{"version":1,"message_id":"01936e8f-e5a7-7000-8000-000000000004","session":"01936e8f-e5a7-7000-8000-000000000001","parent_id":"01936e8f-e5a7-7000-8000-000000000003","ts":"2025-01-01T12:01:10Z","role":"assistant","author":"claude","content_md":"Here's how to implement JWT authentication...","attachments":[],"metadata":{"model":"claude-3-5-sonnet"}} +``` + +**IMPORTANT - Ordering**: The order of lines in the message log file (`messages.jsonl` or `messages.ndjson`) reflects **event arrival order**, NOT chronological order. Assistant replies often arrive after subsequent user messages due to streaming latency, resulting in non-chronological file order. Always use the `Archivist::get_messages()` API to retrieve messages, which sorts by `ts` (timestamp) and `message_id` (UUIDv7) to guarantee chronological order. + +**File Format Compatibility**: The archivist supports both `.ndjson` and `.jsonl` file extensions for newline-delimited JSON files. When reading, `.jsonl` is preferred if present, with automatic fallback to `.ndjson` for backward compatibility. Write operations use `.jsonl` (canonical format). Both formats are identical in content - the difference is purely the file extension. + +### Connector Index (`index.tsv`) + +Tab-separated values with header row: + +```tsv +connector_uid type title client_native_id alias_of created_at +01936e8f-e5a7-7000-8000-000000000002 OpenCode OpenCode Local opencode@http://localhost:12225 2025-01-01T12:00:00Z +``` + +### Session Mappings (`sessions.jsonl`) + +Maps native session IDs from connectors to scroll IDs in the archive: + +```jsonl +{"version":1,"connector_uid":"01936e8f-e5a7-7000-8000-000000000002","native_session_id":"abc123","scroll_id":"01936e8f-e5a7-7000-8000-000000000001","created_at":"2025-01-01T12:00:00Z","alias_of":null} +``` + +## Message Ordering Guarantees + +### The Problem: Append Order ≠ Chronological Order + +In the event-driven architecture, messages are written to the message log file (`messages.jsonl`) as completion events arrive. Due to streaming latency: + +- User messages complete nearly instantly and are written immediately +- Assistant messages stream over time and complete later +- A second user message can be written before the first assistant reply completes + +Example scenario: +``` +T0: User sends "tell me a joke about snakes" (ts=18:23:36.947) +T1: Assistant starts streaming reply (ts=18:23:36.969) +T2: User sends "now one about tigers" (ts=18:23:49.429) <- completes and writes BEFORE assistant finishes +T3: Assistant finishes "snakes" reply <- writes AFTER "tigers" user message +``` + +File order in the message log file: +``` +1. user "snakes" (18:23:36.947) +2. user "tigers" (18:23:49.429) <- written second +3. assistant "snakes" (18:23:36.969) <- written third, but timestamp is earlier! +``` + +### The Solution: Sort-on-Read + +The `Archivist::get_messages()` implementation sorts messages before returning: + +1. **Primary sort**: `ts` (timestamp) ascending +2. **Secondary sort**: `message_id` (UUIDv7) ascending for stable tie-breaking + +This guarantees chronological order regardless of NDJSON append order: +``` +1. user "snakes" (18:23:36.947) +2. assistant "snakes" (18:23:36.969) +3. user "tigers" (18:23:49.429) +``` + +### Why This Approach? + +- **Maintains durability**: Append-only writes preserve crash safety +- **No migration needed**: Existing archives work without rewrites +- **Simple implementation**: No buffered writes or complex write-time ordering +- **Performance trade-off**: Small CPU cost on read (sorting) vs. complex write-time coordination + +### Consumer Guidance + +- **DO**: Use `Archivist::get_messages()` to retrieve messages +- **DON'T**: Read the message log file directly and assume file order = chronological order +- **UI/API**: Always sort by `ts` then `message_id` for defense in depth +- **Tie-breaking**: Use `message_id` (UUIDv7) as secondary sort for stable ordering when timestamps match + +## Key Types + +### SessionMetadata + +Stores all metadata about a session including: +- **scroll_id**: UUIDv7 identifier for the session +- **connector_uid**: Which connector owns this session +- **native_session_id**: Original session ID from the connector (optional) +- **title**: Optional human-readable session title (see Title Management below) +- **parent_scroll_id**: For session lineage (splits, continuations) +- **continuation**: Type of continuation (SPLIT, COMPACT, REFERENCE, EDIT) +- **tags**: User-defined categorization +- **metadata**: Free-form JSON for connector-specific fields + +#### Title Management + +Session titles are fully supported and persist across restarts. Titles are stored in the `SessionMetadata` struct and saved to the `session.json` file. + +**Setting Titles:** +```rust +// Update title for an existing session +archivist.update_session_metadata( + scroll_id, + Some("My Custom Session Title".to_string()), + None, // model + None // archive +).await?; +``` + +**Default Behavior:** +- New sessions can specify an initial title during registration +- If no title is provided, sessions default to `None` +- The UI typically displays "Untitled" for sessions without titles + +**Title Loading:** +- Titles are automatically loaded when retrieving session metadata via `get_session_metadata()` +- Session lists include titles via `list_sessions()` and `list_sessions_all()` +- Titles are part of the `SessionMetadata` struct returned by all session queries + +**UI Integration:** +- The web UI displays session titles in the session list and sidebar +- Users can rename sessions via the "Rename" button in the session list view +- Renaming calls `api::archivist::rename_session()` which uses `update_session_metadata()` +- Title changes are persisted immediately and survive application restarts + +### MessageRecord + +Represents a single message in the archive: +- **message_id**: UUIDv7 identifier +- **session**: scroll_id this message belongs to +- **role**: "user", "assistant", or "system" +- **content_md**: Message content in Markdown format +- **attachments**: References to attached files +- **metadata**: Free-form JSON for connector-specific fields + +### ConnectorRecord + +Metadata about a connector: +- **connector_uid**: UUIDv7 identifier +- **type**: "OpenCode", "ACP", or custom +- **client_native_id**: Unique identifier from client (e.g., "opencode@http://localhost:12225") +- **alias_of**: If this connector is an alias of another (for deduplication) + +## Archivist Public API + +The `Archivist` struct (in `coordinator/`) is the main public entry point +for archival operations. Consumers hold `Arc` and call inherent +methods — there is no `Archivist` trait anymore. The coordinator resolves +the target backend per call (via `archive: Option`) and delegates +to `ArchiveBackend` methods. + +Key method families (see `coordinator/*.rs` for full signatures): + +- **Archive lifecycle** (`archives.rs`): `add_archive`, `remove_archive`, + `list_archives`, `set_default_archive` +- **Connectors** (`connectors.rs`): `register_connector` with tri-state + result (Accepted / Aliased / Rejected), `list_connectors` +- **Sessions** (`sessions.rs`): `register_session`, `get_session_metadata`, + `update_session_metadata`, `list_sessions_paged`, `move_session`, + `copy_session`, `resolve_session` +- **Messages**: `append_messages`, `get_messages` (sorts by `ts` then + `message_id` for stable chronological order) +- **Meta / DAG** (`meta.rs`): meta-event recording, session lineage DAG + walks, cleanup routines + +## List Filter vs. Full-Text Search + +Two distinct query paths exist — do not conflate them. + +**List filter** — `Archivist::list_sessions_paged(SessionListQuery)` returns a +cursor-paged list of sessions, AND-filtered by `title_query` (substring on +title), `tags`, `model_filter` (substring on `metadata.model`), `project_id`, +`connector_uid`, and `include_hidden`. This is the right tool for "narrow the +list of visible sessions." + +**Full-text search** — `api::search_sessions` (in the `api` package, backed by +`api::archivist::search_task::SearchTask`) spawns `rg --json` over the +archive's `.contexts/` tree to find messages containing text. It streams +`SearchExcerpt`s with parsed NDJSON content and supports cancellation via +`CancellationToken`. This is the right tool for "find messages containing +text." + +**Do not extend `list_sessions_paged` to do content search.** Content search +belongs in the ripgrep pipeline. Future improvements to content search +(indexed backends, relevance scoring) are Phase 2d / Phase 3 concerns. + +## JsonlBackend Implementation + +The Phase 2 production backend — an implementation of `ArchiveBackend` plus +every sub-trait except `SearchBackend`: + +- **Thread-safe**: Uses RwLock for in-memory caches +- **Async**: All operations use tokio for non-blocking I/O +- **Caching**: In-memory caches for connector and session mappings +- **Collision Detection**: Tri-state registration for connectors and sessions + +Located under `src/backends/jsonl/` and split by concern (`backend.rs`, +`connectors.rs`, `dag.rs`, `mapping.rs`, `meta.rs`). + +### Caching Strategy + +`JsonlBackend` maintains two in-memory caches: + +1. **connector_cache**: HashMap + - Populated on registration + - Read from TSV index on startup (future enhancement) + +2. **session_cache**: HashMap<(Uuid, String), Uuid> + - Maps (connector_uid, native_session_id) to scroll_id + - Populated on registration and session resolution + - Enables fast session lookups without disk I/O + +## Event Handling + +The EventHandler subscribes to dirigent_protocol events and archives them in real-time: + +```rust +// Create archivist and event handler +let archivist = Archivist::new_with_single_archive(archive_path).await?; +let handler = EventHandler::new(Arc::new(archivist)); + +// Subscribe to event stream from dirigent_core +let events = event_stream.subscribe(); + +// Run event loop (blocking) +handler.run(events).await; +``` + +### Supported Events + +- **SessionCreated**: Registers new sessions with the archivist +- **MessageCompleted**: Writes finalized messages to the archive +- **SessionUpdate**: Accumulates streaming message chunks + - AgentMessageChunk + - UserMessageChunk + - AgentThoughtChunk + - ToolCall + +### MessageAccumulator + +Assembles streaming message chunks into complete messages: + +- Accumulates text chunks by message_id +- Tracks thinking blocks separately +- Stores tool calls with input/output +- Finalizes messages on MessageCompleted event +- Converts to MessageRecord for archival + +## Integration with dirigent_core + +The Archivist integrates with dirigent_core via the global event stream: + +1. **CoreRuntime** emits events for all connector operations +2. **EventHandler** subscribes to event stream +3. **MessageAccumulator** assembles streaming chunks +4. **Archivist** writes complete messages to archive + +This enables: +- Automatic archival of all sessions and messages +- No polling required - fully event-driven +- Consistent history across restarts +- Offline access to historical data + +## Testing + +The package has comprehensive test coverage across multiple dimensions: + +### Unit Tests + +Located in each module (`src/*.rs`, `src/storage/*.rs`): +- Type serialization/deserialization +- UUIDv7 generation and ordering +- Timestamp formatting (RFC 3339) +- Storage operations (NDJSON, JSON, TSV, files) +- Connector registration tri-state logic +- Session registration and alias detection + +### Integration Tests + +Located in `tests/`: +- `integration_tests.rs`: Full `Archivist` + `JsonlBackend` lifecycle, event + handler integration, multi-connector scenarios, session lineage, message + accumulation +- `list_sessions_paged_test.rs`, `pagination_test.rs`: List filter + cursor + pagination coverage +- `import_claude_idempotency_test.rs`: Claude export re-import idempotency + +### Backend Contract Tests + +`src/backend/contract.rs` holds reusable async assertions that any +`&dyn ArchiveBackend` must pass. `JsonlBackend` and `MockBackend` both +run the contract suite; new backends added in Phase 3+ should do the same. + +### Examples + +Located in `examples/`: +- `basic_usage.rs`: Core archivist operations +- `event_handling.rs`: EventHandler and MessageAccumulator +- `file_storage.rs`: Content-addressable file storage + +Run tests: +```bash +cargo test --package dirigent_archivist +``` + +Run examples: +```bash +cargo run --package dirigent_archivist --example basic_usage +cargo run --package dirigent_archivist --example event_handling +cargo run --package dirigent_archivist --example file_storage +``` + +## Performance Characteristics + +- **Append Operations**: O(1) with sequential file writes +- **Session Lookup**: O(1) with in-memory cache, O(n) cache miss +- **Message Retrieval**: O(n) where n = number of messages (NDJSON parsing) +- **File Storage**: O(1) content-addressable lookup with SHA-256 hashing +- **Connector Index**: O(n) TSV scan, suitable for hundreds of connectors + +### Scalability Considerations + +- **Large Sessions**: NDJSON is append-only, so reading large sessions requires parsing all lines +- **Many Sessions**: TSV indices are suitable for thousands of sessions per connector +- **File Deduplication**: SHA-256 hashing provides automatic deduplication across sessions +- **Concurrent Access**: RwLock allows multiple concurrent readers, single writer + +## Error Handling + +The Archivist uses thiserror for rich error types: + +```rust +pub enum ArchivistError { + IoError(std::io::Error), + SerdeError(serde_json::Error), + SessionUnknown(Uuid), + CollisionInconsistent(Uuid), + // ... etc +} +``` + +All public APIs return `Result` for explicit error handling. + +## Development Notes + +- All storage operations are async (using tokio) +- Content-addressable storage uses SHA-256 hashes (hex-encoded) +- Archive directory structure mirrors session/message hierarchy +- UUIDv7 provides time-ordered, sortable identifiers +- RFC 3339 UTC timestamps for all time-based fields +- Schema versioning via `version` field in all records + +## Related Packages + +- **dirigent_protocol**: Shared types and protocol definitions (dependency) +- **dirigent_core**: Runtime integration for SSE event capture (integration point) +- **api**: Server functions for archive queries (future) +- **web**: UI for archive browsing and search (future) + +## Phase 4: `ArchiveFilter` (2026-04-21) + +Every `ArchiveRegistration` carries a `filter: ArchiveFilter`. The filter +describes which sessions/writes the backend wants to receive. Fields: + +- `include_connectors: Option>` — if Some, only these + connector UIDs pass. `None` means no connector gate. +- `exclude_connectors: HashSet` — always rejected. +- `include_tags: HashSet` — if non-empty, the session must carry + at least one matching tag. +- `exclude_tags: HashSet` — any matching tag rejects. +- `include_hidden: bool` — default `true`. If `false`, sessions whose + metadata has `"hidden": true` are skipped. + +### Primary-always-writes invariant + +The per-call primary (either the `archive: Some(name)` argument or the +default write-target) is **never** filtered. If a caller explicitly asks +to write to archive X, the filter on X is not consulted. Filters only +gate secondary fanout. + +### Boot validator + +At boot (`coordinator/boot.rs`), the validator rejects configurations +where: + +- No write-active + enabled registration has an **unrestricted** filter + (`ArchiveFilter::default()` is unrestricted). Prevents configurations + that silently drop all writes. +- An archive's filter has `include_connectors = Some(empty set)` — + equivalent to "reject everything", which is almost certainly a config + bug. + +See `docs/plans/2026-04-21-archivist-phase4-design.md` §4 for the full +design rationale. + +## Phase 5: Importers (2026-04-21) + +The `import::` module centres on an `Importer` trait with per-source +implementations under `import::sources::*`. Each source produces a +`ParsedConversation` (ChatGPT) / `ParsedSession` (Codex) / session +directory walk (Claude) and feeds the results through the common +`import_sessions` orchestrator, which fires `ImportProgressEvent`s on a +bounded `ImportProgressSink`. + +### `Importer` trait + +Every importer declares a `config_shape()` so UIs can render a dynamic +form; a `discover()` that returns an `ImportDiscovery` preview; and an +`import()` that does the actual work. All three methods are async. + +The trait lives in `import::trait_def`. Shape types (`ImportConfig`, +`ImportTarget`, `ConfigField`, `ConfigFieldKind`, `ImportError`) are +serialisable and safe to cross the WASM boundary. + +### Registry + +`ImporterRegistry::with_defaults()` registers every enabled +`importer-*` feature. Currently: `claude`, `chatgpt`, `codex`. The +registry is constructed at boot and stored on `AppState`. + +### Progress sink + +`ImportProgressSink::channel()` returns a bounded mpsc pair. +Non-terminal events use `try_send` (dropped on full); terminal events +use `send().await` so consumers always see the final state. + +### Source crates + +- `dirigent_chatgpt` — parses `conversations.json` from the OpenAI data + export. +- `dirigent_codex` — parses `*.jsonl` session files under + `~/.codex/sessions`. + +Both crates hold pure parser types with zero dirigent-specific types. + +See `docs/plans/2026-04-21-archivist-phase5-design.md`. + +## Future Enhancements + +- Indexed `SearchBackend` implementations (tantivy/sqlite) — currently + content search is ripgrep-based in the `api` package +- Session splitting and lineage management (mutations.ndjson) +- Knowledge overview generation (chat.md exports) +- Embedding storage and search (embeds/) +- Network RPC interface for remote archivist +- Compaction and pruning policies +- Additional concrete backends (e.g. SQLite, remote) + +## Documentation + +- **Package README**: `./README.md` - User-facing overview +- **Architecture Docs**: `../../docs/building/05_archivist/` - Design and planning +- **API Docs**: Run `cargo doc --package dirigent_archivist --open` +- **Examples**: See `examples/` directory for working code samples diff --git a/crates/dirigent_archivist/Cargo.toml b/crates/dirigent_archivist/Cargo.toml new file mode 100644 index 0000000..5f29eea --- /dev/null +++ b/crates/dirigent_archivist/Cargo.toml @@ -0,0 +1,69 @@ +[package] +name = "dirigent_archivist" +version = "0.1.0" +edition = "2021" + +[lib] +path = "src/lib.rs" + +[features] +# All built-in importers are on by default. Turn the corresponding +# `importer-*` flag off (and opt out of `default`) to ship a slimmer build. +default = ["importer-claude", "importer-chatgpt", "importer-codex"] + +# Exposes the sub-trait contract test harness (`backend::contract`) to +# downstream crates so new backends can reuse the same behavioral checks. +test-utils = [] + +# Per-source importer feature gates. Each flag guards the corresponding +# `ImporterRegistry::with_defaults` registration and (where relevant) the +# source module itself. +importer-claude = [] +importer-chatgpt = ["dep:dirigent_chatgpt"] +importer-codex = ["dep:dirigent_codex"] + +[dependencies] +# Core dependencies +dirigent_protocol = { path = "../dirigent_protocol" } +dirigent_anth = { path = "../dirigent_anth" } +dirigent_chatgpt = { path = "../dirigent_chatgpt", optional = true } +dirigent_codex = { path = "../dirigent_codex", optional = true } +camino = "1.1" + +# UUID support with v7 and serde +uuid = { version = "1.11", features = ["v5", "v7", "serde"] } + +# Date/time handling +chrono = { version = "0.4", features = ["serde"] } + +# Serialization +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +toml = "0.8" + +# Async runtime and file operations +tokio = { version = "1.42", features = ["fs", "sync", "time", "io-util", "macros", "rt-multi-thread"] } + +# Logging +tracing = "0.1" + +# Error handling +thiserror = "2.0" +anyhow = "1" + +# Hashing for content-addressable storage +sha2 = "0.10" +hex = "0.4" + +# LRU read cache for registry backends +lru = "0.12" + +# Async traits +async-trait = "0.1" + +# Async futures +futures = "0.3" + +[dev-dependencies] +tempfile = "3.0" +walkdir = "2" diff --git a/crates/dirigent_archivist/README.md b/crates/dirigent_archivist/README.md new file mode 100644 index 0000000..5c860fc --- /dev/null +++ b/crates/dirigent_archivist/README.md @@ -0,0 +1,338 @@ +# Dirigent Archivist + +Persistent storage for all agentic interactions in Dirigent. + +## Overview + +The Archivist automatically archives every conversation, message, and file from your AI sessions into a local, grep-able, human-readable archive. No cloud required - your data stays on your machine in formats you can read and search manually. + +## Why Archivist? + +- **Offline Access**: All conversations are saved locally, accessible even when connectors are offline +- **Manual Curation**: Files are in plain JSON/NDJSON/TSV - grep, edit, or analyze them with any tool +- **Knowledge Base**: Build a searchable archive of all your AI interactions across projects +- **Session Lineage**: Track conversation branches, splits, and continuations +- **File Deduplication**: Attachments are stored once, referenced multiple times (content-addressable) +- **Archive-First**: UI reads from local archive first, only falls back to remote connectors when needed + +## Quick Start + +The Archivist runs automatically when you start Dirigent. The archive location is determined by the `DIRIGENT_DATA_DIR` environment variable (archives are stored at `/archives/`): + +```bash +# Override data directory (archives at /path/to/data/archives/) +DIRIGENT_DATA_DIR=/path/to/data dx serve +``` + +That's it! Every session and message will be automatically archived. + +## Archive Structure + +Your archive is organized like this: + +``` +dirigent_archive/ +├── .contexts/ # Session data +│ └── 01936e8f-e5a7-7000-8000.../ +│ ├── session.json # Session metadata +│ └── messages.ndjson # All messages (one JSON per line) +├── .db/ +│ └── connectors/ # Connector registry +│ ├── index.tsv # Fast lookup table +│ └── 01936e8f-e5a7.../ +│ ├── connector.json # Connector info +│ └── sessions.ndjson # Session ID mappings +└── .files/ # Attachments (by SHA-256) + └── a1b2c3d4... # Content-addressable storage +``` + +### Why Hidden Directories? + +The `.contexts`, `.db`, and `.files` directories start with `.` to keep them internal (like `.git`). In the future, you'll be able to export rendered markdown files into the archive root for easy reading. + +## File Formats + +### Session Metadata (`.contexts/{id}/session.json`) + +```json +{ + "version": 1, + "scroll_id": "01936e8f-e5a7-7000-8000-000000000001", + "created_at": "2025-01-01T12:00:00Z", + "updated_at": "2025-01-01T12:30:00Z", + "title": "Implement user authentication", + "connector_uid": "01936e8f-e5a7-7000-8000-000000000002", + "tags": ["backend", "auth"], + "metadata": { + "source": "OpenCode", + "model": "claude-3-5-sonnet" + } +} +``` + +### Messages (`.contexts/{id}/messages.ndjson`) + +Newline-delimited JSON - one message per line, **append-only**: + +```jsonl +{"version":1,"message_id":"...","session":"...","role":"user","ts":"2025-01-01T12:01:00Z","content_md":"How do I implement JWT auth?","attachments":[],"metadata":{}} +{"version":1,"message_id":"...","session":"...","role":"assistant","ts":"2025-01-01T12:01:10Z","content_md":"Here's how to implement JWT authentication...","attachments":[],"metadata":{"model":"claude-3-5-sonnet"}} +``` + +**IMPORTANT**: Messages are written as events arrive, NOT in chronological order. Assistant replies often appear after subsequent user messages due to streaming latency. When reading programmatically, use the Archivist API which sorts by timestamp (`ts`) to ensure correct order. For manual inspection, sort by the `ts` field. + +### Connector Index (`.db/connectors/index.tsv`) + +Tab-separated values for fast scanning: + +```tsv +connector_uid type title client_native_id alias_of created_at +01936e8f... OpenCode OpenCode Local opencode@http://localhost:12225 2025-01-01T12:00:00Z +``` + +## Searching Your Archive + +Since everything is plain text, you can use standard Unix tools: + +```bash +# Find all sessions about "authentication" +grep -r "authentication" dirigent_archive/.contexts/*/session.json + +# Find messages mentioning a specific error +grep "ECONNREFUSED" dirigent_archive/.contexts/*/messages.ndjson + +# List all sessions for a connector +cat dirigent_archive/.db/connectors/*/sessions.ndjson | jq . + +# Get all user messages from a session (sorted by timestamp) +cat dirigent_archive/.contexts/01936e8f.../messages.ndjson | jq -s 'sort_by(.ts) | .[] | select(.role=="user")' + +# View messages in chronological order +cat dirigent_archive/.contexts/01936e8f.../messages.ndjson | jq -s 'sort_by(.ts)' +``` + +**Note on ordering**: Remember that the file order is append-only (event arrival order). Always sort by `ts` (timestamp) when reading manually to see messages in chronological order. + +## Integration with Dirigent + +The Archivist integrates seamlessly with Dirigent's core runtime: + +1. **Automatic Archiving**: Every session and message is archived in real-time as events arrive +2. **Event-Driven**: No polling - listens to dirigent_core's event stream +3. **Append-Only Writes**: Messages written as completion events arrive (preserves durability) +4. **Sort-on-Read**: API returns messages in chronological order despite append-only file order +5. **UI Integration**: Web UI reads from archive first, shows data even when connectors are offline +6. **Connector Coordination**: Assigns stable UUIDs to connectors with collision detection + +## Key Concepts + +### Scroll IDs + +Every session gets a unique `scroll_id` (UUIDv7) that's independent of the connector's native session ID. This allows: +- Sessions to move between connectors +- Stable references even if connector data is deleted +- Time-ordered sorting (UUIDv7 encodes timestamp) + +### Session Lineage + +Sessions can have parent sessions, creating a tree of related conversations: +- **Split**: Fork conversation at a specific message +- **Compact**: Summarized version of parent +- **Reference**: Points to parent without duplication +- **Edit**: Modified version of parent + +### Content-Addressable Storage + +Files are stored by their SHA-256 hash, so: +- Same file uploaded twice uses same storage +- Files can be shared across sessions without duplication +- You can verify file integrity by hash + +## Configuration + +### Environment Variables + +- `DIRIGENT_DATA_DIR`: Override data directory; archives are stored at `/archives/` + +### Example Configurations + +```bash +# Use custom data directory (archives at /home/user/mydata/archives/) +DIRIGENT_DATA_DIR=/home/user/mydata dx serve + +# Use global data directory +DIRIGENT_DATA_DIR=/home/user/.dirigent dx serve + +# Use temporary data directory (testing) +DIRIGENT_DATA_DIR=/tmp/dirigent_test dx serve +``` + +## Programmatic Access + +While the Archivist runs automatically, you can also use it programmatically: + +```rust +use dirigent_archivist::Archivist; +use std::path::PathBuf; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Create an archivist over a single archive directory. + // Internally this wires up a `JsonlBackend` for the archive. + let archivist = Archivist::new_with_single_archive( + PathBuf::from("./dirigent_archive") + ).await?; + + // List sessions for a connector + let sessions = archivist.list_sessions(connector_uid).await?; + + for session in sessions { + println!("{}: {}", session.scroll_id, session.title.unwrap_or_default()); + } + + Ok(()) +} +``` + +`Archivist` is a concrete struct that owns a registry of `ArchiveBackend` +implementations keyed by archive name. In Phase 2 the only backend is +`JsonlBackend` (file-based NDJSON/JSON/TSV). See `examples/` for more +detailed usage. + +## Performance + +The Archivist is designed for human-scale workloads (thousands of sessions, millions of messages): + +- **Fast Writes**: Append-only NDJSON is O(1) +- **Cached Reads**: Common lookups cached in memory +- **Grep-able**: TSV indices can be scanned in milliseconds +- **Incremental**: Only new messages are written, no full re-writes + +### Scalability Notes + +- Large sessions (1000+ messages) may take a few seconds to load +- TSV indices are suitable for 100-1000 connectors +- File deduplication saves space for repeated attachments + +## Querying and Curation + +### Future: Knowledge Overviews + +The Archivist is designed to support knowledge curation workflows: +- Export sessions as clean markdown files +- Create summaries and overviews across sessions +- Tag and categorize conversations +- Build a personal knowledge base + +These features are planned for future releases. + +### Current: Manual Curation + +For now, you can manually curate your archive: +- Edit `session.json` to add tags +- Grep through messages for specific topics +- Copy/organize sessions into project folders +- Use jq/awk/sed to extract insights + +## Advanced Features + +### Session Splitting + +Create a new conversation branch from any point in history: + +```rust +// Future API (not yet implemented) +let new_session = archivist.split_session( + session_id, + at_message_id, + Continuation::Split +).await?; +``` + +### Attachment Storage + +Files are automatically deduplicated using SHA-256: + +```rust +// Store file (content-addressable) +let file_id = archivist.store_file( + &file_data, + "spec.pdf", + Some("application/pdf") +).await?; + +// Reference in message +let attachment = AttachmentRef { + file_id, // "sha256:abc123..." + name: "spec.pdf".to_string(), + mime_type: Some("application/pdf".to_string()), +}; +``` + +### Multi-Archive Support + +`Archivist` natively manages multiple named archives via an on-disk +registry. Each archive is backed by its own `ArchiveBackend` (currently +`JsonlBackend`) and selected per call via an optional `archive` argument. +This enables: +- Separate archives per project +- A default archive plus specialized side archives +- Moving or copying sessions between archives + +Future backends (e.g. SQLite, indexed, remote) will plug into the same +trait layer without changing the coordinator API. + +## Troubleshooting + +### Archive Not Created + +If the archive directory doesn't appear: +1. Check `DIRIGENT_DATA_DIR` is set correctly (or that the default data directory is writable) +2. Ensure write permissions on parent directory +3. Check logs for I/O errors + +### Missing Sessions + +If sessions don't appear in archive: +1. Verify EventHandler is running +2. Check for event subscription errors in logs +3. Ensure connector emits `SessionCreated` events + +### Large Archive Size + +If archive grows too large: +1. Check for duplicate files in `.files/` +2. Consider archiving old sessions separately +3. Future: Use compaction features (not yet implemented) + +## Development Status + +**Current** (Phase 2 complete): +- Automatic archival of sessions and messages +- Event-driven integration with dirigent_core +- File-based storage with NDJSON/JSON/TSV (`JsonlBackend`) +- Content-addressable file storage +- Multi-archive coordinator with per-archive backends +- Trait-based backend abstraction (`ArchiveBackend` + sub-traits) + +**Future**: +- Indexed `SearchBackend` implementations (full-text search) +- Additional concrete backends (SQLite, remote) +- Session splitting and lineage management +- Knowledge overview generation +- Network RPC interface + +## Documentation + +- **Developer Guide**: `CLAUDE.md` - Package architecture and implementation details +- **Architecture**: `docs/building/05_archivist/vision.md` - Design rationale +- **API Docs**: `cargo doc --package dirigent_archivist --open` +- **Examples**: See `examples/` for working code + +## Contributing + +The Archivist is part of the Dirigent project. See the main repository for contribution guidelines. + +## License + +Part of the Dirigent project. diff --git a/crates/dirigent_archivist/examples/basic_usage.rs b/crates/dirigent_archivist/examples/basic_usage.rs new file mode 100644 index 0000000..5eacea2 --- /dev/null +++ b/crates/dirigent_archivist/examples/basic_usage.rs @@ -0,0 +1,198 @@ +//! Basic usage example for dirigent_archivist +//! +//! This example demonstrates: +//! - Creating a Archivist +//! - Registering a connector +//! - Registering a session +//! - Appending messages to a session +//! - Listing sessions for a connector +//! - Retrieving messages for a session + +use chrono::Utc; +use dirigent_archivist::{ + Archivist, MessageRecord, RegisterConnectorRequest, RegisterSessionRequest, + Result, +}; +use std::path::PathBuf; +use uuid::Uuid; + +#[tokio::main] +async fn main() -> Result<()> { + // Create a temporary archive directory for this example + let temp_dir = std::env::temp_dir().join(format!("dirigent_example_{}", Uuid::now_v7())); + println!("Creating archive at: {}", temp_dir.display()); + + // Step 1: Create a Archivist + let archivist = Archivist::new_with_single_archive(temp_dir.clone()).await?; + println!("Archivist created successfully"); + + // Step 2: Register a connector + println!("\n--- Registering Connector ---"); + let connector_req = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "OpenCode Local".to_string(), + client_native_id: "opencode@http://localhost:12225".to_string(), + custom_uid: None, // Let archivist generate a UID + metadata: serde_json::json!({ + "version": "0.1.0", + "protocol": "OpenCode HTTP API" + }), + fingerprint: None, + }; + + let connector_resp = archivist.register_connector(connector_req, None).await?; + println!("Connector registered: {:?}", connector_resp); + let connector_uid = connector_resp.connector_uid; + + // Step 3: Register a session + println!("\n--- Registering Session ---"); + let session_req = RegisterSessionRequest { + connector_uid, + native_session_id: "session-abc123".to_string(), + title: Some("Example chat session".to_string()), + custom_scroll_id: None, // Let archivist generate a scroll ID + metadata: serde_json::json!({ + "project_path": "/home/user/projects/example", + "model": "claude-3-5-sonnet" + }), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let session_resp = archivist.register_session(session_req, None).await?; + println!("Session registered: {:?}", session_resp); + let scroll_id = session_resp.scroll_id; + + // Step 4: Append messages to the session + println!("\n--- Appending Messages ---"); + + // User message + let user_msg = MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: scroll_id, + parent_id: None, + ts: Utc::now(), + role: "user".to_string(), + author: Some("alice".to_string()), + content_md: "Hello! Can you help me write a function to calculate fibonacci numbers?" + .to_string(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }; + + // Assistant message + let assistant_msg = MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: scroll_id, + parent_id: Some(user_msg.message_id), + ts: Utc::now(), + role: "assistant".to_string(), + author: Some("claude".to_string()), + content_md: r#"Sure! Here's a recursive fibonacci function in Rust: + +```rust +fn fibonacci(n: u32) -> u64 { + match n { + 0 => 0, + 1 => 1, + _ => fibonacci(n - 1) + fibonacci(n - 2), + } +} +``` + +This is the classic recursive implementation, though it's not the most efficient for large values of n."# + .to_string(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({ + "model": "claude-3-5-sonnet", + "latency_ms": 1245 + }), + }; + + archivist + .append_messages(scroll_id, vec![user_msg.clone(), assistant_msg.clone()], None) + .await?; + println!("Appended 2 messages to session"); + + // Step 5: List all sessions for the connector + println!("\n--- Listing Sessions ---"); + let page = archivist + .list_sessions_paged( + dirigent_archivist::SessionListQuery::default() + .with_connector(connector_uid) + .with_limit(100), + ) + .await?; + let sessions = page.items; + println!("Found {} session(s) for connector:", sessions.len()); + for session in &sessions { + println!( + " - {} ({}): {:?}", + session.scroll_id, + session.created_at.format("%Y-%m-%d %H:%M:%S"), + session.title + ); + } + + // Step 6: Retrieve all messages for the session + println!("\n--- Retrieving Messages ---"); + let messages = archivist.get_messages(scroll_id, None).await?; + println!("Retrieved {} message(s):", messages.len()); + for msg in &messages { + println!("\n[{}] {}", msg.role, msg.ts.format("%Y-%m-%d %H:%M:%S")); + println!("{}", msg.content_md); + } + + // Step 7: Demonstrate session resolution + println!("\n--- Resolving Session ---"); + let resolved_scroll_id = archivist + .resolve_session(connector_uid, "session-abc123", None) + .await?; + println!( + "Resolved native session 'session-abc123' to scroll_id: {}", + resolved_scroll_id + ); + assert_eq!(resolved_scroll_id, scroll_id); + + // Step 8: Show archive structure + println!("\n--- Archive Structure ---"); + println!("Archive root: {}", temp_dir.display()); + println!("\nDirectory structure:"); + show_directory_tree(&temp_dir, 0)?; + + // Cleanup + println!("\n--- Cleanup ---"); + std::fs::remove_dir_all(&temp_dir)?; + println!("Removed temporary archive"); + + Ok(()) +} + +/// Helper function to display directory tree +fn show_directory_tree(path: &PathBuf, depth: usize) -> Result<()> { + let indent = " ".repeat(depth); + + if path.is_dir() { + println!("{}{}/", indent, path.file_name().unwrap().to_string_lossy()); + + let mut entries: Vec<_> = std::fs::read_dir(path)?.filter_map(|e| e.ok()).collect(); + entries.sort_by_key(|e| e.path()); + + for entry in entries { + show_directory_tree(&entry.path(), depth + 1)?; + } + } else { + println!("{}{}", indent, path.file_name().unwrap().to_string_lossy()); + } + + Ok(()) +} diff --git a/crates/dirigent_archivist/examples/demo_types.rs b/crates/dirigent_archivist/examples/demo_types.rs new file mode 100644 index 0000000..3b3fa20 --- /dev/null +++ b/crates/dirigent_archivist/examples/demo_types.rs @@ -0,0 +1,156 @@ +// Demonstration of archivist types serialization +// Run with: cargo run --package dirigent_archivist --example demo_types + +use chrono::Utc; +use dirigent_archivist::*; +use uuid::Uuid; + +fn main() { + println!("=== ARCHIVIST TYPES DEMONSTRATION ===\n"); + + // Demo 1: SessionMetadata (matches session.json format) + println!("1. SessionMetadata (session.json):"); + let session_metadata = SessionMetadata { + version: 1, + scroll_id: Uuid::now_v7(), + created_at: Utc::now(), + updated_at: Utc::now(), + title: Some("Example Session".to_string()), + connector_uid: Uuid::now_v7(), + native_session_id: Some("abc123".to_string()), + agent_id: Some("claude-3-5".to_string()), + parent_scroll_id: None, + continuation: Some(Continuation::Split), + tags: vec!["example".to_string(), "test".to_string()], + metadata: serde_json::json!({ + "source": "OpenCode", + "project": "dirigent" + }), + no_update: false, + kind: SessionKind::Chat, + acp_client_id: None, + is_connected: None, + current_session_id: None, + models: None, + modes: None, + config_options: None, + completeness: SessionCompleteness::default(), + matrix_room_id: None, + matrix_sharing_active: false, + matrix_shared_at: None, + is_subagent: false, + subagent_type: None, + spawning_tool_use_id: None, + }; + println!( + "{}\n", + serde_json::to_string_pretty(&session_metadata).unwrap() + ); + + // Demo 2: MessageRecord (matches messages.ndjson format) + println!("2. MessageRecord (messages.ndjson line):"); + let message = MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: session_metadata.scroll_id, + parent_id: None, + ts: Utc::now(), + role: "user".to_string(), + author: Some("alice".to_string()), + content_md: "How do I implement archivist types?".to_string(), + content_parts: None, + attachments: vec![AttachmentRef { + file_id: "sha256:abc123".to_string(), + name: "spec.pdf".to_string(), + mime_type: Some("application/pdf".to_string()), + }], + metadata: serde_json::json!({ + "connector_msg_id": "msg-456" + }), + }; + // NDJSON format (one line) + println!("{}\n", serde_json::to_string(&message).unwrap()); + + // Demo 3: ConnectorRecord (matches connector.json format) + println!("3. ConnectorRecord (connector.json):"); + let connector = ConnectorRecord { + version: 1, + connector_uid: session_metadata.connector_uid, + r#type: "OpenCode".to_string(), + title: "OpenCode Local".to_string(), + client_native_id: "opencode@http://localhost:12225".to_string(), + alias_of: None, + created_at: Utc::now(), + metadata: serde_json::json!({}), + fingerprint: None, + }; + println!("{}\n", serde_json::to_string_pretty(&connector).unwrap()); + + // Demo 4: SessionMapping (matches sessions.ndjson format) + println!("4. SessionMapping (sessions.ndjson line):"); + let mapping = SessionMapping { + version: 1, + connector_uid: connector.connector_uid, + native_session_id: "abc123".to_string(), + scroll_id: session_metadata.scroll_id, + created_at: Utc::now(), + alias_of: None, + }; + println!("{}\n", serde_json::to_string(&mapping).unwrap()); + + // Demo 5: FileRecord (matches file_index.jsonl format) + println!("5. FileRecord (file_index.jsonl line):"); + let file_record = FileRecord { + version: 1, + file_id: "sha256:abc123def456".to_string(), + path: ".files/ab/cd/abc123def456".to_string(), + size: 123456, + mime: Some("application/pdf".to_string()), + original_name: "spec.pdf".to_string(), + sessions: vec![session_metadata.scroll_id], + metadata: serde_json::json!({ + "source": "upload" + }), + }; + println!("{}\n", serde_json::to_string(&file_record).unwrap()); + + // Demo 6: Enum serialization + println!("6. Enum Serialization:"); + println!( + " Continuation::Split: {}", + serde_json::to_string(&Continuation::Split).unwrap() + ); + println!( + " Continuation::Compact: {}", + serde_json::to_string(&Continuation::Compact).unwrap() + ); + println!( + " RegisterStatus::Accepted: {}", + serde_json::to_string(&RegisterStatus::Accepted).unwrap() + ); + println!( + " RegisterStatus::Aliased: {}", + serde_json::to_string(&RegisterStatus::Aliased).unwrap() + ); + println!(); + + // Demo 7: API types + println!("7. RegisterConnectorResponse:"); + let response = RegisterConnectorResponse { + status: RegisterStatus::Accepted, + connector_uid: Uuid::now_v7(), + alias_of: None, + note: Some("Successfully registered".to_string()), + }; + println!("{}\n", serde_json::to_string_pretty(&response).unwrap()); + + println!("8. RegisterSessionResponse:"); + let response = RegisterSessionResponse { + status: RegisterStatus::Aliased, + scroll_id: Uuid::now_v7(), + alias_of: Some(Uuid::now_v7()), + }; + println!("{}\n", serde_json::to_string_pretty(&response).unwrap()); + + println!("=== ALL TYPES MATCH VISION.MD SPECIFICATION ==="); +} diff --git a/crates/dirigent_archivist/examples/event_handling.rs b/crates/dirigent_archivist/examples/event_handling.rs new file mode 100644 index 0000000..50b6acf --- /dev/null +++ b/crates/dirigent_archivist/examples/event_handling.rs @@ -0,0 +1,277 @@ +//! Event handling example for dirigent_archivist +//! +//! This example demonstrates: +//! - Creating an EventHandler +//! - Subscribing to dirigent_protocol events +//! - Accumulating streaming message chunks +//! - Finalizing complete messages +//! - Automatic archival via event stream + +use chrono::Utc; +use dirigent_archivist::{Archivist, EventHandler, Result}; +use dirigent_protocol::streaming::{BusEvent, BusReceiver, EventOrigin, EventRouting}; +use dirigent_protocol::{ + ContentBlock, Event, Message, MessageMetadata, MessagePart, MessageRole, MessageStatus, + Session, SessionMetadata, SessionUpdate, ToolCall, ToolCallStatus, +}; +use std::sync::Arc; +use std::sync::atomic::AtomicU64; +use tokio::sync::mpsc; +use uuid::Uuid; + +/// Wrap a raw `Event` in a `BusEvent` with default routing. +fn wrap(event: Event) -> BusEvent { + BusEvent { + routing: EventRouting::default(), + origin: EventOrigin::Runtime, + event: Arc::new(event), + } +} + +#[tokio::main] +async fn main() -> Result<()> { + // Create a temporary archive directory for this example + let temp_dir = std::env::temp_dir().join(format!("dirigent_event_example_{}", Uuid::now_v7())); + println!("Creating archive at: {}", temp_dir.display()); + + // Step 1: Create archivist and event handler + let archivist = Archivist::new_with_single_archive(temp_dir.clone()).await?; + let archivist = Arc::new(archivist); + let handler = EventHandler::new(archivist.clone()); + + println!("EventHandler created successfully"); + + // Step 2: Create a mock event stream. In production this is built + // by `SharingBus::subscribe_all()`; here we fabricate a `BusReceiver` + // directly so the example stays self-contained. + let (tx, rx) = mpsc::channel::(100); + let bus_rx = BusReceiver { + id: 0, + rx, + lagged: Arc::new(AtomicU64::new(0)), + }; + + // Step 3: Spawn event handler task + let handler_task = tokio::spawn(async move { + handler.run(bus_rx).await; + }); + + // Step 4: Simulate event flow + println!("\n--- Simulating Event Stream ---"); + + // Generate connector and session IDs + let connector_id = Uuid::now_v7().to_string(); + let session_id = Uuid::now_v7().to_string(); + let message_id = Uuid::now_v7().to_string(); + + // Event 1: SessionCreated + println!("\n1. Sending SessionCreated event..."); + let session_created = Event::SessionCreated { + connector_id: connector_id.clone(), + session: Session { + id: session_id.clone(), + title: "Example streaming session".to_string(), + created_at: Utc::now(), + updated_at: Utc::now(), + metadata: SessionMetadata { + project_path: "/home/user/project".to_string(), + model: Some("claude-3-5-sonnet".to_string()), + total_messages: 0, + system_message: None, + current_mode_id: None, + _meta: None, + project_id: None, + }, + cwd: None, + models: None, + modes: None, + config_options: None, + acp_client_id: None, + }, + }; + tx.send(wrap(session_created)).await.unwrap(); + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Event 2-5: Streaming message chunks (AgentMessageChunk) + println!("2. Sending streaming message chunks..."); + let chunks = vec!["Hello! ", "I'm here to ", "help you with ", "your code."]; + + for (i, chunk) in chunks.iter().enumerate() { + let chunk_event = Event::SessionUpdate { + connector_id: connector_id.clone(), + session_id: session_id.clone(), + update: SessionUpdate::AgentMessageChunk { + message_id: message_id.clone(), + content: ContentBlock::Text { + text: chunk.to_string(), + }, + _meta: None, + }, + }; + tx.send(wrap(chunk_event)).await.unwrap(); + println!(" Chunk {}: {:?}", i + 1, chunk); + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + } + + // Event 6: Thinking chunk + println!("3. Sending thinking chunk..."); + let thinking_event = Event::SessionUpdate { + connector_id: connector_id.clone(), + session_id: session_id.clone(), + update: SessionUpdate::AgentThoughtChunk { + message_id: message_id.clone(), + content: ContentBlock::Text { + text: "Let me consider the best approach...".to_string(), + }, + _meta: None, + }, + }; + tx.send(wrap(thinking_event)).await.unwrap(); + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Event 7: Tool call + println!("4. Sending tool call event..."); + let tool_call_event = Event::SessionUpdate { + connector_id: connector_id.clone(), + session_id: session_id.clone(), + update: SessionUpdate::ToolCall { + message_id: message_id.clone(), + tool_call: ToolCall { + id: "tool_call_123".to_string(), + tool_name: "read_file".to_string(), + status: ToolCallStatus::Completed, + content: vec![], + raw_input: Some(serde_json::json!({ + "path": "/home/user/project/main.rs" + })), + raw_output: Some(serde_json::json!({ + "content": "fn main() { println!(\"Hello\"); }" + })), + title: None, + error: None, + metadata: None, + origin: None, + }, + _meta: None, + }, + }; + tx.send(wrap(tool_call_event)).await.unwrap(); + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Event 8: MessageCompleted (triggers finalization) + println!("5. Sending MessageCompleted event..."); + let message_completed = Event::MessageCompleted { + connector_id: connector_id.clone(), + message: Message { + id: message_id.clone(), + session_id: session_id.clone(), + role: MessageRole::Assistant, + created_at: Utc::now(), + content: vec![MessagePart::Text { + text: chunks.concat(), + }], + status: MessageStatus::Completed, + metadata: Some(MessageMetadata { + cost: None, + tokens_input: None, + tokens_output: None, + response_time_ms: None, + latency_ms: Some(1500), + model: Some("claude-3-5-sonnet".to_string()), + other: None, + }), + }, + }; + tx.send(wrap(message_completed)).await.unwrap(); + tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; + + // Event 9: Second message (user response) + println!("6. Sending user message..."); + let user_message_id = Uuid::now_v7().to_string(); + let user_chunks = vec!["Thanks! ", "Can you explain ", "the code?"]; + + for (i, chunk) in user_chunks.iter().enumerate() { + let chunk_event = Event::SessionUpdate { + connector_id: connector_id.clone(), + session_id: session_id.clone(), + update: SessionUpdate::UserMessageChunk { + message_id: user_message_id.clone(), + content: ContentBlock::Text { + text: chunk.to_string(), + }, + _meta: None, + }, + }; + tx.send(wrap(chunk_event)).await.unwrap(); + println!(" User chunk {}: {:?}", i + 1, chunk); + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + } + + let user_message_completed = Event::MessageCompleted { + connector_id: connector_id.clone(), + message: Message { + id: user_message_id.clone(), + session_id: session_id.clone(), + role: MessageRole::User, + created_at: Utc::now(), + content: vec![MessagePart::Text { + text: user_chunks.concat(), + }], + status: MessageStatus::Completed, + metadata: None, + }, + }; + tx.send(wrap(user_message_completed)).await.unwrap(); + tokio::time::sleep(tokio::time::Duration::from_millis(200)).await; + + // Step 5: Verify archived data + println!("\n--- Verifying Archived Data ---"); + + // Parse connector_uid from connector_id string + let connector_uid = + Uuid::parse_str(&connector_id).expect("connector_id should be a valid UUID"); + + // List sessions + let page = archivist + .list_sessions_paged( + dirigent_archivist::SessionListQuery::default() + .with_connector(connector_uid) + .with_limit(100), + ) + .await?; + let sessions = page.items; + println!("Found {} session(s) in archive", sessions.len()); + for session in &sessions { + println!(" Session: {} - {:?}", session.scroll_id, session.title); + } + + // Get messages + if let Some(session) = sessions.first() { + let messages = archivist.get_messages(session.scroll_id, None).await?; + println!("\nFound {} message(s):", messages.len()); + for msg in &messages { + println!("\n[{}] {} chars", msg.role, msg.content_md.len()); + println!( + "Content preview: {}", + &msg.content_md.chars().take(100).collect::() + ); + } + } + + // Step 6: Cleanup + println!("\n--- Cleanup ---"); + + // Drop the event sender to close the channel + drop(tx); + + // Wait for handler to finish + handler_task.await.expect("Handler task failed"); + + // Remove temporary archive + std::fs::remove_dir_all(&temp_dir)?; + println!("Removed temporary archive"); + + println!("\nExample completed successfully!"); + + Ok(()) +} diff --git a/crates/dirigent_archivist/examples/file_storage.rs b/crates/dirigent_archivist/examples/file_storage.rs new file mode 100644 index 0000000..2c7737d --- /dev/null +++ b/crates/dirigent_archivist/examples/file_storage.rs @@ -0,0 +1,214 @@ +//! File storage example for dirigent_archivist +//! +//! This example demonstrates: +//! - Storing files with content-addressing +//! - Retrieving files by file_id +//! - Automatic deduplication of identical content +//! - Session tracking for file references + +use dirigent_archivist::storage::{files, ndjson, paths::ArchivePaths}; +use dirigent_archivist::types::FileRecord; +use dirigent_archivist::Result; +use uuid::Uuid; + +#[tokio::main] +async fn main() -> Result<()> { + // Create a temporary archive directory for this example + let temp_dir = std::env::temp_dir().join(format!("dirigent_files_example_{}", Uuid::now_v7())); + println!("Creating archive at: {}", temp_dir.display()); + + let paths = ArchivePaths::new(temp_dir.clone()); + + // Example 1: Store a file + println!("\n--- Example 1: Store a File ---"); + let content1 = b"This is a sample document with some text content."; + let session1 = Uuid::now_v7(); + + let file_id1 = files::store_file( + &paths, + content1, + "document.txt".to_string(), + Some("text/plain".to_string()), + session1, + ) + .await?; + + println!("Stored file with ID: {}", file_id1); + println!("Session: {}", session1); + + // Example 2: Retrieve the file + println!("\n--- Example 2: Retrieve the File ---"); + let retrieved1 = files::get_file(&paths, &file_id1).await?; + println!("Retrieved {} bytes", retrieved1.len()); + println!("Content: {}", String::from_utf8_lossy(&retrieved1)); + + // Example 3: Store the same content from a different session (deduplication) + println!("\n--- Example 3: Deduplication Demo ---"); + let session2 = Uuid::now_v7(); + + let file_id2 = files::store_file( + &paths, + content1, // Same content as before + "duplicate.txt".to_string(), // Different name + Some("text/plain".to_string()), + session2, + ) + .await?; + + println!("Stored same content with different name"); + println!("File ID 1: {}", file_id1); + println!("File ID 2: {}", file_id2); + println!("Same file_id? {}", file_id1 == file_id2); + println!("\nDeduplication: Same content produces same file_id, stored only once!"); + + // Example 4: Check the file index + println!("\n--- Example 4: File Index ---"); + let index_path = paths.root().join(".files").join("file_index.jsonl"); + let records: Vec = ndjson::read_ndjson(&index_path).await?; + + println!("File index contains {} record(s)", records.len()); + for record in &records { + println!("\nFile: {}", record.file_id); + println!(" Original name: {}", record.original_name); + println!(" MIME type: {:?}", record.mime); + println!(" Size: {} bytes", record.size); + println!(" Referenced by {} session(s):", record.sessions.len()); + for session_id in &record.sessions { + println!(" - {}", session_id); + } + } + + // Example 5: Store different content + println!("\n--- Example 5: Store Different Content ---"); + let content2 = b"This is completely different content with more data!"; + let session3 = Uuid::now_v7(); + + let file_id3 = files::store_file( + &paths, + content2, + "different.txt".to_string(), + Some("text/plain".to_string()), + session3, + ) + .await?; + + println!("Stored different content"); + println!("File ID 3: {}", file_id3); + println!("Different from file_id1? {}", file_id1 != file_id3); + + // Example 6: Store binary content + println!("\n--- Example 6: Binary Content ---"); + let binary_content: Vec = (0..256).map(|i| i as u8).collect(); + let session4 = Uuid::now_v7(); + + let file_id4 = files::store_file( + &paths, + &binary_content, + "binary.dat".to_string(), + Some("application/octet-stream".to_string()), + session4, + ) + .await?; + + println!("Stored binary content (256 bytes)"); + println!("File ID: {}", file_id4); + + // Retrieve and verify + let retrieved_binary = files::get_file(&paths, &file_id4).await?; + println!("Retrieved {} bytes", retrieved_binary.len()); + println!( + "Binary content verified: {}", + retrieved_binary == binary_content + ); + + // Example 7: Show final archive structure + println!("\n--- Example 7: Archive Structure ---"); + println!("Archive root: {}", temp_dir.display()); + show_files_directory(&paths)?; + + // Example 8: Final statistics + println!("\n--- Final Statistics ---"); + let final_records: Vec = ndjson::read_ndjson(&index_path).await?; + println!("Total unique files stored: {}", final_records.len()); + + let total_sessions: usize = final_records.iter().map(|r| r.sessions.len()).sum(); + println!("Total session references: {}", total_sessions); + + let total_size: u64 = final_records.iter().map(|r| r.size).sum(); + println!("Total storage used: {} bytes", total_size); + + // Content-addressing means if we had stored content1 1000 times, + // we'd still only use storage for it once! + println!("\nContent-addressing benefit:"); + println!(" File '{}' is referenced by {} sessions", file_id1, 2); + println!(" But stored only once on disk!"); + + // Cleanup + println!("\n--- Cleanup ---"); + std::fs::remove_dir_all(&temp_dir)?; + println!("Removed temporary archive"); + + println!("\nExample completed successfully!"); + + Ok(()) +} + +/// Helper function to show .files directory structure +fn show_files_directory(paths: &ArchivePaths) -> Result<()> { + let files_dir = paths.root().join(".files"); + + if !files_dir.exists() { + println!("No files directory found"); + return Ok(()); + } + + println!("\n.files/ directory:"); + + // Show index file + let index_path = files_dir.join("file_index.jsonl"); + if index_path.exists() { + let metadata = std::fs::metadata(&index_path)?; + println!(" file_index.jsonl ({} bytes)", metadata.len()); + } + + // Show shard directories + for entry in std::fs::read_dir(&files_dir)? { + let entry = entry?; + let path = entry.path(); + + if path.is_dir() { + println!(" {}/", path.file_name().unwrap().to_string_lossy()); + + // Show files in shard + for sub_entry in std::fs::read_dir(&path)? { + let sub_entry = sub_entry?; + let sub_path = sub_entry.path(); + + if sub_path.is_dir() { + println!(" {}/", sub_path.file_name().unwrap().to_string_lossy()); + + // Show files in sub-shard + for file_entry in std::fs::read_dir(&sub_path)? { + let file_entry = file_entry?; + let file_path = file_entry.path(); + let metadata = std::fs::metadata(&file_path)?; + println!( + " {} ({} bytes)", + file_path.file_name().unwrap().to_string_lossy(), + metadata.len() + ); + } + } else { + let metadata = std::fs::metadata(&sub_path)?; + println!( + " {} ({} bytes)", + sub_path.file_name().unwrap().to_string_lossy(), + metadata.len() + ); + } + } + } + } + + Ok(()) +} diff --git a/crates/dirigent_archivist/examples/multi_backend.rs b/crates/dirigent_archivist/examples/multi_backend.rs new file mode 100644 index 0000000..73773ea --- /dev/null +++ b/crates/dirigent_archivist/examples/multi_backend.rs @@ -0,0 +1,199 @@ +//! Example: two `JsonlBackend`s side by side, demonstrating boot-from-config, +//! priority-ordered read routing, write fanout, and a health snapshot. +//! +//! Layout: +//! - `primary` → `read_priority = 0`, `failure_mode = required` (default) +//! - `mirror` → `read_priority = 10`, `failure_mode = best_effort` +//! +//! The primary is the default write target (lowest priority among +//! Required+write-active backends). `append_messages` fans out inline to the +//! mirror too. Reads walk the registrations in priority order, so the primary +//! answers first; if it is missing a session, the walk falls through to the +//! mirror. +//! +//! Run with: +//! +//! cargo run --package dirigent_archivist --example multi_backend + +use std::sync::Arc; + +use chrono::Utc; +use dirigent_archivist::coordinator::Archivist; +use dirigent_archivist::registry::{ArchivesConfig, BackendRegistry}; +use dirigent_archivist::types::{ + MessageRecord, RegisterConnectorRequest, RegisterSessionRequest, +}; +use uuid::Uuid; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let dir_a = tempfile::tempdir()?; + let dir_b = tempfile::tempdir()?; + + // Build a two-archive config entirely from TOML so the example doubles as + // a faithful demonstration of the config surface. + let cfg_src = format!( + r#" + [[archives]] + name = "primary" + type = "jsonl" + read_priority = 0 + [archives.params] + path = "{a}" + + [[archives]] + name = "mirror" + type = "jsonl" + failure_mode = "best_effort" + read_priority = 10 + [archives.params] + path = "{b}" + "#, + a = dir_a.path().to_string_lossy().replace('\\', "/"), + b = dir_b.path().to_string_lossy().replace('\\', "/"), + ); + let cfg: ArchivesConfig = toml::from_str(&cfg_src)?; + let registry = BackendRegistry::with_jsonl(); + let archivist = Arc::new(Archivist::from_config(cfg, ®istry, None).await?); + + println!("\n=== Multi-backend Archivist example ===\n"); + println!("Boot complete. Archives (ordered by read_priority):"); + for s in archivist.list_archives_with_health().await { + println!( + " - name={:<8} type={:<6} priority={:<3} enabled={} write_active={} failure_mode={:?} health={:?}", + s.name, + s.type_name, + s.read_priority, + s.enabled, + s.write_active, + s.failure_mode, + s.health, + ); + } + + // ------------------------------------------------------------------ + // Register a connector. The primary owns the canonical record; fanout + // mirrors it to the secondary. + // ------------------------------------------------------------------ + let connector_resp = archivist + .register_connector( + RegisterConnectorRequest { + r#type: "Example".into(), + title: "multi-backend demo".into(), + client_native_id: "example://multi_backend".into(), + custom_uid: None, + metadata: serde_json::json!({ "demo": true }), + fingerprint: None, + }, + None, + ) + .await?; + let connector_uid = connector_resp.connector_uid; + println!( + "\nRegistered connector: uid={} status={:?}", + connector_uid, connector_resp.status + ); + + // ------------------------------------------------------------------ + // Register a session under that connector. `register_session` writes + // the mapping and `session.json` on the primary first, then fans out + // to any enabled secondaries. + // ------------------------------------------------------------------ + let session_resp = archivist + .register_session( + RegisterSessionRequest { + connector_uid, + native_session_id: "demo-session-1".into(), + title: Some("multi-backend demo session".into()), + custom_scroll_id: None, + metadata: serde_json::json!({ "model": "demo" }), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }, + None, + ) + .await?; + let scroll_id = session_resp.scroll_id; + println!( + "Registered session: scroll_id={} status={:?}", + scroll_id, session_resp.status + ); + + // ------------------------------------------------------------------ + // Append a couple of messages. `append_messages` writes to the primary + // and then fans out inline to the mirror. + // ------------------------------------------------------------------ + let user_msg = MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: scroll_id, + parent_id: None, + ts: Utc::now(), + role: "user".into(), + author: Some("alice".into()), + content_md: "Hello from the multi-backend example!".into(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }; + let asst_msg = MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: scroll_id, + parent_id: Some(user_msg.message_id), + ts: Utc::now(), + role: "assistant".into(), + author: Some("claude".into()), + content_md: "Greetings. I have been written to two archives.".into(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }; + archivist + .append_messages(scroll_id, vec![user_msg, asst_msg], None) + .await?; + println!("\nAppended 2 messages — fanned out to primary + mirror."); + + // ------------------------------------------------------------------ + // Read path: the priority walk tries the primary first (priority=0). + // It finds the session there and never consults the mirror. + // ------------------------------------------------------------------ + let meta = archivist.get_session_metadata(scroll_id, None).await?; + println!( + "\nRead session via priority walk: title={:?} completeness={:?}", + meta.title, meta.completeness + ); + println!( + "Read cache size after read: {}", + archivist.read_cache_size().await + ); + + let messages = archivist.get_messages(scroll_id, None).await?; + println!("Read {} message(s) from the archive:", messages.len()); + for m in &messages { + println!(" - [{}] {}", m.role, m.content_md); + } + + // ------------------------------------------------------------------ + // Final health snapshot. Both backends should still be Available and + // have no queued writes (both run Inline write policies by default). + // ------------------------------------------------------------------ + println!("\nFinal health snapshot:"); + for s in archivist.list_archives_with_health().await { + println!( + " - {:<8} health={:?} queue_depth={:?} last_error={:?}", + s.name, s.health, s.queue_depth, s.last_error + ); + } + + // Clean shutdown drains any queued writer tasks. Both backends here run + // Inline, so this is effectively a no-op but remains the correct API. + archivist.shutdown().await?; + println!("\nShutdown complete."); + Ok(()) +} diff --git a/crates/dirigent_archivist/src/accumulator.rs b/crates/dirigent_archivist/src/accumulator.rs new file mode 100644 index 0000000..c6296b6 --- /dev/null +++ b/crates/dirigent_archivist/src/accumulator.rs @@ -0,0 +1,923 @@ +//! Message accumulator for incremental message assembly. +//! +//! This is a thin wrapper around [`dirigent_protocol::accumulator::MessageAccumulator`] +//! that delegates chunk/tool/thinking operations to the protocol accumulator and +//! converts [`AccumulatedMessage`] to [`MessageRecord`] on `finalize()`. +//! +//! The accumulator preserves the order of content parts (text, thinking, tool calls) +//! as they arrive in the event stream, enabling inline tool rendering in the UI. + +use chrono::{DateTime, Utc}; +use dirigent_protocol::accumulator::{ + AccumulatedMessage, AccumulatedPart, + MessageAccumulator as ProtocolAccumulator, +}; +#[cfg(test)] +use dirigent_protocol::MessagePart; +use dirigent_protocol::ContentBlock; +use serde_json::Value; +use std::collections::HashMap; +use uuid::Uuid; + +use crate::error::Result; +use crate::types::MessageRecord; + +// Re-export ToolCallData from the protocol for backward compatibility. +pub use dirigent_protocol::accumulator::ToolCallData; + +/// Accumulator for assembling streaming message deltas into [`MessageRecord`]s. +/// +/// Wraps the protocol-level [`ProtocolAccumulator`] and adds archivist-specific +/// concerns: per-message metadata, UUID parsing, and markdown generation. +#[derive(Debug, Default)] +pub struct MessageAccumulator { + inner: ProtocolAccumulator, + /// Per-message metadata not tracked by the protocol accumulator. + metadata: HashMap, +} + +impl MessageAccumulator { + /// Create a new message accumulator + pub fn new() -> Result { + Ok(Self { + inner: ProtocolAccumulator::new(), + metadata: HashMap::new(), + }) + } + + /// Add a content chunk to the message buffer + pub fn add_chunk( + &mut self, + message_id: String, + session_id: String, + connector_id: String, + role: String, + content: ContentBlock, + ) { + self.inner + .add_chunk(&message_id, &session_id, &connector_id, &role, content); + } + + /// Add thinking content to the message buffer + pub fn add_thinking( + &mut self, + message_id: String, + session_id: String, + connector_id: String, + content: String, + ) { + self.inner.add_thinking(&message_id, &session_id, &connector_id, &content); + } + + /// Add or update a tool call in the message buffer + /// + /// This method handles both initial ToolCall events and ToolCallUpdate events. + /// If a tool call with the given ID already exists, it updates the existing entry. + /// Otherwise, it adds a new entry. + /// + /// This ensures that each tool_call_id appears exactly ONCE in the final message, + /// with the most recent input/output data. + pub fn add_or_update_tool_call(&mut self, message_id: String, tool_call: ToolCallData) { + self.inner.add_or_update_tool_call(&message_id, tool_call); + } + + /// Add a tool call to the message buffer (DEPRECATED - use add_or_update_tool_call) + #[deprecated(note = "Use add_or_update_tool_call instead to avoid duplicates")] + pub fn add_tool_call(&mut self, message_id: String, tool_call: ToolCallData) { + self.add_or_update_tool_call(message_id, tool_call); + } + + /// Update an existing tool call in the message buffer + /// + /// Finds the tool call by ID and updates its input/output with non-empty values + /// from the update. If no matching tool call is found, this is a no-op (the + /// update arrived before the initial ToolCall). + pub fn update_tool_call( + &mut self, + message_id: String, + tool_call_id: &str, + input: Option, + output: Option, + ) { + // Construct a ToolCallData and delegate to add_or_update_tool_call. + // We need the tool_name but don't have it here; use an empty string + // since add_or_update_tool_call only updates existing entries when the + // id matches. However, if there's no existing entry, this would create + // a new one with empty tool_name - so we need to check first. + // + // Instead, we use the protocol accumulator's update semantics directly: + // build a ToolCallData with the values we have. + let tool_call = ToolCallData { + id: tool_call_id.to_string(), + tool_name: String::new(), // Will be overwritten by existing entry's name + input: input.unwrap_or(Value::Null), + output, + }; + + // Only delegate if a buffer exists for this message (matching original behavior). + if self.inner.has_buffer(&message_id) { + self.inner.add_or_update_tool_call(&message_id, tool_call); + } + } + + /// Get all message IDs for a given session that have active buffers + pub fn get_message_ids_for_session(&self, session_id: &str) -> Vec { + self.inner.message_ids_for_session(session_id) + } + + /// Get message IDs for buffers that have been inactive longer than the threshold + pub fn get_stale_message_ids( + &self, + _now: DateTime, + threshold: std::time::Duration, + ) -> Vec { + self.inner.stale_message_ids(threshold) + } + + /// Get all message IDs that have active buffers + pub fn get_all_message_ids(&self) -> Vec { + self.inner.active_message_ids() + } + + /// Finalize a message and produce a complete `(MessageRecord, connector_id, native_session_id)`. + /// + /// Returns `None` if no buffer exists for the given `message_id`. + /// The `connector_id` and `native_session_id` in the tuple are the raw values + /// that were passed into `add_chunk`/`add_thinking` — callers in Task 5 will use + /// these to resolve the canonical scroll_id. + pub fn finalize(&mut self, message_id: &str) -> Option<(MessageRecord, String, String)> { + let accumulated = self.inner.finalize(message_id)?; + + let connector_id = accumulated.connector_id.clone(); + let native_session_id = accumulated.session_id.clone(); + + // Take stored metadata for this message (if any). + let metadata = self + .metadata + .remove(message_id) + .unwrap_or(Value::Null); + + let record = accumulated_to_record(accumulated, metadata); + Some((record, connector_id, native_session_id)) + } +} + +// --------------------------------------------------------------------------- +// Conversion helpers +// --------------------------------------------------------------------------- + +/// Convert an [`AccumulatedMessage`] into a [`MessageRecord`] for archival. +fn accumulated_to_record(accumulated: AccumulatedMessage, metadata: Value) -> MessageRecord { + // Build content_md by iterating parts in order + let mut content_md = String::new(); + + for part in &accumulated.parts { + match part { + AccumulatedPart::Text { text } => { + content_md.push_str(text); + } + AccumulatedPart::Thinking { text } => { + content_md.push_str("\n\n\n"); + content_md.push_str(text); + content_md.push_str("\n"); + } + AccumulatedPart::Tool { data } => { + content_md.push_str(&format!( + "\n\n**Tool**: {}\n```json\n{}\n```", + data.tool_name, + serde_json::to_string_pretty(&data.input) + .unwrap_or_else(|_| "{}".to_string()) + )); + } + } + } + + // Convert accumulated parts to protocol MessageParts for rich rendering + let message_parts = accumulated.to_message_parts(); + + // Serialize content_parts for storage (None if empty to save space) + let content_parts = if message_parts.is_empty() { + None + } else { + serde_json::to_value(&message_parts).ok() + }; + + // Parse UUIDs from strings + // Strip "msg-" prefix if present (ACP connectors use this format) + let message_id_str = accumulated + .message_id + .strip_prefix("msg-") + .unwrap_or(&accumulated.message_id); + + if message_id_str != accumulated.message_id.as_str() { + tracing::debug!( + "Stripped 'msg-' prefix from message_id: {} -> {}", + accumulated.message_id, + message_id_str + ); + } + + let message_uuid = match Uuid::parse_str(message_id_str) { + Ok(uuid) => uuid, + Err(_) => { + tracing::warn!( + "Failed to parse message_id as UUID: {}", + accumulated.message_id + ); + Uuid::now_v7() + } + }; + + // Strip "msg-" prefix from session_id if present (for consistency) + let session_id_str = accumulated + .session_id + .strip_prefix("msg-") + .unwrap_or(&accumulated.session_id); + + if session_id_str != accumulated.session_id.as_str() { + tracing::debug!( + "Stripped 'msg-' prefix from session_id: {} -> {}", + accumulated.session_id, + session_id_str + ); + } + + let session_uuid = match Uuid::parse_str(session_id_str) { + Ok(uuid) => uuid, + Err(_) => { + tracing::warn!( + "Failed to parse session_id as UUID: {}", + accumulated.session_id + ); + Uuid::now_v7() + } + }; + + MessageRecord { + version: 1, + message_id: message_uuid, + session: session_uuid, + parent_id: None, + ts: accumulated.created_at.unwrap_or_else(Utc::now), + role: accumulated.role, + author: None, + content_md, + content_parts, + attachments: Vec::new(), + metadata, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_accumulator_creation() { + let acc = MessageAccumulator::new().unwrap(); + assert_eq!(acc.get_all_message_ids().len(), 0); + } + + #[test] + fn test_add_text_chunk() { + let mut acc = MessageAccumulator::new().unwrap(); + + acc.add_chunk( + "msg_1".to_string(), + "session_1".to_string(), + "connector_1".to_string(), + "user".to_string(), + ContentBlock::Text { + text: "Hello, ".to_string(), + }, + ); + + acc.add_chunk( + "msg_1".to_string(), + "session_1".to_string(), + "connector_1".to_string(), + "user".to_string(), + ContentBlock::Text { + text: "world!".to_string(), + }, + ); + + // Two consecutive text chunks should be coalesced. + // Finalize and check the result. + let (record, _, _) = acc.finalize("msg_1").unwrap(); + assert_eq!(record.content_md, "Hello, world!"); + } + + #[test] + fn test_add_thinking_chunk() { + let mut acc = MessageAccumulator::new().unwrap(); + + acc.add_thinking( + "msg_2".to_string(), + "session_2".to_string(), + "connector_1".to_string(), + "Let me think... ".to_string(), + ); + + acc.add_thinking( + "msg_2".to_string(), + "session_2".to_string(), + "connector_1".to_string(), + "I need to analyze this.".to_string(), + ); + + // Finalize and verify thinking was coalesced + let (record, _, _) = acc.finalize("msg_2").unwrap(); + assert!(record.content_md.contains("Let me think... I need to analyze this.")); + assert_eq!(record.role, "assistant"); + } + + #[test] + fn test_finalize_text_only() { + let mut acc = MessageAccumulator::new().unwrap(); + + acc.add_chunk( + "01936e8f-e5a7-7000-8000-000000000001".to_string(), + "01936e8f-e5a7-7000-8000-000000000002".to_string(), + "connector_1".to_string(), + "user".to_string(), + ContentBlock::Text { + text: "Hello, ".to_string(), + }, + ); + + acc.add_chunk( + "01936e8f-e5a7-7000-8000-000000000001".to_string(), + "01936e8f-e5a7-7000-8000-000000000002".to_string(), + "connector_1".to_string(), + "user".to_string(), + ContentBlock::Text { + text: "world!".to_string(), + }, + ); + + let (record, _, _) = acc + .finalize("01936e8f-e5a7-7000-8000-000000000001") + .unwrap(); + + assert_eq!(record.content_md, "Hello, world!"); + assert_eq!(record.role, "user"); + assert!(record.ts <= Utc::now()); + } + + #[test] + fn test_finalize_with_thinking() { + let mut acc = MessageAccumulator::new().unwrap(); + + acc.add_chunk( + "01936e8f-e5a7-7000-8000-000000000003".to_string(), + "01936e8f-e5a7-7000-8000-000000000004".to_string(), + "connector_1".to_string(), + "assistant".to_string(), + ContentBlock::Text { + text: "Here's my response.".to_string(), + }, + ); + + acc.add_thinking( + "01936e8f-e5a7-7000-8000-000000000003".to_string(), + "01936e8f-e5a7-7000-8000-000000000004".to_string(), + "connector_1".to_string(), + "Let me analyze this carefully.".to_string(), + ); + + let (record, _, _) = acc + .finalize("01936e8f-e5a7-7000-8000-000000000003") + .unwrap(); + + assert!(record.content_md.contains("Here's my response.")); + assert!(record.content_md.contains("")); + assert!(record.content_md.contains("Let me analyze this carefully.")); + assert!(record.content_md.contains("")); + } + + #[test] + fn test_finalize_nonexistent_message() { + let mut acc = MessageAccumulator::new().unwrap(); + let result = acc.finalize("nonexistent"); + assert!(result.is_none()); + } + + #[test] + fn test_add_tool_call() { + let mut acc = MessageAccumulator::new().unwrap(); + + // First add a text chunk to create the buffer + acc.add_chunk( + "msg_tool".to_string(), + "session_tool".to_string(), + "connector_1".to_string(), + "assistant".to_string(), + ContentBlock::Text { + text: "I'll use a tool.".to_string(), + }, + ); + + // Add a tool call + let tool_call = ToolCallData { + id: "call_123".to_string(), + tool_name: "search".to_string(), + input: serde_json::json!({"query": "test"}), + output: Some(serde_json::json!({"results": ["a", "b"]})), + }; + + #[allow(deprecated)] + acc.add_tool_call("msg_tool".to_string(), tool_call); + + // Finalize and verify + let (record, _, _) = acc.finalize("msg_tool").unwrap(); + let parts = + serde_json::from_value::>(record.content_parts.unwrap()).unwrap(); + assert_eq!(parts.len(), 2); // One Text, one Tool + assert!(matches!(parts[1], MessagePart::Tool { .. })); + if let MessagePart::Tool { tool, .. } = &parts[1] { + assert_eq!(tool, "search"); + } + } + + #[test] + fn test_finalize_with_tool_calls() { + let mut acc = MessageAccumulator::new().unwrap(); + + acc.add_chunk( + "01936e8f-e5a7-7000-8000-000000000005".to_string(), + "01936e8f-e5a7-7000-8000-000000000006".to_string(), + "connector_1".to_string(), + "assistant".to_string(), + ContentBlock::Text { + text: "Let me search for that.".to_string(), + }, + ); + + let tool_call = ToolCallData { + id: "call_456".to_string(), + tool_name: "web_search".to_string(), + input: serde_json::json!({"query": "Rust async"}), + output: None, + }; + + #[allow(deprecated)] + acc.add_tool_call( + "01936e8f-e5a7-7000-8000-000000000005".to_string(), + tool_call, + ); + + let (record, _, _) = acc + .finalize("01936e8f-e5a7-7000-8000-000000000005") + .unwrap(); + + assert!(record.content_md.contains("Let me search for that.")); + assert!(record.content_md.contains("**Tool**: web_search")); + assert!(record.content_md.contains("Rust async")); + } + + #[test] + fn test_concurrent_messages() { + let mut acc = MessageAccumulator::new().unwrap(); + + // Add chunks for two different messages + acc.add_chunk( + "msg_a".to_string(), + "session_1".to_string(), + "connector_1".to_string(), + "user".to_string(), + ContentBlock::Text { + text: "Message A".to_string(), + }, + ); + + acc.add_chunk( + "msg_b".to_string(), + "session_1".to_string(), + "connector_1".to_string(), + "assistant".to_string(), + ContentBlock::Text { + text: "Message B".to_string(), + }, + ); + + acc.add_chunk( + "msg_a".to_string(), + "session_1".to_string(), + "connector_1".to_string(), + "user".to_string(), + ContentBlock::Text { + text: " continued".to_string(), + }, + ); + + // Both messages should be buffered + assert_eq!(acc.get_all_message_ids().len(), 2); + + // Finalize and check + let (record_a, _, _) = acc.finalize("msg_a").unwrap(); + assert_eq!(record_a.content_md, "Message A continued"); + + let (record_b, _, _) = acc.finalize("msg_b").unwrap(); + assert_eq!(record_b.content_md, "Message B"); + } + + #[test] + fn test_get_message_ids_for_session() { + let mut acc = MessageAccumulator::new().unwrap(); + + // Add messages to different sessions + acc.add_chunk( + "msg_1".to_string(), + "session_a".to_string(), + "connector_1".to_string(), + "user".to_string(), + ContentBlock::Text { + text: "Message 1".to_string(), + }, + ); + + acc.add_chunk( + "msg_2".to_string(), + "session_a".to_string(), + "connector_1".to_string(), + "assistant".to_string(), + ContentBlock::Text { + text: "Message 2".to_string(), + }, + ); + + acc.add_chunk( + "msg_3".to_string(), + "session_b".to_string(), + "connector_1".to_string(), + "user".to_string(), + ContentBlock::Text { + text: "Message 3".to_string(), + }, + ); + + // Get message IDs for session_a + let mut session_a_ids = acc.get_message_ids_for_session("session_a"); + session_a_ids.sort(); + assert_eq!(session_a_ids, vec!["msg_1", "msg_2"]); + + // Get message IDs for session_b + let session_b_ids = acc.get_message_ids_for_session("session_b"); + assert_eq!(session_b_ids, vec!["msg_3"]); + + // Get message IDs for non-existent session + let empty_ids = acc.get_message_ids_for_session("session_c"); + assert!(empty_ids.is_empty()); + } + + #[test] + fn test_finalize_with_msg_prefix() { + let mut acc = MessageAccumulator::new().unwrap(); + + // Use message_id and session_id with "msg-" prefix (ACP format) + let uuid_str = "01936e8f-e5a7-7000-8000-000000000007"; + let session_uuid_str = "01936e8f-e5a7-7000-8000-000000000008"; + + acc.add_chunk( + format!("msg-{}", uuid_str), + format!("msg-{}", session_uuid_str), + "connector_1".to_string(), + "assistant".to_string(), + ContentBlock::Text { + text: "Testing msg- prefix handling.".to_string(), + }, + ); + + let (record, _, _) = acc.finalize(&format!("msg-{}", uuid_str)).unwrap(); + + // Verify that the UUID was correctly parsed (not regenerated) + assert_eq!(record.message_id.to_string(), uuid_str); + assert_eq!(record.session.to_string(), session_uuid_str); + assert_eq!(record.content_md, "Testing msg- prefix handling."); + } + + #[test] + fn test_finalize_without_msg_prefix() { + let mut acc = MessageAccumulator::new().unwrap(); + + // Use message_id and session_id without "msg-" prefix + let uuid_str = "01936e8f-e5a7-7000-8000-000000000009"; + let session_uuid_str = "01936e8f-e5a7-7000-8000-00000000000a"; + + acc.add_chunk( + uuid_str.to_string(), + session_uuid_str.to_string(), + "connector_1".to_string(), + "user".to_string(), + ContentBlock::Text { + text: "Testing without prefix.".to_string(), + }, + ); + + let (record, _, _) = acc.finalize(uuid_str).unwrap(); + + // Verify that the UUID was correctly parsed + assert_eq!(record.message_id.to_string(), uuid_str); + assert_eq!(record.session.to_string(), session_uuid_str); + assert_eq!(record.content_md, "Testing without prefix."); + } + + #[test] + fn test_interleaved_tool_calls() { + let mut acc = MessageAccumulator::new().unwrap(); + let msg_id = "01936e8f-e5a7-7000-8000-000000000010"; + let session_id = "01936e8f-e5a7-7000-8000-000000000011"; + + // Text chunk 1 + acc.add_chunk( + msg_id.to_string(), + session_id.to_string(), + "connector_1".to_string(), + "assistant".to_string(), + ContentBlock::Text { + text: "Let me search for that. ".to_string(), + }, + ); + + // Tool call 1 + acc.add_or_update_tool_call( + msg_id.to_string(), + ToolCallData { + id: "call_1".to_string(), + tool_name: "search".to_string(), + input: serde_json::json!({"query": "rust"}), + output: None, + }, + ); + + // Text chunk 2 + acc.add_chunk( + msg_id.to_string(), + session_id.to_string(), + "connector_1".to_string(), + "assistant".to_string(), + ContentBlock::Text { + text: "Now let me check the documentation. ".to_string(), + }, + ); + + // Tool call 2 + acc.add_or_update_tool_call( + msg_id.to_string(), + ToolCallData { + id: "call_2".to_string(), + tool_name: "read_docs".to_string(), + input: serde_json::json!({"path": "README.md"}), + output: None, + }, + ); + + // Text chunk 3 + acc.add_chunk( + msg_id.to_string(), + session_id.to_string(), + "connector_1".to_string(), + "assistant".to_string(), + ContentBlock::Text { + text: "Based on my research...".to_string(), + }, + ); + + let (record, _, _) = acc.finalize(msg_id).unwrap(); + + // Verify content_md has correct order (text1, tool1, text2, tool2, text3) + let content = &record.content_md; + let search_pos = content.find("**Tool**: search").expect("search tool not found"); + let docs_pos = content + .find("**Tool**: read_docs") + .expect("read_docs tool not found"); + let text1_pos = content.find("Let me search").expect("text1 not found"); + let text2_pos = content.find("Now let me check").expect("text2 not found"); + let text3_pos = content.find("Based on my research").expect("text3 not found"); + + // Verify order: text1 < search < text2 < read_docs < text3 + assert!( + text1_pos < search_pos, + "text1 should come before search tool" + ); + assert!( + search_pos < text2_pos, + "search tool should come before text2" + ); + assert!( + text2_pos < docs_pos, + "text2 should come before read_docs tool" + ); + assert!( + docs_pos < text3_pos, + "read_docs tool should come before text3" + ); + + // Verify content_parts structure + let parts = + serde_json::from_value::>(record.content_parts.unwrap()).unwrap(); + assert_eq!( + parts.len(), + 5, + "Should have 5 parts: text, tool, text, tool, text" + ); + + // Verify each part type in order + assert!(matches!(parts[0], MessagePart::Text { .. })); + assert!(matches!(parts[1], MessagePart::Tool { .. })); + assert!(matches!(parts[2], MessagePart::Text { .. })); + assert!(matches!(parts[3], MessagePart::Tool { .. })); + assert!(matches!(parts[4], MessagePart::Text { .. })); + } + + #[test] + fn test_text_coalescing_with_tool_separation() { + let mut acc = MessageAccumulator::new().unwrap(); + let msg_id = "msg1"; + let session_id = "session1"; + + // Two consecutive text chunks (should coalesce) + acc.add_chunk( + msg_id.to_string(), + session_id.to_string(), + "connector_1".to_string(), + "assistant".to_string(), + ContentBlock::Text { + text: "Hello ".to_string(), + }, + ); + acc.add_chunk( + msg_id.to_string(), + session_id.to_string(), + "connector_1".to_string(), + "assistant".to_string(), + ContentBlock::Text { + text: "world. ".to_string(), + }, + ); + + // Tool call (separates text) + acc.add_or_update_tool_call( + msg_id.to_string(), + ToolCallData { + id: "call_1".to_string(), + tool_name: "search".to_string(), + input: serde_json::json!({}), + output: None, + }, + ); + + // Two more consecutive text chunks (should coalesce separately) + acc.add_chunk( + msg_id.to_string(), + session_id.to_string(), + "connector_1".to_string(), + "assistant".to_string(), + ContentBlock::Text { + text: "More ".to_string(), + }, + ); + acc.add_chunk( + msg_id.to_string(), + session_id.to_string(), + "connector_1".to_string(), + "assistant".to_string(), + ContentBlock::Text { + text: "text.".to_string(), + }, + ); + + let (record, _, _) = acc.finalize(msg_id).unwrap(); + + // Should have 3 parts: coalesced text1, tool, coalesced text2 + let parts = + serde_json::from_value::>(record.content_parts.unwrap()).unwrap(); + assert_eq!(parts.len(), 3); + + // Verify first text part is coalesced + if let MessagePart::Text { text } = &parts[0] { + assert_eq!(text, "Hello world. "); + } else { + panic!("Expected Text part"); + } + + // Verify tool part + assert!(matches!(parts[1], MessagePart::Tool { .. })); + + // Verify second text part is coalesced + if let MessagePart::Text { text } = &parts[2] { + assert_eq!(text, "More text."); + } else { + panic!("Expected Text part"); + } + } + + #[test] + fn test_tool_call_progressive_updates() { + let mut acc = MessageAccumulator::new().unwrap(); + let msg_id = "msg1"; + let session_id = "session1"; + + // Create buffer with initial text chunk + acc.add_chunk( + msg_id.to_string(), + session_id.to_string(), + "connector_1".to_string(), + "assistant".to_string(), + ContentBlock::Text { + text: "Using grep... ".to_string(), + }, + ); + + // Initial tool call (empty input, no output) + acc.add_or_update_tool_call( + msg_id.to_string(), + ToolCallData { + id: "call_1".to_string(), + tool_name: "grep".to_string(), + input: serde_json::json!({}), + output: None, + }, + ); + + // Update with actual input + acc.add_or_update_tool_call( + msg_id.to_string(), + ToolCallData { + id: "call_1".to_string(), + tool_name: "grep".to_string(), + input: serde_json::json!({"pattern": "rust"}), + output: None, + }, + ); + + // Update with output + acc.add_or_update_tool_call( + msg_id.to_string(), + ToolCallData { + id: "call_1".to_string(), + tool_name: "grep".to_string(), + input: serde_json::json!({}), // Empty, should not overwrite + output: Some(serde_json::json!({"results": ["match1", "match2"]})), + }, + ); + + let (record, _, _) = acc.finalize(msg_id).unwrap(); + + // Should have 2 parts: text and tool (tool merged from 3 updates) + let parts = + serde_json::from_value::>(record.content_parts.unwrap()).unwrap(); + assert_eq!(parts.len(), 2); + + // Verify first part is text + assert!(matches!(parts[0], MessagePart::Text { .. })); + + // Verify second part is tool with merged data + if let MessagePart::Tool { + tool, input, output, .. + } = &parts[1] + { + assert_eq!(tool, "grep"); + assert_eq!(input, &serde_json::json!({"pattern": "rust"})); // Input preserved + assert!(output.is_some()); // Output added + } else { + panic!("Expected Tool part"); + } + } + + #[test] + fn test_thinking_coalescing() { + let mut acc = MessageAccumulator::new().unwrap(); + let msg_id = "msg1"; + let session_id = "session1"; + + // Add multiple thinking chunks + acc.add_thinking( + msg_id.to_string(), + session_id.to_string(), + "connector_1".to_string(), + "First thought. ".to_string(), + ); + acc.add_thinking( + msg_id.to_string(), + session_id.to_string(), + "connector_1".to_string(), + "Second thought.".to_string(), + ); + + let (record, _, _) = acc.finalize(msg_id).unwrap(); + + // Should have 1 thinking part (coalesced) + let parts = + serde_json::from_value::>(record.content_parts.unwrap()).unwrap(); + assert_eq!(parts.len(), 1); + + // Verify it's coalesced thinking + if let MessagePart::Thinking { text } = &parts[0] { + assert_eq!(text, "First thought. Second thought."); + } else { + panic!("Expected Thinking part"); + } + } +} diff --git a/crates/dirigent_archivist/src/backend/capability.rs b/crates/dirigent_archivist/src/backend/capability.rs new file mode 100644 index 0000000..c5a3669 --- /dev/null +++ b/crates/dirigent_archivist/src/backend/capability.rs @@ -0,0 +1,18 @@ +//! Archive backend capability enumeration. +//! +//! Mandatory session + message primitives are NOT listed here — every +//! backend has them. This enum represents the *optional* sub-traits a +//! backend opts into, surfaced through `ArchiveBackend::as_xxx()` accessors. + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub enum ArchiveCapability { + Search, + Dag, + MetaEvents, + ConnectorRegistry, + SessionMapping, +} + +pub type CapabilitySet = std::collections::HashSet; diff --git a/crates/dirigent_archivist/src/backend/contract.rs b/crates/dirigent_archivist/src/backend/contract.rs new file mode 100644 index 0000000..5599620 --- /dev/null +++ b/crates/dirigent_archivist/src/backend/contract.rs @@ -0,0 +1,108 @@ +//! Reusable sub-trait contract tests. +//! +//! Pass any `&dyn ArchiveBackend` to verify it honors the behavioral +//! contract of each sub-trait it exposes. Phase 2 runs this against +//! `JsonlBackend`; Phase 3+ reuses it for every new backend. + +#![cfg(any(test, feature = "test-utils"))] + +use uuid::Uuid; + +use crate::backend::ArchiveBackend; + +/// Exercises `ConnectorRegistryBackend` through `as_connector_registry()`. +/// Skips silently if the backend does not expose the sub-trait. +pub async fn verify_connector_registry_contract(backend: &dyn ArchiveBackend) { + let Some(registry) = backend.as_connector_registry() else { + return; + }; + + // Empty state — listing returns Vec::new(), not an error. + let list = registry.list_connectors().await.expect("list_connectors"); + assert!(list.is_empty(), "fresh backend should have no connectors"); + + // get_connector on missing UID returns Ok(None). + let missing = registry + .get_connector(Uuid::new_v4()) + .await + .expect("get_connector"); + assert!(missing.is_none()); + + // resolve_connector_uid on unknown id returns Ok(None). + let unresolved = registry + .resolve_connector_uid("nonexistent@host") + .await + .expect("resolve_connector_uid"); + assert!(unresolved.is_none()); +} + +/// Exercises `SessionMappingBackend`. +pub async fn verify_session_mapping_contract(backend: &dyn ArchiveBackend) { + let Some(mapping) = backend.as_session_mapping() else { + return; + }; + + let missing = mapping + .get_mapping(Uuid::new_v4(), "absent") + .await + .expect("get_mapping"); + assert!(missing.is_none()); + + let owner = mapping + .find_owner("absent") + .await + .expect("find_owner"); + assert!(owner.is_none()); +} + +/// Exercises `DagBackend`. +pub async fn verify_dag_contract(backend: &dyn ArchiveBackend) { + let Some(dag) = backend.as_dag() else { + return; + }; + + let children = dag + .get_children(Uuid::new_v4()) + .await + .expect("get_children"); + assert!(children.is_empty()); + + let edges = dag + .get_dag_edges(Uuid::new_v4()) + .await + .expect("get_dag_edges"); + assert!(edges.is_empty()); +} + +/// Exercises `MetaEventsBackend`. +pub async fn verify_meta_events_contract(backend: &dyn ArchiveBackend) { + let Some(meta) = backend.as_meta_events() else { + return; + }; + + let events = meta + .get_meta_events(Uuid::new_v4()) + .await + .expect("get_meta_events"); + assert!(events.is_empty()); + + let by_client = meta + .find_meta_session_by_client("absent") + .await + .expect("find_meta_session_by_client"); + assert!(by_client.is_none()); + + let all = meta + .list_meta_sessions() + .await + .expect("list_meta_sessions"); + assert!(all.is_empty()); +} + +/// One-shot helper: runs every sub-trait contract whose capability is present. +pub async fn verify_all_contracts(backend: &dyn ArchiveBackend) { + verify_connector_registry_contract(backend).await; + verify_session_mapping_contract(backend).await; + verify_dag_contract(backend).await; + verify_meta_events_contract(backend).await; +} diff --git a/crates/dirigent_archivist/src/backend/health.rs b/crates/dirigent_archivist/src/backend/health.rs new file mode 100644 index 0000000..6cc3882 --- /dev/null +++ b/crates/dirigent_archivist/src/backend/health.rs @@ -0,0 +1,10 @@ +//! Health status reported by `ArchiveBackend::health_check`. + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum HealthStatus { + Healthy, + Degraded { reason: String }, + Unavailable { reason: String }, +} diff --git a/crates/dirigent_archivist/src/backend/mock.rs b/crates/dirigent_archivist/src/backend/mock.rs new file mode 100644 index 0000000..e2215bc --- /dev/null +++ b/crates/dirigent_archivist/src/backend/mock.rs @@ -0,0 +1,574 @@ +//! In-memory `ArchiveBackend` for coordinator unit tests. +//! +//! Fully supports every sub-trait. State lives in `Mutex>`. + +#![cfg(any(test, feature = "test-utils"))] + +use std::collections::HashMap; +use std::sync::Mutex; + +use async_trait::async_trait; +use uuid::Uuid; + +use crate::backend::{ + ArchiveBackend, ArchiveCapability, CapabilitySet, ConnectorRegistryBackend, + DagBackend, HealthStatus, MetaEventsBackend, SessionMappingBackend, +}; +use crate::error::{ArchivistError, Result}; +use crate::types::{ + ConnectorRecord, DagEdge, MessageCursor, MessagePage, MessageRecord, + MetaEventRecord, SessionListQuery, SessionMapping, SessionMetadata, SessionPage, +}; + +pub struct MockBackend { + capabilities: CapabilitySet, + sessions: Mutex>, + messages: Mutex>>, + connectors: Mutex>, + mappings: Mutex>, + meta_events: Mutex>>, + dag_edges: Mutex>, + fail_next_writes: std::sync::atomic::AtomicUsize, + fail_next_reads: std::sync::atomic::AtomicUsize, + permanent_error: std::sync::Mutex>, + append_calls: std::sync::Mutex>, + per_op_delay: std::sync::Mutex, +} + +impl MockBackend { + pub fn new() -> Self { + let mut capabilities = CapabilitySet::new(); + capabilities.insert(ArchiveCapability::Dag); + capabilities.insert(ArchiveCapability::MetaEvents); + capabilities.insert(ArchiveCapability::ConnectorRegistry); + capabilities.insert(ArchiveCapability::SessionMapping); + Self { + capabilities, + sessions: Mutex::new(HashMap::new()), + messages: Mutex::new(HashMap::new()), + connectors: Mutex::new(HashMap::new()), + mappings: Mutex::new(HashMap::new()), + meta_events: Mutex::new(HashMap::new()), + dag_edges: Mutex::new(Vec::new()), + fail_next_writes: std::sync::atomic::AtomicUsize::new(0), + fail_next_reads: std::sync::atomic::AtomicUsize::new(0), + permanent_error: std::sync::Mutex::new(None), + append_calls: std::sync::Mutex::new(std::collections::HashMap::new()), + per_op_delay: std::sync::Mutex::new(std::time::Duration::ZERO), + } + } +} + +impl MockBackend { + /// Build a mock with the exact capability set provided. All other state + /// starts empty (same as `new()`). + pub fn with_capabilities(capabilities: CapabilitySet) -> Self { + let mut m = Self::new(); + m.capabilities = capabilities; + m + } + + /// Test helper: does this mock have any meta events for the given session? + pub fn has_meta_events(&self, scroll_id: uuid::Uuid) -> bool { + self.meta_events + .lock() + .unwrap() + .get(&scroll_id) + .map(|v| !v.is_empty()) + .unwrap_or(false) + } + + /// Queue up `count` injected write failures. The next `count` calls to + /// any mutating API return `ArchivistError::Other("injected write failure")` + /// before touching state. + pub fn inject_write_failures(&self, count: usize) { + self.fail_next_writes + .store(count, std::sync::atomic::Ordering::SeqCst); + } + + /// Queue up `count` injected read failures for per-scroll_id reads. + pub fn inject_read_failures(&self, count: usize) { + self.fail_next_reads + .store(count, std::sync::atomic::Ordering::SeqCst); + } + + /// Simulate a permanently broken backend. + pub fn break_permanently(&self, reason: impl Into) { + *self.permanent_error.lock().unwrap() = Some(reason.into()); + } + + pub fn clear_failures(&self) { + self.fail_next_writes + .store(0, std::sync::atomic::Ordering::SeqCst); + self.fail_next_reads + .store(0, std::sync::atomic::Ordering::SeqCst); + *self.permanent_error.lock().unwrap() = None; + } + + /// Test helper: how many `MessageRecord`s this mock has for the given session. + pub fn appended_count(&self, scroll_id: uuid::Uuid) -> usize { + self.messages + .lock() + .unwrap() + .get(&scroll_id) + .map(|v| v.len()) + .unwrap_or(0) + } + + /// Test helper: how many times `append_messages` was invoked for the + /// given session (regardless of message count per invocation). + pub fn append_call_count(&self, scroll_id: uuid::Uuid) -> usize { + self.append_calls + .lock() + .unwrap() + .get(&scroll_id) + .copied() + .unwrap_or(0) + } + + /// Test helper: artificially slow every mutating backend operation by + /// sleeping `d` before it touches state. Used to simulate a slow backend + /// for backpressure tests. + pub fn set_per_op_delay(&self, d: std::time::Duration) { + *self.per_op_delay.lock().unwrap() = d; + } + + async fn maybe_delay(&self) { + let d = *self.per_op_delay.lock().unwrap(); + if !d.is_zero() { + tokio::time::sleep(d).await; + } + } + + pub(crate) fn check_write_failure(&self) -> Result<()> { + if let Some(reason) = self.permanent_error.lock().unwrap().clone() { + return Err(ArchivistError::Other(reason)); + } + let prev = self + .fail_next_writes + .fetch_update( + std::sync::atomic::Ordering::SeqCst, + std::sync::atomic::Ordering::SeqCst, + |n| if n > 0 { Some(n - 1) } else { None }, + ) + .ok(); + if prev.is_some() { + return Err(ArchivistError::Other("injected write failure".into())); + } + Ok(()) + } + + pub(crate) fn check_read_failure(&self) -> Result<()> { + if let Some(reason) = self.permanent_error.lock().unwrap().clone() { + return Err(ArchivistError::Other(reason)); + } + let prev = self + .fail_next_reads + .fetch_update( + std::sync::atomic::Ordering::SeqCst, + std::sync::atomic::Ordering::SeqCst, + |n| if n > 0 { Some(n - 1) } else { None }, + ) + .ok(); + if prev.is_some() { + return Err(ArchivistError::Other("injected read failure".into())); + } + Ok(()) + } +} + +impl Default for MockBackend { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl ArchiveBackend for MockBackend { + fn capabilities(&self) -> &CapabilitySet { + &self.capabilities + } + async fn health_check(&self) -> HealthStatus { + HealthStatus::Healthy + } + + async fn put_session(&self, meta: SessionMetadata) -> Result<()> { + self.check_write_failure()?; + self.maybe_delay().await; + self.sessions.lock().unwrap().insert(meta.scroll_id, meta); + Ok(()) + } + async fn get_session(&self, scroll_id: Uuid) -> Result> { + self.check_read_failure()?; + Ok(self.sessions.lock().unwrap().get(&scroll_id).cloned()) + } + async fn list_sessions_paged(&self, query: SessionListQuery) -> Result { + let mut items: Vec = + self.sessions.lock().unwrap().values().cloned().collect(); + if !query.connector_uids.is_empty() { + items.retain(|s| query.connector_uids.contains(&s.connector_uid)); + } + items.sort_by(|a, b| { + b.updated_at + .cmp(&a.updated_at) + .then(b.scroll_id.cmp(&a.scroll_id)) + }); + let limit = query.limit.min(crate::types::MAX_PAGE_LIMIT).max(1); + let total_count = items.len(); + let items: Vec<_> = items.into_iter().take(limit).collect(); + Ok(SessionPage { + items, + next_cursor: None, + total_count: Some(total_count), + }) + } + async fn delete_session(&self, scroll_id: Uuid) -> Result<()> { + self.check_write_failure()?; + self.maybe_delay().await; + if self.sessions.lock().unwrap().remove(&scroll_id).is_none() { + return Err(ArchivistError::SessionUnknown(scroll_id)); + } + self.messages.lock().unwrap().remove(&scroll_id); + Ok(()) + } + + async fn append_messages( + &self, + scroll_id: Uuid, + msgs: Vec, + ) -> Result<()> { + self.check_write_failure()?; + self.maybe_delay().await; + *self + .append_calls + .lock() + .unwrap() + .entry(scroll_id) + .or_insert(0) += 1; + self.messages + .lock() + .unwrap() + .entry(scroll_id) + .or_default() + .extend(msgs); + Ok(()) + } + async fn get_messages_paged( + &self, + scroll_id: Uuid, + cursor: Option, + limit: usize, + ) -> Result { + self.check_read_failure()?; + let mut all = self + .messages + .lock() + .unwrap() + .get(&scroll_id) + .cloned() + .unwrap_or_default(); + all.sort_by(|a, b| a.ts.cmp(&b.ts).then(a.message_id.cmp(&b.message_id))); + if let Some(c) = cursor.as_ref() { + all.retain(|m| (m.ts, m.message_id) > (c.ts, c.message_id)); + } + let total = all.len(); + let taken: Vec<_> = all.into_iter().take(limit.max(1)).collect(); + let next_cursor = if total > taken.len() { + taken.last().map(|m| MessageCursor { + ts: m.ts, + message_id: m.message_id, + }) + } else { + None + }; + Ok(MessagePage { + items: taken, + next_cursor, + }) + } + async fn count_messages(&self, scroll_id: Uuid) -> Result { + self.check_read_failure()?; + Ok(self + .messages + .lock() + .unwrap() + .get(&scroll_id) + .map(|v| v.len()) + .unwrap_or(0)) + } + async fn clear_session_messages(&self, scroll_id: Uuid) -> Result<()> { + self.check_write_failure()?; + self.maybe_delay().await; + self.messages.lock().unwrap().remove(&scroll_id); + Ok(()) + } + + fn as_dag(&self) -> Option<&dyn DagBackend> { + if self.capabilities.contains(&ArchiveCapability::Dag) { + Some(self) + } else { + None + } + } + fn as_meta_events(&self) -> Option<&dyn MetaEventsBackend> { + if self.capabilities.contains(&ArchiveCapability::MetaEvents) { + Some(self) + } else { + None + } + } + fn as_connector_registry(&self) -> Option<&dyn ConnectorRegistryBackend> { + if self + .capabilities + .contains(&ArchiveCapability::ConnectorRegistry) + { + Some(self) + } else { + None + } + } + fn as_session_mapping(&self) -> Option<&dyn SessionMappingBackend> { + if self + .capabilities + .contains(&ArchiveCapability::SessionMapping) + { + Some(self) + } else { + None + } + } +} + +#[async_trait] +impl ConnectorRegistryBackend for MockBackend { + async fn put_connector(&self, record: ConnectorRecord) -> Result<()> { + self.check_write_failure()?; + self.maybe_delay().await; + self.connectors + .lock() + .unwrap() + .insert(record.connector_uid, record); + Ok(()) + } + async fn get_connector(&self, connector_uid: Uuid) -> Result> { + Ok(self + .connectors + .lock() + .unwrap() + .get(&connector_uid) + .cloned()) + } + async fn list_connectors(&self) -> Result> { + Ok(self.connectors.lock().unwrap().values().cloned().collect()) + } + async fn resolve_connector_uid(&self, client_native_id: &str) -> Result> { + Ok(self + .connectors + .lock() + .unwrap() + .values() + .find(|c| c.client_native_id == client_native_id) + .map(|c| c.connector_uid)) + } + async fn update_connector_fingerprint( + &self, + connector_uid: Uuid, + fingerprint: String, + ) -> Result<()> { + self.check_write_failure()?; + self.maybe_delay().await; + if let Some(r) = self.connectors.lock().unwrap().get_mut(&connector_uid) { + r.fingerprint = Some(fingerprint); + Ok(()) + } else { + Err(ArchivistError::ConnectorUnknown(connector_uid)) + } + } +} + +#[async_trait] +impl SessionMappingBackend for MockBackend { + async fn put_mapping( + &self, + connector_uid: Uuid, + native_session_id: &str, + scroll_id: Uuid, + ) -> Result<()> { + self.check_write_failure()?; + self.maybe_delay().await; + self.mappings + .lock() + .unwrap() + .insert((connector_uid, native_session_id.to_string()), scroll_id); + Ok(()) + } + async fn get_mapping( + &self, + connector_uid: Uuid, + native_session_id: &str, + ) -> Result> { + Ok(self + .mappings + .lock() + .unwrap() + .get(&(connector_uid, native_session_id.to_string())) + .copied()) + } + async fn list_mappings_for_connector( + &self, + connector_uid: Uuid, + ) -> Result> { + Ok(self + .mappings + .lock() + .unwrap() + .iter() + .filter(|((c, _), _)| *c == connector_uid) + .map(|((c, n), s)| SessionMapping { + version: 1, + connector_uid: *c, + native_session_id: n.clone(), + scroll_id: *s, + created_at: chrono::Utc::now(), + alias_of: None, + }) + .collect()) + } + async fn find_owner(&self, native_session_id: &str) -> Result> { + Ok(self + .mappings + .lock() + .unwrap() + .iter() + .find(|((_, n), _)| n == native_session_id) + .map(|((c, _), s)| (*c, *s))) + } + + async fn rewrite_connector_mappings( + &self, + connector_uid: Uuid, + mappings: Vec, + ) -> Result<()> { + self.check_write_failure()?; + self.maybe_delay().await; + let mut map = self.mappings.lock().unwrap(); + map.retain(|(c, _), _| *c != connector_uid); + for m in mappings { + map.insert((connector_uid, m.native_session_id), m.scroll_id); + } + Ok(()) + } +} + +#[async_trait] +impl DagBackend for MockBackend { + async fn append_dag_edge(&self, edge: DagEdge) -> Result<()> { + self.check_write_failure()?; + self.maybe_delay().await; + self.dag_edges.lock().unwrap().push(edge); + Ok(()) + } + async fn get_children(&self, parent: Uuid) -> Result> { + self.check_read_failure()?; + let edges = self.dag_edges.lock().unwrap(); + let sessions = self.sessions.lock().unwrap(); + Ok(edges + .iter() + .filter(|e| e.parent == parent) + .filter_map(|e| sessions.get(&e.child).cloned()) + .collect()) + } + async fn get_dag_edges(&self, root: Uuid) -> Result> { + self.check_read_failure()?; + Ok(self + .dag_edges + .lock() + .unwrap() + .iter() + .filter(|e| e.parent == root) + .cloned() + .collect()) + } +} + +#[async_trait] +impl MetaEventsBackend for MockBackend { + async fn append_meta_events( + &self, + scroll_id: Uuid, + events: Vec, + ) -> Result<()> { + self.check_write_failure()?; + self.maybe_delay().await; + self.meta_events + .lock() + .unwrap() + .entry(scroll_id) + .or_default() + .extend(events); + Ok(()) + } + async fn get_meta_events(&self, scroll_id: Uuid) -> Result> { + self.check_read_failure()?; + Ok(self + .meta_events + .lock() + .unwrap() + .get(&scroll_id) + .cloned() + .unwrap_or_default()) + } + async fn update_meta_session_status( + &self, + scroll_id: Uuid, + is_connected: bool, + current_session_id: Option, + ) -> Result<()> { + self.check_write_failure()?; + self.maybe_delay().await; + if let Some(s) = self.sessions.lock().unwrap().get_mut(&scroll_id) { + s.is_connected = Some(is_connected); + s.current_session_id = current_session_id; + Ok(()) + } else { + Err(ArchivistError::SessionUnknown(scroll_id)) + } + } + async fn list_meta_sessions(&self) -> Result> { + Ok(self + .sessions + .lock() + .unwrap() + .values() + .filter(|s| matches!(s.kind, crate::types::SessionKind::AcpConnection)) + .cloned() + .collect()) + } + async fn find_meta_session_by_client( + &self, + client_id: &str, + ) -> Result> { + Ok(self + .sessions + .lock() + .unwrap() + .values() + .find(|s| s.acp_client_id.as_deref() == Some(client_id)) + .cloned()) + } +} + +#[cfg(test)] +mod failure_injection_tests { + use super::*; + + #[tokio::test] + async fn injected_write_failure_returns_error_then_recovers() { + let m = MockBackend::new(); + m.inject_write_failures(2); + let scroll = uuid::Uuid::nil(); + assert!(m.append_messages(scroll, vec![]).await.is_err()); + assert!(m.append_messages(scroll, vec![]).await.is_err()); + assert!(m.append_messages(scroll, vec![]).await.is_ok()); // back to normal + } +} diff --git a/crates/dirigent_archivist/src/backend/mod.rs b/crates/dirigent_archivist/src/backend/mod.rs new file mode 100644 index 0000000..24fa7d3 --- /dev/null +++ b/crates/dirigent_archivist/src/backend/mod.rs @@ -0,0 +1,20 @@ +//! Archive backend trait layer. +//! +//! See `docs/plans/2026-04-18-archivist-phase2-design.md` for the design. + +pub mod capability; +pub mod health; +pub mod traits; + +#[cfg(any(test, feature = "test-utils"))] +pub mod contract; + +#[cfg(any(test, feature = "test-utils"))] +pub mod mock; + +pub use capability::{ArchiveCapability, CapabilitySet}; +pub use health::HealthStatus; +pub use traits::{ + ArchiveBackend, ConnectorRegistryBackend, DagBackend, MetaEventsBackend, + SearchBackend, SessionMappingBackend, +}; diff --git a/crates/dirigent_archivist/src/backend/traits.rs b/crates/dirigent_archivist/src/backend/traits.rs new file mode 100644 index 0000000..62b49aa --- /dev/null +++ b/crates/dirigent_archivist/src/backend/traits.rs @@ -0,0 +1,167 @@ +//! Archive backend trait definitions. +//! +//! `ArchiveBackend` is mandatory for every backend: session + message +//! primitives plus self-description (capabilities, health). Optional +//! sub-traits (`SearchBackend`, `DagBackend`, `MetaEventsBackend`, +//! `ConnectorRegistryBackend`, `SessionMappingBackend`) are surfaced +//! via `as_xxx() -> Option<&dyn SubTrait>` accessors returning a +//! borrow from `self`. +//! +//! See `docs/plans/2026-04-18-archivist-phase2-design.md` §Trait Definitions. + +use async_trait::async_trait; +use uuid::Uuid; + +use crate::backend::capability::CapabilitySet; +use crate::backend::health::HealthStatus; +use crate::error::Result; +use crate::types::{ + ConnectorRecord, DagEdge, MessageCursor, MessagePage, MessageRecord, + MetaEventRecord, SessionListQuery, SessionMapping, SessionMetadata, + SessionPage, +}; + +// --------------------------------------------------------------------------- +// Mandatory backend surface +// --------------------------------------------------------------------------- + +/// An archive storage backend. +/// +/// All backends must implement session metadata and message primitives; +/// optional capabilities are exposed through `as_xxx()` accessors that +/// return `None` when unsupported. `JsonlBackend` implements every +/// sub-trait except `SearchBackend`. +#[async_trait] +pub trait ArchiveBackend: Send + Sync { + // --- Self-description --- + fn capabilities(&self) -> &CapabilitySet; + async fn health_check(&self) -> HealthStatus; + + // --- Session metadata --- + async fn put_session(&self, meta: SessionMetadata) -> Result<()>; + async fn get_session(&self, scroll_id: Uuid) -> Result>; + async fn list_sessions_paged(&self, query: SessionListQuery) -> Result; + async fn delete_session(&self, scroll_id: Uuid) -> Result<()>; + + // --- Messages --- + async fn append_messages( + &self, + scroll_id: Uuid, + messages: Vec, + ) -> Result<()>; + async fn get_messages_paged( + &self, + scroll_id: Uuid, + cursor: Option, + limit: usize, + ) -> Result; + async fn count_messages(&self, scroll_id: Uuid) -> Result; + async fn clear_session_messages(&self, scroll_id: Uuid) -> Result<()>; + + // --- Optional capability accessors --- + fn as_search(&self) -> Option<&dyn SearchBackend> { + None + } + fn as_dag(&self) -> Option<&dyn DagBackend> { + None + } + fn as_meta_events(&self) -> Option<&dyn MetaEventsBackend> { + None + } + fn as_connector_registry(&self) -> Option<&dyn ConnectorRegistryBackend> { + None + } + fn as_session_mapping(&self) -> Option<&dyn SessionMappingBackend> { + None + } +} + +// --------------------------------------------------------------------------- +// Optional sub-traits +// --------------------------------------------------------------------------- + +/// Content search. Reserved in Phase 2; not wired to `JsonlBackend`. +/// +/// `packages/api/src/archivist/search_task.rs` continues to serve content +/// search via ripgrep — this trait exists as a forward-compatible hook for +/// indexed backends (ChromaDB, tantivy, …) arriving in Phase 3+. +#[async_trait] +pub trait SearchBackend: Send + Sync { + // Deliberately left without methods; Phase 3 adds the concrete + // query/result shapes when a real indexed backend lands. +} + +#[async_trait] +pub trait DagBackend: Send + Sync { + async fn append_dag_edge(&self, edge: DagEdge) -> Result<()>; + async fn get_children(&self, parent: Uuid) -> Result>; + async fn get_dag_edges(&self, root: Uuid) -> Result>; +} + +#[async_trait] +pub trait MetaEventsBackend: Send + Sync { + async fn append_meta_events( + &self, + scroll_id: Uuid, + events: Vec, + ) -> Result<()>; + async fn get_meta_events(&self, scroll_id: Uuid) -> Result>; + async fn update_meta_session_status( + &self, + scroll_id: Uuid, + is_connected: bool, + current_session_id: Option, + ) -> Result<()>; + async fn list_meta_sessions(&self) -> Result>; + async fn find_meta_session_by_client( + &self, + client_id: &str, + ) -> Result>; +} + +#[async_trait] +pub trait ConnectorRegistryBackend: Send + Sync { + async fn put_connector(&self, record: ConnectorRecord) -> Result<()>; + async fn get_connector(&self, connector_uid: Uuid) -> Result>; + async fn list_connectors(&self) -> Result>; + async fn resolve_connector_uid(&self, client_native_id: &str) -> Result>; + async fn update_connector_fingerprint( + &self, + connector_uid: Uuid, + fingerprint: String, + ) -> Result<()>; +} + +#[async_trait] +pub trait SessionMappingBackend: Send + Sync { + async fn put_mapping( + &self, + connector_uid: Uuid, + native_session_id: &str, + scroll_id: Uuid, + ) -> Result<()>; + async fn get_mapping( + &self, + connector_uid: Uuid, + native_session_id: &str, + ) -> Result>; + async fn list_mappings_for_connector( + &self, + connector_uid: Uuid, + ) -> Result>; + async fn find_owner(&self, native_session_id: &str) -> Result>; + + /// Replace the entire mapping table for `connector_uid` with `mappings`. + /// + /// Phase 2 uses this to remove an individual mapping — callers read the + /// current table via `list_mappings_for_connector`, filter out the + /// unwanted row, and call this method with the remainder. Implementations + /// must also invalidate any in-memory cache entries that reference the + /// removed rows so subsequent `get_mapping` / `find_owner` calls don't + /// return stale hits. + async fn rewrite_connector_mappings( + &self, + connector_uid: Uuid, + mappings: Vec, + ) -> Result<()>; +} diff --git a/crates/dirigent_archivist/src/backends/jsonl/backend.rs b/crates/dirigent_archivist/src/backends/jsonl/backend.rs new file mode 100644 index 0000000..4e5ae54 --- /dev/null +++ b/crates/dirigent_archivist/src/backends/jsonl/backend.rs @@ -0,0 +1,624 @@ +//! `JsonlBackend` — the Phase 2 concrete backend. + +use std::collections::{HashMap, HashSet}; +use std::path::PathBuf; + +use async_trait::async_trait; +use chrono::Utc; +use tokio::sync::RwLock; +use uuid::Uuid; + +use crate::backend::{ + ArchiveBackend, ArchiveCapability, CapabilitySet, ConnectorRegistryBackend, + DagBackend, HealthStatus, MetaEventsBackend, SessionMappingBackend, +}; +use crate::error::{ArchivistError, Result}; +use crate::storage::{ + append_ndjson, read_connector_index, read_json, read_ndjson, write_json, ArchivePaths, +}; +use crate::types::{ + ConnectorRecord, MessageCursor, MessagePage, MessageRecord, SessionCompleteness, + SessionKind, SessionListQuery, SessionMapping, SessionMetadata, SessionPage, +}; + +/// NDJSON/JSON/TSV file-based `ArchiveBackend`. +pub struct JsonlBackend { + pub(crate) paths: ArchivePaths, + pub(crate) connector_cache: RwLock>, + pub(crate) session_cache: RwLock>, + pub(crate) capabilities: CapabilitySet, +} + +impl JsonlBackend { + /// Create a new backend rooted at `archive_root`. + /// + /// Creates the required directories (`.contexts`, `.db/connectors`, `.files`) + /// and initializes empty caches. Matches `FileBasedArchivist::new`. + pub async fn new(archive_root: PathBuf) -> Result { + let paths = ArchivePaths::new(archive_root); + + tokio::fs::create_dir_all(paths.root().join(".contexts")).await?; + tokio::fs::create_dir_all(paths.root().join(".db").join("connectors")).await?; + tokio::fs::create_dir_all(paths.root().join(".files")).await?; + + let mut capabilities = HashSet::new(); + capabilities.insert(ArchiveCapability::Dag); + capabilities.insert(ArchiveCapability::MetaEvents); + capabilities.insert(ArchiveCapability::ConnectorRegistry); + capabilities.insert(ArchiveCapability::SessionMapping); + + Ok(Self { + paths, + connector_cache: RwLock::new(HashMap::new()), + session_cache: RwLock::new(HashMap::new()), + capabilities, + }) + } + + /// Filesystem path utilities for this backend. + pub fn paths(&self) -> &ArchivePaths { + &self.paths + } + + /// Read and chronologically sort all messages for a session. + /// + /// See module docs for the append-order vs. chronological-order rationale. + pub(crate) async fn read_messages_sorted( + &self, + scroll_id: Uuid, + ) -> Result> { + let path = self.paths.messages_path_for_read(scroll_id); + let mut msgs: Vec = + read_ndjson(&path).await.unwrap_or_default(); + msgs.sort_by(|a, b| { + a.ts.cmp(&b.ts).then(a.message_id.cmp(&b.message_id)) + }); + Ok(msgs) + } + + /// Locate the (connector_uid, native_session_id) owning `scroll_id` by + /// scanning the session cache first, then each connector's session + /// mapping files on disk. + async fn find_mapping_for_scroll_id(&self, scroll_id: Uuid) -> Option<(Uuid, String)> { + // Check cache first + { + let cache = self.session_cache.read().await; + for ((connector_uid, native_id), cached_scroll_id) in cache.iter() { + if *cached_scroll_id == scroll_id { + return Some((*connector_uid, native_id.clone())); + } + } + } + + // Cache miss: scan connector index and each connector's sessions file + let index_path = self.paths.connector_index_tsv(); + let rows = match read_connector_index(&index_path).await { + Ok(rows) => rows, + Err(_) => return None, + }; + + for row in &rows { + let sessions_path = self.paths.sessions_path_for_read(row.connector_uid); + let mappings: Vec = match read_ndjson(&sessions_path).await { + Ok(m) => m, + Err(_) => continue, + }; + for mapping in mappings { + if mapping.scroll_id == scroll_id { + return Some((row.connector_uid, mapping.native_session_id)); + } + } + } + + None + } + + /// Load every session for a connector, including hidden ones. Used by + /// `list_sessions_paged` — it applies visibility filters itself. + async fn load_sessions_for_connector( + &self, + connector_uid: Uuid, + ) -> Result> { + let sessions_path = self.paths.sessions_path_for_read(connector_uid); + let mappings: Vec = read_ndjson(&sessions_path).await?; + + let mut sessions = Vec::new(); + for mapping in mappings { + let session_json_path = self.paths.session_json(mapping.scroll_id); + match read_json::(&session_json_path).await { + Ok(metadata) => sessions.push(metadata), + Err(e) if e.kind() == std::io::ErrorKind::NotFound => { + tracing::debug!( + scroll_id = %mapping.scroll_id, + "session.json missing, surfacing as Discovered stub" + ); + sessions.push(SessionMetadata { + version: 1, + scroll_id: mapping.scroll_id, + created_at: mapping.created_at, + updated_at: mapping.created_at, + title: None, + connector_uid, + native_session_id: Some(mapping.native_session_id.clone()), + agent_id: None, + parent_scroll_id: None, + continuation: None, + tags: Vec::new(), + metadata: serde_json::json!({}), + no_update: false, + kind: SessionKind::Chat, + acp_client_id: None, + is_connected: None, + current_session_id: None, + models: None, + modes: None, + config_options: None, + completeness: SessionCompleteness::Discovered, + matrix_room_id: None, + matrix_sharing_active: false, + matrix_shared_at: None, + is_subagent: false, + subagent_type: None, + spawning_tool_use_id: None, + }); + } + Err(e) => return Err(e.into()), + } + } + Ok(sessions) + } +} + +/// Returns true if `session` satisfies every filter in `query`. +/// +/// `connector_uid` is already honored by the caller (it picks which connector +/// directories to scan), so we do not re-check it here. +fn matches_query( + session: &SessionMetadata, + query: &crate::types::SessionListQuery, +) -> bool { + // Visibility + if !query.include_hidden && (session.no_update || session.is_subagent) { + return false; + } + + // Project scope — project_ids lives in metadata.project_id + if !query.project_ids.is_empty() { + let session_project_id = session + .metadata + .get("project_id") + .and_then(|v| v.as_str()); + match session_project_id { + Some(pid) => { + if !query.project_ids.iter().any(|q| q.as_str() == pid) { + return false; + } + } + None => return false, + } + } + + // Project path filter — exact match on metadata.project_path + if let Some(ref path) = query.project_path { + let session_path = session + .metadata + .get("project_path") + .and_then(|v| v.as_str()); + if session_path != Some(path.as_str()) { + return false; + } + } + + // Title filter — case-insensitive substring. + if let Some(q) = query.title_query.as_ref() { + let needle = q.to_lowercase(); + let haystack = match session.title.as_ref() { + Some(t) => t.to_lowercase(), + None => return false, + }; + if !haystack.contains(&needle) { + return false; + } + } + + // Tag filter — all requested tags must be present on the session. + if !query.tags.is_empty() { + for required in &query.tags { + if !session.tags.iter().any(|t| t == required) { + return false; + } + } + } + + // Model filter — case-insensitive substring on metadata.model. + if let Some(q) = query.model_filter.as_ref() { + let needle = q.to_lowercase(); + let haystack = session + .metadata + .get("model") + .and_then(|v| v.as_str()) + .map(|s| s.to_lowercase()); + match haystack { + Some(h) if h.contains(&needle) => {} + _ => return false, + } + } + + true +} + +#[async_trait] +impl ArchiveBackend for JsonlBackend { + fn capabilities(&self) -> &CapabilitySet { + &self.capabilities + } + + async fn health_check(&self) -> HealthStatus { + match tokio::fs::metadata(self.paths.root()).await { + Ok(m) if m.is_dir() => HealthStatus::Healthy, + Ok(_) => HealthStatus::Unavailable { + reason: "archive root is not a directory".into(), + }, + Err(e) => HealthStatus::Unavailable { + reason: format!("stat archive root failed: {e}"), + }, + } + } + + async fn put_session(&self, meta: SessionMetadata) -> Result<()> { + tokio::fs::create_dir_all(&self.paths.session_dir(meta.scroll_id)).await?; + write_json(&self.paths.session_json(meta.scroll_id), &meta).await?; + Ok(()) + } + + async fn get_session(&self, scroll_id: Uuid) -> Result> { + // FileBasedArchivist ignores archive parameter (single-archive only) + let session_json_path = self.paths.session_json(scroll_id); + + match read_json(&session_json_path).await { + Ok(metadata) => Ok(Some(metadata)), + Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None), + Err(e) => Err(e.into()), + } + } + + async fn list_sessions_paged(&self, query: SessionListQuery) -> Result { + use crate::types::{SessionCursor, SessionPage, MAX_PAGE_LIMIT}; + + // Determine which connectors to scan. + let connector_uids: Vec = if !query.connector_uids.is_empty() { + query.connector_uids.clone() + } else { + // Iterate every primary (non-alias) connector. + let index_path = self.paths.connector_index_tsv(); + let connectors = read_connector_index(&index_path).await?; + connectors + .into_iter() + .filter(|c| c.alias_of.is_none()) + .map(|c| c.connector_uid) + .collect() + }; + + // Stream matching sessions from every selected connector. + let mut matched: Vec = Vec::new(); + + for connector_uid in connector_uids { + let sessions = match self.load_sessions_for_connector(connector_uid).await { + Ok(s) => s, + Err(e) => { + tracing::warn!( + connector_uid = %connector_uid, + error = %e, + "Failed to list sessions for connector during paged scan, skipping" + ); + continue; + } + }; + + for session in sessions { + if !matches_query(&session, &query) { + continue; + } + matched.push(session); + } + } + + // Sort by (updated_at DESC, scroll_id DESC). + matched.sort_by(|a, b| { + b.updated_at + .cmp(&a.updated_at) + .then_with(|| b.scroll_id.cmp(&a.scroll_id)) + }); + + // Skip entries at-or-before the cursor. + if let Some(cursor) = query.cursor.as_ref() { + matched.retain(|s| { + (s.updated_at, s.scroll_id) < (cursor.updated_at, cursor.scroll_id) + }); + } + + // Capture total count before slicing. + let total_count = matched.len(); + + // Clamp limit and paginate. + let effective_limit = query.limit.min(MAX_PAGE_LIMIT).max(1); + let has_more = matched.len() > effective_limit; + matched.truncate(effective_limit); + + let next_cursor = if has_more { + matched.last().map(|s| SessionCursor { + updated_at: s.updated_at, + scroll_id: s.scroll_id, + }) + } else { + None + }; + + Ok(SessionPage { + items: matched, + next_cursor, + total_count: Some(total_count), + }) + } + + async fn delete_session(&self, scroll_id: Uuid) -> Result<()> { + // FileBasedArchivist ignores archive parameter (single-archive only) + + // First, read session metadata to get connector_uid and native_session_id + let session_dir = self.paths.session_dir(scroll_id); + let session_json_path = self.paths.session_json(scroll_id); + + if !session_dir.exists() { + return Err(ArchivistError::SessionUnknown(scroll_id)); + } + + // Read session metadata to get connector info + let metadata: SessionMetadata = read_json(&session_json_path).await?; + let connector_uid = metadata.connector_uid; + let native_session_id = metadata.native_session_id.clone(); + + // Delete the session directory and all its contents + tokio::fs::remove_dir_all(&session_dir).await.map_err(|e| { + tracing::error!("Failed to delete session directory {:?}: {}", session_dir, e); + ArchivistError::Io(e) + })?; + + tracing::info!( + "Deleted session directory for scroll_id: {}", + scroll_id + ); + + // Remove from session cache + if let Some(native_id) = &native_session_id { + let mut cache = self.session_cache.write().await; + cache.remove(&(connector_uid, native_id.clone())); + } + + // Note: We're not removing from sessions.ndjson because it's append-only. + // The session simply won't have a directory anymore, so list_sessions will skip it. + // A future enhancement could add a "deleted" flag or periodic compaction. + + tracing::info!( + "Successfully deleted session {} (connector: {})", + scroll_id, + connector_uid + ); + + Ok(()) + } + + async fn append_messages( + &self, + scroll_id: Uuid, + messages: Vec, + ) -> Result<()> { + // Ensure session directory exists (handles resync case where directory was deleted) + self.paths.ensure_dirs(scroll_id).await?; + + // Append each message to messages.jsonl + let messages_path = self.paths.messages_path_for_write(scroll_id); + for message in &messages { + append_ndjson(&messages_path, message).await?; + } + + // Update session.json timestamp (or create if missing) + let session_json_path = self.paths.session_json(scroll_id); + let now = Utc::now(); + + let session_metadata = match read_json::(&session_json_path).await { + Ok(mut metadata) => { + metadata.updated_at = now; + metadata + } + Err(_) => { + // session.json doesn't exist, create minimal metadata + // This handles resync case where directory was deleted but mapping still exists + tracing::info!( + scroll_id = %scroll_id, + "Creating minimal session.json during append (was missing)" + ); + + // Look up the correct connector_uid and native_session_id via session mappings + let (connector_uid, native_session_id) = match self.find_mapping_for_scroll_id(scroll_id).await { + Some(mapping) => mapping, + None => { + tracing::error!( + scroll_id = %scroll_id, + "Cannot reconstruct session.json: no connector mapping found. \ + Messages written but session metadata will remain missing." + ); + return Ok(()); + } + }; + + SessionMetadata { + version: 1, + scroll_id, + created_at: now, + updated_at: now, + title: None, + connector_uid, + native_session_id: Some(native_session_id), + agent_id: None, + parent_scroll_id: None, + continuation: None, + tags: Vec::new(), + metadata: serde_json::json!({}), + no_update: false, + kind: SessionKind::Chat, + acp_client_id: None, + is_connected: None, + current_session_id: None, + models: None, + modes: None, + config_options: None, + completeness: SessionCompleteness::default(), + matrix_room_id: None, + matrix_sharing_active: false, + matrix_shared_at: None, + is_subagent: false, + subagent_type: None, + spawning_tool_use_id: None, + } + } + }; + + write_json(&session_json_path, &session_metadata).await?; + + Ok(()) + } + + async fn get_messages_paged( + &self, + scroll_id: Uuid, + cursor: Option, + limit: usize, + ) -> Result { + use crate::types::MAX_PAGE_LIMIT; + + // Hard-clamp limit — same policy as sessions. + let effective_limit = limit.min(MAX_PAGE_LIMIT).max(1); + + // Read NDJSON, sort, apply cursor. + let mut all = self.read_messages_sorted(scroll_id).await?; + + if let Some(c) = cursor.as_ref() { + // Keep strictly-after the cursor point in (ts, message_id) order. + all.retain(|m| (m.ts, m.message_id) > (c.ts, c.message_id)); + } + + let total = all.len(); + let taken: Vec<_> = all.into_iter().take(effective_limit).collect(); + + let next_cursor = if total > taken.len() { + taken.last().map(|m| MessageCursor { + ts: m.ts, + message_id: m.message_id, + }) + } else { + None + }; + + Ok(MessagePage { + items: taken, + next_cursor, + }) + } + + async fn count_messages(&self, scroll_id: Uuid) -> Result { + let messages_path = self.paths.messages_path_for_read(scroll_id); + + // Read file and count lines (each line = one message) + // If file doesn't exist, return 0 (empty session) + match tokio::fs::read_to_string(&messages_path).await { + Ok(content) => { + // Count non-empty lines + let count = content.lines().filter(|line| !line.trim().is_empty()).count(); + Ok(count) + } + Err(e) if e.kind() == std::io::ErrorKind::NotFound => { + // File doesn't exist yet - empty session + Ok(0) + } + Err(e) => Err(e.into()), + } + } + + async fn clear_session_messages(&self, scroll_id: Uuid) -> Result<()> { + // First, verify the session exists by reading its metadata + let session_json_path = self.paths.session_json(scroll_id); + let mut session_metadata: SessionMetadata = match read_json(&session_json_path).await { + Ok(metadata) => metadata, + Err(e) if e.kind() == std::io::ErrorKind::NotFound => { + return Err(ArchivistError::SessionUnknown(scroll_id)); + } + Err(e) => return Err(e.into()), + }; + + // Truncate the messages file (clear all messages) + // First try to clear .jsonl (new format), then fall back to .ndjson (legacy) + let jsonl_path = self.paths.messages_path_for_write(scroll_id); + #[allow(deprecated)] + let ndjson_path = self.paths.messages_ndjson(scroll_id); + + let mut cleared = false; + + // Clear .jsonl if it exists + if jsonl_path.exists() { + tokio::fs::write(&jsonl_path, "").await?; + cleared = true; + } + + // Also clear .ndjson if it exists (in case both are present) + if ndjson_path.exists() { + tokio::fs::write(&ndjson_path, "").await?; + cleared = true; + } + + if cleared { + tracing::info!( + scroll_id = %scroll_id, + "Cleared all messages from session" + ); + } + + // Update the session's updated_at timestamp + session_metadata.updated_at = Utc::now(); + write_json(&session_json_path, &session_metadata).await?; + + tracing::info!( + scroll_id = %scroll_id, + "Updated session metadata after clearing messages" + ); + + Ok(()) + } + + fn as_dag(&self) -> Option<&dyn DagBackend> { + Some(self) + } + fn as_meta_events(&self) -> Option<&dyn MetaEventsBackend> { + Some(self) + } + fn as_connector_registry(&self) -> Option<&dyn ConnectorRegistryBackend> { + Some(self) + } + fn as_session_mapping(&self) -> Option<&dyn SessionMappingBackend> { + Some(self) + } +} + +#[cfg(test)] +mod contract_tests { + use super::*; + use tempfile::tempdir; + + #[tokio::test] + async fn jsonl_backend_honors_all_contracts() { + let dir = tempdir().expect("tempdir"); + let backend = JsonlBackend::new(dir.path().to_path_buf()) + .await + .expect("new"); + crate::backend::contract::verify_all_contracts(&backend).await; + } +} diff --git a/crates/dirigent_archivist/src/backends/jsonl/connectors.rs b/crates/dirigent_archivist/src/backends/jsonl/connectors.rs new file mode 100644 index 0000000..2373e5b --- /dev/null +++ b/crates/dirigent_archivist/src/backends/jsonl/connectors.rs @@ -0,0 +1,161 @@ +//! `ConnectorRegistryBackend` impl for `JsonlBackend`. + +use async_trait::async_trait; +use uuid::Uuid; + +use crate::backend::ConnectorRegistryBackend; +use crate::backends::jsonl::backend::JsonlBackend; +use crate::error::{ArchivistError, Result}; +use crate::storage::{ + read_connector_index, read_json, write_connector_index, write_json, +}; +use crate::types::{ConnectorIndexRow, ConnectorRecord}; + +#[async_trait] +impl ConnectorRegistryBackend for JsonlBackend { + async fn put_connector(&self, record: ConnectorRecord) -> Result<()> { + // Write connector.json + let connector_dir = self.paths.connector_dir(record.connector_uid); + tokio::fs::create_dir_all(&connector_dir).await?; + write_json(&connector_dir.join("connector.json"), &record).await?; + + // Append row to index.tsv (read-modify-write). + let index_path = self.paths.connector_index_tsv(); + let mut rows = read_connector_index(&index_path).await?; + rows.push(ConnectorIndexRow { + connector_uid: record.connector_uid, + r#type: record.r#type.clone(), + title: record.title.clone(), + client_native_id: record.client_native_id.clone(), + alias_of: record.alias_of, + created_at: record.created_at, + fingerprint: record.fingerprint.clone(), + }); + write_connector_index(&index_path, &rows).await?; + + // Update cache + self.connector_cache + .write() + .await + .insert(record.connector_uid, record); + + Ok(()) + } + + async fn get_connector(&self, connector_uid: Uuid) -> Result> { + // Fast path: consult the in-memory cache. + { + let cache = self.connector_cache.read().await; + if let Some(record) = cache.get(&connector_uid) { + return Ok(Some(record.clone())); + } + } + + // Disk fallback. + let connector_json = self + .paths + .connector_dir(connector_uid) + .join("connector.json"); + match read_json::(&connector_json).await { + Ok(record) => Ok(Some(record)), + Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None), + Err(e) => Err(e.into()), + } + } + + async fn list_connectors(&self) -> Result> { + let index_path = self.paths.connector_index_tsv(); + let rows = read_connector_index(&index_path).await?; + let mut connectors = Vec::new(); + for row in rows { + if row.alias_of.is_some() { + continue; + } + let connector_json = self + .paths + .connector_dir(row.connector_uid) + .join("connector.json"); + match read_json::(&connector_json).await { + Ok(record) => connectors.push(record), + Err(e) if e.kind() == std::io::ErrorKind::NotFound => continue, + Err(e) => return Err(e.into()), + } + } + Ok(connectors) + } + + async fn resolve_connector_uid( + &self, + client_native_id: &str, + ) -> Result> { + // First, try parsing client_native_id as a UUID directly + // This handles the common case where the connector_id IS the UUID + if let Ok(uuid) = Uuid::parse_str(client_native_id) { + // Check if this UUID is a registered connector_uid in cache + let cache = self.connector_cache.read().await; + if cache.contains_key(&uuid) { + return Ok(Some(uuid)); + } + drop(cache); + + // Check on disk if not in cache + let connector_json = self.paths.connector_dir(uuid).join("connector.json"); + if connector_json.exists() { + return Ok(Some(uuid)); + } + } + + // Not a UUID or not registered as a connector_uid - search by client_native_id + // Load connector index and find by client_native_id + let index_path = self.paths.connector_index_tsv(); + let connectors = read_connector_index(&index_path).await?; + + if let Some(connector) = connectors + .iter() + .find(|c| c.client_native_id == client_native_id) + { + return Ok(Some(connector.connector_uid)); + } + + // Not found - return Ok(None). Error wrapping is a coordinator concern. + tracing::warn!( + "Failed to resolve connector_uid for client_native_id '{}'. \ + This connector may not be registered with the archivist.", + client_native_id + ); + + Ok(None) + } + + async fn update_connector_fingerprint( + &self, + connector_uid: Uuid, + fingerprint: String, + ) -> Result<()> { + // 1. Read and update connector.json + let connector_dir = self.paths.connector_dir(connector_uid); + let connector_json = connector_dir.join("connector.json"); + let mut record: ConnectorRecord = read_json(&connector_json) + .await + .map_err(|_| ArchivistError::ConnectorUnknown(connector_uid))?; + + record.fingerprint = Some(fingerprint.clone()); + write_json(&connector_json, &record).await?; + + // 2. Update in-memory cache + self.connector_cache + .write() + .await + .insert(connector_uid, record); + + // 3. Update index.tsv + let index_path = self.paths.connector_index_tsv(); + let mut rows = read_connector_index(&index_path).await?; + if let Some(row) = rows.iter_mut().find(|r| r.connector_uid == connector_uid) { + row.fingerprint = Some(fingerprint); + } + write_connector_index(&index_path, &rows).await?; + + Ok(()) + } +} diff --git a/crates/dirigent_archivist/src/backends/jsonl/dag.rs b/crates/dirigent_archivist/src/backends/jsonl/dag.rs new file mode 100644 index 0000000..4c70fb2 --- /dev/null +++ b/crates/dirigent_archivist/src/backends/jsonl/dag.rs @@ -0,0 +1,69 @@ +//! `DagBackend` impl for `JsonlBackend`. + +use async_trait::async_trait; +use uuid::Uuid; + +use crate::backend::DagBackend; +use crate::backends::jsonl::backend::JsonlBackend; +use crate::error::Result; +use crate::storage::{append_ndjson, read_ndjson}; +use crate::types::{DagEdge, SessionMetadata}; + +#[async_trait] +impl DagBackend for JsonlBackend { + async fn append_dag_edge(&self, edge: DagEdge) -> Result<()> { + let dag_path = self.paths.dag_path(); + if let Some(parent) = dag_path.parent() { + tokio::fs::create_dir_all(parent).await?; + } + append_ndjson(&dag_path, &edge).await?; + Ok(()) + } + + async fn get_children(&self, parent: Uuid) -> Result> { + let dag_path = self.paths.dag_path(); + let edges: Vec = read_ndjson(&dag_path).await.unwrap_or_default(); + + let child_ids: Vec = edges + .iter() + .filter(|e| e.parent == parent) + .map(|e| e.child) + .collect(); + + let mut children = Vec::new(); + for child_id in child_ids { + match crate::backend::ArchiveBackend::get_session(self, child_id).await { + Ok(Some(meta)) => children.push(meta), + Ok(None) => { + tracing::warn!( + child_scroll_id = %child_id, + "DAG child session not found" + ); + } + Err(e) => { + tracing::warn!( + child_scroll_id = %child_id, + error = %e, + "DAG child session not found" + ); + } + } + } + + Ok(children) + } + + async fn get_dag_edges(&self, root: Uuid) -> Result> { + // Single-level read: return edges whose parent == root. + // The recursive DAG walk is coordinator-level orchestration. + let dag_path = self.paths.dag_path(); + let all_edges: Vec = read_ndjson(&dag_path).await.unwrap_or_default(); + + let edges = all_edges + .into_iter() + .filter(|e| e.parent == root) + .collect(); + + Ok(edges) + } +} diff --git a/crates/dirigent_archivist/src/backends/jsonl/mapping.rs b/crates/dirigent_archivist/src/backends/jsonl/mapping.rs new file mode 100644 index 0000000..b9c8b00 --- /dev/null +++ b/crates/dirigent_archivist/src/backends/jsonl/mapping.rs @@ -0,0 +1,179 @@ +//! `SessionMappingBackend` impl for `JsonlBackend`. + +use async_trait::async_trait; +use chrono::Utc; +use uuid::Uuid; + +use crate::backend::SessionMappingBackend; +use crate::backends::jsonl::backend::JsonlBackend; +use crate::error::Result; +use crate::storage::{append_ndjson, read_connector_index, read_ndjson, write_ndjson}; +use crate::types::SessionMapping; + +#[async_trait] +impl SessionMappingBackend for JsonlBackend { + async fn put_mapping( + &self, + connector_uid: Uuid, + native_session_id: &str, + scroll_id: Uuid, + ) -> Result<()> { + // Ported from the mapping-persistence tail of + // `FileBasedArchivist::register_session`: ensure the connector + // directory exists, append a `SessionMapping` row to + // `.db/connectors/{uid}/sessions.jsonl`, and prime `session_cache`. + // + // No alias detection — the caller has already chosen `scroll_id`. + let now = Utc::now(); + + // Ensure connector directory exists before appending. + self.paths.ensure_connector_dir(connector_uid).await?; + + let session_mapping = SessionMapping { + version: 1, + connector_uid, + native_session_id: native_session_id.to_string(), + scroll_id, + created_at: now, + alias_of: None, + }; + + let sessions_write_path = self.paths.sessions_path_for_write(connector_uid); + append_ndjson(&sessions_write_path, &session_mapping).await?; + + // Prime the in-memory cache for fast resolution. + self.session_cache + .write() + .await + .insert((connector_uid, native_session_id.to_string()), scroll_id); + + Ok(()) + } + + async fn get_mapping( + &self, + connector_uid: Uuid, + native_session_id: &str, + ) -> Result> { + // Ported from `FileBasedArchivist::resolve_session`. Cache-first + // lookup; on miss, scan the connector's sessions file and populate + // the cache on hit. Unlike the archivist trait, a miss returns + // `Ok(None)` instead of `Err(SessionUnknown)`. + + // Check cache first + let cache_key = (connector_uid, native_session_id.to_string()); + { + let cache = self.session_cache.read().await; + if let Some(&scroll_id) = cache.get(&cache_key) { + return Ok(Some(scroll_id)); + } + } + + // Cache miss - load from disk + let sessions_path = self.paths.sessions_path_for_read(connector_uid); + let mappings: Vec = read_ndjson(&sessions_path).await?; + + // Find mapping by native_session_id + if let Some(mapping) = mappings + .iter() + .find(|m| m.native_session_id == native_session_id) + { + // Update cache + self.session_cache + .write() + .await + .insert(cache_key, mapping.scroll_id); + Ok(Some(mapping.scroll_id)) + } else { + Ok(None) + } + } + + async fn list_mappings_for_connector( + &self, + connector_uid: Uuid, + ) -> Result> { + // Read `.db/connectors/{uid}/sessions.jsonl` (with `.ndjson` + // fallback handled by `sessions_path_for_read` + `read_ndjson`). + let sessions_path = self.paths.sessions_path_for_read(connector_uid); + let mappings: Vec = + read_ndjson(&sessions_path).await.unwrap_or_default(); + Ok(mappings) + } + + async fn find_owner( + &self, + native_session_id: &str, + ) -> Result> { + // Ported verbatim from `FileBasedArchivist::find_session_owner`. + + // Fast path: scan in-memory session_cache + { + let cache = self.session_cache.read().await; + for ((connector_uid, cached_native_id), scroll_id) in cache.iter() { + if cached_native_id == native_session_id { + return Ok(Some((*connector_uid, *scroll_id))); + } + } + } + + // Slow path: read connector index and scan each connector's sessions file + let index_path = self.paths.connector_index_tsv(); + let rows = read_connector_index(&index_path).await?; + + for row in &rows { + // Skip alias connectors - only search primary connectors + if row.alias_of.is_some() { + continue; + } + + let sessions_path = self.paths.sessions_path_for_read(row.connector_uid); + let mappings: Vec = read_ndjson(&sessions_path).await?; + + if let Some(mapping) = mappings + .iter() + .find(|m| m.native_session_id == native_session_id) + { + // Cache the found mapping for future lookups + let cache_key = (row.connector_uid, native_session_id.to_string()); + self.session_cache + .write() + .await + .insert(cache_key, mapping.scroll_id); + + return Ok(Some((row.connector_uid, mapping.scroll_id))); + } + } + + Ok(None) + } + + async fn rewrite_connector_mappings( + &self, + connector_uid: Uuid, + mappings: Vec, + ) -> Result<()> { + // Ensure the connector directory exists before we write. + self.paths.ensure_connector_dir(connector_uid).await?; + + // Invalidate cache entries for this connector first, then re-prime + // from the new mapping set. Any (connector_uid, native_id) entry not + // present in `mappings` is dropped. + { + let mut cache = self.session_cache.write().await; + cache.retain(|(cu, _), _| *cu != connector_uid); + for m in &mappings { + cache.insert( + (connector_uid, m.native_session_id.clone()), + m.scroll_id, + ); + } + } + + // Truncate + re-write the canonical `.jsonl` table. + let write_path = self.paths.sessions_path_for_write(connector_uid); + write_ndjson(&write_path, &mappings).await?; + + Ok(()) + } +} diff --git a/crates/dirigent_archivist/src/backends/jsonl/meta.rs b/crates/dirigent_archivist/src/backends/jsonl/meta.rs new file mode 100644 index 0000000..90321fe --- /dev/null +++ b/crates/dirigent_archivist/src/backends/jsonl/meta.rs @@ -0,0 +1,200 @@ +//! `MetaEventsBackend` impl for `JsonlBackend`. + +use async_trait::async_trait; +use chrono::Utc; +use uuid::Uuid; + +use crate::backend::MetaEventsBackend; +use crate::backends::jsonl::backend::JsonlBackend; +use crate::error::{ArchivistError, Result}; +use crate::storage::{append_ndjson, read_json, read_ndjson, write_json}; +use crate::types::{ + MetaEventRecord, SessionCompleteness, SessionKind, SessionMetadata, +}; + +#[async_trait] +impl MetaEventsBackend for JsonlBackend { + async fn append_meta_events( + &self, + scroll_id: Uuid, + events: Vec, + ) -> Result<()> { + // Ensure session directory exists + self.paths.ensure_dirs(scroll_id).await?; + + // Append each event to events.jsonl + let events_path = self.paths.events_path(scroll_id); + for event in &events { + append_ndjson(&events_path, event).await?; + } + + // Update session.json timestamp + let session_json_path = self.paths.session_json(scroll_id); + let now = Utc::now(); + + let session_metadata = match read_json::(&session_json_path).await { + Ok(mut metadata) => { + metadata.updated_at = now; + metadata + } + Err(_) => { + // session.json doesn't exist, this shouldn't happen for meta sessions + // but we'll handle it gracefully + tracing::warn!( + scroll_id = %scroll_id, + "session.json missing when appending meta events, creating minimal metadata" + ); + + SessionMetadata { + version: 1, + scroll_id, + created_at: now, + updated_at: now, + title: None, + connector_uid: scroll_id, // Use scroll_id as placeholder + native_session_id: None, + agent_id: None, + parent_scroll_id: None, + continuation: None, + tags: Vec::new(), + metadata: serde_json::json!({}), + no_update: false, + kind: SessionKind::AcpConnection, + acp_client_id: None, + is_connected: None, + current_session_id: None, + models: None, + modes: None, + config_options: None, + completeness: SessionCompleteness::default(), + matrix_room_id: None, + matrix_sharing_active: false, + matrix_shared_at: None, + is_subagent: false, + subagent_type: None, + spawning_tool_use_id: None, + } + } + }; + + write_json(&session_json_path, &session_metadata).await?; + + Ok(()) + } + + async fn get_meta_events(&self, scroll_id: Uuid) -> Result> { + let events_path = self.paths.events_path(scroll_id); + + // Read events from events.jsonl + let mut events: Vec = read_ndjson(&events_path) + .await + .unwrap_or_else(|_| Vec::new()); + + // Sort by timestamp then event_id for stable ordering + events.sort_by(|a, b| { + a.ts.cmp(&b.ts).then_with(|| a.event_id.cmp(&b.event_id)) + }); + + Ok(events) + } + + async fn update_meta_session_status( + &self, + scroll_id: Uuid, + is_connected: bool, + current_session_id: Option, + ) -> Result<()> { + // Load existing session metadata + let session_json_path = self.paths.session_json(scroll_id); + + let mut session_metadata: SessionMetadata = match read_json(&session_json_path).await { + Ok(metadata) => metadata, + Err(e) if e.kind() == std::io::ErrorKind::NotFound => { + return Err(ArchivistError::SessionUnknown(scroll_id)); + } + Err(e) => return Err(e.into()), + }; + + // Update connection status fields + session_metadata.is_connected = Some(is_connected); + session_metadata.current_session_id = current_session_id; + session_metadata.updated_at = Utc::now(); + + // Write updated metadata back to disk + write_json(&session_json_path, &session_metadata).await?; + + tracing::info!( + scroll_id = %scroll_id, + is_connected = %is_connected, + current_session_id = ?current_session_id, + "Updated meta session status" + ); + + Ok(()) + } + + async fn list_meta_sessions(&self) -> Result> { + // Scan .contexts/ directory for all session.json files + let contexts_dir = self.paths.root().join(".contexts"); + + if !contexts_dir.exists() { + return Ok(Vec::new()); + } + + let mut meta_sessions = Vec::new(); + + // Read all session directories + let mut entries = tokio::fs::read_dir(&contexts_dir).await?; + + while let Some(entry) = entries.next_entry().await? { + if !entry.file_type().await?.is_dir() { + continue; + } + + let session_json_path = entry.path().join("session.json"); + + // Try to read session.json + match read_json::(&session_json_path).await { + Ok(metadata) => { + // Filter to only AcpConnection sessions + if metadata.kind == SessionKind::AcpConnection { + meta_sessions.push(metadata); + } + } + Err(e) if e.kind() == std::io::ErrorKind::NotFound => { + // Skip missing session files + continue; + } + Err(e) => { + tracing::warn!( + path = ?session_json_path, + error = %e, + "Failed to read session.json while listing meta sessions" + ); + continue; + } + } + } + + // Sort by updated_at descending (newest first) + meta_sessions.sort_by(|a, b| b.updated_at.cmp(&a.updated_at)); + + Ok(meta_sessions) + } + + async fn find_meta_session_by_client( + &self, + client_id: &str, + ) -> Result> { + // Use list_meta_sessions and filter by acp_client_id + let meta_sessions = self.list_meta_sessions().await?; + + let result = meta_sessions + .into_iter() + .find(|session| { + session.acp_client_id.as_deref() == Some(client_id) + }); + + Ok(result) + } +} diff --git a/crates/dirigent_archivist/src/backends/jsonl/mod.rs b/crates/dirigent_archivist/src/backends/jsonl/mod.rs new file mode 100644 index 0000000..01f9a00 --- /dev/null +++ b/crates/dirigent_archivist/src/backends/jsonl/mod.rs @@ -0,0 +1,12 @@ +//! NDJSON/JSON/TSV file-based backend. +//! +//! Ports the body of the former `FileBasedArchivist`. Uses the existing +//! `crate::storage` free-function primitives unchanged. + +mod backend; +mod connectors; +mod dag; +mod mapping; +mod meta; + +pub use backend::JsonlBackend; diff --git a/crates/dirigent_archivist/src/backends/mod.rs b/crates/dirigent_archivist/src/backends/mod.rs new file mode 100644 index 0000000..1866265 --- /dev/null +++ b/crates/dirigent_archivist/src/backends/mod.rs @@ -0,0 +1,5 @@ +//! Concrete backend implementations for `ArchiveBackend`. + +pub mod jsonl; + +pub use jsonl::JsonlBackend; diff --git a/crates/dirigent_archivist/src/backfill.rs b/crates/dirigent_archivist/src/backfill.rs new file mode 100644 index 0000000..27c6cbe --- /dev/null +++ b/crates/dirigent_archivist/src/backfill.rs @@ -0,0 +1,558 @@ +//! Backfill functionality for importing existing sessions from connectors. +//! +//! This module provides utilities to import sessions and messages from connectors +//! that support listing operations (like OpenCode connectors) into the Archivist. + +use futures::future::BoxFuture; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::coordinator::Archivist; +use crate::error::{ArchivistError, Result}; +use crate::types::{MessageRecord, RegisterSessionRequest, RegisterStatus}; +use dirigent_protocol::{Message, Session}; + +/// Statistics collected during a backfill operation. +/// +/// This provides a summary of what was imported and any errors encountered. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct BackfillStats { + /// Total number of sessions found in the connector + pub sessions_found: usize, + /// Number of sessions successfully imported (new registrations) + pub sessions_imported: usize, + /// Number of sessions skipped (already archived) + pub sessions_skipped: usize, + /// Total number of messages imported across all sessions + pub messages_imported: usize, + /// Error messages for sessions that failed to import + pub errors: Vec, +} + +impl BackfillStats { + /// Create a new BackfillStats with all counts at zero + pub fn new() -> Self { + Self { + sessions_found: 0, + sessions_imported: 0, + sessions_skipped: 0, + messages_imported: 0, + errors: Vec::new(), + } + } +} + +impl Default for BackfillStats { + fn default() -> Self { + Self::new() + } +} + +/// Backfill sessions from a connector into the archive. +/// +/// This function imports existing sessions from a connector by: +/// 1. Attempting to register each session with the archivist +/// 2. For newly registered sessions, fetching messages via the provided closure +/// 3. Appending fetched messages to the archive +/// 4. Collecting statistics on successes, failures, and skips +/// +/// # Arguments +/// +/// * `archivist` - The archivist to backfill into +/// * `connector_uid` - The UID of the connector being backfilled +/// * `sessions` - List of sessions to import (from connector's list_sessions()) +/// * `fetch_messages` - Async closure to fetch messages for a given native session ID +/// +/// # Returns +/// +/// Statistics about the backfill operation including counts and errors +/// +/// # Error Handling +/// +/// This function continues processing all sessions even if individual sessions fail. +/// Errors are collected in `BackfillStats.errors` rather than aborting the operation. +/// +/// # Example +/// +/// ```no_run +/// use dirigent_archivist::{Archivist, backfill_from_sessions}; +/// use dirigent_protocol::{Session, Message}; +/// use uuid::Uuid; +/// +/// # async fn example(archivist: &Archivist, sessions: Vec) { +/// let connector_uid = Uuid::now_v7(); +/// +/// let stats = backfill_from_sessions( +/// archivist, +/// connector_uid, +/// sessions, +/// |session_id| { +/// Box::pin(async move { +/// // Fetch messages from connector +/// // Return Vec +/// Ok(vec![]) +/// }) +/// } +/// ).await.unwrap(); +/// +/// println!("Imported {} sessions, {} messages", +/// stats.sessions_imported, +/// stats.messages_imported); +/// # } +/// ``` +pub async fn backfill_from_sessions( + archivist: &Archivist, + connector_uid: Uuid, + sessions: Vec, + fetch_messages: F, +) -> Result +where + F: Fn(&str) -> BoxFuture<'static, Result>> + Send + Sync, +{ + let mut stats = BackfillStats::new(); + stats.sessions_found = sessions.len(); + + for session in sessions { + let native_session_id = session.id.clone(); + + // Try to resolve the session - if it exists, skip it + match archivist + .resolve_session(connector_uid, &native_session_id, None) + .await + { + Ok(_scroll_id) => { + // Session already archived, skip + stats.sessions_skipped += 1; + continue; + } + Err(ArchivistError::SessionUnknown(_)) => { + // Session not found, proceed with import + } + Err(e) => { + // Unexpected error during resolution + stats.errors.push(format!( + "Failed to resolve session {}: {}", + native_session_id, e + )); + continue; + } + } + + // Register the session + let register_req = RegisterSessionRequest { + connector_uid, + native_session_id: native_session_id.clone(), + title: Some(session.title.clone()), + custom_scroll_id: None, // Let archivist generate + metadata: serde_json::to_value(&session.metadata) + .unwrap_or_else(|_| serde_json::json!({})), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let scroll_id = match archivist.register_session(register_req, None).await { + Ok(response) => { + match response.status { + RegisterStatus::Accepted => { + stats.sessions_imported += 1; + response.scroll_id + } + RegisterStatus::Aliased => { + // Already exists (shouldn't happen since we checked, but handle gracefully) + stats.sessions_skipped += 1; + continue; + } + RegisterStatus::Rejected => { + // Registration rejected (collision inconsistency) + stats.errors.push(format!( + "Session registration rejected for {}: UID collision", + native_session_id + )); + continue; + } + } + } + Err(e) => { + stats.errors.push(format!( + "Failed to register session {}: {}", + native_session_id, e + )); + continue; + } + }; + + // Fetch messages for this session + let messages = match fetch_messages(&native_session_id).await { + Ok(msgs) => msgs, + Err(e) => { + stats.errors.push(format!( + "Failed to fetch messages for session {}: {}", + native_session_id, e + )); + continue; + } + }; + + // Convert protocol messages to message records + let message_records: Vec = messages + .into_iter() + .map(|msg| convert_message_to_record(msg, scroll_id)) + .collect(); + + let message_count = message_records.len(); + + // Append messages to the archive + if let Err(e) = archivist + .append_messages(scroll_id, message_records, None) + .await + { + stats.errors.push(format!( + "Failed to append messages for session {}: {}", + native_session_id, e + )); + continue; + } + + stats.messages_imported += message_count; + } + + Ok(stats) +} + +/// Convert a dirigent_protocol::Message to a MessageRecord for archival. +/// +/// This function translates the protocol message format into the archivist's +/// internal storage format, extracting markdown content and metadata. +pub fn convert_message_to_record(msg: Message, scroll_id: Uuid) -> MessageRecord { + // Extract text content from message parts and convert to markdown + let mut md_parts = Vec::new(); + for part in &msg.content { + match part { + dirigent_protocol::MessagePart::Text { text } => { + md_parts.push(text.clone()); + } + dirigent_protocol::MessagePart::Thinking { text } => { + md_parts.push(format!("\n{}\n", text)); + } + dirigent_protocol::MessagePart::Code { language, code } => { + md_parts.push(format!("```{}\n{}\n```", language, code)); + } + dirigent_protocol::MessagePart::Tool { + tool, + tool_call_id: _, + input, + output, + } => { + let mut tool_text = + format!("**Tool: {}**\n\nInput:\n```json\n{}\n```", tool, input); + if let Some(out) = output { + tool_text.push_str(&format!("\n\nOutput:\n```json\n{}\n```", out)); + } + md_parts.push(tool_text); + } + dirigent_protocol::MessagePart::File { path, content } => { + md_parts.push(format!("**File: {}**\n\n```\n{}\n```", path, content)); + } + } + } + let content_md = md_parts.join("\n\n"); + + // Serialize original content parts for rich UI rendering + let content_parts = serde_json::to_value(&msg.content).ok(); + + // Convert role + let role = match msg.role { + dirigent_protocol::MessageRole::User => "user", + dirigent_protocol::MessageRole::Assistant => "assistant", + } + .to_string(); + + // Generate message ID from the protocol message ID or create new one + let message_id = Uuid::now_v7(); + + MessageRecord { + version: 1, + message_id, + session: scroll_id, + parent_id: None, + ts: msg.created_at, + role, + author: None, // Protocol messages don't have author field + content_md, + content_parts, + attachments: Vec::new(), // Would need to extract from message parts if supported + metadata: msg + .metadata + .and_then(|m| serde_json::to_value(m).ok()) + .unwrap_or_else(|| serde_json::json!({})), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::coordinator::Archivist; + use chrono::Utc; + use dirigent_protocol::{MessageRole, MessageStatus, SessionMetadata}; + use tempfile::TempDir; + + async fn setup_test_archivist() -> (Archivist, TempDir) { + let temp_dir = TempDir::new().unwrap(); + // Use `from_single_backend` so each test is isolated (no shared + // registry file in the tempdir's parent racing against siblings). + let backend = std::sync::Arc::new( + crate::backends::JsonlBackend::new(temp_dir.path().to_path_buf()) + .await + .unwrap(), + ); + let archivist = Archivist::from_single_backend("main".into(), backend) + .await + .unwrap(); + (archivist, temp_dir) + } + + fn create_test_session(id: &str, title: &str) -> Session { + Session { + id: id.to_string(), + title: title.to_string(), + created_at: Utc::now(), + updated_at: Utc::now(), + metadata: SessionMetadata { + project_path: "/test".to_string(), + model: Some("test-model".to_string()), + total_messages: 0, + system_message: None, + current_mode_id: None, + _meta: None, + project_id: None, + }, + cwd: None, + config_options: None, + acp_client_id: None, + models: None, + modes: None, + } + } + + fn create_test_message(id: &str, session_id: &str, role: MessageRole, text: &str) -> Message { + Message { + id: id.to_string(), + session_id: session_id.to_string(), + role, + created_at: Utc::now(), + content: vec![dirigent_protocol::MessagePart::Text { + text: text.to_string(), + }], + status: MessageStatus::Completed, + metadata: None, + } + } + + #[tokio::test] + async fn test_backfill_new_sessions() { + let (archivist, _temp) = setup_test_archivist().await; + + // Register connector first + let connector_uid = Uuid::now_v7(); + let connector_req = crate::types::RegisterConnectorRequest { + custom_uid: Some(connector_uid), + r#type: "OpenCode".to_string(), + title: "Test Connector".to_string(), + client_native_id: "test-connector".to_string(), + metadata: serde_json::json!({}), + fingerprint: None, + }; + archivist + .register_connector(connector_req, None) + .await + .unwrap(); + + // Create test sessions + let sessions = vec![ + create_test_session("session-1", "Session 1"), + create_test_session("session-2", "Session 2"), + ]; + + // Mock message fetcher + let fetch_messages = |session_id: &str| { + let sid = session_id.to_string(); + Box::pin(async move { + Ok(vec![ + create_test_message("msg-1", &sid, MessageRole::User, "Hello"), + create_test_message("msg-2", &sid, MessageRole::Assistant, "Hi there"), + ]) + }) as BoxFuture<'static, Result>> + }; + + // Backfill + let stats = backfill_from_sessions(&archivist, connector_uid, sessions, fetch_messages) + .await + .unwrap(); + + // Verify stats + assert_eq!(stats.sessions_found, 2); + assert_eq!(stats.sessions_imported, 2); + assert_eq!(stats.sessions_skipped, 0); + assert_eq!(stats.messages_imported, 4); // 2 messages per session + assert_eq!(stats.errors.len(), 0); + } + + #[tokio::test] + async fn test_backfill_skips_existing_sessions() { + let (archivist, _temp) = setup_test_archivist().await; + + // Register connector first + let connector_uid = Uuid::now_v7(); + let connector_req = crate::types::RegisterConnectorRequest { + custom_uid: Some(connector_uid), + r#type: "OpenCode".to_string(), + title: "Test Connector".to_string(), + client_native_id: "test-connector".to_string(), + metadata: serde_json::json!({}), + fingerprint: None, + }; + archivist + .register_connector(connector_req, None) + .await + .unwrap(); + + // Pre-register one session + let session1 = create_test_session("session-1", "Session 1"); + let req = RegisterSessionRequest { + connector_uid, + native_session_id: session1.id.clone(), + title: Some(session1.title.clone()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + archivist.register_session(req, None).await.unwrap(); + + // Create sessions including the pre-registered one + let sessions = vec![session1, create_test_session("session-2", "Session 2")]; + + // Mock message fetcher + let fetch_messages = |session_id: &str| { + let sid = session_id.to_string(); + Box::pin(async move { + Ok(vec![create_test_message( + "msg-1", + &sid, + MessageRole::User, + "Test", + )]) + }) as BoxFuture<'static, Result>> + }; + + // Backfill + let stats = backfill_from_sessions(&archivist, connector_uid, sessions, fetch_messages) + .await + .unwrap(); + + // Verify stats - session-1 should be skipped + assert_eq!(stats.sessions_found, 2); + assert_eq!(stats.sessions_imported, 1); // Only session-2 + assert_eq!(stats.sessions_skipped, 1); // session-1 already exists + assert_eq!(stats.messages_imported, 1); // Only messages from session-2 + assert_eq!(stats.errors.len(), 0); + } + + #[tokio::test] + async fn test_backfill_handles_fetch_errors() { + let (archivist, _temp) = setup_test_archivist().await; + + // Register connector first + let connector_uid = Uuid::now_v7(); + let connector_req = crate::types::RegisterConnectorRequest { + custom_uid: Some(connector_uid), + r#type: "OpenCode".to_string(), + title: "Test Connector".to_string(), + client_native_id: "test-connector".to_string(), + metadata: serde_json::json!({}), + fingerprint: None, + }; + archivist + .register_connector(connector_req, None) + .await + .unwrap(); + + let sessions = vec![create_test_session("session-1", "Session 1")]; + + // Mock message fetcher that fails + let fetch_messages = |_session_id: &str| { + Box::pin(async move { + Err(ArchivistError::InvalidRequest( + "Failed to fetch messages".to_string(), + )) + }) as BoxFuture<'static, Result>> + }; + + // Backfill + let stats = backfill_from_sessions(&archivist, connector_uid, sessions, fetch_messages) + .await + .unwrap(); + + // Verify stats - session registered but messages failed + assert_eq!(stats.sessions_found, 1); + assert_eq!(stats.sessions_imported, 1); // Session was registered + assert_eq!(stats.messages_imported, 0); // But no messages imported + assert_eq!(stats.errors.len(), 1); // Error recorded + assert!(stats.errors[0].contains("Failed to fetch messages")); + } + + #[test] + fn test_backfill_stats_default() { + let stats = BackfillStats::default(); + assert_eq!(stats.sessions_found, 0); + assert_eq!(stats.sessions_imported, 0); + assert_eq!(stats.sessions_skipped, 0); + assert_eq!(stats.messages_imported, 0); + assert_eq!(stats.errors.len(), 0); + } + + #[test] + fn test_convert_message_to_record() { + let scroll_id = Uuid::now_v7(); + let msg = create_test_message("msg-1", "session-1", MessageRole::User, "Hello world"); + + let record = convert_message_to_record(msg, scroll_id); + + assert_eq!(record.session, scroll_id); + assert_eq!(record.role, "user"); + assert_eq!(record.content_md, "Hello world"); + assert_eq!(record.version, 1); + } + + #[test] + fn test_convert_message_with_thinking() { + let scroll_id = Uuid::now_v7(); + let msg = Message { + id: "msg-1".to_string(), + session_id: "session-1".to_string(), + role: MessageRole::Assistant, + created_at: Utc::now(), + content: vec![dirigent_protocol::MessagePart::Thinking { + text: "Let me think...".to_string(), + }], + status: MessageStatus::Completed, + metadata: None, + }; + + let record = convert_message_to_record(msg, scroll_id); + + assert!(record.content_md.contains("")); + assert!(record.content_md.contains("Let me think...")); + assert!(record.content_md.contains("")); + } +} diff --git a/crates/dirigent_archivist/src/coordinator/admin.rs b/crates/dirigent_archivist/src/coordinator/admin.rs new file mode 100644 index 0000000..39cf36f --- /dev/null +++ b/crates/dirigent_archivist/src/coordinator/admin.rs @@ -0,0 +1,70 @@ +//! Admin / inspection methods on `Archivist`. +//! +//! Split out because they aren't part of the hot-path coordinator API: +//! `shutdown` drains queued writer tasks, `list_archives_with_health` +//! snapshots every registration's health + queue depth, and the cache +//! admin methods delegate to `ReadCache`. + +use std::sync::Arc; + +use tokio::sync::oneshot; + +use crate::error::Result; +use crate::registry::writer::WriteOp; +use crate::registry::{ArchiveRegistration, ArchiveStatus}; + +use super::Archivist; + +impl Archivist { + /// Drain every queued writer task. Inline backends are no-ops. + /// Call before process exit to ensure in-flight batches land. + pub async fn shutdown(&self) -> Result<()> { + let regs: Vec> = self.registrations.read().await.clone(); + for reg in regs.iter() { + if let Some(writer) = reg.writer.as_ref() { + let (tx, rx) = oneshot::channel(); + // If the send fails, the writer task has already exited — skip the wait. + if writer.sender.send(WriteOp::Shutdown(tx)).await.is_ok() { + let _ = rx.await; + } + // Join the task, if it's still attached. + if let Some(handle) = writer.join.lock().await.take() { + let _ = handle.await; + } + } + } + Ok(()) + } + + /// Snapshot every registered archive's current status. + pub async fn list_archives_with_health(&self) -> Vec { + let regs: Vec> = self.registrations.read().await.clone(); + let mut out = Vec::with_capacity(regs.len()); + for reg in regs.iter() { + let health = reg.last_health.read().await.clone(); + let last_error = reg.last_error.read().await.clone(); + let queue_depth = reg.writer.as_ref().map(|w| w.queue_depth_now()); + out.push(ArchiveStatus { + name: reg.name.clone(), + type_name: reg.type_name.to_string(), + enabled: reg.enabled, + write_active: reg.write_active, + failure_mode: reg.failure_mode, + read_priority: reg.read_priority, + capabilities: reg.capabilities().clone(), + health, + last_error, + queue_depth, + }); + } + out + } + + pub async fn clear_read_cache(&self) { + self.read_cache.clear().await; + } + + pub async fn read_cache_size(&self) -> usize { + self.read_cache.len().await + } +} diff --git a/crates/dirigent_archivist/src/coordinator/archives.rs b/crates/dirigent_archivist/src/coordinator/archives.rs new file mode 100644 index 0000000..3004d02 --- /dev/null +++ b/crates/dirigent_archivist/src/coordinator/archives.rs @@ -0,0 +1,77 @@ +//! Archive lifecycle methods for `Archivist`. +//! +//! Phase 3 is **startup-only**: the archive registry is constructed from +//! `dirigent.toml` at boot and not mutated at runtime. Accordingly, +//! `add_archive`, `remove_archive`, and `set_default_archive` all return +//! [`ArchivistError::DynamicRegistryUnsupported`]. The `list_archives` +//! and `get_default_archive` read-paths continue to operate against the +//! new `Vec>` storage. + +use std::path::PathBuf; + +use crate::coordinator::Archivist; +use crate::error::{ArchivistError, Result}; +use crate::registry::FailureMode; + +impl Archivist { + /// **Deprecated in Phase 3.** Archive registry is configured at boot + /// via `dirigent.toml`; runtime mutation is not supported. + pub async fn add_archive(&self, _name: String, _path: PathBuf) -> Result<()> { + Err(ArchivistError::DynamicRegistryUnsupported) + } + + /// **Deprecated in Phase 3.** Archive registry is configured at boot + /// via `dirigent.toml`; runtime mutation is not supported. + pub async fn remove_archive(&self, _name: String, _force: bool) -> Result<()> { + Err(ArchivistError::DynamicRegistryUnsupported) + } + + /// List all configured archives. Session counts are reported as `0` + /// because the Phase 3 multi-backend coordinator does not persist a + /// per-archive connector index; counts will be reintroduced by the + /// admin-status query in Task 23. + pub async fn list_archives(&self) -> Result> { + let regs = self.registrations.read().await; + let primary_name = regs + .iter() + .filter(|r| { + r.enabled && r.write_active && r.failure_mode == FailureMode::Required + }) + .min_by_key(|r| r.read_priority) + .map(|r| r.name.clone()); + + Ok(regs + .iter() + .map(|r| super::types::ArchiveInfo { + name: r.name.clone(), + path: PathBuf::new(), + created_at: chrono::Utc::now(), + session_count: 0, + is_default: primary_name.as_deref() == Some(r.name.as_str()), + }) + .collect()) + } + + /// Get the name of the "default" archive — interpreted in Phase 3 as + /// the enabled, write-active, `Required` backend with the lowest + /// `read_priority`. + pub async fn get_default_archive(&self) -> Result { + let regs = self.registrations.read().await; + regs.iter() + .filter(|r| { + r.enabled && r.write_active && r.failure_mode == FailureMode::Required + }) + .min_by_key(|r| r.read_priority) + .map(|r| r.name.clone()) + .ok_or_else(|| ArchivistError::PrimaryUnavailable { + name: "".into(), + reason: "no required write-active backend".into(), + }) + } + + /// **Deprecated in Phase 3.** Archive registry is configured at boot + /// via `dirigent.toml`; runtime mutation is not supported. + pub async fn set_default_archive(&self, _name: String) -> Result<()> { + Err(ArchivistError::DynamicRegistryUnsupported) + } +} diff --git a/crates/dirigent_archivist/src/coordinator/boot.rs b/crates/dirigent_archivist/src/coordinator/boot.rs new file mode 100644 index 0000000..ce890cb --- /dev/null +++ b/crates/dirigent_archivist/src/coordinator/boot.rs @@ -0,0 +1,281 @@ +//! Boot-time construction of the `Archivist` coordinator from a parsed +//! `ArchivesConfig` and a `BackendRegistry` of factories. + +use std::sync::Arc; + +use tokio::sync::RwLock; + +use crate::backend::HealthStatus; +use crate::error::ArchivistBootError; +use crate::registry::{ + cache::ReadCache, ArchiveRegistration, ArchivesConfig, BackendRegistry, FailureMode, + WritePolicy, +}; + +use super::Archivist; + +impl Archivist { + /// Construct the coordinator from a parsed `[[archives]]` config block + /// and a registry of backend factories. + /// + /// - Validates the config (duplicate-name / no-primary rules). + /// - Instantiates every enabled backend via the factory. + /// - Runs a startup `health_check` per backend. + /// - Sorts registrations by `read_priority` (ties by declaration order). + /// - Writer tasks for `WritePolicy::Queued` backends are wired in Task 17; + /// for now every backend boots with `writer = None`. + pub async fn from_config( + mut config: ArchivesConfig, + registry: &BackendRegistry, + base_dir: Option<&std::path::Path>, + ) -> Result { + config.validate()?; + + // Filter-level validation (Phase 4, Task 19). + // + // 1. At least one enabled write-active archive must have an + // unrestricted filter. Otherwise there is no default home for + // a session that does not match any filter, and the primary + // target would silently exclude sessions despite being the + // "write-always" backend. + // 2. No archive may declare a filter whose `include_connectors` + // set is `Some(empty)` — that form rejects every session + // unconditionally and is almost always a config typo. + let mut has_unrestricted_write_active = false; + for entry in &config.entries { + if let Some(inc) = &entry.filter.include_connectors { + if inc.is_empty() { + return Err(ArchivistBootError::FilterRejectsEverything { + archive: entry.name.clone(), + }); + } + } + if entry.enabled && entry.write_active && entry.filter.is_unrestricted() { + has_unrestricted_write_active = true; + } + } + if !config.entries.is_empty() && !has_unrestricted_write_active { + return Err(ArchivistBootError::NoUnrestrictedPrimary); + } + + // Resolve relative `params.path` values against `base_dir` so that + // archives declared with relative paths land under the data directory + // rather than the binary's CWD. + if let Some(base) = base_dir { + for entry in &mut config.entries { + if let toml::Value::Table(ref mut table) = entry.params { + if let Some(toml::Value::String(ref mut path_str)) = table.get_mut("path") { + let p = std::path::Path::new(path_str.as_str()); + if p.is_relative() { + *path_str = base.join(&*path_str).to_string_lossy().into_owned(); + } + } + } + } + } + + let mut registrations: Vec> = Vec::new(); + + for entry in config.entries.into_iter() { + let backend = registry + .build(&entry.name, &entry.type_name, entry.params) + .await + .map_err(|e| match e { + crate::registry::BackendBuildError::UnknownType(t) => { + ArchivistBootError::UnknownType { + name: entry.name.clone(), + type_name: t, + } + } + other => ArchivistBootError::BackendBuild { + name: entry.name.clone(), + source: other, + }, + })?; + + let initial_health = backend.health_check().await; + + if entry.failure_mode == FailureMode::Required { + if let HealthStatus::Unavailable { reason } = &initial_health { + return Err(ArchivistBootError::UnavailableRequiredBackend { + name: entry.name.clone(), + reason: reason.clone(), + }); + } + } + + let runtime_policy: WritePolicy = entry.write_policy.into_runtime(); + + // Build shared drift state up-front so the writer task (if any) + // and the registration's health-drift helpers mutate the SAME + // `Arc>` cells. This keeps Task 22's drift semantics + // coherent across the inline and queued paths. + let health_state: Arc> = + Arc::new(RwLock::new(initial_health.clone())); + let error_state: Arc< + RwLock, String)>>, + > = Arc::new(RwLock::new(None)); + let failure_counter: Arc> = Arc::new(RwLock::new(0u32)); + + let writer = match &runtime_policy { + WritePolicy::Inline => None, + WritePolicy::Queued { + batch_window_ms, + capacity, + overflow, + } => Some(crate::registry::writer::spawn_writer( + backend.clone(), + entry.name.clone(), + *capacity, + std::time::Duration::from_millis(*batch_window_ms), + *overflow, + health_state.clone(), + error_state.clone(), + failure_counter.clone(), + )), + }; + + // Leak `type_name` to satisfy &'static str on the registration; safe at boot, + // and a constant number of entries (O(archives in config)). + let type_name_static: &'static str = Box::leak(entry.type_name.into_boxed_str()); + + let registration = ArchiveRegistration::new_with_shared_state( + entry.name, + type_name_static, + backend, + entry.write_active, + entry.failure_mode, + entry.read_priority, + entry.enabled, + runtime_policy, + writer, + health_state, + error_state, + failure_counter, + ) + .with_filter(entry.filter); + + registrations.push(Arc::new(registration)); + } + + // Sort by `read_priority`. Rust's sort is stable, so ties keep declaration order. + registrations.sort_by_key(|r| r.read_priority); + + Ok(Self { + registrations: RwLock::new(registrations), + read_cache: Arc::new(ReadCache::new()), + registry_path: std::path::PathBuf::new(), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::registry::{ArchivesConfig, BackendRegistry}; + + fn parse(toml_src: &str) -> ArchivesConfig { + toml::from_str(toml_src).unwrap() + } + + #[tokio::test] + async fn relative_archive_path_resolved_against_base_dir() { + let base = tempfile::tempdir().unwrap(); + let cfg = parse( + r#" + [[archives]] + name = "main" + type = "jsonl" + [archives.params] + path = "my_archive" + "#, + ); + let registry = BackendRegistry::with_jsonl(); + let archivist = Archivist::from_config(cfg, ®istry, Some(base.path())) + .await + .unwrap(); + + // The archive should have been created under base_dir/my_archive. + // Verify by checking the .contexts directory exists. + let expected = base.path().join("my_archive").join(".contexts"); + assert!( + expected.exists(), + "expected {expected:?} to exist after boot with relative path" + ); + + // Also verify the archivist is functional (has one registration). + let archives = archivist.list_archives().await.unwrap(); + assert_eq!(archives.len(), 1); + } + + #[tokio::test] + async fn absolute_archive_path_not_affected_by_base_dir() { + let base = tempfile::tempdir().unwrap(); + let archive_dir = tempfile::tempdir().unwrap(); + let abs_path = archive_dir.path().to_string_lossy().replace('\\', "/"); + + let cfg = parse(&format!( + r#" + [[archives]] + name = "main" + type = "jsonl" + [archives.params] + path = "{abs_path}" + "#, + )); + let registry = BackendRegistry::with_jsonl(); + let archivist = Archivist::from_config(cfg, ®istry, Some(base.path())) + .await + .unwrap(); + + // The archive should be at the absolute path, NOT under base_dir. + let expected = archive_dir.path().join(".contexts"); + assert!( + expected.exists(), + "expected {expected:?} to exist (absolute path should be used as-is)" + ); + + // Verify nothing was created under base_dir with the archive name. + // If base_dir resolution incorrectly touched the absolute path, we'd + // see stray directories under base_dir. + let base_entries: Vec<_> = std::fs::read_dir(base.path()) + .unwrap() + .collect(); + assert!( + base_entries.is_empty(), + "base_dir should be untouched when archive path is absolute, found: {base_entries:?}" + ); + + let archives = archivist.list_archives().await.unwrap(); + assert_eq!(archives.len(), 1); + } + + #[tokio::test] + async fn none_base_dir_preserves_existing_behavior() { + let archive_dir = tempfile::tempdir().unwrap(); + let abs_path = archive_dir.path().to_string_lossy().replace('\\', "/"); + + let cfg = parse(&format!( + r#" + [[archives]] + name = "main" + type = "jsonl" + [archives.params] + path = "{abs_path}" + "#, + )); + let registry = BackendRegistry::with_jsonl(); + let archivist = Archivist::from_config(cfg, ®istry, None) + .await + .unwrap(); + + let expected = archive_dir.path().join(".contexts"); + assert!( + expected.exists(), + "expected {expected:?} to exist with None base_dir and absolute path" + ); + + let archives = archivist.list_archives().await.unwrap(); + assert_eq!(archives.len(), 1); + } +} diff --git a/crates/dirigent_archivist/src/coordinator/connectors.rs b/crates/dirigent_archivist/src/coordinator/connectors.rs new file mode 100644 index 0000000..1195afc --- /dev/null +++ b/crates/dirigent_archivist/src/coordinator/connectors.rs @@ -0,0 +1,285 @@ +//! Connector orchestration for `Archivist`. +//! +//! Alias detection and tri-state registration logic live here; persistence is +//! delegated to each backend's `ConnectorRegistryBackend` sub-trait. Ported +//! from `FileBasedArchivist::register_connector` and +//! `MultiArchiveArchivist::resolve_connector_uid`. + +use chrono::Utc; +use uuid::Uuid; + +use crate::backend::ArchiveCapability; +use crate::coordinator::Archivist; +use crate::error::{ArchivistError, Result}; +use crate::types::{ + ConnectorRecord, RegisterConnectorRequest, RegisterConnectorResponse, RegisterStatus, +}; + +impl Archivist { + /// Register a connector with alias detection. + /// + /// Ported from `FileBasedArchivist::register_connector`. Decision order: + /// + /// 1. If `custom_uid` collides with an existing connector: + /// - same `client_native_id` → `Aliased` (idempotent re-registration). + /// - different `client_native_id` → `CollisionInconsistent` error. + /// 2. If the `client_native_id` is already registered under a different + /// UID → `Aliased` to that pre-existing UID. + /// 3. If a `fingerprint` matches a pre-existing connector → `Aliased` to + /// that UID. (Identity persistence across connector re-adds.) + /// 4. Otherwise → `Accepted`; a new `ConnectorRecord` is persisted via + /// `ConnectorRegistryBackend::put_connector`. + // TODO(phase3 task 16): register_connector fanout requires replicating the + // ConnectorRecord to secondaries. Since connectors are identity-shaped (UIDs + // must match across backends), the tri-state alias detection must stay + // canonical on the primary, but the accepted record should be mirrored to + // secondaries. Deferred to a follow-up within Phase 3 — the core Task 16 + // plan covers append_messages and the session mutators which are the hot + // paths. Current behaviour: single-primary via `resolve_backend`. + pub async fn register_connector( + &self, + req: RegisterConnectorRequest, + archive: Option, + ) -> Result { + let backend = self.resolve_backend(archive).await?; + let registry = backend.as_connector_registry().ok_or_else(|| { + ArchivistError::CapabilityNotSupported { + capability: ArchiveCapability::ConnectorRegistry, + backend: "selected".into(), + } + })?; + + // Generate connector UID (use custom_uid or generate new) + let connector_uid = req.custom_uid.unwrap_or_else(Uuid::now_v7); + + // Load existing (non-alias) connectors for collision detection. + let existing_connectors = registry.list_connectors().await?; + + // 1. Check for UID collision. + if let Some(existing) = existing_connectors + .iter() + .find(|c| c.connector_uid == connector_uid) + { + if existing.client_native_id == req.client_native_id { + // Same UID with same client_native_id -> ALIASED (idempotent). + return Ok(RegisterConnectorResponse { + status: RegisterStatus::Aliased, + connector_uid, + alias_of: Some(connector_uid), + note: Some("Connector already registered with this UID".to_string()), + }); + } else { + // Same UID with different client_native_id -> REJECTED. + return Err(ArchivistError::CollisionInconsistent(connector_uid)); + } + } + + // 2. Check for existing client_native_id (different UID collision). + if let Some(existing) = existing_connectors + .iter() + .find(|c| c.client_native_id == req.client_native_id) + { + return Ok(RegisterConnectorResponse { + status: RegisterStatus::Aliased, + connector_uid: existing.connector_uid, + alias_of: Some(existing.connector_uid), + note: Some("Connector already registered with different UID".to_string()), + }); + } + + // 3. Check for fingerprint match (identity persistence across re-adds). + // + // Note: the original `FileBasedArchivist` additionally refreshed the + // matched connector's `title`/`metadata` on disk and in cache here. + // That refresh bypassed both the TSV index and any backend abstraction + // (direct `read_json`/`write_json` against `connector.json`). The + // `ConnectorRegistryBackend` trait does not yet expose an + // "update metadata" method, and `put_connector` would append a + // duplicate row to the index rather than mutate in place. The refresh + // was best-effort (`let _ = write_json(...)`) and is not exercised by + // existing tests; deliberately skipped here. Re-introduce via a + // dedicated backend method if a consumer relies on it. + if let Some(ref fp) = req.fingerprint { + if let Some(existing) = existing_connectors + .iter() + .find(|c| c.fingerprint.as_deref() == Some(fp.as_str())) + { + let matched_uid = existing.connector_uid; + return Ok(RegisterConnectorResponse { + status: RegisterStatus::Aliased, + connector_uid: matched_uid, + alias_of: Some(matched_uid), + note: Some(format!("Matched by fingerprint: {}", fp)), + }); + } + } + + // 4. No collision -> ACCEPTED, create and persist new connector. + let now = Utc::now(); + let connector_record = ConnectorRecord { + version: 1, + connector_uid, + r#type: req.r#type, + title: req.title, + client_native_id: req.client_native_id, + alias_of: None, + created_at: now, + metadata: req.metadata, + fingerprint: req.fingerprint, + }; + + registry.put_connector(connector_record).await?; + + Ok(RegisterConnectorResponse { + status: RegisterStatus::Accepted, + connector_uid, + alias_of: None, + note: None, + }) + } + + /// Resolve a connector UID by scanning every registered backend. + /// + /// Ported from `MultiArchiveArchivist::resolve_connector_uid`: each + /// backend is tried in turn; the first backend that recognises the + /// `client_native_id` wins. As a secondary path, if `client_native_id` + /// parses as a UUID, checks whether a backend already has a connector + /// record at that UID. Returns `ConnectorUnknown(Uuid::nil())` if no + /// backend can resolve it. + pub async fn resolve_connector_uid(&self, client_native_id: &str) -> Result { + // Hand-rolled walk rather than `read_walk_collection`: we want + // "try every backend" semantics — a backend that returns `Ok(None)` + // should NOT win the walk. `read_walk_collection` treats any `Ok(_)` + // as a hit, so it would stop at the first backend that answered at + // all. Health drift is still wired through `record_read_*`. + let regs: Vec<_> = self.registrations.read().await.clone(); + for reg in regs.iter() { + if !reg.enabled { + continue; + } + let Some(registry) = reg.backend.as_connector_registry() else { + continue; + }; + match registry.resolve_connector_uid(client_native_id).await { + Ok(Some(uid)) => { + self.record_read_success(reg).await; + return Ok(uid); + } + Ok(None) => { + self.record_read_success(reg).await; + if let Ok(parsed) = Uuid::parse_str(client_native_id) { + match registry.get_connector(parsed).await { + Ok(Some(_)) => return Ok(parsed), + Ok(None) => {} + Err(_) => { + self.record_read_failure(reg).await; + } + } + } + } + Err(_) => { + self.record_read_failure(reg).await; + } + } + } + Err(ArchivistError::ConnectorUnknown(Uuid::nil())) + } + + /// List connectors in the selected archive (non-aliases only). + /// + /// When `archive` is `Some`, the explicit override still resolves directly + /// against that named backend (returning `ArchiveNameUnknown` / + /// `CapabilityNotSupported` as appropriate). When `None`, routing walks + /// enabled backends in `read_priority` order and returns the first + /// `ConnectorRegistry`-capable answer. + pub async fn list_connectors( + &self, + archive: Option, + ) -> Result> { + if let Some(name) = archive { + let reg = self + .find_registration(&name) + .await + .ok_or(ArchivistError::ArchiveNameUnknown(name))?; + let registry = reg.backend.as_connector_registry().ok_or_else(|| { + ArchivistError::CapabilityNotSupported { + capability: ArchiveCapability::ConnectorRegistry, + backend: reg.name.clone(), + } + })?; + return registry.list_connectors().await; + } + Ok(self + .read_walk_collection( + |reg| reg.backend.as_connector_registry().is_some(), + |backend| async move { + let cr = backend + .as_connector_registry() + .expect("predicate ensured"); + cr.list_connectors().await + }, + ) + .await? + .unwrap_or_default()) + } + + /// Update the stable fingerprint of an existing connector. + /// + /// NOTE: read-mutate-write on the backend side; falls through to inline + /// under `WritePolicy::Queued` (no `WriteOp` variant). + pub async fn update_connector_fingerprint( + &self, + connector_uid: Uuid, + fingerprint: String, + archive: Option, + ) -> Result<()> { + let primary = self.resolve_primary(archive.clone()).await?; + let regs: Vec> = + self.registrations.read().await.clone(); + + let primary_reg = primary.backend.as_connector_registry().ok_or_else(|| { + ArchivistError::PrimaryUnavailable { + name: primary.name.clone(), + reason: "backend lacks ConnectorRegistry capability".into(), + } + })?; + if let Err(e) = primary_reg + .update_connector_fingerprint(connector_uid, fingerprint.clone()) + .await + { + self.record_write_failure(&primary, &format!("{e}")).await; + return Err(e); + } + self.record_write_success(&primary).await; + + for reg in regs.iter() { + if reg.name == primary.name { + continue; + } + if !reg.enabled || !reg.write_active { + continue; + } + let Some(sec_reg) = reg.backend.as_connector_registry() else { + tracing::debug!( + backend = reg.name.as_str(), + type_name = reg.type_name, + op = "update_connector_fingerprint", + "capability_skip" + ); + continue; + }; + if let Err(e) = sec_reg + .update_connector_fingerprint(connector_uid, fingerprint.clone()) + .await + { + self.record_write_failure(reg, &format!("{e}")).await; + if reg.failure_mode == crate::registry::FailureMode::Required { + return Err(e); + } + } else { + self.record_write_success(reg).await; + } + } + Ok(()) + } +} diff --git a/crates/dirigent_archivist/src/coordinator/meta.rs b/crates/dirigent_archivist/src/coordinator/meta.rs new file mode 100644 index 0000000..94c8c6f --- /dev/null +++ b/crates/dirigent_archivist/src/coordinator/meta.rs @@ -0,0 +1,526 @@ +//! Meta events, DAG, and cleanup orchestration for `Archivist`. +//! +//! Ported from `FileBasedArchivist` in `archivist.rs`. Meta events and DAG +//! methods are thin delegates over `as_meta_events()` / `as_dag()`; +//! `get_session_tree` performs a recursive DAG walk; `cleanup_empty_sessions` +//! pages through all sessions and deletes those with zero messages (skipping +//! `SessionKind::AcpConnection` meta sessions, which track events rather than +//! messages). + +use uuid::Uuid; + +use crate::backend::ArchiveCapability; +use crate::coordinator::Archivist; +use crate::error::{ArchivistError, Result}; +use crate::types::{ + DagEdge, MetaEventRecord, SessionKind, SessionListQuery, SessionMetadata, MAX_PAGE_LIMIT, +}; + +impl Archivist { + // ------------------------------------------------------------------ + // Meta events + // ------------------------------------------------------------------ + + pub async fn append_meta_events( + &self, + scroll_id: Uuid, + events: Vec, + archive: Option, + ) -> Result<()> { + let primary = self.resolve_primary(archive.clone()).await?; + let regs: Vec> = + self.registrations.read().await.clone(); + + // Primary must have MetaEvents capability even in the queued path — + // the writer task dispatches to `as_meta_events()`, so we'd silently + // drop events on an incapable primary. Fail fast here. + let _ = primary.backend.as_meta_events().ok_or_else(|| { + ArchivistError::PrimaryUnavailable { + name: primary.name.clone(), + reason: "backend lacks MetaEvents capability".into(), + } + })?; + + match &primary.write_policy { + crate::registry::WritePolicy::Inline => { + let primary_meta = primary + .backend + .as_meta_events() + .expect("capability checked above"); + if let Err(e) = primary_meta + .append_meta_events(scroll_id, events.clone()) + .await + { + self.record_write_failure(&primary, &format!("{e}")).await; + return Err(e); + } + self.record_write_success(&primary).await; + } + crate::registry::WritePolicy::Queued { .. } => { + let writer = primary + .writer + .as_ref() + .expect("queued policy implies writer handle present"); + writer + .enqueue(crate::registry::writer::WriteOp::AppendMetaEvents { + scroll_id, + events: events.clone(), + }) + .await?; + } + } + + let session_metadata_for_filter = self + .load_metadata_for_filter(scroll_id, ®s, &primary.name) + .await; + + for reg in regs.iter() { + if reg.name == primary.name { + continue; + } + if !reg.enabled || !reg.write_active { + continue; + } + if !Self::filter_allows(reg, session_metadata_for_filter.as_ref()) { + tracing::debug!( + archive = %reg.name, + scroll_id = %scroll_id, + op = "append_meta_events", + "filter_skip" + ); + continue; + } + if reg.backend.as_meta_events().is_none() { + tracing::debug!( + backend = reg.name.as_str(), + type_name = reg.type_name, + op = "append_meta_events", + "capability_skip" + ); + continue; + } + match ®.write_policy { + crate::registry::WritePolicy::Inline => { + let me = reg + .backend + .as_meta_events() + .expect("capability checked above"); + if let Err(e) = me.append_meta_events(scroll_id, events.clone()).await { + self.record_write_failure(reg, &format!("{e}")).await; + if reg.failure_mode == crate::registry::FailureMode::Required { + return Err(e); + } + } else { + self.record_write_success(reg).await; + } + } + crate::registry::WritePolicy::Queued { .. } => { + let writer = reg + .writer + .as_ref() + .expect("queued policy implies writer handle present"); + if let Err(e) = writer + .enqueue(crate::registry::writer::WriteOp::AppendMetaEvents { + scroll_id, + events: events.clone(), + }) + .await + { + if reg.failure_mode == crate::registry::FailureMode::Required { + return Err(e); + } + } + } + } + } + Ok(()) + } + + pub async fn get_meta_events( + &self, + scroll_id: Uuid, + _archive: Option, + ) -> Result> { + // `archive` is now ignored for reads; routing picks the highest-priority + // backend that has the session and supports `MetaEvents`. + Ok(self + .read_walk_per_session( + scroll_id, + |reg| reg.backend.as_meta_events().is_some(), + |backend| async move { + let me = backend.as_meta_events().expect("predicate ensured"); + me.get_meta_events(scroll_id).await.map(Some) + }, + ) + .await? + .unwrap_or_default()) + } + + /// Update the connection status of an ACP meta-session. + /// + /// NOTE: read-mutate-write on the backend side (the impl rewrites fields + /// on the stored session); falls through to inline under + /// `WritePolicy::Queued` (no `WriteOp` variant). + pub async fn update_meta_session_status( + &self, + scroll_id: Uuid, + is_connected: bool, + current_session_id: Option, + archive: Option, + ) -> Result<()> { + let primary = self.resolve_primary(archive.clone()).await?; + let regs: Vec> = + self.registrations.read().await.clone(); + + let primary_meta = primary.backend.as_meta_events().ok_or_else(|| { + ArchivistError::PrimaryUnavailable { + name: primary.name.clone(), + reason: "backend lacks MetaEvents capability".into(), + } + })?; + if let Err(e) = primary_meta + .update_meta_session_status(scroll_id, is_connected, current_session_id) + .await + { + self.record_write_failure(&primary, &format!("{e}")).await; + return Err(e); + } + self.record_write_success(&primary).await; + + for reg in regs.iter() { + if reg.name == primary.name { + continue; + } + if !reg.enabled || !reg.write_active { + continue; + } + let Some(me) = reg.backend.as_meta_events() else { + tracing::debug!( + backend = reg.name.as_str(), + type_name = reg.type_name, + op = "update_meta_session_status", + "capability_skip" + ); + continue; + }; + if let Err(e) = me + .update_meta_session_status(scroll_id, is_connected, current_session_id) + .await + { + self.record_write_failure(reg, &format!("{e}")).await; + if reg.failure_mode == crate::registry::FailureMode::Required { + return Err(e); + } + } else { + self.record_write_success(reg).await; + } + } + Ok(()) + } + + pub async fn list_meta_sessions( + &self, + _archive: Option, + ) -> Result> { + // Collection-shape read: first enabled/healthy backend that supports + // `MetaEvents` wins. `archive` override is no longer honoured here — + // routing decides. + Ok(self + .read_walk_collection( + |reg| reg.backend.as_meta_events().is_some(), + |backend| async move { + let me = backend.as_meta_events().expect("predicate ensured"); + me.list_meta_sessions().await + }, + ) + .await? + .unwrap_or_default()) + } + + pub async fn find_meta_session_by_client( + &self, + client_id: &str, + _archive: Option, + ) -> Result> { + // Collection-shape read: first enabled/healthy backend that supports + // `MetaEvents` wins. The inner op returns `Result>`, so the + // walker's outer `Option` flattens to the inner one — "no backend + // answered" and "backend answered None" collapse the same way. + let client_id = client_id.to_string(); + let result = self + .read_walk_collection( + |reg| reg.backend.as_meta_events().is_some(), + |backend| { + let client_id = client_id.clone(); + async move { + let me = backend.as_meta_events().expect("predicate ensured"); + me.find_meta_session_by_client(&client_id).await + } + }, + ) + .await?; + Ok(result.flatten()) + } + + // ------------------------------------------------------------------ + // DAG + // ------------------------------------------------------------------ + + pub async fn append_dag_edge( + &self, + edge: DagEdge, + archive: Option, + ) -> Result<()> { + let primary = self.resolve_primary(archive.clone()).await?; + let regs: Vec> = + self.registrations.read().await.clone(); + + // Primary must have DAG capability — even in the queued path the + // writer task dispatches to `as_dag()`, so silently accepting a + // non-DAG primary would lose the edge. + let _ = primary.backend.as_dag().ok_or_else(|| { + ArchivistError::PrimaryUnavailable { + name: primary.name.clone(), + reason: "backend lacks Dag capability".into(), + } + })?; + + match &primary.write_policy { + crate::registry::WritePolicy::Inline => { + let primary_dag = primary + .backend + .as_dag() + .expect("capability checked above"); + if let Err(e) = primary_dag.append_dag_edge(edge.clone()).await { + self.record_write_failure(&primary, &format!("{e}")).await; + return Err(e); + } + self.record_write_success(&primary).await; + } + crate::registry::WritePolicy::Queued { .. } => { + let writer = primary + .writer + .as_ref() + .expect("queued policy implies writer handle present"); + writer + .enqueue(crate::registry::writer::WriteOp::AppendDagEdge(edge.clone())) + .await?; + } + } + + // DAG edges are indexed under the parent scroll_id, so use that for + // filter evaluation (the session whose DAG is being extended). + let parent_scroll_id = edge.parent; + let session_metadata_for_filter = self + .load_metadata_for_filter(parent_scroll_id, ®s, &primary.name) + .await; + + for reg in regs.iter() { + if reg.name == primary.name { + continue; + } + if !reg.enabled || !reg.write_active { + continue; + } + if !Self::filter_allows(reg, session_metadata_for_filter.as_ref()) { + tracing::debug!( + archive = %reg.name, + scroll_id = %parent_scroll_id, + op = "append_dag_edge", + "filter_skip" + ); + continue; + } + if reg.backend.as_dag().is_none() { + tracing::debug!( + backend = reg.name.as_str(), + type_name = reg.type_name, + op = "append_dag_edge", + "capability_skip" + ); + continue; + } + match ®.write_policy { + crate::registry::WritePolicy::Inline => { + let d = reg.backend.as_dag().expect("capability checked above"); + if let Err(e) = d.append_dag_edge(edge.clone()).await { + self.record_write_failure(reg, &format!("{e}")).await; + if reg.failure_mode == crate::registry::FailureMode::Required { + return Err(e); + } + } else { + self.record_write_success(reg).await; + } + } + crate::registry::WritePolicy::Queued { .. } => { + let writer = reg + .writer + .as_ref() + .expect("queued policy implies writer handle present"); + if let Err(e) = writer + .enqueue(crate::registry::writer::WriteOp::AppendDagEdge(edge.clone())) + .await + { + if reg.failure_mode == crate::registry::FailureMode::Required { + return Err(e); + } + } + } + } + } + Ok(()) + } + + pub async fn get_children( + &self, + scroll_id: Uuid, + _archive: Option, + ) -> Result> { + // `archive` is now ignored for reads; routing picks the highest-priority + // backend that has the session and supports `Dag`. + Ok(self + .read_walk_per_session( + scroll_id, + |reg| reg.backend.as_dag().is_some(), + |backend| async move { + let d = backend.as_dag().expect("predicate ensured"); + d.get_children(scroll_id).await.map(Some) + }, + ) + .await? + .unwrap_or_default()) + } + + /// Recursive DAG walk rooted at `root_scroll_id`. + /// + /// Matches the shape of `FileBasedArchivist::get_session_tree`: returns + /// every edge reachable from `root_scroll_id` (children, grandchildren, + /// …). Uses `DagBackend::get_dag_edges` per-parent plus a `seen` set to + /// guard against cycles. + pub async fn get_session_tree( + &self, + root_scroll_id: Uuid, + archive: Option, + ) -> Result> { + // TODO(phase3): consider multi-backend DAG walk in a future phase — + // current impl uses the default backend only. Consistent BFS across + // a tree requires all `get_dag_edges` calls to target the SAME + // backend as the root, which the walker API does not yet expose. + let backend = self.resolve_backend(archive).await?; + let dag = backend + .as_dag() + .ok_or_else(|| ArchivistError::CapabilityNotSupported { + capability: ArchiveCapability::Dag, + backend: "selected".into(), + })?; + + let mut out = Vec::new(); + let mut stack = vec![root_scroll_id]; + let mut seen = std::collections::HashSet::new(); + while let Some(parent) = stack.pop() { + if !seen.insert(parent) { + continue; + } + let edges = dag.get_dag_edges(parent).await?; + for e in &edges { + stack.push(e.child); + } + out.extend(edges); + } + Ok(out) + } + + // ------------------------------------------------------------------ + // Cleanup + // ------------------------------------------------------------------ + + /// Delete sessions that have zero messages. + /// + /// Ported from `FileBasedArchivist::cleanup_empty_sessions`. Pages through + /// every session (including hidden ones) via `list_sessions_paged`, counts + /// messages per session, and deletes those with zero. Meta sessions + /// (`SessionKind::AcpConnection`) are skipped — they track connection + /// events in `events.jsonl`, not messages, so an empty message log is + /// expected. + /// + /// Returns `(deleted, total_scanned)`. + pub async fn cleanup_empty_sessions( + &self, + archive: Option, + ) -> Result<(usize, usize)> { + let backend = self.resolve_backend(archive).await?; + + let mut total: usize = 0; + let mut deleted: usize = 0; + let mut q = SessionListQuery { + include_hidden: true, + limit: MAX_PAGE_LIMIT, + ..SessionListQuery::default() + }; + + loop { + let page = backend.list_sessions_paged(q.clone()).await?; + + for session in page.items.iter() { + total += 1; + + // Skip meta sessions - they track events, not messages, so a + // zero message count is expected and not a signal of emptiness. + if session.kind == SessionKind::AcpConnection { + tracing::debug!( + scroll_id = %session.scroll_id, + "Skipping meta session (AcpConnection) during cleanup" + ); + continue; + } + + let count = match backend.count_messages(session.scroll_id).await { + Ok(c) => c, + Err(e) => { + // Match legacy semantics: if we can't count messages, + // skip this session rather than risk deleting a + // non-empty one. + tracing::warn!( + scroll_id = %session.scroll_id, + error = %e, + "Failed to count messages for session, skipping cleanup" + ); + continue; + } + }; + + if count == 0 { + match backend.delete_session(session.scroll_id).await { + Ok(()) => { + tracing::info!( + scroll_id = %session.scroll_id, + "Deleted empty session during cleanup" + ); + deleted += 1; + } + Err(e) => { + tracing::warn!( + scroll_id = %session.scroll_id, + error = %e, + "Failed to delete empty session during cleanup" + ); + } + } + } + } + + match page.next_cursor { + Some(cursor) => q.cursor = Some(cursor), + None => break, + } + } + + tracing::info!( + deleted = deleted, + total = total, + "Completed empty session cleanup" + ); + + Ok((deleted, total)) + } +} diff --git a/crates/dirigent_archivist/src/coordinator/mod.rs b/crates/dirigent_archivist/src/coordinator/mod.rs new file mode 100644 index 0000000..61e4654 --- /dev/null +++ b/crates/dirigent_archivist/src/coordinator/mod.rs @@ -0,0 +1,231 @@ +//! Concrete archivist coordinator. +//! +//! Owns a `Vec>` sorted by `read_priority`, plus a +//! positive `scroll_id → backend` cache. The registry is constructed from +//! `dirigent.toml` at boot (Task 12). `Archivist::new` remains a legacy +//! convenience for the dev-instance migration path; later tasks migrate +//! consumers to `Archivist::from_config`. + +mod admin; +mod archives; +mod boot; +mod connectors; +mod meta; +mod routing; +mod sessions; +pub mod types; + +pub use types::{ArchiveInfo, ArchiveMetadata}; + +use std::path::PathBuf; +use std::sync::Arc; + +use tokio::sync::RwLock; + +use crate::backend::ArchiveBackend; +use crate::error::{ArchivistError, Result}; +use crate::registry::{ + cache::ReadCache, ArchiveRegistration, FailureMode, WritePolicy, +}; + +pub struct Archivist { + pub(crate) registrations: RwLock>>, + #[allow(dead_code)] // wired up in later tasks (cache-backed reads) + pub(crate) read_cache: Arc, + #[allow(dead_code)] // retained for future admin endpoints / diagnostics + pub(crate) registry_path: PathBuf, +} + +impl Archivist { + /// Legacy constructor: builds a single JsonlBackend rooted at + /// `registry_path.parent()`. Kept so dev-instance migration still + /// succeeds before Task 28 migrates consumers to `from_config`. + pub async fn new(registry_path: PathBuf) -> Result { + use crate::backends::JsonlBackend; + + let mut registrations: Vec> = Vec::new(); + if !registry_path.as_os_str().is_empty() { + let archive_root = registry_path + .parent() + .map(|p| p.to_path_buf()) + .unwrap_or_else(|| registry_path.clone()); + let backend = Arc::new(JsonlBackend::new(archive_root).await?) + as Arc; + let initial_health = backend.health_check().await; + registrations.push(Arc::new(ArchiveRegistration::new( + "main".into(), + "jsonl", + backend, + /* write_active */ true, + FailureMode::Required, + /* read_priority */ 0, + /* enabled */ true, + WritePolicy::Inline, + /* writer */ None, + initial_health, + ))); + } + + Ok(Self { + registrations: RwLock::new(registrations), + read_cache: Arc::new(ReadCache::new()), + registry_path, + }) + } + + /// Construct a coordinator with a single `JsonlBackend` archive named + /// "main" rooted at `archive_root`. + pub async fn new_with_single_archive(archive_root: PathBuf) -> Result { + use crate::backends::JsonlBackend; + + let backend = Arc::new(JsonlBackend::new(archive_root).await?) + as Arc; + let initial_health = backend.health_check().await; + let reg = Arc::new(ArchiveRegistration::new( + "main".into(), + "jsonl", + backend, + true, + FailureMode::Required, + 0, + true, + WritePolicy::Inline, + None, + initial_health, + )); + Ok(Self { + registrations: RwLock::new(vec![reg]), + read_cache: Arc::new(ReadCache::new()), + registry_path: PathBuf::new(), + }) + } + + /// Construct a coordinator with a pre-built single backend (for tests + /// that need to hold the backend directly alongside the coordinator). + pub async fn from_single_backend( + name: String, + backend: Arc, + ) -> Result { + let initial_health = backend.health_check().await; + let reg = Arc::new(ArchiveRegistration::new( + name, + "external", + backend, + true, + FailureMode::Required, + 0, + true, + WritePolicy::Inline, + None, + initial_health, + )); + Ok(Self { + registrations: RwLock::new(vec![reg]), + read_cache: Arc::new(ReadCache::new()), + registry_path: PathBuf::new(), + }) + } + + /// Resolve a single backend by optional name. + /// + /// `None` → lowest-`read_priority` enabled write-active `Required` + /// backend. `Some(name)` → the backend with that name (must exist). + #[allow(dead_code)] // wired up in later tasks + pub(crate) async fn resolve_backend( + &self, + archive: Option, + ) -> Result> { + let regs = self.registrations.read().await; + + if regs.is_empty() { + return Err(ArchivistError::NoArchiveConfigured); + } + + let chosen = match archive { + Some(name) => match regs.iter().find(|r| r.name == name) { + Some(r) => r, + None => return Err(ArchivistError::ArchiveNameUnknown(name)), + }, + None => regs + .iter() + .filter(|r| { + r.enabled && r.write_active && r.failure_mode == FailureMode::Required + }) + .min_by_key(|r| r.read_priority) + .ok_or_else(|| ArchivistError::PrimaryUnavailable { + name: "".into(), + reason: "no required write-active backend".into(), + })?, + }; + + Ok(chosen.backend.clone()) + } + + /// Resolve the primary `ArchiveRegistration` for a write. + /// + /// `None` → default-write target (lowest `read_priority` among enabled + /// write-active `Required` backends). `Some(name)` → the backend with that + /// name; errors if disabled or not write-active. + #[allow(dead_code)] // wired up in Task 16 + pub(crate) async fn resolve_primary( + &self, + archive: Option, + ) -> Result> { + let regs = self.registrations.read().await; + if regs.is_empty() { + return Err(ArchivistError::NoArchiveConfigured); + } + let chosen = match archive { + Some(name) => { + let r = regs + .iter() + .find(|r| r.name == name) + .ok_or_else(|| ArchivistError::ArchiveNameUnknown(name.clone()))?; + if !r.enabled { + return Err(ArchivistError::PrimaryUnavailable { + name: r.name.clone(), + reason: "backend is disabled".into(), + }); + } + if !r.write_active { + return Err(ArchivistError::PrimaryUnavailable { + name: r.name.clone(), + reason: "backend is not write-active".into(), + }); + } + r.clone() + } + None => regs + .iter() + .filter(|r| { + r.enabled + && r.write_active + && r.failure_mode == crate::registry::FailureMode::Required + }) + .min_by_key(|r| r.read_priority) + .cloned() + .ok_or_else(|| ArchivistError::PrimaryUnavailable { + name: "".into(), + reason: "no required write-active backend".into(), + })?, + }; + Ok(chosen) + } +} + +#[cfg(any(test, feature = "test-utils"))] +impl Archivist { + /// Test-only: construct directly from pre-built registrations. + pub fn from_registrations( + regs: Vec>, + ) -> Self { + Self { + registrations: tokio::sync::RwLock::new(regs), + read_cache: std::sync::Arc::new(crate::registry::cache::ReadCache::new()), + registry_path: std::path::PathBuf::new(), + } + } +} + +#[cfg(test)] +mod tests; diff --git a/crates/dirigent_archivist/src/coordinator/routing.rs b/crates/dirigent_archivist/src/coordinator/routing.rs new file mode 100644 index 0000000..47b0eed --- /dev/null +++ b/crates/dirigent_archivist/src/coordinator/routing.rs @@ -0,0 +1,136 @@ +//! Read priority walk shared by every per-scroll_id and collection-shape +//! coordinator method. +//! +//! The walk honours per-backend `enabled`, caller-supplied capability +//! predicates, and current health. Per-scroll_id reads populate a positive +//! LRU cache keyed on `scroll_id`, so the second read for the same session +//! can short-circuit the priority walk. + +use std::sync::Arc; + +use uuid::Uuid; + +use crate::backend::ArchiveBackend; +use crate::error::Result; +use crate::registry::ArchiveRegistration; + +use super::Archivist; + +impl Archivist { + /// Walk enabled + healthy registrations in `read_priority` order. + /// + /// `predicate` decides whether a backend can serve the read (typically a + /// capability check). `op` is invoked on the first matching backend: + /// - `Ok(Some(value))` — wins the walk; per-scroll_id cache is updated; returned. + /// - `Ok(None)` — backend doesn't have it; continue. + /// - `Err(_)` — drift the backend's health and continue. + pub(crate) async fn read_walk_per_session( + &self, + scroll_id: Uuid, + predicate: P, + op: F, + ) -> Result> + where + T: Send, + P: Fn(&ArchiveRegistration) -> bool + Send + Sync, + F: Fn(Arc) -> Fut + Send + Sync, + Fut: std::future::Future>> + Send, + { + // Cache hit: try the cached backend first. + if let Some(cached_name) = self.read_cache.get(scroll_id).await { + if let Some(reg) = self.find_registration(&cached_name).await { + if predicate(®) && reg.enabled && !self.is_unavailable(®).await { + match op(reg.backend.clone()).await { + Ok(Some(value)) => return Ok(Some(value)), + Ok(None) => { + // Cached entry no longer holds — invalidate and fall through. + self.read_cache.invalidate(scroll_id).await; + } + Err(_) => { + self.record_read_failure(®).await; + self.read_cache.invalidate(scroll_id).await; + } + } + } + } + } + + // Priority walk. + let regs: Vec> = self.registrations.read().await.clone(); + for reg in regs.iter() { + if !reg.enabled || !predicate(reg) || self.is_unavailable(reg).await { + continue; + } + match op(reg.backend.clone()).await { + Ok(Some(value)) => { + self.record_read_success(reg).await; + self.read_cache.put(scroll_id, reg.name.clone()).await; + return Ok(Some(value)); + } + Ok(None) => { + self.record_read_success(reg).await; + continue; + } + Err(_) => { + self.record_read_failure(reg).await; + continue; + } + } + } + + Ok(None) + } + + /// Collection-shape read variant: returns the first enabled/healthy backend's + /// result, no cache. `op`'s return type is `Result` (no `Option`): + /// an error is treated as "backend couldn't serve this" and drifted; `Ok(T)` + /// is the answer. + pub(crate) async fn read_walk_collection( + &self, + predicate: P, + op: F, + ) -> Result> + where + T: Send, + P: Fn(&ArchiveRegistration) -> bool + Send + Sync, + F: Fn(Arc) -> Fut + Send + Sync, + Fut: std::future::Future> + Send, + { + let regs: Vec> = self.registrations.read().await.clone(); + for reg in regs.iter() { + if !reg.enabled || !predicate(reg) || self.is_unavailable(reg).await { + continue; + } + match op(reg.backend.clone()).await { + Ok(value) => { + self.record_read_success(reg).await; + return Ok(Some(value)); + } + Err(_) => { + self.record_read_failure(reg).await; + continue; + } + } + } + Ok(None) + } + + pub(crate) async fn find_registration( + &self, + name: &str, + ) -> Option> { + self.registrations + .read() + .await + .iter() + .find(|r| r.name == name) + .cloned() + } + + async fn is_unavailable(&self, reg: &ArchiveRegistration) -> bool { + matches!( + *reg.last_health.read().await, + crate::backend::HealthStatus::Unavailable { .. } + ) + } +} diff --git a/crates/dirigent_archivist/src/coordinator/sessions.rs b/crates/dirigent_archivist/src/coordinator/sessions.rs new file mode 100644 index 0000000..f0c8c5e --- /dev/null +++ b/crates/dirigent_archivist/src/coordinator/sessions.rs @@ -0,0 +1,1470 @@ +//! Session orchestration for `Archivist`. +//! +//! Covers registration (with alias detection), resolution, mandatory-primitive +//! dispatch, read-modify-write metadata wrappers, and move/copy semantics. +//! Ported from the `FileBasedArchivist` bodies in `archivist.rs`; persistence +//! routes through backend sub-traits (`SessionMappingBackend`, +//! `ConnectorRegistryBackend`) plus the mandatory `ArchiveBackend` surface. + +use chrono::Utc; +use uuid::Uuid; + +use crate::backend::ArchiveCapability; +use crate::coordinator::Archivist; +use crate::error::{ArchivistError, Result}; +use crate::types::{ + MessageCursor, MessagePage, MessageRecord, MoveReport, RegisterSessionRequest, + RegisterSessionResponse, RegisterStatus, SessionCompleteness, SessionListQuery, + SessionMetadata, SessionPage, MAX_PAGE_LIMIT, +}; + +impl Archivist { + // ------------------------------------------------------------------ + // Filter helpers (Task 20) + // ------------------------------------------------------------------ + + /// Evaluate a registration's filter against optional session metadata. + /// + /// - If the filter is unrestricted, always allows. + /// - If the filter is restricted and metadata is present, delegates to + /// `ArchiveFilter::allows`. + /// - If the filter is restricted but metadata is unavailable, rejects + /// (the safe default: without metadata we cannot prove the session + /// should be replicated). + pub(crate) fn filter_allows( + reg: &crate::registry::ArchiveRegistration, + session: Option<&SessionMetadata>, + ) -> bool { + if reg.filter.is_unrestricted() { + return true; + } + match session { + Some(s) => reg.filter.allows(s, &s.connector_uid), + None => false, + } + } + + /// Lazily load session metadata for filter evaluation. + /// + /// Returns `None` when none of the non-primary registrations carry a + /// non-unrestricted filter (no load needed) or when the load itself + /// failed (e.g. the session isn't present anywhere yet). + pub(crate) async fn load_metadata_for_filter( + &self, + scroll_id: Uuid, + regs: &[std::sync::Arc], + primary_name: &str, + ) -> Option { + let any_restricted = regs.iter().any(|r| { + r.name != primary_name + && r.enabled + && r.write_active + && !r.filter.is_unrestricted() + }); + if !any_restricted { + return None; + } + self.get_session_metadata(scroll_id, None).await.ok() + } + + // ------------------------------------------------------------------ + // Registration & alias detection + // ------------------------------------------------------------------ + + /// Register a session with alias detection. + /// + /// Ported from `FileBasedArchivist::register_session`. Persistence routes + /// through the backend's `SessionMappingBackend::put_mapping` and + /// `ArchiveBackend::put_session` methods; alias detection + /// (Accepted vs. Aliased) stays at the coordinator. + pub async fn register_session( + &self, + req: RegisterSessionRequest, + archive: Option, + ) -> Result { + let primary = self.resolve_primary(archive.clone()).await?; + let regs: Vec> = + self.registrations.read().await.clone(); + + let primary_mapping = primary.backend.as_session_mapping().ok_or_else(|| { + ArchivistError::PrimaryUnavailable { + name: primary.name.clone(), + reason: "backend lacks SessionMapping capability".into(), + } + })?; + let primary_registry = primary.backend.as_connector_registry().ok_or_else(|| { + ArchivistError::PrimaryUnavailable { + name: primary.name.clone(), + reason: "backend lacks ConnectorRegistry capability".into(), + } + })?; + + // Generate scroll_id (use custom_scroll_id only if it's UUID7, otherwise generate new) + // This validation prevents UUID4 leaks into folder names. + let scroll_id = match req.custom_scroll_id { + Some(uuid) if uuid.get_version_num() == 7 => { + tracing::debug!("Using provided UUID7 as scroll_id: {}", uuid); + uuid + } + Some(uuid) => { + tracing::warn!( + "Rejected non-UUID7 custom_scroll_id: {} (version {}), generating fresh UUID7", + uuid, + uuid.get_version_num() + ); + Uuid::now_v7() + } + None => Uuid::now_v7(), + }; + + // Validate that the connector exists on the PRIMARY. + if primary_registry + .get_connector(req.connector_uid) + .await? + .is_none() + { + return Err(ArchivistError::ConnectorUnknown(req.connector_uid)); + } + + // Alias detection: if a mapping already exists for this + // (connector_uid, native_session_id) on the PRIMARY, return ALIASED. + // Alias detection stays on the primary only — the answer is canonical + // and must not be per-backend. + if let Some(existing_scroll) = primary_mapping + .get_mapping(req.connector_uid, &req.native_session_id) + .await? + { + return Ok(RegisterSessionResponse { + status: RegisterStatus::Aliased, + scroll_id: existing_scroll, + alias_of: Some(existing_scroll), + }); + } + + // ACCEPTED: write mapping first, then the session metadata. This order + // matches the original `FileBasedArchivist` sequence — if metadata + // creation fails, the mapping still lets `resolve_session` work, and + // the next `append_messages` reconstructs a minimal session.json. + let now = Utc::now(); + + // Write mapping on PRIMARY first. Any failure here propagates — the + // canonical mapping must succeed for the operation to be meaningful. + if let Err(e) = primary_mapping + .put_mapping(req.connector_uid, &req.native_session_id, scroll_id) + .await + { + self.record_write_failure(&primary, &format!("{e}")).await; + return Err(e); + } + self.record_write_success(&primary).await; + + let session_metadata = SessionMetadata { + version: 1, + scroll_id, + created_at: now, + updated_at: now, + title: req.title.clone(), + connector_uid: req.connector_uid, + native_session_id: Some(req.native_session_id.clone()), + agent_id: req.agent_id.clone(), + parent_scroll_id: req.parent_scroll_id, + continuation: req.continuation, + tags: Vec::new(), + metadata: req.metadata.clone(), + no_update: false, + kind: crate::types::SessionKind::Chat, + acp_client_id: None, + is_connected: None, + current_session_id: None, + models: None, + modes: None, + config_options: None, + completeness: req.completeness, + matrix_room_id: None, + matrix_sharing_active: false, + matrix_shared_at: None, + is_subagent: req.is_subagent, + subagent_type: req.subagent_type.clone(), + spawning_tool_use_id: req.spawning_tool_use_id.clone(), + }; + + // Mirror the original best-effort semantics: if writing session.json + // fails, log and proceed — the mapping is already durable. + match &primary.write_policy { + crate::registry::WritePolicy::Inline => { + if let Err(e) = primary.backend.put_session(session_metadata.clone()).await { + tracing::warn!( + scroll_id = %scroll_id, + native_session_id = %req.native_session_id, + connector_uid = %req.connector_uid, + error = %e, + "Failed to write session metadata after mapping write. \ + Session is registered but metadata will be created on first message write." + ); + self.record_write_failure(&primary, &format!("{e}")).await; + } else { + self.record_write_success(&primary).await; + } + } + crate::registry::WritePolicy::Queued { .. } => { + let writer = primary + .writer + .as_ref() + .expect("queued policy implies writer handle present"); + // Queued put_session is fire-and-forget; errors drift health + // on the writer task side. + let _ = writer + .enqueue(crate::registry::writer::WriteOp::PutSession( + session_metadata.clone(), + )) + .await; + } + } + + // Secondaries: fan out mapping + session with capability filter. + for reg in regs.iter() { + if reg.name == primary.name { + continue; + } + if !reg.enabled || !reg.write_active { + continue; + } + + // Per-archive include/exclude filter. Checked before capability so + // cheaper checks come first. + if !reg.filter.allows(&session_metadata, &session_metadata.connector_uid) { + tracing::debug!( + archive = %reg.name, + scroll_id = %scroll_id, + op = "register_session", + "filter_skip" + ); + continue; + } + + // Mapping on secondary (capability check). Mapping writes have no + // `WriteOp` variant (SessionMapping is a sub-trait, not part of the + // queued dispatch surface), so they remain inline in both policies. + let Some(sec_mapping) = reg.backend.as_session_mapping() else { + tracing::debug!( + backend = reg.name.as_str(), + type_name = reg.type_name, + op = "register_session:mapping", + "capability_skip" + ); + continue; + }; + if let Err(e) = sec_mapping + .put_mapping(req.connector_uid, &req.native_session_id, scroll_id) + .await + { + self.record_write_failure(reg, &format!("{e}")).await; + if reg.failure_mode == crate::registry::FailureMode::Required { + return Err(e); + } + continue; + } + + // put_session on secondary — honours the secondary's write_policy. + match ®.write_policy { + crate::registry::WritePolicy::Inline => { + if let Err(e) = reg.backend.put_session(session_metadata.clone()).await { + self.record_write_failure(reg, &format!("{e}")).await; + if reg.failure_mode == crate::registry::FailureMode::Required { + return Err(e); + } + } else { + self.record_write_success(reg).await; + } + } + crate::registry::WritePolicy::Queued { .. } => { + let writer = reg + .writer + .as_ref() + .expect("queued policy implies writer handle present"); + if let Err(e) = writer + .enqueue(crate::registry::writer::WriteOp::PutSession( + session_metadata.clone(), + )) + .await + { + if reg.failure_mode == crate::registry::FailureMode::Required { + return Err(e); + } + } + } + } + } + + Ok(RegisterSessionResponse { + status: RegisterStatus::Accepted, + scroll_id, + alias_of: None, + }) + } + + /// Resolve a native session ID into a scroll_id via the mapping table. + /// + /// A missing mapping is surfaced as `SessionUnknown(Uuid::nil())` to keep + /// the legacy `Archivist::resolve_session` contract (caller expects an + /// error, not `Ok(None)`). + pub async fn resolve_session( + &self, + connector_uid: Uuid, + native_session_id: &str, + archive: Option, + ) -> Result { + let backend = self.resolve_backend(archive).await?; + let mapping = backend.as_session_mapping().ok_or_else(|| { + ArchivistError::CapabilityNotSupported { + capability: ArchiveCapability::SessionMapping, + backend: "selected".into(), + } + })?; + + match mapping.get_mapping(connector_uid, native_session_id).await? { + Some(scroll_id) => Ok(scroll_id), + None => Err(ArchivistError::SessionUnknown(Uuid::nil())), + } + } + + /// Locate the `(connector_uid, scroll_id)` that owns a native session id. + pub async fn find_session_owner( + &self, + native_session_id: &str, + archive: Option, + ) -> Result> { + let backend = self.resolve_backend(archive).await?; + let mapping = backend.as_session_mapping().ok_or_else(|| { + ArchivistError::CapabilityNotSupported { + capability: ArchiveCapability::SessionMapping, + backend: "selected".into(), + } + })?; + mapping.find_owner(native_session_id).await + } + + // ------------------------------------------------------------------ + // Mandatory primitive dispatch + // ------------------------------------------------------------------ + + /// Fetch session metadata. `None` from the backend becomes + /// `SessionUnknown` so callers can treat a missing session as an error. + pub async fn get_session_metadata( + &self, + scroll_id: Uuid, + _archive: Option, + ) -> Result { + // `archive` is now ignored for reads; routing picks the highest-priority + // backend that has the session. + self.read_walk_per_session( + scroll_id, + |_reg| true, + |backend| async move { backend.get_session(scroll_id).await }, + ) + .await? + .ok_or(ArchivistError::SessionUnknown(scroll_id)) + } + + /// List sessions with cursor pagination. The `archive` field on the query + /// selects the backend; it is consumed by the coordinator and not passed + /// on to the backend implementation. + pub async fn list_sessions_paged( + &self, + mut query: SessionListQuery, + ) -> Result { + let archive = query.archive.take(); + let backend = self.resolve_backend(archive).await?; + backend.list_sessions_paged(query).await + } + + pub async fn append_messages( + &self, + scroll_id: Uuid, + messages: Vec, + archive: Option, + ) -> Result<()> { + let primary = self.resolve_primary(archive.clone()).await?; + let regs: Vec> = + self.registrations.read().await.clone(); + + match &primary.write_policy { + crate::registry::WritePolicy::Inline => { + if let Err(e) = primary + .backend + .append_messages(scroll_id, messages.clone()) + .await + { + self.record_write_failure(&primary, &format!("{e}")).await; + return Err(e); + } + self.record_write_success(&primary).await; + } + crate::registry::WritePolicy::Queued { .. } => { + let writer = primary + .writer + .as_ref() + .expect("queued policy implies writer handle present"); + writer + .enqueue(crate::registry::writer::WriteOp::AppendMessages { + scroll_id, + msgs: messages.clone(), + }) + .await?; + } + } + + // Load metadata once for filter checks. Only needed if at least one + // non-primary secondary has a non-unrestricted filter. + let session_metadata_for_filter = self + .load_metadata_for_filter(scroll_id, ®s, &primary.name) + .await; + + for reg in regs.iter() { + if reg.name == primary.name { + continue; + } + if !reg.enabled || !reg.write_active { + continue; + } + if !Self::filter_allows(reg, session_metadata_for_filter.as_ref()) { + tracing::debug!( + archive = %reg.name, + scroll_id = %scroll_id, + op = "append_messages", + "filter_skip" + ); + continue; + } + match ®.write_policy { + crate::registry::WritePolicy::Inline => { + if let Err(e) = reg.backend.append_messages(scroll_id, messages.clone()).await + { + self.record_write_failure(reg, &format!("{e}")).await; + if reg.failure_mode == crate::registry::FailureMode::Required { + return Err(e); + } + } else { + self.record_write_success(reg).await; + } + } + crate::registry::WritePolicy::Queued { .. } => { + let writer = reg + .writer + .as_ref() + .expect("queued policy implies writer handle present"); + if let Err(e) = writer + .enqueue(crate::registry::writer::WriteOp::AppendMessages { + scroll_id, + msgs: messages.clone(), + }) + .await + { + if reg.failure_mode == crate::registry::FailureMode::Required { + return Err(e); + } + } + } + } + } + Ok(()) + } + + pub async fn get_messages_paged( + &self, + scroll_id: Uuid, + cursor: Option, + limit: usize, + _archive: Option, + ) -> Result { + // `archive` is now ignored for reads; routing picks the highest-priority + // backend that has the session. + Ok(self + .read_walk_per_session( + scroll_id, + |_reg| true, + |backend| { + let cursor = cursor.clone(); + async move { + backend + .get_messages_paged(scroll_id, cursor, limit) + .await + .map(Some) + } + }, + ) + .await? + .unwrap_or(MessagePage { + items: Vec::new(), + next_cursor: None, + })) + } + + /// Fetch all messages for a session by draining every cursor page. + /// + /// This is a convenience wrapper over `get_messages_paged` that reconstructs + /// the flat `Vec` shape the pre-Phase-2 `Archivist::get_messages` trait + /// method promised. Consumers that know they want to stream large histories + /// should prefer `get_messages_paged` directly. + pub async fn get_messages( + &self, + scroll_id: Uuid, + archive: Option, + ) -> Result> { + let mut out = Vec::new(); + let mut cursor = None; + loop { + let page = self + .get_messages_paged( + scroll_id, + cursor, + crate::types::MAX_PAGE_LIMIT, + archive.clone(), + ) + .await?; + out.extend(page.items); + match page.next_cursor { + Some(c) => cursor = Some(c), + None => return Ok(out), + } + } + } + + /// Return a slice of messages by offset/limit (sorted chronologically). + /// + /// Ported from `FileBasedArchivist::get_messages_range`. This loads the + /// full message set (via `get_messages`) and slices it — it is NOT cursor- + /// based and should only be used for small offsets / tests. For anything + /// production-sized, use `get_messages_paged`. + pub async fn get_messages_range( + &self, + scroll_id: Uuid, + offset: usize, + limit: usize, + archive: Option, + ) -> Result> { + let all = self.get_messages(scroll_id, archive).await?; + Ok(all.into_iter().skip(offset).take(limit).collect()) + } + + pub async fn count_messages( + &self, + scroll_id: Uuid, + _archive: Option, + ) -> Result { + // `archive` is now ignored for reads; routing picks the highest-priority + // backend that has the session. + Ok(self + .read_walk_per_session( + scroll_id, + |_reg| true, + |backend| async move { backend.count_messages(scroll_id).await.map(Some) }, + ) + .await? + .unwrap_or(0)) + } + + pub async fn delete_session( + &self, + scroll_id: Uuid, + _archive: Option, // ignored: delete spans every backend that has the session + ) -> Result<()> { + let regs: Vec> = + self.registrations.read().await.clone(); + + let mut read_only_violations: Vec = Vec::new(); + let mut last_required_error: Option = None; + + for reg in regs.iter() { + if !reg.enabled { + continue; + } + + // Cheap existence check before attempting delete. We treat read failures + // here as "backend doesn't have it" to avoid cascading errors. + let exists = reg + .backend + .get_session(scroll_id) + .await + .ok() + .flatten() + .is_some(); + if !exists { + continue; + } + + if !reg.write_active { + read_only_violations.push(reg.name.clone()); + continue; + } + + match ®.write_policy { + crate::registry::WritePolicy::Inline => { + match reg.backend.delete_session(scroll_id).await { + Ok(()) => self.record_write_success(reg).await, + Err(e) => { + self.record_write_failure(reg, &format!("{e}")).await; + if reg.failure_mode == crate::registry::FailureMode::Required + && last_required_error.is_none() + { + last_required_error = Some(e); + } + } + } + } + crate::registry::WritePolicy::Queued { .. } => { + if let Some(writer) = reg.writer.as_ref() { + let _ = writer + .enqueue(crate::registry::writer::WriteOp::DeleteSession { scroll_id }) + .await; + } + } + } + } + + // Cache invalidation regardless of outcome — the session no longer has a stable home. + self.read_cache.invalidate(scroll_id).await; + + if let Some(e) = last_required_error { + return Err(e); + } + if let Some(name) = read_only_violations.into_iter().next() { + return Err(ArchivistError::DeleteOnReadOnlyBackend { + backend: name, + scroll_id, + }); + } + Ok(()) + } + + pub async fn clear_session_messages( + &self, + scroll_id: Uuid, + archive: Option, + ) -> Result<()> { + let primary = self.resolve_primary(archive.clone()).await?; + let regs: Vec> = + self.registrations.read().await.clone(); + + match &primary.write_policy { + crate::registry::WritePolicy::Inline => { + if let Err(e) = primary.backend.clear_session_messages(scroll_id).await { + self.record_write_failure(&primary, &format!("{e}")).await; + return Err(e); + } + self.record_write_success(&primary).await; + } + crate::registry::WritePolicy::Queued { .. } => { + let writer = primary + .writer + .as_ref() + .expect("queued policy implies writer handle present"); + writer + .enqueue(crate::registry::writer::WriteOp::ClearSessionMessages { + scroll_id, + }) + .await?; + } + } + + let session_metadata_for_filter = self + .load_metadata_for_filter(scroll_id, ®s, &primary.name) + .await; + + for reg in regs.iter() { + if reg.name == primary.name { + continue; + } + if !reg.enabled || !reg.write_active { + continue; + } + if !Self::filter_allows(reg, session_metadata_for_filter.as_ref()) { + tracing::debug!( + archive = %reg.name, + scroll_id = %scroll_id, + op = "clear_session_messages", + "filter_skip" + ); + continue; + } + match ®.write_policy { + crate::registry::WritePolicy::Inline => { + if let Err(e) = reg.backend.clear_session_messages(scroll_id).await { + self.record_write_failure(reg, &format!("{e}")).await; + if reg.failure_mode == crate::registry::FailureMode::Required { + return Err(e); + } + } else { + self.record_write_success(reg).await; + } + } + crate::registry::WritePolicy::Queued { .. } => { + let writer = reg + .writer + .as_ref() + .expect("queued policy implies writer handle present"); + if let Err(e) = writer + .enqueue(crate::registry::writer::WriteOp::ClearSessionMessages { + scroll_id, + }) + .await + { + if reg.failure_mode == crate::registry::FailureMode::Required { + return Err(e); + } + } + } + } + } + Ok(()) + } + + // ------------------------------------------------------------------ + // Read-modify-write metadata wrappers + // ------------------------------------------------------------------ + + /// Update `title` and/or `model` on an existing session. + /// + /// Ported from `FileBasedArchivist::update_session_metadata`. `title` is + /// set directly; `model` is written into `metadata.model` (JSON object). + /// `updated_at` is always bumped. + /// + /// NOTE: Read-mutate-write operations silently fall through to inline + /// even under `WritePolicy::Queued`. There is no `WriteOp::UpdateSessionMetadata` + /// variant because RMW doesn't compose with batching — the read-side of the + /// RMW must see a consistent, already-persisted session. + pub async fn update_session_metadata( + &self, + scroll_id: Uuid, + title: Option, + model: Option, + archive: Option, + ) -> Result<()> { + let primary = self.resolve_primary(archive.clone()).await?; + let regs: Vec> = + self.registrations.read().await.clone(); + + let apply = |backend: std::sync::Arc, + title: Option, + model: Option| + -> std::pin::Pin>> + Send>> { + Box::pin(async move { + let mut session_metadata = backend + .get_session(scroll_id) + .await? + .ok_or(ArchivistError::SessionUnknown(scroll_id))?; + + if let Some(new_title) = title { + session_metadata.title = Some(new_title); + } + if let Some(new_model) = model { + if let Some(obj) = session_metadata.metadata.as_object_mut() { + obj.insert("model".to_string(), serde_json::Value::String(new_model)); + } + } + + session_metadata.updated_at = Utc::now(); + let title_dbg = session_metadata.title.clone(); + backend.put_session(session_metadata).await?; + Ok(title_dbg) + }) + }; + + let title_dbg = match apply(primary.backend.clone(), title.clone(), model.clone()).await { + Ok(t) => { + self.record_write_success(&primary).await; + t + } + Err(e) => { + self.record_write_failure(&primary, &format!("{e}")).await; + return Err(e); + } + }; + + let session_metadata_for_filter = self + .load_metadata_for_filter(scroll_id, ®s, &primary.name) + .await; + + for reg in regs.iter() { + if reg.name == primary.name { + continue; + } + if !reg.enabled || !reg.write_active { + continue; + } + if !Self::filter_allows(reg, session_metadata_for_filter.as_ref()) { + tracing::debug!( + archive = %reg.name, + scroll_id = %scroll_id, + op = "update_session_metadata", + "filter_skip" + ); + continue; + } + if let Err(e) = apply(reg.backend.clone(), title.clone(), model.clone()).await { + self.record_write_failure(reg, &format!("{e}")).await; + if reg.failure_mode == crate::registry::FailureMode::Required { + return Err(e); + } + } else { + self.record_write_success(reg).await; + } + } + + tracing::info!( + scroll_id = %scroll_id, + title = ?title_dbg, + "Updated session metadata" + ); + + Ok(()) + } + + /// Update ACP-specific metadata (`models`, `modes`, `config_options`). + /// + /// NOTE: read-mutate-write; falls through to inline under + /// `WritePolicy::Queued` (no `WriteOp` variant). + pub async fn update_session_acp_metadata( + &self, + scroll_id: Uuid, + models: Option, + modes: Option, + config_options: Option, + archive: Option, + ) -> Result<()> { + let primary = self.resolve_primary(archive.clone()).await?; + let regs: Vec> = + self.registrations.read().await.clone(); + + let apply = |backend: std::sync::Arc, + models: Option, + modes: Option, + config_options: Option| + -> std::pin::Pin> + Send>> { + Box::pin(async move { + let mut session_metadata = backend + .get_session(scroll_id) + .await? + .ok_or(ArchivistError::SessionUnknown(scroll_id))?; + + if let Some(new_models) = models { + session_metadata.models = Some(new_models); + } + if let Some(new_modes) = modes { + session_metadata.modes = Some(new_modes); + } + if let Some(new_config_options) = config_options { + session_metadata.config_options = Some(new_config_options); + } + + session_metadata.updated_at = Utc::now(); + + let has_models = session_metadata.models.is_some(); + let has_modes = session_metadata.modes.is_some(); + let has_config_options = session_metadata.config_options.is_some(); + + backend.put_session(session_metadata).await?; + Ok((has_models, has_modes, has_config_options)) + }) + }; + + let (has_models, has_modes, has_config_options) = match apply( + primary.backend.clone(), + models.clone(), + modes.clone(), + config_options.clone(), + ) + .await + { + Ok(flags) => { + self.record_write_success(&primary).await; + flags + } + Err(e) => { + self.record_write_failure(&primary, &format!("{e}")).await; + return Err(e); + } + }; + + let session_metadata_for_filter = self + .load_metadata_for_filter(scroll_id, ®s, &primary.name) + .await; + + for reg in regs.iter() { + if reg.name == primary.name { + continue; + } + if !reg.enabled || !reg.write_active { + continue; + } + if !Self::filter_allows(reg, session_metadata_for_filter.as_ref()) { + tracing::debug!( + archive = %reg.name, + scroll_id = %scroll_id, + op = "update_session_acp_metadata", + "filter_skip" + ); + continue; + } + if let Err(e) = apply( + reg.backend.clone(), + models.clone(), + modes.clone(), + config_options.clone(), + ) + .await + { + self.record_write_failure(reg, &format!("{e}")).await; + if reg.failure_mode == crate::registry::FailureMode::Required { + return Err(e); + } + } else { + self.record_write_success(reg).await; + } + } + + tracing::info!( + scroll_id = %scroll_id, + has_models = has_models, + has_modes = has_modes, + has_config_options = has_config_options, + "Updated session agent metadata" + ); + + Ok(()) + } + + /// Update the Matrix sharing flags on a session. `matrix_shared_at` is + /// stamped on the first transition to `sharing_active=true`. + /// + /// NOTE: read-mutate-write; falls through to inline under + /// `WritePolicy::Queued` (no `WriteOp` variant). + pub async fn update_session_sharing( + &self, + scroll_id: Uuid, + matrix_room_id: Option, + sharing_active: bool, + archive: Option, + ) -> Result<()> { + let primary = self.resolve_primary(archive.clone()).await?; + let regs: Vec> = + self.registrations.read().await.clone(); + + let apply = |backend: std::sync::Arc, + matrix_room_id: Option, + sharing_active: bool| + -> std::pin::Pin> + Send>> { + Box::pin(async move { + let mut session_metadata = backend + .get_session(scroll_id) + .await? + .ok_or(ArchivistError::SessionUnknown(scroll_id))?; + + session_metadata.matrix_room_id = matrix_room_id; + session_metadata.matrix_sharing_active = sharing_active; + if sharing_active && session_metadata.matrix_shared_at.is_none() { + session_metadata.matrix_shared_at = Some(Utc::now()); + } + session_metadata.updated_at = Utc::now(); + + backend.put_session(session_metadata).await + }) + }; + + if let Err(e) = apply(primary.backend.clone(), matrix_room_id.clone(), sharing_active).await + { + self.record_write_failure(&primary, &format!("{e}")).await; + return Err(e); + } + self.record_write_success(&primary).await; + + let session_metadata_for_filter = self + .load_metadata_for_filter(scroll_id, ®s, &primary.name) + .await; + + for reg in regs.iter() { + if reg.name == primary.name { + continue; + } + if !reg.enabled || !reg.write_active { + continue; + } + if !Self::filter_allows(reg, session_metadata_for_filter.as_ref()) { + tracing::debug!( + archive = %reg.name, + scroll_id = %scroll_id, + op = "update_session_sharing", + "filter_skip" + ); + continue; + } + if let Err(e) = + apply(reg.backend.clone(), matrix_room_id.clone(), sharing_active).await + { + self.record_write_failure(reg, &format!("{e}")).await; + if reg.failure_mode == crate::registry::FailureMode::Required { + return Err(e); + } + } else { + self.record_write_success(reg).await; + } + } + + tracing::debug!( + scroll_id = %scroll_id, + sharing_active = sharing_active, + "Updated session sharing metadata" + ); + + Ok(()) + } + + /// Set the `completeness` level on a session. + /// + /// NOTE: read-mutate-write; falls through to inline under + /// `WritePolicy::Queued` (no `WriteOp` variant). + pub async fn update_session_completeness( + &self, + scroll_id: Uuid, + completeness: SessionCompleteness, + archive: Option, + ) -> Result<()> { + let primary = self.resolve_primary(archive.clone()).await?; + let regs: Vec> = + self.registrations.read().await.clone(); + + let apply = |backend: std::sync::Arc, + completeness: SessionCompleteness| + -> std::pin::Pin> + Send>> { + Box::pin(async move { + let mut session_metadata = backend + .get_session(scroll_id) + .await? + .ok_or(ArchivistError::SessionUnknown(scroll_id))?; + + session_metadata.completeness = completeness; + session_metadata.updated_at = Utc::now(); + + backend.put_session(session_metadata).await + }) + }; + + if let Err(e) = apply(primary.backend.clone(), completeness).await { + self.record_write_failure(&primary, &format!("{e}")).await; + return Err(e); + } + self.record_write_success(&primary).await; + + let session_metadata_for_filter = self + .load_metadata_for_filter(scroll_id, ®s, &primary.name) + .await; + + for reg in regs.iter() { + if reg.name == primary.name { + continue; + } + if !reg.enabled || !reg.write_active { + continue; + } + if !Self::filter_allows(reg, session_metadata_for_filter.as_ref()) { + tracing::debug!( + archive = %reg.name, + scroll_id = %scroll_id, + op = "update_session_completeness", + "filter_skip" + ); + continue; + } + if let Err(e) = apply(reg.backend.clone(), completeness).await { + self.record_write_failure(reg, &format!("{e}")).await; + if reg.failure_mode == crate::registry::FailureMode::Required { + return Err(e); + } + } else { + self.record_write_success(reg).await; + } + } + + tracing::debug!( + scroll_id = %scroll_id, + completeness = ?completeness, + "Updated session completeness" + ); + + Ok(()) + } + + // ------------------------------------------------------------------ + // Listing helpers + // ------------------------------------------------------------------ + + /// Return every session in the selected archive whose + /// `matrix_sharing_active` flag is set. + /// + /// `SessionListQuery` does not yet expose a `matrix_sharing_active` + /// filter, so we page through every session (including hidden ones) and + /// filter client-side. Included-hidden matches the original + /// `FileBasedArchivist::list_sessions_with_active_sharing`, which scanned + /// `.contexts/` without respecting `no_update` / `is_subagent`. + pub async fn list_sessions_with_active_sharing( + &self, + archive: Option, + ) -> Result> { + let backend = self.resolve_backend(archive).await?; + + let mut results = Vec::new(); + let mut q = SessionListQuery { + include_hidden: true, + limit: MAX_PAGE_LIMIT, + ..SessionListQuery::default() + }; + + loop { + let page = backend.list_sessions_paged(q.clone()).await?; + for session in page.items { + if session.matrix_sharing_active { + results.push(session); + } + } + match page.next_cursor { + Some(c) => q.cursor = Some(c), + None => break, + } + } + + Ok(results) + } + + // ------------------------------------------------------------------ + // Move & copy + // ------------------------------------------------------------------ + + /// Move a session to a different connector. + /// + /// Ported from `FileBasedArchivist::move_session`: updates the session's + /// `connector_uid`, removes the mapping from the source connector's + /// table (via `rewrite_connector_mappings`), and appends a new mapping + /// to the target. + pub async fn move_session_to_connector( + &self, + scroll_id: Uuid, + target_connector_uid: Uuid, + archive: Option, + ) -> Result<()> { + let backend = self.resolve_backend(archive).await?; + let mapping = backend.as_session_mapping().ok_or_else(|| { + ArchivistError::CapabilityNotSupported { + capability: ArchiveCapability::SessionMapping, + backend: "selected".into(), + } + })?; + let registry = backend.as_connector_registry().ok_or_else(|| { + ArchivistError::CapabilityNotSupported { + capability: ArchiveCapability::ConnectorRegistry, + backend: "selected".into(), + } + })?; + + // 1. Read session metadata. + let mut metadata = backend + .get_session(scroll_id) + .await? + .ok_or(ArchivistError::SessionUnknown(scroll_id))?; + let old_connector_uid = metadata.connector_uid; + if old_connector_uid == target_connector_uid { + return Ok(()); + } + + // 2. Verify target connector exists. + if registry.get_connector(target_connector_uid).await?.is_none() { + return Err(ArchivistError::ConnectorUnknown(target_connector_uid)); + } + + // 3. Update session metadata with new connector_uid. + metadata.connector_uid = target_connector_uid; + let native_session_id = metadata + .native_session_id + .clone() + .unwrap_or_else(|| scroll_id.to_string()); + backend.put_session(metadata).await?; + + // 4. Remove mapping from old connector by rewriting its sessions + // table without the moved row. The trait's + // `rewrite_connector_mappings` also handles cache invalidation. + let old_mappings = mapping.list_mappings_for_connector(old_connector_uid).await?; + let filtered: Vec<_> = old_mappings + .into_iter() + .filter(|m| m.scroll_id != scroll_id) + .collect(); + mapping + .rewrite_connector_mappings(old_connector_uid, filtered) + .await?; + + // 5. Append mapping to target connector. + mapping + .put_mapping(target_connector_uid, &native_session_id, scroll_id) + .await?; + + tracing::info!( + scroll_id = %scroll_id, + from = %old_connector_uid, + to = %target_connector_uid, + "Moved session to new connector" + ); + + Ok(()) + } + + /// Bulk-move many sessions. Collects per-session errors into a + /// `MoveReport` without aborting the whole batch. + pub async fn move_sessions_to_connector( + &self, + scroll_ids: Vec, + target_connector_uid: Uuid, + archive: Option, + ) -> Result { + let mut report = MoveReport::default(); + + for scroll_id in scroll_ids { + match self + .move_session_to_connector(scroll_id, target_connector_uid, archive.clone()) + .await + { + Ok(()) => { + report.moved += 1; + } + Err(e) => { + report.failed += 1; + report + .errors + .push(format!("Failed to move session {}: {}", scroll_id, e)); + tracing::warn!( + scroll_id = %scroll_id, + error = %e, + "Failed to move session during bulk move" + ); + } + } + } + + tracing::info!( + moved = report.moved, + failed = report.failed, + target = %target_connector_uid, + "Completed bulk session move" + ); + + Ok(report) + } + + /// Copy a session to a new connector under a fresh scroll_id. + /// + /// Copies the session metadata (with updated `scroll_id`, `connector_uid`, + /// `created_at`, `updated_at`) and all messages, page by page. The target + /// connector gets a new mapping entry. + pub async fn copy_session_to_connector( + &self, + scroll_id: Uuid, + target_connector_uid: Uuid, + archive: Option, + ) -> Result { + let backend = self.resolve_backend(archive).await?; + let mapping = backend.as_session_mapping().ok_or_else(|| { + ArchivistError::CapabilityNotSupported { + capability: ArchiveCapability::SessionMapping, + backend: "selected".into(), + } + })?; + let registry = backend.as_connector_registry().ok_or_else(|| { + ArchivistError::CapabilityNotSupported { + capability: ArchiveCapability::ConnectorRegistry, + backend: "selected".into(), + } + })?; + + // 1. Read source session metadata. + let source_metadata = backend + .get_session(scroll_id) + .await? + .ok_or(ArchivistError::SessionUnknown(scroll_id))?; + + // 2. Verify target connector exists. + if registry.get_connector(target_connector_uid).await?.is_none() { + return Err(ArchivistError::ConnectorUnknown(target_connector_uid)); + } + + // 3. Create new scroll_id + timestamp. + let new_scroll_id = Uuid::now_v7(); + let now = Utc::now(); + + // 4. Persist new session metadata with updated identity. + let new_metadata = SessionMetadata { + scroll_id: new_scroll_id, + connector_uid: target_connector_uid, + created_at: now, + updated_at: now, + ..source_metadata + }; + let native_session_id = new_metadata + .native_session_id + .clone() + .unwrap_or_else(|| new_scroll_id.to_string()); + backend.put_session(new_metadata).await?; + + // 5. Copy messages via the paged primitive. Uses MAX_PAGE_LIMIT + // per batch to stream large sessions without loading everything. + let mut cursor: Option = None; + loop { + let page = backend + .get_messages_paged(scroll_id, cursor.clone(), MAX_PAGE_LIMIT) + .await?; + if page.items.is_empty() { + break; + } + + // Rewrite each message's session reference to the new scroll_id + // so the copied session owns its own records. + let rewritten: Vec = page + .items + .into_iter() + .map(|mut m| { + m.session = new_scroll_id; + m + }) + .collect(); + + backend.append_messages(new_scroll_id, rewritten).await?; + + match page.next_cursor { + Some(c) => cursor = Some(c), + None => break, + } + } + + // 6. Register mapping for the new (connector, native_id) → new_scroll_id. + mapping + .put_mapping(target_connector_uid, &native_session_id, new_scroll_id) + .await?; + + tracing::info!( + source_scroll_id = %scroll_id, + new_scroll_id = %new_scroll_id, + target_connector = %target_connector_uid, + "Copied session to new connector" + ); + + Ok(new_scroll_id) + } + + // ------------------------------------------------------------------ + // Cross-archive move/copy (Phase 3) + // ------------------------------------------------------------------ + + /// Copy a session from one archive to another, preserving `scroll_id`. + /// + /// Copies session metadata, message pages, DAG edges, and meta events + /// where both sides have the required capability. Leaves the source + /// intact. + pub async fn copy_session( + &self, + scroll_id: Uuid, + from: &str, + to: &str, + ) -> Result<()> { + let from_reg = self + .find_registration(from) + .await + .ok_or_else(|| ArchivistError::ArchiveNameUnknown(from.into()))?; + let to_reg = self + .find_registration(to) + .await + .ok_or_else(|| ArchivistError::ArchiveNameUnknown(to.into()))?; + + if !to_reg.enabled || !to_reg.write_active { + return Err(ArchivistError::PrimaryUnavailable { + name: to_reg.name.clone(), + reason: "target backend is disabled or not write-active".into(), + }); + } + + // 1. Session metadata + let meta = from_reg + .backend + .get_session(scroll_id) + .await? + .ok_or(ArchivistError::SessionUnknown(scroll_id))?; + to_reg.backend.put_session(meta).await?; + + // 2. Messages — page through. + let mut cursor: Option = None; + loop { + let page = from_reg + .backend + .get_messages_paged(scroll_id, cursor.clone(), MAX_PAGE_LIMIT) + .await?; + if page.items.is_empty() { + break; + } + to_reg.backend.append_messages(scroll_id, page.items).await?; + match page.next_cursor { + Some(c) => cursor = Some(c), + None => break, + } + } + + // 3. DAG edges (both sides must support Dag). + if let (Some(src_dag), Some(dst_dag)) = + (from_reg.backend.as_dag(), to_reg.backend.as_dag()) + { + for edge in src_dag.get_dag_edges(scroll_id).await? { + dst_dag.append_dag_edge(edge).await?; + } + } + + // 4. Meta events (both sides must support MetaEvents). + if let (Some(src_me), Some(dst_me)) = + (from_reg.backend.as_meta_events(), to_reg.backend.as_meta_events()) + { + let events = src_me.get_meta_events(scroll_id).await?; + if !events.is_empty() { + dst_me.append_meta_events(scroll_id, events).await?; + } + } + + // Cache: leave pointing at `from` (source remains canonical). + Ok(()) + } + + /// Move a session from one archive to another: `copy_session` followed by + /// source-side delete. If the copy fails, the source is intact. If the + /// source-side delete fails AFTER a successful copy, returns + /// `ArchivistError::PartialMove`. + pub async fn move_session( + &self, + scroll_id: Uuid, + from: &str, + to: &str, + ) -> Result<()> { + // 1. Copy. + self.copy_session(scroll_id, from, to).await?; + + // 2. Delete from source only. + let from_reg = self + .find_registration(from) + .await + .ok_or_else(|| ArchivistError::ArchiveNameUnknown(from.into()))?; + + if let Err(e) = from_reg.backend.delete_session(scroll_id).await { + self.record_write_failure(&from_reg, &format!("{e}")).await; + return Err(ArchivistError::PartialMove { + copied_to: to.into(), + delete_error: Box::new(e), + }); + } + + // 3. Cache: rewrite to `to`. + self.read_cache.rewrite(scroll_id, to.into()).await; + + Ok(()) + } +} diff --git a/crates/dirigent_archivist/src/coordinator/tests.rs b/crates/dirigent_archivist/src/coordinator/tests.rs new file mode 100644 index 0000000..02ccb65 --- /dev/null +++ b/crates/dirigent_archivist/src/coordinator/tests.rs @@ -0,0 +1,195 @@ +//! Coordinator orchestration unit tests using `MockBackend`. +//! +//! These tests exercise alias detection, move/copy semantics, DAG walks, +//! and cleanup policies without any disk I/O. + +#![cfg(test)] + +use std::sync::Arc; + +use tokio::sync::RwLock; +use uuid::Uuid; + +use crate::backend::mock::MockBackend; +use crate::backend::ArchiveBackend; +use crate::coordinator::Archivist; +use crate::registry::{ + cache::ReadCache, ArchiveRegistration, FailureMode, WritePolicy, +}; +use crate::types::{ + DagEdge, MessageRecord, RegisterConnectorRequest, RegisterStatus, SessionCompleteness, + SessionKind, SessionMetadata, +}; + +/// Construct a blank `SessionMetadata` with the given `scroll_id` and +/// `connector_uid`. Sensible defaults for every other field. +fn blank_session(scroll_id: Uuid, connector_uid: Uuid) -> SessionMetadata { + let now = chrono::Utc::now(); + SessionMetadata { + version: 1, + scroll_id, + created_at: now, + updated_at: now, + title: None, + connector_uid, + native_session_id: None, + agent_id: None, + parent_scroll_id: None, + continuation: None, + tags: Vec::new(), + metadata: serde_json::Value::Null, + no_update: false, + kind: SessionKind::Chat, + acp_client_id: None, + is_connected: None, + current_session_id: None, + models: None, + modes: None, + config_options: None, + completeness: SessionCompleteness::Complete, + matrix_room_id: None, + matrix_sharing_active: false, + matrix_shared_at: None, + is_subagent: false, + subagent_type: None, + spawning_tool_use_id: None, + } +} + +/// Construct a blank `MessageRecord` scoped to the given session with a +/// freshly generated `message_id` and current timestamp. +fn blank_message(session: Uuid) -> MessageRecord { + MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session, + parent_id: None, + ts: chrono::Utc::now(), + role: "user".into(), + author: None, + content_md: String::new(), + content_parts: None, + attachments: Vec::new(), + metadata: serde_json::Value::Null, + } +} + +async fn make_coordinator_with_single_mock() -> Archivist { + let backend: Arc = Arc::new(MockBackend::new()); + let initial_health = backend.health_check().await; + let reg = Arc::new(ArchiveRegistration::new( + "main".into(), + "mock", + backend, + /* write_active */ true, + FailureMode::Required, + /* read_priority */ 0, + /* enabled */ true, + WritePolicy::Inline, + /* writer */ None, + initial_health, + )); + Archivist { + registrations: RwLock::new(vec![reg]), + read_cache: Arc::new(ReadCache::new()), + registry_path: std::path::PathBuf::from("/mock/.archives.json"), + } +} + +#[tokio::test] +async fn register_connector_assigns_uid_and_returns_accepted() { + let coord = make_coordinator_with_single_mock().await; + let req = RegisterConnectorRequest { + r#type: "OpenCode".into(), + title: "test".into(), + client_native_id: "opencode@localhost".into(), + custom_uid: None, + metadata: serde_json::Value::Null, + fingerprint: None, + }; + let resp = coord.register_connector(req, None).await.expect("register"); + assert!(matches!(resp.status, RegisterStatus::Accepted)); + assert_ne!(resp.connector_uid, Uuid::nil()); +} + +#[tokio::test] +async fn register_connector_aliases_on_duplicate_native_id() { + let coord = make_coordinator_with_single_mock().await; + let mk_req = || RegisterConnectorRequest { + r#type: "OpenCode".into(), + title: "test".into(), + client_native_id: "opencode@localhost".into(), + custom_uid: None, + metadata: serde_json::Value::Null, + fingerprint: None, + }; + let first = coord.register_connector(mk_req(), None).await.unwrap(); + let second = coord.register_connector(mk_req(), None).await.unwrap(); + assert_eq!(second.connector_uid, first.connector_uid); + assert!(matches!(second.status, RegisterStatus::Aliased)); +} + +#[tokio::test] +async fn get_session_tree_walks_full_dag() { + let coord = make_coordinator_with_single_mock().await; + let connector_uid = Uuid::now_v7(); + let root = Uuid::now_v7(); + let child_a = Uuid::now_v7(); + let child_b = Uuid::now_v7(); + let grand = Uuid::now_v7(); + + let backend = coord.registrations.read().await[0].backend.clone(); + for id in [root, child_a, child_b, grand] { + backend + .put_session(blank_session(id, connector_uid)) + .await + .unwrap(); + } + + for (p, c) in [(root, child_a), (root, child_b), (child_a, grand)] { + coord + .append_dag_edge( + DagEdge { + parent: p, + child: c, + agent_id: String::new(), + subagent_type: None, + tool_use_id: None, + ts: Some(chrono::Utc::now()), + }, + None, + ) + .await + .unwrap(); + } + + let edges = coord.get_session_tree(root, None).await.unwrap(); + assert_eq!(edges.len(), 3, "expected 3 edges, got {}", edges.len()); +} + +#[tokio::test] +async fn cleanup_empty_sessions_deletes_only_message_less_sessions() { + let coord = make_coordinator_with_single_mock().await; + + let connector_uid = Uuid::now_v7(); + let empty = Uuid::now_v7(); + let populated = Uuid::now_v7(); + + let backend = coord.registrations.read().await[0].backend.clone(); + for scroll_id in [empty, populated] { + backend + .put_session(blank_session(scroll_id, connector_uid)) + .await + .unwrap(); + } + backend + .append_messages(populated, vec![blank_message(populated)]) + .await + .unwrap(); + + let (deleted, total) = coord.cleanup_empty_sessions(None).await.unwrap(); + assert_eq!(deleted, 1); + assert_eq!(total, 2); + assert!(backend.get_session(empty).await.unwrap().is_none()); + assert!(backend.get_session(populated).await.unwrap().is_some()); +} diff --git a/crates/dirigent_archivist/src/coordinator/types.rs b/crates/dirigent_archivist/src/coordinator/types.rs new file mode 100644 index 0000000..71c1588 --- /dev/null +++ b/crates/dirigent_archivist/src/coordinator/types.rs @@ -0,0 +1,60 @@ +//! Shared data types used by the archivist coordinator. +//! +//! `ArchiveMetadata` is persisted per-archive in the registry file and +//! tracks creation time, path, and the set of connectors registered in +//! the archive. `ArchiveInfo` is the display-friendly projection returned +//! from listing APIs; it extends the metadata with computed fields like +//! session count and default-archive status. + +use std::path::PathBuf; + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// Metadata about a single archive. +/// +/// This structure contains all the information needed to track and display +/// an archive without loading its full backend instance. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ArchiveMetadata { + /// Unique name for this archive (e.g., "personal", "work", "experiments") + pub name: String, + + /// Filesystem path to the archive root directory + pub path: PathBuf, + + /// When this archive was first registered with the coordinator + pub created_at: DateTime, + + /// List of connector UIDs registered in this archive + /// + /// This is updated as connectors are registered/unregistered and provides + /// a quick way to see which connectors belong to which archive. + pub connector_uids: Vec, +} + +/// Display-friendly information about an archive. +/// +/// This struct is returned by listing operations and includes computed +/// fields like session count and default status. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ArchiveInfo { + /// Unique name for this archive + pub name: String, + + /// Filesystem path to the archive root directory + pub path: PathBuf, + + /// When this archive was first registered + pub created_at: DateTime, + + /// Total number of sessions across all connectors in this archive + /// + /// This is computed by counting sessions across all connectors and + /// may be expensive for large archives. + pub session_count: usize, + + /// Whether this is the current default archive + pub is_default: bool, +} diff --git a/crates/dirigent_archivist/src/error.rs b/crates/dirigent_archivist/src/error.rs new file mode 100644 index 0000000..0d42863 --- /dev/null +++ b/crates/dirigent_archivist/src/error.rs @@ -0,0 +1,314 @@ +//! Error types for the Archivist. +//! +//! This module defines all error types that can occur during archival operations, +//! including I/O errors, JSON errors, and domain-specific errors for connectors +//! and sessions. + +use std::path::PathBuf; +use thiserror::Error; +use uuid::Uuid; + +/// Result type alias for Archivist operations +pub type Result = std::result::Result; + +/// Errors that can occur during archival operations +#[derive(Debug, Error)] +pub enum ArchivistError { + /// Connector with the given UID was not found + #[error("Connector not found: {0}")] + ConnectorUnknown(Uuid), + + /// Session with the given scroll ID was not found + #[error("Session not found: {0}")] + SessionUnknown(Uuid), + + /// UUID collision detected with inconsistent data + /// + /// This occurs when a custom UUID is provided that matches an existing + /// entity but with different attributes (e.g., different connector type). + #[error("UUID collision: {0}")] + CollisionInconsistent(Uuid), + + /// Invalid request (e.g., missing required fields, invalid format) + #[error("Invalid request: {0}")] + InvalidRequest(String), + + /// I/O error during file operations + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + /// JSON serialization/deserialization error + #[error("JSON error: {0}")] + Json(#[from] serde_json::Error), + + // Multi-archive errors + /// Invalid archive name (empty or contains invalid characters) + #[error("Invalid archive name: {0}")] + InvalidArchiveName(String), + + /// Archive already exists with the given name + #[error("Archive already exists: {0}")] + ArchiveAlreadyExists(String), + + /// Archive not found with the given name + #[error("Archive not found: {0}")] + ArchiveNotFound(String), + + /// Archive path conflict (path is already used by another archive) + #[error("Archive path conflict: {0}")] + ArchivePathConflict(PathBuf), + + /// Cannot remove default archive without force flag + #[error("Cannot remove default archive without force flag")] + CannotRemoveDefaultArchive, + + /// Archive is not empty (has sessions) + #[error("Archive '{name}' is not empty ({session_count} sessions)")] + ArchiveNotEmpty { + name: String, + session_count: usize, + }, + + /// No archives configured + #[error("No archives configured")] + NoArchivesConfigured, + + /// Failed to load registry file + #[error("Failed to load registry: {0}")] + RegistryLoadError(String), + + /// Failed to parse registry JSON + #[error("Failed to parse registry: {0}")] + RegistryParseError(String), + + /// Failed to serialize registry to JSON + #[error("Failed to serialize registry: {0}")] + RegistrySerializeError(String), + + /// Failed to write registry file + #[error("Failed to write registry: {0}")] + RegistryWriteError(String), + + /// Backend is unavailable (e.g., disk full, connection lost, degraded state) + #[error("Backend {name} is unavailable")] + BackendUnavailable { name: String }, + + /// Backend does not support the requested capability + #[error("Backend {backend} does not support capability {capability:?}")] + CapabilityNotSupported { + capability: crate::backend::ArchiveCapability, + backend: String, + }, + + /// Health check for a backend failed + #[error("Health check for backend {name} failed: {reason}")] + BackendHealthCheckFailed { name: String, reason: String }, + + /// Primary write backend is unavailable or misconfigured. + #[error("primary write backend `{name}` is unavailable: {reason}")] + PrimaryUnavailable { name: String, reason: String }, + + /// Session exists on a read-only (not write_active) backend; deletion impossible. + #[error("session {scroll_id} exists in read-only backend `{backend}`; cannot delete")] + DeleteOnReadOnlyBackend { backend: String, scroll_id: uuid::Uuid }, + + /// Move succeeded at the destination but source-side delete failed. + #[error("partial move: copy to `{copied_to}` succeeded but source-side delete failed: {delete_error}")] + PartialMove { + copied_to: String, + delete_error: Box, + }, + + /// Queued-write backend's queue is full. + #[error("write queue full for backend `{backend}` (op `{op}`)")] + WriteQueueFull { + backend: String, + op: &'static str, + }, + + /// The coordinator has no archive configured (ephemeral mode). + #[error("no archive is configured (ephemeral mode)")] + NoArchiveConfigured, + + /// A requested archive name does not exist in the registry. + #[error("archive name `{0}` is unknown")] + ArchiveNameUnknown(String), + + /// Runtime mutation of the archive registry is not supported in Phase 3. + #[error("dynamic registry mutation is not supported (Phase 3 is startup-only)")] + DynamicRegistryUnsupported, + + /// Catch-all for injected failures / legacy call sites. Prefer a typed variant when possible. + #[error("{0}")] + Other(String), +} + +/// Errors raised exclusively at boot, by `Archivist::from_config`. +#[derive(Debug, thiserror::Error)] +pub enum ArchivistBootError { + #[error("duplicate archive name `{0}` in config")] + DuplicateName(String), + + #[error("archive `{name}` declares unknown type `{type_name}`")] + UnknownType { name: String, type_name: String }, + + #[error("no `required` write-active backend configured (need at least one primary)")] + NoPrimary, + + #[error("backend `{name}` failed to build: {source}")] + BackendBuild { + name: String, + #[source] + source: crate::registry::BackendBuildError, + }, + + #[error("required backend `{name}` is unavailable at boot: {reason}")] + UnavailableRequiredBackend { name: String, reason: String }, + + #[error("no unrestricted write-active archive — at least one enabled, write_active backend must have an empty filter")] + NoUnrestrictedPrimary, + + #[error("filter for archive `{archive}` rejects all sessions (include_connectors is empty)")] + FilterRejectsEverything { archive: String }, + + #[error("config validation failed: {0}")] + Validation(#[from] crate::registry::ConfigValidationError), +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io; + + #[test] + fn test_error_display() { + let uuid = Uuid::now_v7(); + + // Test ConnectorUnknown + let err = ArchivistError::ConnectorUnknown(uuid); + assert_eq!(err.to_string(), format!("Connector not found: {}", uuid)); + + // Test SessionUnknown + let err = ArchivistError::SessionUnknown(uuid); + assert_eq!(err.to_string(), format!("Session not found: {}", uuid)); + + // Test CollisionInconsistent + let err = ArchivistError::CollisionInconsistent(uuid); + assert_eq!(err.to_string(), format!("UUID collision: {}", uuid)); + + // Test InvalidRequest + let err = ArchivistError::InvalidRequest("missing field".to_string()); + assert_eq!(err.to_string(), "Invalid request: missing field"); + } + + #[test] + fn test_io_error_conversion() { + // Create an I/O error + let io_err = io::Error::new(io::ErrorKind::NotFound, "file not found"); + + // Convert to ArchivistError using From trait + let archivist_err: ArchivistError = io_err.into(); + + // Verify it's the right variant + match archivist_err { + ArchivistError::Io(e) => { + assert_eq!(e.kind(), io::ErrorKind::NotFound); + assert_eq!(e.to_string(), "file not found"); + } + _ => panic!("Expected Io variant"), + } + } + + #[test] + fn test_json_error_conversion() { + // Create a JSON error by trying to parse invalid JSON + let json_err = serde_json::from_str::("invalid json").unwrap_err(); + + // Convert to ArchivistError using From trait + let archivist_err: ArchivistError = json_err.into(); + + // Verify it's the right variant + match archivist_err { + ArchivistError::Json(_) => { + // Success - it's a JSON error + } + _ => panic!("Expected Json variant"), + } + } + + #[test] + fn test_result_type_with_question_mark() { + // Test that Result works with the ? operator + fn test_function() -> Result { + // This should compile and work with ? + let _data: serde_json::Value = serde_json::from_str(r#"{"key": "value"}"#)?; + Ok("success".to_string()) + } + + let result = test_function(); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "success"); + } + + #[test] + fn test_error_chain() { + // Test that errors can be chained properly + fn inner_function() -> std::io::Result { + Err(std::io::Error::new( + std::io::ErrorKind::NotFound, + "inner error", + )) + } + + fn outer_function() -> Result { + // The ? operator should automatically convert io::Error to ArchivistError + let _result = inner_function()?; + Ok("success".to_string()) + } + + let result = outer_function(); + assert!(result.is_err()); + + match result { + Err(ArchivistError::Io(e)) => { + assert_eq!(e.kind(), std::io::ErrorKind::NotFound); + } + _ => panic!("Expected Io error"), + } + } + + #[test] + fn test_error_debug() { + let uuid = Uuid::now_v7(); + let err = ArchivistError::ConnectorUnknown(uuid); + + // Verify Debug implementation works + let debug_str = format!("{:?}", err); + assert!(debug_str.contains("ConnectorUnknown")); + assert!(debug_str.contains(&uuid.to_string())); + } + + #[test] + fn test_all_error_variants() { + let uuid = Uuid::now_v7(); + + // Test all variants can be created + let errors = vec![ + ArchivistError::ConnectorUnknown(uuid), + ArchivistError::SessionUnknown(uuid), + ArchivistError::CollisionInconsistent(uuid), + ArchivistError::InvalidRequest("test".to_string()), + ArchivistError::Io(io::Error::new(io::ErrorKind::Other, "test")), + ArchivistError::Json(serde_json::from_str::("bad").unwrap_err()), + ]; + + // Verify each error has a non-empty display string + for err in errors { + let display = err.to_string(); + assert!(!display.is_empty(), "Error display should not be empty"); + + let debug = format!("{:?}", err); + assert!(!debug.is_empty(), "Error debug should not be empty"); + } + } +} diff --git a/crates/dirigent_archivist/src/events.rs b/crates/dirigent_archivist/src/events.rs new file mode 100644 index 0000000..92adbc2 --- /dev/null +++ b/crates/dirigent_archivist/src/events.rs @@ -0,0 +1,2162 @@ +//! Event handling for dirigent_core event stream. +//! +//! The EventHandler subscribes to dirigent_core's global event stream and writes +//! to the archive in real-time, accumulating streaming message chunks into complete +//! messages. + +use std::collections::HashSet; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Mutex; +use uuid::Uuid; +use chrono::Utc; + +use dirigent_protocol::{ + ContentBlock, Event, Message, MessagePart, MessageRole, Session, SessionUpdate, + TurnCompleteTrigger, +}; +use dirigent_protocol::streaming::{BusEvent, BusReceiver}; + +/// Closure that publishes a `BusEvent` to the shared event bus. +/// +/// The archivist cannot depend on `dirigent_core::SharingBus` directly +/// (that would introduce a dependency cycle). Callers in `api::archivist` +/// and `web::server` wrap the core's `SharingBus::publish` call in a +/// closure that matches this signature and install it via +/// [`EventHandler::set_bus_publisher`]. Implementations must spawn the +/// publish onto a tokio task (the call site here is synchronous). +pub type BusPublishFn = Arc; + +use crate::accumulator::MessageAccumulator; +use crate::coordinator::Archivist; +use crate::error::Result; +use crate::types::{ + MessageRecord, MetaEventRecord, MetaEventType, RegisterSessionRequest, + RegisterSessionResponse, RegisterStatus, SessionCompleteness, +}; + +/// Event handler for subscribing to dirigent_core events and archiving them +pub struct EventHandler { + archivist: Arc, + accumulator: Mutex, + /// Track which message IDs have been archived to prevent duplicates + archived_messages: Mutex>, + /// Sessions excluded from archiving (per-session toggle support) + /// Key is "connector_id:session_id" composite key + excluded_sessions: Mutex>, + /// Connectors whose sessions should be excluded from archiving by default + /// This is used for Gateway connectors where users can opt-in to archiving + excluded_connector_ids: Mutex>, + /// Sessions currently being replayed via session/load. + /// SessionUpdate events for these sessions are suppressed to avoid + /// re-archiving messages that already exist in the archive. + replaying_sessions: Mutex>, + /// Bus publisher used to emit `SessionRegistered` back onto the shared + /// event bus (`dirigent_core::SharingBus`). Installed via + /// [`Self::set_bus_publisher`] from the boot wiring in the server + /// crate. Without it, downstream consumers (e.g. the web UI's session + /// list refresh) won't receive the signal that a session is durably + /// registered — but archival still functions. + bus_publisher: Option, +} + +impl EventHandler { + /// Create a new event handler + pub fn new(archivist: Arc) -> Self { + Self { + archivist, + accumulator: Mutex::new( + MessageAccumulator::new().expect("Failed to create accumulator"), + ), + archived_messages: Mutex::new(HashSet::new()), + excluded_sessions: Mutex::new(HashSet::new()), + excluded_connector_ids: Mutex::new(HashSet::new()), + replaying_sessions: Mutex::new(HashSet::new()), + bus_publisher: None, + } + } + + /// Install a bus publisher so the archivist can emit its + /// `SessionRegistered` events onto the shared event bus. + /// + /// The callback receives a [`BusEvent`] and is expected to forward it + /// to `dirigent_core::SharingBus::publish` (likely by spawning a tokio + /// task). + pub fn set_bus_publisher(&mut self, publisher: BusPublishFn) { + self.bus_publisher = Some(publisher); + } + + /// Add a connector ID to the list of connectors whose sessions should be excluded by default + /// + /// Sessions from these connectors will start with archiving disabled. Users can still + /// enable archiving for individual sessions using the toggle. + /// + /// This is useful for Gateway connectors where most sessions are transient + /// and don't need to be archived. + pub async fn add_excluded_connector(&self, connector_id: &str) { + let mut excluded = self.excluded_connector_ids.lock().await; + excluded.insert(connector_id.to_string()); + tracing::info!( + "Connector '{}' added to excluded list - sessions will not archive by default", + connector_id + ); + } + + /// Remove a connector ID from the list of excluded connectors + /// + /// New sessions from this connector will archive by default (existing exclusions are not removed). + pub async fn remove_excluded_connector(&self, connector_id: &str) { + let mut excluded = self.excluded_connector_ids.lock().await; + excluded.remove(connector_id); + tracing::info!( + "Connector '{}' removed from excluded list - sessions will archive by default", + connector_id + ); + } + + /// Check if a connector is in the excluded list + pub async fn is_connector_excluded(&self, connector_id: &str) -> bool { + let excluded = self.excluded_connector_ids.lock().await; + excluded.contains(connector_id) + } + + /// Check if a session is excluded from archiving + /// + /// A session is excluded if: + /// 1. It's explicitly in the excluded_sessions set, OR + /// 2. Its connector is in excluded_connector_ids (e.g., Gateway) and the session + /// hasn't been explicitly enabled (which would have removed it from excluded_sessions) + /// + /// For Gateway connectors, sessions default to excluded until explicitly enabled. + pub async fn is_session_excluded(&self, connector_id: &str, session_id: &str) -> bool { + let key = format!("{}:{}", connector_id, session_id); + let excluded = self.excluded_sessions.lock().await; + + // If session is explicitly in excluded set, it's excluded + if excluded.contains(&key) { + return true; + } + + // Check if connector defaults to excluded (e.g., Gateway) + // For these connectors, sessions are excluded by default + let connector_excluded = { + let excluded_connectors = self.excluded_connector_ids.lock().await; + excluded_connectors.contains(connector_id) + }; + + if connector_excluded { + // Connector defaults to excluded. Session is excluded UNLESS it has been + // explicitly enabled. We track enabled sessions by checking if they were + // previously in excluded_sessions and removed (toggle enabled them). + // Since they're not in excluded_sessions now, check included_sessions. + // For simplicity, we use a marker: if connector is excluded but session + // is not in excluded_sessions, check if we have an "enabled" marker. + // Since we don't have a separate set, we'll treat "not in excluded_sessions" + // for an excluded connector as still excluded (the default state). + // The toggle will add an "enabled:key" marker to excluded_sessions to indicate + // the session was explicitly enabled. + let enabled_key = format!("enabled:{}", key); + if excluded.contains(&enabled_key) { + return false; // Explicitly enabled + } + return true; // Default excluded for Gateway + } + + false + } + + /// Exclude a session from archiving (disable archiving for this session) + /// + /// Returns true if the session was newly excluded, false if already excluded. + pub async fn exclude_session(&self, connector_id: &str, session_id: &str) -> bool { + let key = format!("{}:{}", connector_id, session_id); + let mut excluded = self.excluded_sessions.lock().await; + let newly_excluded = excluded.insert(key.clone()); + if newly_excluded { + tracing::info!( + "Archiving disabled for session {} (connector: {})", + session_id, + connector_id + ); + } + newly_excluded + } + + /// Include a session in archiving (enable archiving for this session) + /// + /// Returns true if the session was previously excluded, false if it wasn't excluded. + pub async fn include_session(&self, connector_id: &str, session_id: &str) -> bool { + let key = format!("{}:{}", connector_id, session_id); + let mut excluded = self.excluded_sessions.lock().await; + let was_excluded = excluded.remove(&key); + if was_excluded { + tracing::info!( + "Archiving enabled for session {} (connector: {})", + session_id, + connector_id + ); + } + was_excluded + } + + /// Toggle session archiving status + /// + /// Returns the new status: true = archiving enabled, false = archiving disabled + /// + /// For Gateway connectors (in excluded_connector_ids), sessions default to excluded. + /// Toggling uses an "enabled:" marker to track explicitly enabled sessions. + pub async fn toggle_session_archiving(&self, connector_id: &str, session_id: &str) -> bool { + let key = format!("{}:{}", connector_id, session_id); + let enabled_key = format!("enabled:{}", key); + + // Check if connector defaults to excluded (e.g., Gateway) + let connector_excluded = { + let excluded_connectors = self.excluded_connector_ids.lock().await; + excluded_connectors.contains(connector_id) + }; + + let mut excluded = self.excluded_sessions.lock().await; + + if connector_excluded { + // Gateway connector: sessions default to excluded + if excluded.contains(&enabled_key) { + // Currently enabled, disable by removing the enabled marker + excluded.remove(&enabled_key); + tracing::info!( + "Archiving disabled for session {} (connector: {}, Gateway default)", + session_id, + connector_id + ); + false // archiving now disabled + } else { + // Currently excluded (default), enable by adding the enabled marker + // Also remove from explicit excluded set if present + excluded.remove(&key); + excluded.insert(enabled_key); + tracing::info!( + "Archiving enabled for session {} (connector: {}, Gateway override)", + session_id, + connector_id + ); + true // archiving now enabled + } + } else { + // Non-Gateway connector: normal toggle logic + if excluded.contains(&key) { + excluded.remove(&key); + tracing::info!( + "Archiving enabled for session {} (connector: {})", + session_id, + connector_id + ); + true // archiving now enabled + } else { + excluded.insert(key); + tracing::info!( + "Archiving disabled for session {} (connector: {})", + session_id, + connector_id + ); + false // archiving now disabled + } + } + } + + /// Canonicalize a finalized message record and write it to the archive. + /// + /// This is the SINGLE write path for all finalization triggers (TurnComplete, + /// SessionIdle, stale timeout, shutdown flush). Resolves connector_id to + /// connector_uid, then to the canonical scroll_id, rewrites the record's + /// session field, and appends. + async fn canonicalize_and_write( + &self, + record: &mut MessageRecord, + connector_id: &str, + native_session_id: &str, + message_id: &str, + finalized_via: &str, + ) -> Result<()> { + // 1. Resolve connector_id to connector_uid + let connector_uid = match self.archivist.resolve_connector_uid(connector_id).await { + Ok(uid) => uid, + Err(e) => { + tracing::warn!( + "canonicalize_and_write: Failed to resolve connector_uid for '{}': {}. \ + Message {} will not be archived.", + connector_id, + e, + message_id + ); + return Ok(()); + } + }; + + // 2. Resolve or lazy-register session to get canonical scroll_id + let scroll_id = self + .resolve_or_register_session( + connector_uid, + native_session_id, + &format!( + "{}: Session mapping missing during message write", + finalized_via + ), + ) + .await?; + + // 3. Rewrite record to canonical identity + record.session = scroll_id; + + tracing::info!( + "canonicalize_and_write: Writing message {} ({} bytes) for session {} -> scroll_id {} via {}", + message_id, + record.content_md.len(), + native_session_id, + scroll_id, + finalized_via, + ); + + // 4. Write to archive + self.archivist + .append_messages(scroll_id, vec![record.clone()], None) + .await?; + + Ok(()) + } + + /// Resolve session with lazy registration fallback + /// + /// Attempts to resolve the session mapping. If resolution fails, performs + /// lazy registration with placeholder metadata to ensure messages can be written. + async fn resolve_or_register_session( + &self, + connector_uid: Uuid, + native_session_id: &str, + reason: &str, + ) -> Result { + match self + .archivist + .resolve_session(connector_uid, native_session_id, None) + .await + { + Ok(scroll_id) => Ok(scroll_id), + Err(e) => { + tracing::warn!( + "Failed to resolve session {} for connector {}: {}. Attempting lazy registration.", + native_session_id, + connector_uid, + e + ); + + // Attempt lazy registration using available metadata + let register_req = RegisterSessionRequest { + connector_uid, + native_session_id: native_session_id.to_string(), + title: Some("Lazy-registered session".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({ + "lazy_registered": true, + "reason": reason + }), + completeness: SessionCompleteness::Discovered, + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + match self.archivist.register_session(register_req, None).await { + Ok(response) => { + tracing::info!( + "Lazy registration successful for session {} (scroll_id: {})", + native_session_id, + response.scroll_id + ); + Ok(response.scroll_id) + } + Err(reg_error) => { + tracing::error!( + "Lazy registration failed for session {}: {}. Message will be lost.", + native_session_id, + reg_error + ); + Err(reg_error) + } + } + } + } + } + + /// Create or get an existing meta session for an ACP client. + /// + /// If a meta session already exists for this client_id, returns its scroll_id. + /// Otherwise creates a new meta session with SessionKind::AcpConnection. + async fn get_or_create_meta_session( + &self, + client_id: &str, + connector_uid: Uuid, + connected_at: &str, + ) -> Result { + // First, try to find existing meta session + if let Some(existing) = self.archivist.find_meta_session_by_client(client_id, None).await? { + tracing::debug!( + "Found existing meta session {} for client {}", + existing.scroll_id, + client_id + ); + return Ok(existing.scroll_id); + } + + // No existing session, create new one + let req = RegisterSessionRequest { + connector_uid, + native_session_id: format!("acp-meta-{}", client_id), + title: Some(format!("ACP Connection: {}", client_id)), + custom_scroll_id: None, + completeness: SessionCompleteness::Complete, + metadata: serde_json::json!({ + "kind": "ACP_CONNECTION", + "acp_client_id": client_id, + "is_connected": true, + "connected_at": connected_at, + }), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let response = self.archivist.register_session(req, None).await?; + + tracing::info!( + "Created meta session {} for ACP client {}", + response.scroll_id, + client_id + ); + + Ok(response.scroll_id) + } + + /// Finalize stale message buffers that have been inactive for too long + /// + /// This prevents data loss when MessageCompleted events are missed or delayed. + /// Buffers inactive for longer than the threshold are finalized and written to archive. + async fn finalize_stale_buffers(&self, threshold: Duration) -> Result<()> { + let now = Utc::now(); + + // Collect stale message IDs + let stale_ids = { + let acc = self.accumulator.lock().await; + acc.get_stale_message_ids(now, threshold) + }; + + if stale_ids.is_empty() { + return Ok(()); + } + + tracing::warn!( + "Finalizing {} stale message buffer(s) (inactive > {:?})", + stale_ids.len(), + threshold + ); + + // Finalize each stale buffer + for message_id in stale_ids { + // Finalize buffer + let finalized_record = { + let mut acc = self.accumulator.lock().await; + acc.finalize(&message_id) + }; + + if let Some((mut record, connector_id, native_session_id)) = finalized_record { + // Check if already archived + let already_archived = { + let mut archived = self.archived_messages.lock().await; + if archived.contains(&message_id) { + true + } else { + archived.insert(message_id.clone()); + false + } + }; + + if already_archived { + tracing::debug!( + "Stale message {} already archived, skipping", + message_id + ); + continue; + } + + // Tag with finalization method + if let Some(metadata_obj) = record.metadata.as_object_mut() { + metadata_obj.insert( + "finalized_via".to_string(), + serde_json::Value::String("stale_timeout".to_string()), + ); + } + + // Canonicalize identity and write to archive + if let Err(e) = self.canonicalize_and_write( + &mut record, + &connector_id, + &native_session_id, + &message_id, + "stale_timeout", + ).await { + tracing::error!( + "Failed to write stale message {} to archive: {}", + message_id, + e + ); + // Continue processing other stale buffers + } + } + } + + Ok(()) + } + + /// Gracefully shut down the event handler, flushing all active buffers + /// + /// This ensures no messages are lost when the server is shut down. + /// All active message buffers are finalized and written to archive. + pub async fn shutdown_and_flush(&self) -> Result<()> { + tracing::info!("Shutting down EventHandler, flushing active buffers"); + + // Get all active buffer message IDs + let message_ids = { + let acc = self.accumulator.lock().await; + acc.get_all_message_ids() + }; + + if message_ids.is_empty() { + tracing::info!("No active buffers to flush on shutdown"); + return Ok(()); + } + + tracing::warn!( + "Flushing {} active message buffer(s) on shutdown", + message_ids.len() + ); + + let mut flushed_count = 0; + let mut already_archived_count = 0; + + // Finalize each buffer + for message_id in message_ids { + // Finalize buffer + let finalized_record = { + let mut acc = self.accumulator.lock().await; + acc.finalize(&message_id) + }; + + if let Some((mut record, connector_id, native_session_id)) = finalized_record { + // Check if already archived (deduplication) + let already_archived = { + let mut archived = self.archived_messages.lock().await; + if archived.contains(&message_id) { + true + } else { + archived.insert(message_id.clone()); + false + } + }; + + if already_archived { + tracing::debug!("Shutdown: Message {} already archived, skipping", message_id); + already_archived_count += 1; + continue; + } + + // Tag with finalization method + if let Some(metadata_obj) = record.metadata.as_object_mut() { + metadata_obj.insert( + "finalized_via".to_string(), + serde_json::Value::String("shutdown_flush".to_string()), + ); + } + + // Canonicalize identity and write to archive + if let Err(e) = self.canonicalize_and_write( + &mut record, + &connector_id, + &native_session_id, + &message_id, + "shutdown_flush", + ).await { + tracing::error!( + "Shutdown: Failed to write message {} to archive: {}", + message_id, + e + ); + // Continue processing other buffers despite error + } else { + flushed_count += 1; + } + } + } + + tracing::info!( + "Shutdown complete: {} messages flushed, {} already archived", + flushed_count, + already_archived_count + ); + + Ok(()) + } + + /// Run the event loop, consuming events from the bus receiver. + /// + /// Accepts a [`BusReceiver`] from `dirigent_core::SharingBus`. Each + /// received `BusEvent` is unwrapped to its inner `Event` and dispatched + /// to [`Self::handle_event`]. The `routing` metadata is currently + /// ignored by the archivist dispatcher — it continues to operate on the + /// pre-existing `Event` enum — but the scroll_id hints are already + /// attached at publish time for future consumers. + pub async fn run(&self, mut bus_rx: BusReceiver) { + let mut interval = tokio::time::interval(Duration::from_secs(1)); + // Use a long stale threshold (300s / 5 minutes) to match the general tool execution + // limit. Task agents and other slow tools can run for minutes. TurnComplete is the + // authoritative finalization signal; this is just a safety net for missed events. + let stale_threshold = Duration::from_secs(300); + + // Sample the lagged counter periodically so operators see drops in logs. + let lagged = Arc::clone(&bus_rx.lagged); + let mut last_lagged: u64 = 0; + + loop { + tokio::select! { + maybe_bus_event = bus_rx.rx.recv() => { + match maybe_bus_event { + Some(bus_event) => { + // Unwrap the Arc; clone the inner only when + // we're the last holder so the common path is a + // move rather than a deep clone. + let event = Arc::try_unwrap(bus_event.event) + .unwrap_or_else(|shared| (*shared).clone()); + if let Err(e) = self.handle_event(event).await { + tracing::error!("Failed to archive event: {}", e); + } + } + None => { + tracing::info!("Bus event stream closed, stopping event handler"); + break; + } + } + } + _ = interval.tick() => { + // Log any newly-dropped events since last tick. + let current = lagged.load(std::sync::atomic::Ordering::Relaxed); + if current > last_lagged { + tracing::warn!( + dropped = current - last_lagged, + total = current, + "archivist bus subscriber is lagging; events were dropped" + ); + last_lagged = current; + } + if let Err(e) = self.finalize_stale_buffers(stale_threshold).await { + tracing::error!("Failed to finalize stale buffers: {}", e); + } + } + } + } + } + + /// Dispatch event to appropriate handler + async fn handle_event(&self, event: Event) -> Result<()> { + match event { + Event::SessionCreated { + connector_id, + session, + } => { + // Always register sessions - we want to track them even if archiving is disabled + // This allows enabling archiving later + self.handle_session_created(connector_id, session).await?; + } + Event::SessionsListed { + connector_id, + sessions, + } => { + // Register all discovered sessions — archivist deduplicates via ALIASED status + self.handle_sessions_listed(connector_id, sessions).await?; + } + Event::TurnComplete { + connector_id, + session_id, + message_id, + trigger, + } => { + // Check if session is excluded from archiving + if self.is_session_excluded(&connector_id, &session_id).await { + tracing::debug!( + "Skipping TurnComplete for excluded session {} (connector: {})", + session_id, + connector_id + ); + return Ok(()); + } + tracing::debug!( + "TurnComplete received for message {} in session {} (trigger: {:?})", + message_id, + session_id, + trigger + ); + self.handle_turn_complete(connector_id, session_id, message_id, trigger).await?; + } + Event::MessageCompleted { + connector_id, + message, + } => { + // MessageCompleted is now informational only - metadata is ready + // Finalization happens on TurnComplete + tracing::debug!( + "MessageCompleted received for message {} in session {} (connector: {}) - metadata ready, awaiting TurnComplete for finalization", + message.id, + message.session_id, + connector_id + ); + // We could update message metadata in accumulator here if needed, + // but for now we keep this as a no-op for backward compatibility + } + Event::SessionUpdate { + connector_id, + session_id, + update, + } => { + // Skip re-archiving for sessions being replayed (already Complete) + if self.replaying_sessions.lock().await.contains(&session_id) { + tracing::debug!( + "Suppressing SessionUpdate for replaying session {} (already archived)", + session_id + ); + return Ok(()); + } + // Check if session is excluded from archiving + if self.is_session_excluded(&connector_id, &session_id).await { + tracing::debug!( + "Skipping SessionUpdate for excluded session {} (connector: {})", + session_id, + connector_id + ); + return Ok(()); + } + self.handle_session_update(connector_id, session_id, update) + .await?; + } + Event::SessionIdle { connector_id, session_id } => { + // Clear replay state if this session was being replayed + if self.replaying_sessions.lock().await.remove(&session_id) { + tracing::debug!( + "Session {} replay complete, resuming normal archiving", + session_id + ); + } + self.handle_session_idle(connector_id, session_id).await?; + } + Event::SessionMetadataUpdated { + connector_id, + session_id, + title, + total_messages: _, + model, + } => { + // Check if session is excluded from archiving + if self.is_session_excluded(&connector_id, &session_id).await { + tracing::debug!( + "Skipping SessionMetadataUpdated for excluded session {} (connector: {})", + session_id, + connector_id + ); + return Ok(()); + } + self.handle_session_metadata_updated(connector_id, session_id, title, model) + .await?; + } + Event::AcpClientConnected { + client_id, + connected_at, + capabilities, + connector_uid, + } => { + self.handle_acp_client_connected(client_id, connected_at, capabilities, connector_uid).await?; + } + Event::AcpClientDisconnected { + client_id, + disconnected_at, + reason, + } => { + self.handle_acp_client_disconnected(client_id, disconnected_at, reason).await?; + } + Event::AcpClientSessionOpened { + client_id, + gateway_session_id, + client_session_id: _, + timestamp: _, + } => { + self.handle_acp_session_opened(client_id, gateway_session_id).await?; + } + Event::AcpClientSessionRouted { + client_id, + from_session_id, + to_session_id, + connector_id, + connector_title: _, + connector_kind: _, + model: _, + agent_info: _, + timestamp: _, + } => { + self.handle_acp_session_routed(client_id, from_session_id, to_session_id, connector_id).await?; + } + Event::SessionMetadataReceived { + connector_id, + session_id, + models, + modes, + config_options, + } => { + // Check if session is excluded from archiving + if self.is_session_excluded(&connector_id, &session_id).await { + tracing::debug!( + "Skipping SessionMetadataReceived for excluded session {} (connector: {})", + session_id, + connector_id + ); + return Ok(()); + } + self.handle_session_acp_metadata_received(connector_id, session_id, models, modes, config_options).await?; + } + Event::SessionClosed { connector_id, session_id } => { + tracing::debug!("SessionClosed: session '{}' from connector '{}'", session_id, connector_id); + // No action needed — session remains in archive, just not actively connected + } + _ => { + // Ignore other event types for MVP + } + } + Ok(()) + } + + /// Handle SessionCreated event + /// Register a single session with the archivist, given an already-resolved connector_uid. + /// + /// Returns the registration response (Accepted/Aliased/Rejected). + /// Callers that process batches should resolve connector_uid once and reuse it. + async fn register_single_session( + &self, + connector_uid: Uuid, + session: &Session, + completeness: SessionCompleteness, + ) -> Result { + // Never use external session ID as scroll_id, even if it's a valid UUID. + // The archivist MUST generate a fresh UUID7 for storage consistency. + // External IDs (including UUID4s from connectors like claude-code-acp) + // are stored in session metadata for reverse lookup via sessions.jsonl. + let mut metadata = serde_json::json!({ + "project_path": session.metadata.project_path, + "model": session.metadata.model, + "project_id": session.metadata.project_id.map(|u| u.to_string()), + }); + + // Propagate tool_configuration from _meta.extra into session metadata. + // This preserves the connector's tool configuration in the archived session + // so it can be restored when the session is loaded later. + if let Some(ref meta) = session.metadata._meta { + if let Some(tool_config) = meta.extra.get("tool_configuration") { + if let Some(obj) = metadata.as_object_mut() { + obj.insert("tool_configuration".to_string(), tool_config.clone()); + } + } + } + + let req = RegisterSessionRequest { + connector_uid, + native_session_id: session.id.clone(), + title: Some(session.title.clone()), + custom_scroll_id: None, + metadata, + completeness, + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + self.archivist.register_session(req, None).await + } + + async fn handle_session_created(&self, connector_id: String, session: Session) -> Result<()> { + if session.id.is_empty() { + tracing::warn!( + "SessionCreated event has empty session ID — session will not be resumable" + ); + } + + // Resolve connector_id to connector_uid using the archivist + // This handles both UUID connector IDs and human-readable IDs like "gateway-1" + let connector_uid = match self.archivist.resolve_connector_uid(&connector_id).await { + Ok(uid) => uid, + Err(e) => { + tracing::warn!( + "Failed to resolve connector_uid for connector_id '{}': {}. \ + This connector may not be registered with the archivist.", + connector_id, + e + ); + return Ok(()); // Skip this event, don't fail the handler + } + }; + + let response = self.register_single_session(connector_uid, &session, SessionCompleteness::Complete).await?; + + // If the session was already known (Aliased), check if it's a replay of a Complete session. + if response.status == RegisterStatus::Aliased { + // Check if the existing session is Complete — if so, this is a replay + // and we should suppress re-archiving of replayed messages. + if let Ok(meta) = self.archivist.get_session_metadata(response.scroll_id, None).await { + if meta.completeness == SessionCompleteness::Complete { + self.replaying_sessions.lock().await.insert(session.id.clone()); + tracing::debug!( + "Session {} marked as replaying (already Complete)", + session.id + ); + } + } + // Upgrade completeness to Complete (idempotent for already-Complete sessions) + if let Err(e) = self.archivist.update_session_completeness( + response.scroll_id, + SessionCompleteness::Complete, + None, + ).await { + tracing::debug!( + "Failed to upgrade completeness for session {}: {}", + session.id, e + ); + } + } + + tracing::info!( + "Registered session {} for connector {}", + session.id, + connector_id + ); + + // Emit SessionRegistered to signal that this session is durably + // registered and the frontend can refresh the session list with + // confidence. Published onto the SharingBus. + let registered_event = Event::SessionRegistered { + connector_id: connector_id.clone(), + session_id: session.id.clone(), + scroll_id: response.scroll_id.to_string(), + }; + if let Some(ref publisher) = self.bus_publisher { + let bus_event = BusEvent::from_archivist_event( + registered_event, + &connector_id, + &session.id, + Some(response.scroll_id), + ); + publisher(bus_event); + } + + // T015: Check if this connector is in the excluded list + // If so, automatically exclude the session from archiving + if self.is_connector_excluded(&connector_id).await { + self.exclude_session(&connector_id, &session.id).await; + tracing::info!( + "Session {} auto-excluded from archiving (connector '{}' is excluded by default)", + session.id, + connector_id + ); + } + + Ok(()) + } + + /// Handle SessionsListed event — register all discovered sessions. + /// + /// Sessions already known return ALIASED (fast, from cache) with metadata refresh. + /// Newly discovered sessions return ACCEPTED and get a fresh scroll_id. + /// Errors on individual sessions are logged and skipped. + async fn handle_sessions_listed( + &self, + connector_id: String, + sessions: Vec, + ) -> Result<()> { + if sessions.is_empty() { + return Ok(()); + } + + let connector_uid = match self.archivist.resolve_connector_uid(&connector_id).await { + Ok(uid) => uid, + Err(e) => { + tracing::warn!( + "SessionsListed: Failed to resolve connector '{}': {}. Skipping {} sessions.", + connector_id, e, sessions.len() + ); + return Ok(()); + } + }; + + let is_excluded = self.is_connector_excluded(&connector_id).await; + let mut accepted = 0u32; + let mut aliased = 0u32; + let mut errors = 0u32; + + for session in &sessions { + match self.register_single_session(connector_uid, session, SessionCompleteness::Discovered).await { + Ok(response) => match response.status { + RegisterStatus::Accepted => { + accepted += 1; + tracing::info!( + "SessionsListed: New session '{}' registered (scroll: {}) for connector '{}'", + session.id, response.scroll_id, connector_id + ); + // Emit SessionRegistered to signal durable registration + // (published onto the SharingBus). + let registered_event = Event::SessionRegistered { + connector_id: connector_id.clone(), + session_id: session.id.clone(), + scroll_id: response.scroll_id.to_string(), + }; + if let Some(ref publisher) = self.bus_publisher { + let bus_event = BusEvent::from_archivist_event( + registered_event, + &connector_id, + &session.id, + Some(response.scroll_id), + ); + publisher(bus_event); + } + if is_excluded { + self.exclude_session(&connector_id, &session.id).await; + } + // Refresh metadata — title/model from session/list may differ from + // what was stored at registration (e.g., session was created via + // SessionCreated with empty title, now session/list has the real title) + let title = if session.title.is_empty() { None } else { Some(session.title.clone()) }; + let model = session.metadata.model.clone(); + if title.is_some() || model.is_some() { + if let Err(e) = self.archivist.update_session_metadata( + response.scroll_id, title, model, None, + ).await { + tracing::debug!( + "SessionsListed: metadata refresh after accept failed for '{}': {}", + session.id, e + ); + } + } + } + RegisterStatus::Aliased => { + aliased += 1; + // Refresh metadata — title/model may have changed upstream + let title = if session.title.is_empty() { None } else { Some(session.title.clone()) }; + let model = session.metadata.model.clone(); + if title.is_some() || model.is_some() { + if let Err(e) = self.archivist.update_session_metadata( + response.scroll_id, title, model, None, + ).await { + tracing::debug!( + "SessionsListed: metadata refresh failed for '{}': {}", + session.id, e + ); + } + } + } + RegisterStatus::Rejected => { + errors += 1; + tracing::warn!( + "SessionsListed: Registration rejected for session '{}' (connector '{}')", + session.id, connector_id + ); + } + }, + Err(e) => { + errors += 1; + tracing::warn!( + "SessionsListed: Failed to register '{}': {}. Continuing.", + session.id, e + ); + } + } + } + + tracing::info!( + "SessionsListed: {} sessions for '{}': {} new, {} known, {} errors", + sessions.len(), connector_id, accepted, aliased, errors + ); + + Ok(()) + } + + /// Handle TurnComplete event - triggers finalization and archiving + /// + /// This is THE primary signal that all content for a turn has been received + /// and the message is ready to be finalized and archived. + async fn handle_turn_complete( + &self, + connector_id: String, + session_id: String, + message_id: String, + trigger: TurnCompleteTrigger, + ) -> Result<()> { + tracing::info!( + "TurnComplete: Finalizing message {} in session {} (connector: {}, trigger: {:?})", + message_id, + session_id, + connector_id, + trigger + ); + + // ALWAYS finalize to remove buffer, even if already archived + let finalized_record = { + let mut acc = self.accumulator.lock().await; + acc.finalize(&message_id) + }; + + // Check if we have a buffer for this message + let Some((mut record, _acc_connector_id, _acc_native_session_id)) = finalized_record else { + tracing::warn!( + "TurnComplete: No accumulated content for message {} in session {} - message may not have been streamed", + message_id, + session_id + ); + // This could happen if: + // 1. Message was sent without streaming (but MessageCompleted should have handled it) + // 2. TurnComplete arrived before any chunks (timing issue) + // 3. Message was already finalized by SessionIdle or stale timeout + // For now, we log and skip - the message may already be archived + return Ok(()); + }; + + // Check if we've already archived this message AFTER finalization + // This prevents duplicate writes while ensuring buffer cleanup + { + let mut archived = self.archived_messages.lock().await; + if archived.contains(&message_id) { + tracing::debug!( + "TurnComplete: Message {} already archived, buffer cleaned but not re-written", + message_id + ); + return Ok(()); + } + // Mark as archived before processing to prevent race conditions + archived.insert(message_id.clone()); + } + + // Tag with TurnComplete-specific metadata + if let Some(metadata_obj) = record.metadata.as_object_mut() { + metadata_obj.insert( + "finalized_via".to_string(), + serde_json::Value::String("turn_complete".to_string()), + ); + metadata_obj.insert( + "turn_complete_trigger".to_string(), + serde_json::to_value(&trigger).unwrap_or(serde_json::json!(null)), + ); + } + + // Canonicalize identity and write to archive + self.canonicalize_and_write( + &mut record, + &connector_id, + &session_id, + &message_id, + "turn_complete", + ) + .await?; + + Ok(()) + } + + /// Handle SessionIdle event + /// Finalizes and writes any buffered messages for the session that haven't been completed yet. + /// This provides a safety net for the race condition where MessageCompleted arrives before + /// chunks have been accumulated, or when MessageCompleted is missed entirely. + async fn handle_session_idle(&self, connector_id: String, session_id: String) -> Result<()> { + tracing::debug!("SessionIdle received for session {} (connector: {})", session_id, connector_id); + + // Get all message IDs that have buffers for this session + let message_ids_for_session = { + let acc = self.accumulator.lock().await; + acc.get_message_ids_for_session(&session_id) + }; + + if message_ids_for_session.is_empty() { + tracing::debug!( + "SessionIdle: No active buffers for session {}, nothing to finalize", + session_id + ); + return Ok(()); + } + + tracing::info!( + "SessionIdle: Finalizing {} buffered message(s) for session {}", + message_ids_for_session.len(), + session_id + ); + + // Finalize each message and write to archive + for message_id in message_ids_for_session { + // ALWAYS finalize to remove buffer, even if already archived + let finalized_record = { + let mut acc = self.accumulator.lock().await; + acc.finalize(&message_id) + }; + + // If no buffer existed, skip (already cleaned up) + let Some((mut record, _acc_connector_id, _acc_native_session_id)) = finalized_record else { + tracing::debug!("SessionIdle: No buffer for message {}", message_id); + continue; + }; + + // Check if already archived to prevent duplicate writes + let already_archived = { + let archived = self.archived_messages.lock().await; + archived.contains(&message_id) + }; + + if already_archived { + tracing::debug!( + "SessionIdle: Message {} already archived, buffer cleaned but not re-written", + message_id + ); + continue; + } + + // Mark as archived + { + let mut archived = self.archived_messages.lock().await; + archived.insert(message_id.clone()); + } + + // Tag with finalization method + if let Some(metadata_obj) = record.metadata.as_object_mut() { + metadata_obj.insert( + "finalized_via".to_string(), + serde_json::Value::String("session_idle".to_string()), + ); + } + + // Canonicalize identity and write to archive + self.canonicalize_and_write( + &mut record, + &connector_id, + &session_id, + &message_id, + "session_idle", + ) + .await?; + } + + Ok(()) + } + + /// Write message to archive using connector_uid to resolve scroll_id + /// + /// This method is kept for potential future use with non-streaming messages, + /// though currently unused since TurnComplete handles all finalization. + #[allow(dead_code)] + async fn write_message_to_archive(&self, connector_uid: Uuid, message: Message) -> Result<()> { + // Convert dirigent_protocol::Message to markdown + let mut content_md = String::new(); + + for part in &message.content { + match part { + MessagePart::Text { text } => { + content_md.push_str(text); + } + MessagePart::Thinking { text } => { + content_md.push_str(&format!("{}", text)); + } + MessagePart::Code { language, code } => { + content_md.push_str(&format!("\n```{}\n{}\n```\n", language, code)); + } + MessagePart::Tool { + tool, + tool_call_id: _, + input, + output, + } => { + content_md.push_str(&format!("\n**Tool**: {}\n", tool)); + content_md.push_str(&format!( + "```json\n{}\n```\n", + serde_json::to_string_pretty(input).unwrap_or_else(|_| "{}".to_string()) + )); + if let Some(output_val) = output { + content_md.push_str(&format!( + "\n**Output**:\n```json\n{}\n```\n", + serde_json::to_string_pretty(output_val) + .unwrap_or_else(|_| "{}".to_string()) + )); + } + } + MessagePart::File { path, content } => { + content_md.push_str(&format!("\n**File**: {}\n```\n{}\n```\n", path, content)); + } + } + } + + // Parse message_id + let message_uuid = Uuid::parse_str(&message.id).unwrap_or_else(|_| { + tracing::warn!("Failed to parse message_id as UUID: {}", message.id); + Uuid::now_v7() + }); + + // Resolve scroll_id from native session_id BEFORE creating MessageRecord + // This ensures we use the canonical session identifier + // If resolution fails, attempt lazy registration as a fallback + let scroll_id = self + .resolve_or_register_session( + connector_uid, + &message.session_id, + "Session mapping missing during message write (non-streaming path)", + ) + .await?; + + // Convert MessageRole to string + let role = match message.role { + MessageRole::User => "user".to_string(), + MessageRole::Assistant => "assistant".to_string(), + }; + + // Serialize metadata + let metadata = message + .metadata + .map(|m| serde_json::to_value(m).unwrap_or(serde_json::json!({}))) + .unwrap_or(serde_json::json!({})); + + // Serialize original content parts for rich UI rendering + let content_parts = serde_json::to_value(&message.content).ok(); + + let record = MessageRecord { + version: 1, + message_id: message_uuid, + session: scroll_id, // Use canonical scroll_id instead of native session UUID + parent_id: None, + ts: message.created_at, + role, + author: None, + content_md, + content_parts, + attachments: Vec::new(), + metadata, + }; + + // Write message to archive + self.archivist + .append_messages(scroll_id, vec![record], None) + .await?; + + tracing::info!( + "Wrote message {} to archive for session {} (scroll_id: {})", + message.id, + message.session_id, + scroll_id + ); + + Ok(()) + } + + /// Handle SessionUpdate event + async fn handle_session_update( + &self, + connector_id: String, + session_id: String, + update: SessionUpdate, + ) -> Result<()> { + let mut acc = self.accumulator.lock().await; + + match update { + SessionUpdate::AgentMessageChunk { + message_id, + content, + .. + } => { + acc.add_chunk(message_id, session_id, connector_id, "assistant".to_string(), content); + } + SessionUpdate::UserMessageChunk { + message_id, + content, + .. + } => { + acc.add_chunk(message_id, session_id, connector_id, "user".to_string(), content); + } + SessionUpdate::AgentThoughtChunk { + message_id, + content, + .. + } => { + // Extract text from ContentBlock + if let ContentBlock::Text { text } = content { + acc.add_thinking(message_id, session_id, connector_id, text); + } + } + SessionUpdate::ToolCall { + message_id, + tool_call, + .. + } => { + // Convert to internal ToolCallData and add/update + // This handles both initial tool call and any subsequent updates + let tool_call_data = crate::accumulator::ToolCallData { + id: tool_call.id.clone(), + tool_name: tool_call.tool_name.clone(), + input: tool_call.raw_input.unwrap_or_else(|| serde_json::json!({})), + output: tool_call.raw_output, + }; + + acc.add_or_update_tool_call(message_id, tool_call_data); + } + SessionUpdate::ToolCallUpdate { + message_id, + tool_call_id, + tool_call, + .. + } => { + // Convert to ToolCallData and merge with existing tool call + let tool_call_data = crate::accumulator::ToolCallData { + id: tool_call_id, + tool_name: tool_call.tool_name.clone(), + input: tool_call.raw_input.unwrap_or_else(|| serde_json::json!({})), + output: tool_call.raw_output, + }; + + acc.add_or_update_tool_call(message_id, tool_call_data); + } + SessionUpdate::Unknown { .. } => { + // Ignore unknown update types - forward compatibility + } + } + + Ok(()) + } + + /// Handle SessionMetadataUpdated event + async fn handle_session_metadata_updated( + &self, + connector_id: String, + session_id: String, + title: Option, + model: Option, + ) -> Result<()> { + tracing::debug!( + "SessionMetadataUpdated event: connector={}, session={}, title={:?}, model={:?}", + connector_id, + session_id, + title, + model + ); + + // Resolve connector_id to connector_uid using the archivist + let connector_uid = match self.archivist.resolve_connector_uid(&connector_id).await { + Ok(uid) => uid, + Err(e) => { + tracing::warn!( + "Failed to resolve connector_uid for connector_id '{}': {}. \ + Skipping metadata update.", + connector_id, + e + ); + return Ok(()); // Skip this event, don't fail the handler + } + }; + + // Resolve session to scroll_id + let scroll_id = match self + .archivist + .resolve_session(connector_uid, &session_id, None) + .await + { + Ok(scroll_id) => scroll_id, + Err(e) => { + tracing::warn!( + "Failed to resolve session {} for connector {}: {}. Skipping metadata update.", + session_id, + connector_uid, + e + ); + return Ok(()); // Skip this event if session not found + } + }; + + // Update session metadata in archive + self.archivist + .update_session_metadata(scroll_id, title.clone(), model.clone(), None) + .await?; + + tracing::info!( + "Updated session metadata: scroll_id={}, title={:?}, model={:?}", + scroll_id, + title, + model + ); + + Ok(()) + } + + /// Handle SessionMetadataReceived event (ACP-specific models/modes/config_options metadata) + async fn handle_session_acp_metadata_received( + &self, + connector_id: String, + session_id: String, + models: Option, + modes: Option, + config_options: Option>, + ) -> Result<()> { + tracing::debug!( + "SessionMetadataReceived event: connector={}, session={}, has_models={}, has_modes={}, has_config_options={}", + connector_id, + session_id, + models.is_some(), + modes.is_some(), + config_options.is_some() + ); + + // Resolve connector_id to connector_uid using the archivist + let connector_uid = match self.archivist.resolve_connector_uid(&connector_id).await { + Ok(uid) => uid, + Err(e) => { + tracing::warn!( + "Failed to resolve connector_uid for connector_id '{}': {}. \ + Skipping ACP metadata update.", + connector_id, + e + ); + return Ok(()); // Skip this event, don't fail the handler + } + }; + + // Resolve session to scroll_id + let scroll_id = match self + .archivist + .resolve_session(connector_uid, &session_id, None) + .await + { + Ok(scroll_id) => scroll_id, + Err(e) => { + tracing::warn!( + "Failed to resolve session {} for connector {}: {}. Skipping ACP metadata update.", + session_id, + connector_uid, + e + ); + return Ok(()); // Skip this event if session not found + } + }; + + // Convert protocol types to JSON for storage + let models_json = models.and_then(|m| serde_json::to_value(m).ok()); + let modes_json = modes.and_then(|m| serde_json::to_value(m).ok()); + let config_options_json = config_options.and_then(|co| serde_json::to_value(co).ok()); + + // Update session ACP metadata in archive + self.archivist + .update_session_acp_metadata(scroll_id, models_json.clone(), modes_json.clone(), config_options_json.clone(), None) + .await?; + + tracing::info!( + "Updated session ACP metadata: scroll_id={}, has_models={}, has_modes={}, has_config_options={}", + scroll_id, + models_json.is_some(), + modes_json.is_some(), + config_options_json.is_some() + ); + + Ok(()) + } + + /// Handle AcpClientConnected event + async fn handle_acp_client_connected( + &self, + client_id: String, + connected_at: String, + _capabilities: Option, + connector_uid_str: String, + ) -> Result<()> { + tracing::info!("ACP client connected: {}", client_id); + + // Parse the connector_uid from the event + // This is the Acceptor connector's UID, used to create meta sessions under the right connector + let connector_uid = match Uuid::parse_str(&connector_uid_str) { + Ok(uid) if uid != Uuid::nil() => uid, + Ok(_) | Err(_) => { + // If parsing fails or we get a nil UUID, log a warning and skip meta session creation + tracing::warn!( + "Invalid connector_uid '{}' for ACP client {}. Meta session will not be created.", + connector_uid_str, + client_id + ); + return Ok(()); + } + }; + + // Create or get meta session + let scroll_id = self.get_or_create_meta_session(&client_id, connector_uid, &connected_at).await?; + + // Append ClientConnected event + let event = MetaEventRecord { + version: 1, + event_id: Uuid::now_v7(), + session: scroll_id, + ts: chrono::DateTime::parse_from_rfc3339(&connected_at) + .map(|dt| dt.with_timezone(&Utc)) + .unwrap_or_else(|_| Utc::now()), + event_type: MetaEventType::ClientConnected, + description: format!("Client {} connected", client_id), + linked_session_id: None, + linked_connector_id: None, + linked_connector_title: None, + metadata: serde_json::json!({}), + }; + + self.archivist.append_meta_events(scroll_id, vec![event], None).await?; + + // Update connection status + self.archivist.update_meta_session_status(scroll_id, true, None, None).await?; + + Ok(()) + } + + /// Handle AcpClientDisconnected event + async fn handle_acp_client_disconnected( + &self, + client_id: String, + disconnected_at: String, + reason: Option, + ) -> Result<()> { + tracing::info!("ACP client disconnected: {} (reason: {:?})", client_id, reason); + + // Find the meta session for this client + let Some(meta_session) = self.archivist.find_meta_session_by_client(&client_id, None).await? else { + tracing::warn!("No meta session found for disconnecting client {}", client_id); + return Ok(()); + }; + + let scroll_id = meta_session.scroll_id; + + // Append ClientDisconnected event + let event = MetaEventRecord { + version: 1, + event_id: Uuid::now_v7(), + session: scroll_id, + ts: chrono::DateTime::parse_from_rfc3339(&disconnected_at) + .map(|dt| dt.with_timezone(&Utc)) + .unwrap_or_else(|_| Utc::now()), + event_type: MetaEventType::ClientDisconnected, + description: format!( + "Client {} disconnected{}", + client_id, + reason.as_ref().map(|r| format!(" ({})", r)).unwrap_or_default() + ), + linked_session_id: None, + linked_connector_id: None, + linked_connector_title: None, + metadata: serde_json::json!({ + "reason": reason + }), + }; + + self.archivist.append_meta_events(scroll_id, vec![event], None).await?; + + // Update connection status to disconnected + self.archivist.update_meta_session_status(scroll_id, false, None, None).await?; + + Ok(()) + } + + /// Handle AcpClientSessionOpened event + async fn handle_acp_session_opened( + &self, + client_id: String, + gateway_session_id: String, + ) -> Result<()> { + tracing::info!("ACP client {} opened session {}", client_id, gateway_session_id); + + // Find the meta session for this client + let Some(meta_session) = self.archivist.find_meta_session_by_client(&client_id, None).await? else { + tracing::warn!("No meta session found for client {} opening session", client_id); + return Ok(()); + }; + + let scroll_id = meta_session.scroll_id; + + // Try to parse gateway_session_id as UUID for linking + let linked_session_id = Uuid::parse_str(&gateway_session_id).ok(); + + // Append SessionOpened event + let event = MetaEventRecord { + version: 1, + event_id: Uuid::now_v7(), + session: scroll_id, + ts: Utc::now(), + event_type: MetaEventType::SessionOpened, + description: format!("Opened session {}", gateway_session_id), + linked_session_id, + linked_connector_id: None, + linked_connector_title: None, + metadata: serde_json::json!({ + "gateway_session_id": gateway_session_id + }), + }; + + self.archivist.append_meta_events(scroll_id, vec![event], None).await?; + + // Update current session + self.archivist.update_meta_session_status(scroll_id, true, linked_session_id, None).await?; + + Ok(()) + } + + /// Handle AcpClientSessionRouted event + async fn handle_acp_session_routed( + &self, + client_id: String, + from_session_id: String, + to_session_id: String, + to_connector_id: String, + ) -> Result<()> { + tracing::info!( + "ACP client {} session routed: {} -> {} (connector: {})", + client_id, from_session_id, to_session_id, to_connector_id + ); + + // Find the meta session for this client + let Some(meta_session) = self.archivist.find_meta_session_by_client(&client_id, None).await? else { + tracing::warn!("No meta session found for client {} during routing", client_id); + return Ok(()); + }; + + let scroll_id = meta_session.scroll_id; + + // Try to parse to_session_id as UUID for linking + let linked_session_id = Uuid::parse_str(&to_session_id).ok(); + + // Append SessionSwitched event + let event = MetaEventRecord { + version: 1, + event_id: Uuid::now_v7(), + session: scroll_id, + ts: Utc::now(), + event_type: MetaEventType::SessionSwitched, + description: format!("Switched to {} via {}", to_session_id, to_connector_id), + linked_session_id, + linked_connector_id: Some(to_connector_id.clone()), + linked_connector_title: None, // Could be resolved later + metadata: serde_json::json!({ + "from_session_id": from_session_id, + "to_session_id": to_session_id, + "to_connector_id": to_connector_id + }), + }; + + self.archivist.append_meta_events(scroll_id, vec![event], None).await?; + + // Update current session + self.archivist.update_meta_session_status(scroll_id, true, linked_session_id, None).await?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use chrono::Utc; + use dirigent_protocol::MessageStatus; + + /// Create a real `Archivist` backed by a unique temp directory. + /// Previously these tests used a hand-rolled `MockArchivist`, but the + /// trait-object interface was removed in Phase 2 and EventHandler now + /// requires a concrete coordinator. + /// + /// Uses `from_single_backend` (not `new_with_single_archive`) so each test + /// is fully isolated — otherwise tests running in parallel race on the + /// shared `.archives.json` that `new_with_single_archive` writes into the + /// tempdir's parent. + async fn mk_test_archivist() -> Arc { + let tmp = std::env::temp_dir() + .join(format!("events_test_{}", Uuid::now_v7())); + let backend = Arc::new( + crate::backends::JsonlBackend::new(tmp) + .await + .expect("create test backend"), + ); + Arc::new( + Archivist::from_single_backend("main".into(), backend) + .await + .expect("create test archivist"), + ) + } + + + #[tokio::test] + async fn test_event_handler_creation() { + let archivist = mk_test_archivist().await; + let _handler = EventHandler::new(archivist); + } + + #[tokio::test] + async fn test_handle_session_update_agent_chunk() { + let archivist = mk_test_archivist().await; + let handler = EventHandler::new(archivist); + + let update = SessionUpdate::AgentMessageChunk { + message_id: "msg_123".to_string(), + content: ContentBlock::Text { + text: "Hello".to_string(), + }, + _meta: None, + }; + + let result = handler + .handle_session_update( + "connector_123".to_string(), + "session_456".to_string(), + update, + ) + .await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_handle_session_update_thinking_chunk() { + let archivist = mk_test_archivist().await; + let handler = EventHandler::new(archivist); + + let update = SessionUpdate::AgentThoughtChunk { + message_id: "msg_789".to_string(), + content: ContentBlock::Text { + text: "Thinking...".to_string(), + }, + _meta: None, + }; + + let result = handler + .handle_session_update( + "connector_456".to_string(), + "session_abc".to_string(), + update, + ) + .await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_write_message_to_archive() { + use uuid::Uuid; + + let archivist = mk_test_archivist().await; + + // Pre-register the connector so lazy session registration can attach to it. + // (The old MockArchivist answered every call unconditionally; the real + // coordinator requires the connector exist first.) + let connector_uid = Uuid::now_v7(); + archivist + .register_connector( + crate::types::RegisterConnectorRequest { + r#type: "Test".into(), + title: "t".into(), + client_native_id: format!("t@{}", Uuid::now_v7()), + custom_uid: Some(connector_uid), + metadata: serde_json::json!({}), + fingerprint: None, + }, + None, + ) + .await + .unwrap(); + + let handler = EventHandler::new(archivist); + + let message = Message { + id: "01936e8f-e5a7-7000-8000-000000000001".to_string(), + session_id: "01936e8f-e5a7-7000-8000-000000000002".to_string(), + role: MessageRole::User, + created_at: Utc::now(), + content: vec![MessagePart::Text { + text: "Hello, world!".to_string(), + }], + status: MessageStatus::Completed, + metadata: None, + }; + + let result = handler + .write_message_to_archive(connector_uid, message) + .await; + assert!(result.is_ok(), "write_message_to_archive failed: {:?}", result); + } + + #[tokio::test] + async fn test_user_message_turn_complete_event() { + use uuid::Uuid; + + let archivist = mk_test_archivist().await; + let handler = EventHandler::new(archivist); + + let connector_id = Uuid::now_v7().to_string(); + let session_id = "01936e8f-e5a7-7000-8000-000000000102".to_string(); + let message_id = "01936e8f-e5a7-7000-8000-000000000101".to_string(); + + // First, add some chunks to accumulator (simulating streaming) + let update = SessionUpdate::UserMessageChunk { + message_id: message_id.clone(), + content: ContentBlock::Text { + text: "What is the capital of France?".to_string(), + }, + _meta: None, + }; + + let _ = handler + .handle_session_update( + connector_id.clone(), + session_id.clone(), + update, + ) + .await; + + // Now handle TurnComplete event for user message + let result = handler + .handle_turn_complete( + connector_id, + session_id, + message_id, + TurnCompleteTrigger::ExplicitSignal, + ) + .await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_assistant_message_turn_complete_event() { + use uuid::Uuid; + + let archivist = mk_test_archivist().await; + let handler = EventHandler::new(archivist); + + let connector_id = Uuid::now_v7().to_string(); + let session_id = "01936e8f-e5a7-7000-8000-000000000202".to_string(); + let message_id = "01936e8f-e5a7-7000-8000-000000000201".to_string(); + + // First, add some chunks to accumulator (simulating streaming) + let update = SessionUpdate::AgentMessageChunk { + message_id: message_id.clone(), + content: ContentBlock::Text { + text: "The capital of France is Paris.".to_string(), + }, + _meta: None, + }; + + let _ = handler + .handle_session_update( + connector_id.clone(), + session_id.clone(), + update, + ) + .await; + + // Now handle TurnComplete event for assistant message + let result = handler + .handle_turn_complete( + connector_id, + session_id, + message_id, + TurnCompleteTrigger::ExplicitSignal, + ) + .await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_session_idle_finalizes_buffered_messages() { + let archivist = mk_test_archivist().await; + let handler = EventHandler::new(archivist); + + let session_id = "01936e8f-e5a7-7000-8000-000000000010".to_string(); + let message_id = "01936e8f-e5a7-7000-8000-000000000011".to_string(); + + // Add chunks to accumulator (simulating streaming) + let update = SessionUpdate::AgentMessageChunk { + message_id: message_id.clone(), + content: ContentBlock::Text { + text: "Hello from stream".to_string(), + }, + _meta: None, + }; + + let result = handler + .handle_session_update( + "connector_123".to_string(), + session_id.clone(), + update, + ) + .await; + assert!(result.is_ok()); + + // Verify buffer exists + let has_buffer = { + let acc = handler.accumulator.lock().await; + !acc.get_message_ids_for_session(&session_id).is_empty() + }; + assert!(has_buffer, "Buffer should exist before SessionIdle"); + + // Handle SessionIdle + let result = handler.handle_session_idle("connector_123".to_string(), session_id.clone()).await; + assert!(result.is_ok()); + + // Verify buffer was finalized and cleared + let has_buffer_after = { + let acc = handler.accumulator.lock().await; + !acc.get_message_ids_for_session(&session_id).is_empty() + }; + assert!(!has_buffer_after, "Buffer should be cleared after SessionIdle"); + + // Verify message was marked as archived + let is_archived = { + let archived = handler.archived_messages.lock().await; + archived.contains(&message_id) + }; + assert!(is_archived, "Message should be marked as archived"); + } + + #[tokio::test] + async fn test_session_idle_with_no_buffers() { + let archivist = mk_test_archivist().await; + let handler = EventHandler::new(archivist); + + let session_id = "01936e8f-e5a7-7000-8000-000000000020".to_string(); + + // Handle SessionIdle with no buffers (should be no-op) + let result = handler.handle_session_idle("test-connector".to_string(), session_id).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_session_idle_skips_already_archived() { + let archivist = mk_test_archivist().await; + let handler = EventHandler::new(archivist); + + let session_id = "01936e8f-e5a7-7000-8000-000000000030".to_string(); + let message_id = "01936e8f-e5a7-7000-8000-000000000031".to_string(); + + // Add chunks to accumulator + let update = SessionUpdate::AgentMessageChunk { + message_id: message_id.clone(), + content: ContentBlock::Text { + text: "Test message".to_string(), + }, + _meta: None, + }; + + handler + .handle_session_update( + "connector_456".to_string(), + session_id.clone(), + update, + ) + .await + .unwrap(); + + // Verify buffer exists before marking as archived + let has_buffer_before = { + let acc = handler.accumulator.lock().await; + !acc.get_message_ids_for_session(&session_id).is_empty() + }; + assert!(has_buffer_before, "Buffer should exist before marking as archived"); + + // Mark message as already archived + { + let mut archived = handler.archived_messages.lock().await; + archived.insert(message_id.clone()); + } + + // Handle SessionIdle - should clean buffer but skip writing already archived message + let result = handler.handle_session_idle("connector_456".to_string(), session_id.clone()).await; + assert!(result.is_ok()); + + // Buffer SHOULD be cleared even though message was already archived + // This is the fix: finalize() is called before checking archived_messages + let has_buffer = { + let acc = handler.accumulator.lock().await; + !acc.get_message_ids_for_session(&session_id).is_empty() + }; + assert!(!has_buffer, "Buffer should be cleared even when message is already archived (prevents leak)"); + } + + fn make_test_session(id: &str, title: &str) -> Session { + Session { + id: id.to_string(), + title: title.to_string(), + created_at: Utc::now(), + updated_at: Utc::now(), + metadata: dirigent_protocol::SessionMetadata { + project_path: "/test/project".to_string(), + model: Some("claude-4".to_string()), + total_messages: 0, + system_message: None, + current_mode_id: None, + _meta: None, + project_id: None, + }, + cwd: None, + models: None, + modes: None, + config_options: None, + acp_client_id: None, + } + } + + #[tokio::test] + async fn test_handle_sessions_listed_registers_sessions() { + let archivist = mk_test_archivist().await; + let handler = EventHandler::new(archivist); + + let connector_id = Uuid::now_v7().to_string(); + let sessions = vec![ + make_test_session("session-1", "First Session"), + make_test_session("session-2", "Second Session"), + ]; + + let result = handler + .handle_sessions_listed(connector_id, sessions) + .await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_handle_sessions_listed_empty_list() { + let archivist = mk_test_archivist().await; + let handler = EventHandler::new(archivist); + + let result = handler + .handle_sessions_listed("some-connector".to_string(), vec![]) + .await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_handle_event_sessions_listed() { + let archivist = mk_test_archivist().await; + let handler = EventHandler::new(archivist); + + let connector_id = Uuid::now_v7().to_string(); + let event = Event::SessionsListed { + connector_id, + sessions: vec![ + make_test_session("sess-a", "Session A"), + make_test_session("sess-b", "Session B"), + make_test_session("sess-c", "Session C"), + ], + }; + + let result = handler.handle_event(event).await; + assert!(result.is_ok()); + } +} diff --git a/crates/dirigent_archivist/src/import/mod.rs b/crates/dirigent_archivist/src/import/mod.rs new file mode 100644 index 0000000..e63dce9 --- /dev/null +++ b/crates/dirigent_archivist/src/import/mod.rs @@ -0,0 +1,933 @@ +//! Generic import infrastructure for bringing external sessions into the archive. +//! +//! This module provides the shared types and orchestration logic that all importers +//! (Claude, ChatGPT, etc.) reuse. Each importer implements discovery and message +//! conversion, then delegates to [`import_sessions`] for the actual import. + +pub mod progress; +pub mod registry; +pub mod sources; +pub mod trait_def; + +/// Backwards-compatible re-export — external callers (e.g. `api`) import +/// `dirigent_archivist::import::claude::{discover_claude_import, +/// import_claude_sessions}`. Keep the path stable until those callsites +/// migrate to the `Importer` trait. +pub use sources::claude; +#[cfg(feature = "importer-claude")] +pub use sources::claude::ClaudeImporter; + +pub use progress::{ImportProgressEvent, ImportProgressSink, SessionOutcome, StatsDelta}; +pub use registry::ImporterRegistry; +pub use trait_def::{ConfigField, ConfigFieldKind, ImportConfig, ImportConfigShape, ImportError, ImportTarget, Importer, ImporterInfo}; + +use std::collections::HashMap; + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::coordinator::Archivist; +use crate::error::{ArchivistError, Result}; +use crate::types::{ + MessageRecord, RegisterConnectorRequest, RegisterSessionRequest, RegisterStatus, + SessionCompleteness, +}; + +/// Statistics collected during an import operation. +#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)] +pub struct ImportStats { + /// Number of sessions found by the importer's discovery phase. + pub sessions_discovered: usize, + /// Number of sessions successfully imported as new. + pub sessions_imported: usize, + /// Number of sessions skipped (already present with same or more messages). + pub sessions_skipped: usize, + /// Number of sessions that were updated with new messages. + pub sessions_updated: usize, + /// Total number of message records written to the archive. + pub messages_written: usize, + /// Number of messages that were already present (from existing sessions). + pub messages_already_present: usize, + /// Number of sessions skipped because the fingerprint matched (no source changes). + #[serde(default)] + pub sessions_fingerprint_skipped: usize, + /// Errors encountered during import (non-fatal; import continues). + pub errors: Vec, +} + +impl ImportStats { + /// Total sessions processed (imported + skipped + updated + errored). + pub fn total_sessions_processed(&self) -> usize { + self.sessions_imported + self.sessions_skipped + self.sessions_updated + self.errors.len() + } + + /// Whether any errors were encountered during import. + pub fn has_errors(&self) -> bool { + !self.errors.is_empty() + } +} + +/// Intermediate representation for a session discovered by any importer. +/// +/// This is source-agnostic: each importer converts its native session format +/// into `DiscoveredSession` before handing it to [`import_sessions`]. +#[derive(Debug, Clone)] +pub struct DiscoveredSession { + /// The session ID from the original source (e.g., Claude's JSONL filename). + pub native_session_id: String, + /// Human-readable session title, if available. + pub title: Option, + /// When the session was created in the source system. + pub created_at: Option>, + /// When the session was last updated in the source system. + pub updated_at: Option>, + /// Number of messages in the source session (used for skip/update decisions). + pub message_count: usize, + /// Arbitrary source-specific metadata preserved for provenance. + pub metadata: serde_json::Value, + /// Project path associated with the session, if known. + pub project_path: Option, + /// Size of the source file in bytes, if available. Used for fingerprint-based + /// change detection to skip unchanged sessions on re-import. + pub file_size: Option, +} + +/// Snapshot of source-side signals captured after a successful import. +/// +/// Stored in the session's `metadata` JSON under the `"_import_snapshot"` key. +/// On re-import, comparing the current `DiscoveredSession` against the stored +/// snapshot lets us skip expensive full-parse when nothing has changed (O(1) gate). +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct ImportSnapshot { + /// Number of messages in the source at the time of import. + pub source_message_count: usize, + /// Source-side `updated_at` timestamp at the time of import. + pub source_updated_at: Option>, + /// Source file size in bytes at the time of import. + pub source_file_size: Option, + /// When this snapshot was recorded. + pub imported_at: DateTime, +} + +/// Key used to store [`ImportSnapshot`] in session metadata JSON. +const IMPORT_SNAPSHOT_KEY: &str = "_import_snapshot"; + +impl ImportSnapshot { + /// Check whether the source signals in `discovered` match this snapshot. + /// + /// Returns `true` if all present signals match, meaning the session has not + /// changed since this snapshot was taken and a full re-parse can be skipped. + pub fn matches(&self, discovered: &DiscoveredSession) -> bool { + if self.source_message_count != discovered.message_count { + return false; + } + if self.source_updated_at != discovered.updated_at { + return false; + } + // file_size: only compare when both sides have a value. + if let (Some(snap_size), Some(disc_size)) = (self.source_file_size, discovered.file_size) { + if snap_size != disc_size { + return false; + } + } + true + } + + /// Build a snapshot from a discovered session (captures current source signals). + pub fn from_discovered(discovered: &DiscoveredSession) -> Self { + Self { + source_message_count: discovered.message_count, + source_updated_at: discovered.updated_at, + source_file_size: discovered.file_size, + imported_at: Utc::now(), + } + } + + /// Try to deserialize a snapshot from a session's metadata JSON. + pub fn from_metadata(metadata: &serde_json::Value) -> Option { + metadata + .get(IMPORT_SNAPSHOT_KEY) + .and_then(|v| serde_json::from_value(v.clone()).ok()) + } + + /// Serialize this snapshot into the session's metadata JSON under the + /// `_import_snapshot` key. + pub fn write_to_metadata(&self, metadata: &mut serde_json::Value) { + if let Some(obj) = metadata.as_object_mut() { + if let Ok(val) = serde_json::to_value(self) { + obj.insert(IMPORT_SNAPSHOT_KEY.to_string(), val); + } + } else { + tracing::warn!("cannot write import snapshot: metadata is not a JSON object"); + } + } +} + +/// Summary returned by the discovery phase before actual import begins. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImportDiscovery { + /// Human-readable name of the import source (e.g., "Claude Code"). + pub source_name: String, + /// Filesystem path or URI that was scanned. + pub source_path: String, + /// Projects discovered, grouped by name. + pub projects: Vec, + /// Total number of sessions found across all projects. + pub total_sessions: usize, + /// Estimated total messages across all discovered sessions. + pub total_estimated_messages: usize, +} + +/// A project grouping within an import discovery result. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImportProject { + /// Project name (typically derived from the directory path). + pub name: String, + /// Number of sessions belonging to this project. + pub session_count: usize, +} + +/// Resolves the `updated_at` timestamp for an imported session. +/// +/// Prefers the source-provided timestamp from `discovered.updated_at`; falls +/// back to `Utc::now()` only when the source does not supply one. +fn resolve_updated_at(discovered: &DiscoveredSession) -> DateTime { + discovered.updated_at.unwrap_or_else(chrono::Utc::now) +} + +/// Generic async orchestrator that imports discovered sessions into the archive. +/// +/// This function handles the full import lifecycle: +/// 1. Registers the connector (idempotent via fingerprint). +/// 2. For each discovered session, checks whether it already exists in the archive. +/// 3. New sessions are registered and their messages are converted and appended. +/// 4. Existing sessions with fewer archived messages are logged and skipped (v1). +/// 5. Existing sessions with the same or more archived messages are skipped. +/// +/// The `convert_messages` closure receives a `native_session_id` and returns +/// `MessageRecord`s with `Uuid::nil()` in the `session` field. This function +/// patches each record's `session` to the real `scroll_id` before appending. +/// +/// # Arguments +/// +/// * `archivist` - The archivist to import into. +/// * `connector_req` - Registration request for the import connector. +/// * `sessions` - Sessions discovered by the importer. +/// * `convert_messages` - Closure that converts a native session into `MessageRecord`s. +/// * `archive` - Optional archive name (`None` for default archive). +/// * `progress` - Sink for per-session progress events (use +/// [`ImportProgressSink::noop`] when progress reporting is not needed). +pub async fn import_sessions( + archivist: &Archivist, + connector_req: RegisterConnectorRequest, + sessions: Vec, + convert_messages: F, + archive: Option, + progress: &ImportProgressSink, + force_deep_scan: bool, + project_map: &HashMap, +) -> Result +where + F: Fn(&str) -> Result> + Send + Sync, +{ + let mut stats = ImportStats::default(); + stats.sessions_discovered = sessions.len(); + + // Step 1: Register the connector (idempotent). + let connector_resp = archivist + .register_connector(connector_req, archive.clone()) + .await?; + let connector_uid = connector_resp.connector_uid; + + tracing::info!( + connector_uid = %connector_uid, + status = ?connector_resp.status, + "Import connector registered" + ); + + // Step 2: Process each discovered session. + let total = sessions.len(); + for (index, session) in sessions.iter().enumerate() { + let native_id = &session.native_session_id; + + progress + .send(ImportProgressEvent::SessionStarted { + native_id: native_id.clone(), + index, + total, + }) + .await; + + // Per-session outcome + stats delta. Updated as we go; on the early + // `continue` paths we emit Failed/Skipped before moving on. + let mut messages_written_delta: u64 = 0; + let mut messages_already_present_delta: u64 = 0; + let mut session_changed = false; + + // Helper: emit SessionFinished and fall out of the iteration. + macro_rules! emit_finished { + ($outcome:expr) => {{ + progress + .send(ImportProgressEvent::SessionFinished { + native_id: native_id.clone(), + outcome: $outcome, + stats_delta: StatsDelta { + messages_written: messages_written_delta, + messages_already_present: messages_already_present_delta, + }, + }) + .await; + }}; + } + + // --- Step 1: Resolve or create scroll_id BEFORE convert_messages --- + let (scroll_id, session_is_new) = match archivist + .resolve_session(connector_uid, native_id, archive.clone()) + .await + { + Ok(id) => (id, false), + Err(ArchivistError::SessionUnknown(_)) => { + // Inject project_id from project_map if the session has a + // project_path that maps to a known project. + let mut metadata = session.metadata.clone(); + if let Some(project_path) = session.project_path.as_deref() { + if let Some(pid) = project_map.get(project_path) { + if let Some(obj) = metadata.as_object_mut() { + obj.insert( + "project_id".to_string(), + serde_json::Value::String(pid.clone()), + ); + } + } + } + + let register_req = RegisterSessionRequest { + connector_uid, + native_session_id: native_id.clone(), + title: session.title.clone(), + custom_scroll_id: None, + metadata, + completeness: SessionCompleteness::Complete, + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + match archivist + .register_session(register_req, archive.clone()) + .await + { + Ok(resp) => match resp.status { + RegisterStatus::Accepted => (resp.scroll_id, true), + RegisterStatus::Aliased => { + stats.sessions_skipped += 1; + emit_finished!(SessionOutcome::Skipped); + continue; + } + RegisterStatus::Rejected => { + stats.errors.push(format!( + "Session registration rejected for {native_id}" + )); + emit_finished!(SessionOutcome::Failed); + continue; + } + }, + Err(e) => { + stats.errors.push(format!( + "Failed to register session {native_id}: {e}" + )); + emit_finished!(SessionOutcome::Failed); + continue; + } + } + } + Err(e) => { + stats.errors.push(format!( + "Failed to resolve session {native_id}: {e}" + )); + emit_finished!(SessionOutcome::Failed); + continue; + } + }; + + // --- Step 2: Hoist metadata read for existing sessions --- + // Load metadata once; reused for fingerprint check AND title/model diff. + let existing_meta = if !session_is_new { + match archivist + .get_session_metadata(scroll_id, archive.clone()) + .await + { + Ok(m) => Some(m), + Err(e) => { + stats.errors.push(format!( + "Failed to read session metadata for {native_id}: {e}" + )); + emit_finished!(SessionOutcome::Failed); + continue; + } + } + } else { + None + }; + + // --- Step 2b: Retroactive project_id linking for existing sessions --- + // Sessions imported before project detection (or before the project was + // created) have project_path but no project_id. Patch it now if the + // project_map has a match — this runs even for fingerprint-skipped + // sessions so re-import can link them without any source-side changes. + if !session_is_new { + if let Some(ref meta) = existing_meta { + let has_project_path = meta + .metadata + .get("project_path") + .and_then(|v| v.as_str()) + .is_some(); + let has_project_id = meta + .metadata + .get("project_id") + .and_then(|v| v.as_str()) + .filter(|s| !s.is_empty()) + .is_some(); + + if has_project_path && !has_project_id { + let stored_path = meta + .metadata + .get("project_path") + .and_then(|v| v.as_str()) + .unwrap(); + if let Some(pid) = project_map.get(stored_path) { + if let Ok(primary) = + archivist.resolve_primary(archive.clone()).await + { + let mut patched = meta.clone(); + if let Some(obj) = patched.metadata.as_object_mut() { + obj.insert( + "project_id".to_string(), + serde_json::Value::String(pid.clone()), + ); + } + patched.updated_at = resolve_updated_at(session); + match primary.backend.put_session(patched).await { + Ok(_) => { + tracing::info!( + scroll_id = %scroll_id, + project_id = %pid, + "Retroactively linked session to project" + ); + session_changed = true; + } + Err(e) => { + tracing::warn!( + scroll_id = %scroll_id, + error = %e, + "Failed to retroactively link session to project" + ); + } + } + } + } + } + } + } + + // --- Step 3: Fingerprint gate — skip unchanged sessions --- + if !session_is_new && !force_deep_scan { + if let Some(ref meta) = existing_meta { + if let Some(snapshot) = ImportSnapshot::from_metadata(&meta.metadata) { + if snapshot.matches(session) { + stats.sessions_fingerprint_skipped += 1; + if session_changed { + tracing::debug!( + native_id = %native_id, + "Fingerprint match — skipping message scan (metadata was updated)" + ); + stats.sessions_updated += 1; + emit_finished!(SessionOutcome::Updated); + } else { + tracing::debug!( + native_id = %native_id, + "Fingerprint match — skipping unchanged session" + ); + stats.sessions_skipped += 1; + emit_finished!(SessionOutcome::Skipped); + } + continue; + } + } + } + } + + // --- Step 4: Convert messages (EXPENSIVE — after fingerprint gate) --- + let source_records = match convert_messages(native_id) { + Ok(r) => r, + Err(e) => { + stats.errors.push(format!( + "Failed to convert messages for session {native_id}: {e}" + )); + emit_finished!(SessionOutcome::Failed); + continue; + } + }; + + // Build existing_ids set — empty for brand-new sessions. + let existing_ids: std::collections::HashSet = if session_is_new { + std::collections::HashSet::new() + } else { + match archivist.get_messages(scroll_id, archive.clone()).await { + Ok(msgs) => msgs.into_iter().map(|m| m.message_id).collect(), + Err(e) => { + stats.errors.push(format!( + "Failed to read existing messages for session {native_id}: {e}" + )); + emit_finished!(SessionOutcome::Failed); + continue; + } + } + }; + + // Patch placeholder session field and partition. + let mut new_messages: Vec = Vec::new(); + let mut already_present_count: usize = 0; + for mut record in source_records { + if record.session == Uuid::nil() { + record.session = scroll_id; + } + if existing_ids.contains(&record.message_id) { + already_present_count += 1; + } else { + new_messages.push(record); + } + } + + let new_count = new_messages.len(); + if new_count > 0 { + if let Err(e) = archivist + .append_messages(scroll_id, new_messages, archive.clone()) + .await + { + stats.errors.push(format!( + "Failed to append messages for session {native_id}: {e}" + )); + emit_finished!(SessionOutcome::Failed); + continue; + } + stats.messages_written += new_count; + messages_written_delta = new_count as u64; + session_changed = true; + } + stats.messages_already_present += already_present_count; + messages_already_present_delta = already_present_count as u64; + + // --- Step 5: Metadata diff (reuse hoisted metadata) --- + if !session_is_new { + // SAFETY: existing_meta is Some when !session_is_new (guarded above). + let current_meta = existing_meta.unwrap(); + + let new_title = session.title.as_ref(); + let title_differs = new_title.is_some() && new_title != current_meta.title.as_ref(); + + let new_model = session + .metadata + .get("model") + .and_then(|v| v.as_str()) + .map(String::from); + let current_model = current_meta + .metadata + .get("model") + .and_then(|v| v.as_str()) + .map(String::from); + let model_differs = new_model.is_some() && new_model != current_model; + + if title_differs || model_differs { + if let Err(e) = archivist + .update_session_metadata( + scroll_id, + if title_differs { new_title.cloned() } else { None }, + if model_differs { new_model } else { None }, + archive.clone(), + ) + .await + { + stats.errors.push(format!( + "Failed to update session metadata for {native_id}: {e}" + )); + emit_finished!(SessionOutcome::Failed); + continue; + } + session_changed = true; + } + + let new_project_path = session + .metadata + .get("project_path") + .and_then(|v| v.as_str()) + .map(String::from); + let current_project_path = current_meta + .metadata + .get("project_path") + .and_then(|v| v.as_str()) + .map(String::from); + let project_path_differs = + new_project_path.is_some() && new_project_path != current_project_path; + + if project_path_differs { + // project_path lives in the free-form metadata JSON. + // Re-read to pick up any title/model changes applied above. + let mut patched_meta = archivist + .get_session_metadata(scroll_id, archive.clone()) + .await + .unwrap_or(current_meta); + if let Some(obj) = patched_meta.metadata.as_object_mut() { + let path_val = new_project_path.clone().unwrap_or_default(); + obj.insert( + "project_path".to_string(), + serde_json::Value::String(path_val.clone()), + ); + if let Some(pid) = project_map.get(&path_val) { + obj.insert( + "project_id".to_string(), + serde_json::Value::String(pid.clone()), + ); + } + } + patched_meta.updated_at = resolve_updated_at(session); + if let Ok(primary) = archivist.resolve_primary(archive.clone()).await { + if let Err(e) = primary.backend.put_session(patched_meta).await { + tracing::warn!( + scroll_id = %scroll_id, + error = %e, + "Failed to update project_path in session metadata" + ); + } + } + session_changed = true; + } + } + + // --- Step 6: Write import snapshot after successful import/update --- + { + let snapshot = ImportSnapshot::from_discovered(session); + // Re-read metadata to get the latest state (may have been updated above). + let write_result = async { + let mut meta = archivist + .get_session_metadata(scroll_id, archive.clone()) + .await?; + snapshot.write_to_metadata(&mut meta.metadata); + meta.updated_at = resolve_updated_at(session); + let primary = archivist.resolve_primary(archive.clone()).await?; + primary.backend.put_session(meta).await.map_err(|e| { + ArchivistError::InvalidRequest(format!( + "Failed to write import snapshot: {e}" + )) + }) + } + .await; + if let Err(e) = write_result { + tracing::warn!( + scroll_id = %scroll_id, + error = %e, + "Failed to write import snapshot (session still imported)" + ); + } + } + + // Accounting: exactly one of {imported, updated, skipped} per session. + let outcome = if session_is_new { + stats.sessions_imported += 1; + SessionOutcome::Imported + } else if session_changed { + stats.sessions_updated += 1; + SessionOutcome::Updated + } else { + stats.sessions_skipped += 1; + SessionOutcome::Skipped + }; + + emit_finished!(outcome); + } + + Ok(stats) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_import_stats_default() { + let stats = ImportStats::default(); + assert_eq!(stats.sessions_discovered, 0); + assert_eq!(stats.sessions_imported, 0); + assert_eq!(stats.sessions_skipped, 0); + assert_eq!(stats.sessions_updated, 0); + assert_eq!(stats.messages_written, 0); + assert_eq!(stats.messages_already_present, 0); + assert!(stats.errors.is_empty()); + } + + #[test] + fn test_import_stats_total_sessions_processed() { + let mut stats = ImportStats::default(); + stats.sessions_imported = 3; + stats.sessions_skipped = 2; + stats.sessions_updated = 1; + stats.errors.push("oops".to_string()); + assert_eq!(stats.total_sessions_processed(), 7); + } + + #[test] + fn test_import_stats_has_errors() { + let mut stats = ImportStats::default(); + assert!(!stats.has_errors()); + stats.errors.push("something went wrong".to_string()); + assert!(stats.has_errors()); + } +} + +#[cfg(test)] +mod idempotency_tests { + use super::*; + use crate::Archivist; + use chrono::Utc; + use uuid::Uuid; + + async fn mk() -> (Archivist, std::path::PathBuf) { + let tmp = std::env::temp_dir().join(format!("import_idem_{}", Uuid::now_v7())); + // Use `from_single_backend` rather than `new_with_single_archive` so + // each test's archive is fully self-contained (no shared `.archives.json` + // in the parent tempdir racing against sibling tests). + let backend = std::sync::Arc::new( + crate::backends::JsonlBackend::new(tmp.clone()).await.unwrap(), + ); + let a = Archivist::from_single_backend("main".into(), backend) + .await + .unwrap(); + (a, tmp) + } + + fn connector() -> RegisterConnectorRequest { + // Stable client_native_id so that re-registering within the same test + // (which uses an isolated temp dir per test) aliases onto the same + // connector_uid — otherwise each call would produce a fresh connector + // and defeat idempotency. + RegisterConnectorRequest { + r#type: "Fake".into(), + title: "fake".into(), + client_native_id: "fake@local:stable".into(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + } + } + + fn record(session: Uuid, id: Uuid, role: &str, content: &str) -> MessageRecord { + MessageRecord { + version: 1, + message_id: id, + session, + parent_id: None, + ts: Utc::now(), + role: role.to_string(), + author: None, + content_md: content.to_string(), + content_parts: None, + attachments: Vec::new(), + metadata: serde_json::json!({}), + } + } + + #[tokio::test] + async fn import_skips_already_present_messages() { + let (archivist, tmp) = mk().await; + + let a = Uuid::now_v7(); + let b = Uuid::now_v7(); + let c = Uuid::now_v7(); + + let discovered = vec![DiscoveredSession { + native_session_id: "s1".into(), + title: Some("t".into()), + created_at: None, + updated_at: None, + message_count: 3, + metadata: serde_json::json!({}), + project_path: None, + file_size: None, + }]; + let convert = |_: &str| -> Result> { + Ok(vec![ + record(Uuid::nil(), a, "user", "hi-a"), + record(Uuid::nil(), b, "user", "hi-b"), + record(Uuid::nil(), c, "user", "hi-c"), + ]) + }; + let stats = import_sessions(&archivist, connector(), discovered.clone(), convert, None, &ImportProgressSink::noop(), true, &HashMap::new()) + .await + .unwrap(); + assert_eq!(stats.sessions_imported, 1); + assert_eq!(stats.messages_written, 3); + + // Re-import with IDENTICAL records — nothing should be written. + let convert2 = |_: &str| -> Result> { + Ok(vec![ + record(Uuid::nil(), a, "user", "hi-a"), + record(Uuid::nil(), b, "user", "hi-b"), + record(Uuid::nil(), c, "user", "hi-c"), + ]) + }; + let stats2 = import_sessions(&archivist, connector(), discovered, convert2, None, &ImportProgressSink::noop(), true, &HashMap::new()) + .await + .unwrap(); + assert_eq!(stats2.messages_written, 0); + assert_eq!(stats2.messages_already_present, 3); + assert_eq!(stats2.sessions_skipped, 1); + assert_eq!(stats2.sessions_imported, 0); + assert_eq!(stats2.sessions_updated, 0); + + let _ = tokio::fs::remove_dir_all(tmp).await; + } + + #[tokio::test] + async fn import_appends_new_messages_only() { + let (archivist, tmp) = mk().await; + + let a = Uuid::now_v7(); + let b = Uuid::now_v7(); + let c = Uuid::now_v7(); + let d = Uuid::now_v7(); + + let discovered = vec![DiscoveredSession { + native_session_id: "s1".into(), + title: Some("t".into()), + created_at: None, + updated_at: None, + message_count: 2, + metadata: serde_json::json!({}), + project_path: None, + file_size: None, + }]; + let convert1 = |_: &str| -> Result> { + Ok(vec![ + record(Uuid::nil(), a, "user", "hi-a"), + record(Uuid::nil(), b, "user", "hi-b"), + ]) + }; + let _ = import_sessions(&archivist, connector(), discovered.clone(), convert1, None, &ImportProgressSink::noop(), true, &HashMap::new()) + .await + .unwrap(); + + // Second run: source has grown to 4 messages. + let convert2 = |_: &str| -> Result> { + Ok(vec![ + record(Uuid::nil(), a, "user", "hi-a"), + record(Uuid::nil(), b, "user", "hi-b"), + record(Uuid::nil(), c, "user", "hi-c"), + record(Uuid::nil(), d, "user", "hi-d"), + ]) + }; + let stats = import_sessions(&archivist, connector(), discovered, convert2, None, &ImportProgressSink::noop(), true, &HashMap::new()) + .await + .unwrap(); + assert_eq!(stats.messages_written, 2); + assert_eq!(stats.messages_already_present, 2); + assert_eq!(stats.sessions_updated, 1); + assert_eq!(stats.sessions_skipped, 0); + assert_eq!(stats.sessions_imported, 0); + + let _ = tokio::fs::remove_dir_all(tmp).await; + } + + #[tokio::test] + async fn import_updates_metadata_only() { + let (archivist, tmp) = mk().await; + + let a = Uuid::now_v7(); + let convert = |_: &str| -> Result> { + Ok(vec![record(Uuid::nil(), a, "user", "hi")]) + }; + + let first = vec![DiscoveredSession { + native_session_id: "s1".into(), + title: Some("old title".into()), + created_at: None, + updated_at: None, + message_count: 1, + metadata: serde_json::json!({}), + project_path: None, + file_size: None, + }]; + let _ = import_sessions(&archivist, connector(), first, convert, None, &ImportProgressSink::noop(), true, &HashMap::new()) + .await + .unwrap(); + + // Re-import with same messages but new title. + let second = vec![DiscoveredSession { + native_session_id: "s1".into(), + title: Some("new title".into()), + created_at: None, + updated_at: None, + message_count: 1, + metadata: serde_json::json!({}), + project_path: None, + file_size: None, + }]; + let convert2 = |_: &str| -> Result> { + Ok(vec![record(Uuid::nil(), a, "user", "hi")]) + }; + let stats = import_sessions(&archivist, connector(), second, convert2, None, &ImportProgressSink::noop(), true, &HashMap::new()) + .await + .unwrap(); + assert_eq!(stats.messages_written, 0); + assert_eq!(stats.sessions_updated, 1); + assert_eq!(stats.sessions_skipped, 0); + + // Verify title landed on disk. + let meta_list = archivist + .list_sessions_paged( + crate::types::SessionListQuery::default().with_limit(50), + ) + .await + .unwrap(); + assert!(meta_list.items.iter().any(|m| m.title.as_deref() == Some("new title"))); + + let _ = tokio::fs::remove_dir_all(tmp).await; + } + + #[tokio::test] + async fn import_handles_metadata_unchanged() { + let (archivist, tmp) = mk().await; + + let a = Uuid::now_v7(); + let discovered = vec![DiscoveredSession { + native_session_id: "s1".into(), + title: Some("t".into()), + created_at: None, + updated_at: None, + message_count: 1, + metadata: serde_json::json!({"model": "claude"}), + project_path: None, + file_size: None, + }]; + let convert = |_: &str| -> Result> { + Ok(vec![record(Uuid::nil(), a, "user", "hi")]) + }; + let _ = import_sessions(&archivist, connector(), discovered.clone(), convert, None, &ImportProgressSink::noop(), true, &HashMap::new()) + .await + .unwrap(); + + let convert2 = |_: &str| -> Result> { + Ok(vec![record(Uuid::nil(), a, "user", "hi")]) + }; + let stats = import_sessions(&archivist, connector(), discovered, convert2, None, &ImportProgressSink::noop(), true, &HashMap::new()) + .await + .unwrap(); + assert_eq!(stats.sessions_skipped, 1); + assert_eq!(stats.sessions_updated, 0); + assert_eq!(stats.messages_written, 0); + + let _ = tokio::fs::remove_dir_all(tmp).await; + } +} diff --git a/crates/dirigent_archivist/src/import/progress.rs b/crates/dirigent_archivist/src/import/progress.rs new file mode 100644 index 0000000..2a9569b --- /dev/null +++ b/crates/dirigent_archivist/src/import/progress.rs @@ -0,0 +1,117 @@ +//! ImportProgressSink: bounded mpsc with drop-oldest-non-terminal overflow. +//! Terminal events (ImportDone / ImportFailed) are never dropped — on full +//! channel they evict oldest non-terminal events until they fit. The import +//! thread never backpressures on a slow consumer. + +use serde::{Deserialize, Serialize}; +use tokio::sync::mpsc; + +use super::ImportDiscovery; +use super::ImportStats; + +const DEFAULT_CAPACITY: usize = 64; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "kind")] +pub enum ImportProgressEvent { + DiscoveryStarted { source: String }, + DiscoveryProgress { scanned: usize, estimated_total: Option }, + DiscoveryDone { discovered: ImportDiscovery }, + SessionStarted { native_id: String, index: usize, total: usize }, + SessionFinished { native_id: String, outcome: SessionOutcome, stats_delta: StatsDelta }, + ImportDone { stats: ImportStats }, + ImportFailed { error: String }, +} + +impl ImportProgressEvent { + pub fn is_terminal(&self) -> bool { + matches!(self, ImportProgressEvent::ImportDone { .. } | ImportProgressEvent::ImportFailed { .. }) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum SessionOutcome { Imported, Skipped, Updated, Failed } + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct StatsDelta { + pub messages_written: u64, + pub messages_already_present: u64, +} + +pub struct ImportProgressSink { + inner: SinkInner, +} + +enum SinkInner { + Live { tx: mpsc::Sender }, + Noop, +} + +impl ImportProgressSink { + pub fn channel() -> (Self, mpsc::Receiver) { + let (tx, rx) = mpsc::channel(DEFAULT_CAPACITY); + (Self { inner: SinkInner::Live { tx } }, rx) + } + + pub fn noop() -> Self { Self { inner: SinkInner::Noop } } + + pub async fn send(&self, evt: ImportProgressEvent) { + match &self.inner { + SinkInner::Noop => {} + SinkInner::Live { tx } => { + if evt.is_terminal() { + // Force-send: guaranteed delivery of terminal events. + let _ = tx.send(evt).await; + } else { + // Best-effort: drop non-terminal events when the channel is full. + match tx.try_send(evt) { + Ok(()) => {} + Err(mpsc::error::TrySendError::Full(_)) => { + tracing::debug!("import progress: dropped non-terminal event (queue full)"); + } + Err(mpsc::error::TrySendError::Closed(_)) => { + tracing::warn!("import progress: consumer gone"); + } + } + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn terminal_events_always_delivered() { + let (sink, mut rx) = ImportProgressSink::channel(); + // Fill the channel with non-terminal events (mostly drop). + for i in 0..1000 { + sink.send(ImportProgressEvent::SessionStarted { + native_id: format!("s{i}"), index: i, total: 1000, + }).await; + } + // Consumer drains in background. + let handle = tokio::spawn(async move { + let mut saw_done = false; + while let Some(e) = rx.recv().await { + if matches!(e, ImportProgressEvent::ImportDone { .. }) { + saw_done = true; + break; + } + } + saw_done + }); + sink.send(ImportProgressEvent::ImportDone { stats: ImportStats::default() }).await; + let saw_done = tokio::time::timeout(std::time::Duration::from_secs(2), handle).await.unwrap().unwrap(); + assert!(saw_done); + } + + #[tokio::test] + async fn noop_sink_never_fails() { + let sink = ImportProgressSink::noop(); + sink.send(ImportProgressEvent::ImportDone { stats: ImportStats::default() }).await; + } +} diff --git a/crates/dirigent_archivist/src/import/registry.rs b/crates/dirigent_archivist/src/import/registry.rs new file mode 100644 index 0000000..3902247 --- /dev/null +++ b/crates/dirigent_archivist/src/import/registry.rs @@ -0,0 +1,93 @@ +//! Dynamic registry of Importer implementations. Populated at boot. + +use std::collections::HashMap; +use std::sync::Arc; + +use super::trait_def::{Importer, ImporterInfo}; + +pub struct ImporterRegistry { + importers: HashMap<&'static str, Arc>, +} + +impl ImporterRegistry { + pub fn new() -> Self { + Self { + importers: HashMap::new(), + } + } + + /// Populate with all built-in importers. Feature flags select which ship. + pub fn with_defaults() -> Self { + let mut r = Self::new(); + #[cfg(feature = "importer-claude")] + r.register(Arc::new(super::sources::claude::ClaudeImporter)); + #[cfg(feature = "importer-chatgpt")] + r.register(Arc::new(super::sources::chatgpt::ChatGptImporter)); + #[cfg(feature = "importer-codex")] + r.register(Arc::new(super::sources::codex::CodexImporter)); + r + } + + pub fn register(&mut self, importer: Arc) { + self.importers.insert(importer.source_name(), importer); + } + + pub fn get(&self, name: &str) -> Option> { + self.importers.get(name).cloned() + } + + pub fn list(&self) -> Vec { + self.importers + .values() + .map(|i| ImporterInfo { + source_name: i.source_name().to_string(), + display_name: pretty_name(i.source_name()), + config_shape: i.config_shape(), + }) + .collect() + } +} + +fn pretty_name(source: &str) -> String { + match source { + "claude" => "Claude Code".into(), + "chatgpt" => "ChatGPT (OpenAI)".into(), + "codex" => "OpenAI Codex".into(), + other => other.to_string(), + } +} + +impl Default for ImporterRegistry { + fn default() -> Self { + Self::with_defaults() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn defaults_include_claude_when_feature_enabled() { + let reg = ImporterRegistry::with_defaults(); + let list = reg.list(); + #[cfg(feature = "importer-claude")] + { + assert!(list.iter().any(|i| i.source_name == "claude")); + assert!(reg.get("claude").is_some()); + } + #[cfg(not(feature = "importer-claude"))] + { + let _ = list; + assert!(reg.get("claude").is_none()); + } + } + + #[test] + fn pretty_name_known_sources() { + assert_eq!(pretty_name("claude"), "Claude Code"); + assert_eq!(pretty_name("chatgpt"), "ChatGPT (OpenAI)"); + assert_eq!(pretty_name("codex"), "OpenAI Codex"); + assert_eq!(pretty_name("custom"), "custom"); + } +} diff --git a/crates/dirigent_archivist/src/import/sources/chatgpt.rs b/crates/dirigent_archivist/src/import/sources/chatgpt.rs new file mode 100644 index 0000000..7af054d --- /dev/null +++ b/crates/dirigent_archivist/src/import/sources/chatgpt.rs @@ -0,0 +1,361 @@ +//! ChatGPT importer: takes a path to a conversations.json file. + +use std::path::PathBuf; + +use async_trait::async_trait; +use chrono::Utc; +use uuid::Uuid; + +use dirigent_chatgpt::{ContentPart, ParsedConversation, ParsedMessage}; + +use super::super::progress::ImportProgressSink; +use super::super::trait_def::{ + ConfigField, ConfigFieldKind, ImportConfig, ImportConfigShape, ImportError, ImportTarget, + Importer, +}; +use super::super::{ + import_sessions, DiscoveredSession, ImportDiscovery, ImportProject, ImportStats, +}; +use crate::coordinator::Archivist; +use crate::error::{ArchivistError, Result}; +use crate::types::{MessageRecord, RegisterConnectorRequest}; + +/// Connector type string used for imported ChatGPT sessions. +pub const CHATGPT_CONNECTOR_TYPE: &str = "ChatGPT"; + +/// Fingerprint prefix for locally-imported ChatGPT exports. +pub const CHATGPT_FINGERPRINT_PREFIX: &str = "import/local:chatgpt"; + +/// Namespace UUID for deterministic UUIDv5 derivations on ChatGPT message ids +/// that are not already valid UUIDs. +const CHATGPT_MESSAGE_NS: Uuid = Uuid::from_u128(0x4e58_a7cb_bf1c_4de2_b7c9_8c31_11b3_1112); + +pub struct ChatGptImporter; + +#[async_trait] +impl Importer for ChatGptImporter { + fn source_name(&self) -> &'static str { + "chatgpt" + } + + fn config_shape(&self) -> ImportConfigShape { + ImportConfigShape { + fields: vec![ConfigField { + key: "path".into(), + label: "conversations.json path".into(), + kind: ConfigFieldKind::File { + extension: Some("json".into()), + }, + required: true, + help: Some( + "Unzipped OpenAI data export \u{2192} conversations.json".into(), + ), + }], + example: ImportConfig { + source: "chatgpt".into(), + params: { + let mut m = std::collections::BTreeMap::new(); + m.insert( + "path".into(), + serde_json::json!("~/Downloads/chatgpt-export/conversations.json"), + ); + m + }, + }, + } + } + + async fn discover( + &self, + cfg: &ImportConfig, + ) -> std::result::Result { + let path = require_path(cfg)?; + let convs = dirigent_chatgpt::parse_export(&path) + .map_err(|e| ImportError::Discovery(e.to_string()))?; + + let total_sessions = convs.len(); + let total_estimated_messages: usize = convs.iter().map(|c| c.messages.len()).sum(); + + // ChatGPT exports don't carry per-project information, so we bucket + // everything into a single synthetic project named after the file. + let project_name = path + .file_name() + .and_then(|s| s.to_str()) + .unwrap_or("ChatGPT export") + .to_string(); + + Ok(ImportDiscovery { + source_name: "ChatGPT".to_string(), + source_path: path.display().to_string(), + projects: vec![ImportProject { + name: project_name, + session_count: total_sessions, + }], + total_sessions, + total_estimated_messages, + }) + } + + async fn import( + &self, + cfg: &ImportConfig, + archivist: &Archivist, + target: ImportTarget, + progress: ImportProgressSink, + ) -> std::result::Result { + let path = require_path(cfg)?; + let convs = dirigent_chatgpt::parse_export(&path) + .map_err(|e| ImportError::Parser(e.to_string()))?; + + // Build discovered-session list + keep the parsed convs handy for + // message conversion inside the closure. + let mut discovered: Vec = Vec::with_capacity(convs.len()); + for c in &convs { + let metadata = serde_json::json!({ + "source": "chatgpt", + "conversation_id": c.id, + "parser_metadata": c.metadata.clone(), + }); + discovered.push(DiscoveredSession { + native_session_id: c.id.clone(), + title: c.title.clone(), + created_at: c.created_at, + updated_at: c.updated_at, + message_count: c.messages.len(), + metadata, + project_path: None, + file_size: None, + }); + } + + // Map native_id -> parsed conversation for O(1) lookup in `convert`. + let conv_lookup: std::collections::HashMap = convs + .into_iter() + .map(|c| (c.id.clone(), c)) + .collect(); + + // Fingerprint the import by the canonical path. Re-running against the + // same file aliases onto the same connector. + let canonical_path = path.canonicalize().unwrap_or_else(|_| path.clone()); + let fingerprint = format!("{}:{}", CHATGPT_FINGERPRINT_PREFIX, canonical_path.display()); + + let connector_req = RegisterConnectorRequest { + r#type: CHATGPT_CONNECTOR_TYPE.to_string(), + title: format!("ChatGPT ({})", canonical_path.display()), + client_native_id: fingerprint.clone(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: Some(fingerprint), + }; + + let convert = |native_id: &str| -> Result> { + let conv = conv_lookup.get(native_id).ok_or_else(|| { + ArchivistError::InvalidRequest(format!( + "Parsed conversation not found for native_id: {}", + native_id + )) + })?; + Ok(convert_conversation_to_records(conv)) + }; + + import_sessions( + archivist, + connector_req, + discovered, + convert, + target.archive, + &progress, + false, + &target.project_map, + ) + .await + .map_err(|e| ImportError::Archivist(e.to_string())) + } +} + +// --------------------------------------------------------------------------- +// Conversion helpers +// --------------------------------------------------------------------------- + +fn require_path(cfg: &ImportConfig) -> std::result::Result { + cfg.params + .get("path") + .and_then(|v| v.as_str()) + .map(PathBuf::from) + .ok_or_else(|| ImportError::Config("missing `path`".into())) +} + +/// Prefer to parse the native id as a UUID if possible; otherwise derive a +/// stable UUIDv5 under [`CHATGPT_MESSAGE_NS`]. +fn parse_or_derive_uuid(native_id: &str) -> Uuid { + Uuid::parse_str(native_id) + .unwrap_or_else(|_| Uuid::new_v5(&CHATGPT_MESSAGE_NS, native_id.as_bytes())) +} + +/// Convert parsed `ContentPart`s into `dirigent_protocol::MessagePart`s. +fn parts_to_message_parts(parts: &[ContentPart]) -> Vec { + parts + .iter() + .map(|p| match p { + ContentPart::Text { text } => dirigent_protocol::MessagePart::Text { + text: text.clone(), + }, + ContentPart::Code { language, text } => dirigent_protocol::MessagePart::Code { + language: language.clone().unwrap_or_default(), + code: text.clone(), + }, + ContentPart::Tool { name, input, output } => dirigent_protocol::MessagePart::Tool { + tool: name.clone(), + tool_call_id: None, + input: input.clone(), + output: output.clone(), + }, + }) + .collect() +} + +/// Flatten a list of parsed content parts into a markdown-y string for the +/// `content_md` fallback surface. +fn parts_to_markdown(parts: &[ContentPart]) -> String { + parts + .iter() + .map(|p| match p { + ContentPart::Text { text } => text.clone(), + ContentPart::Code { language, text } => { + let lang = language.clone().unwrap_or_default(); + format!("```{}\n{}\n```", lang, text) + } + ContentPart::Tool { name, .. } => format!("[Tool: {}]", name), + }) + .collect::>() + .join("\n\n") +} + +/// Convert a parsed ChatGPT conversation into a vector of `MessageRecord`s. +/// +/// Each message's `session` field is left as `Uuid::nil()`; the generic +/// `import_sessions` orchestrator patches it to the real scroll id. +fn convert_conversation_to_records(conv: &ParsedConversation) -> Vec { + conv.messages + .iter() + .filter_map(convert_parsed_message) + .collect() +} + +fn convert_parsed_message(msg: &ParsedMessage) -> Option { + // Skip messages with entirely empty text payloads (nothing to archive). + let content_md = parts_to_markdown(&msg.content); + if content_md.trim().is_empty() && msg.content.iter().all(is_part_empty) { + return None; + } + + let parts = parts_to_message_parts(&msg.content); + let content_parts = serde_json::to_value(&parts).ok(); + + let ts = msg.ts.unwrap_or_else(Utc::now); + let message_id = if msg.id.is_empty() { + // Fallback: derive from role + timestamp + a hash of content. + let key = format!("{}:{}:{}", msg.role, ts.to_rfc3339(), content_md); + Uuid::new_v5(&CHATGPT_MESSAGE_NS, key.as_bytes()) + } else { + parse_or_derive_uuid(&msg.id) + }; + + Some(MessageRecord { + version: 1, + message_id, + session: Uuid::nil(), + parent_id: None, + ts, + role: msg.role.clone(), + author: None, + content_md, + content_parts, + attachments: Vec::new(), + metadata: msg.metadata.clone(), + }) +} + +fn is_part_empty(p: &ContentPart) -> bool { + match p { + ContentPart::Text { text } => text.trim().is_empty(), + ContentPart::Code { text, .. } => text.trim().is_empty(), + ContentPart::Tool { .. } => false, + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_or_derive_uuid_parses_real_uuid() { + let real = "12345678-1234-5678-1234-567812345678"; + let u = parse_or_derive_uuid(real); + assert_eq!(u.to_string(), real); + } + + #[test] + fn parse_or_derive_uuid_falls_back_to_v5() { + let a = parse_or_derive_uuid("not-a-uuid"); + let b = parse_or_derive_uuid("not-a-uuid"); + assert_eq!(a, b, "deterministic UUIDv5 derivation"); + let c = parse_or_derive_uuid("different"); + assert_ne!(a, c); + } + + #[test] + fn parts_to_message_parts_covers_all_variants() { + let parts = vec![ + ContentPart::Text { text: "hi".into() }, + ContentPart::Code { + language: Some("rust".into()), + text: "fn main() {}".into(), + }, + ContentPart::Tool { + name: "browser".into(), + input: serde_json::json!({"url": "https://example.com"}), + output: Some(serde_json::json!({"status": 200})), + }, + ]; + let mp = parts_to_message_parts(&parts); + assert_eq!(mp.len(), 3); + assert!(matches!(&mp[0], dirigent_protocol::MessagePart::Text { .. })); + assert!(matches!(&mp[1], dirigent_protocol::MessagePart::Code { .. })); + assert!(matches!(&mp[2], dirigent_protocol::MessagePart::Tool { .. })); + } + + #[test] + fn empty_parsed_message_is_skipped() { + let msg = ParsedMessage { + id: "m1".into(), + role: "system".into(), + ts: None, + content: vec![ContentPart::Text { text: " ".into() }], + metadata: serde_json::Value::Null, + }; + assert!(convert_parsed_message(&msg).is_none()); + } + + #[test] + fn non_empty_parsed_message_round_trips() { + let msg = ParsedMessage { + id: "m1".into(), + role: "user".into(), + ts: None, + content: vec![ContentPart::Text { + text: "hello".into(), + }], + metadata: serde_json::Value::Null, + }; + let record = convert_parsed_message(&msg).expect("should convert"); + assert_eq!(record.role, "user"); + assert_eq!(record.content_md, "hello"); + assert_eq!(record.session, Uuid::nil()); + assert!(record.content_parts.is_some()); + } +} diff --git a/crates/dirigent_archivist/src/import/sources/claude.rs b/crates/dirigent_archivist/src/import/sources/claude.rs new file mode 100644 index 0000000..c72d464 --- /dev/null +++ b/crates/dirigent_archivist/src/import/sources/claude.rs @@ -0,0 +1,1356 @@ +//! Claude Code session importer. +//! +//! Uses dirigent_anth to discover and parse Claude Code's local JSONL sessions, +//! then converts them into archivist MessageRecords for import. + +use std::collections::HashMap; + +use camino::Utf8PathBuf; +use chrono::{DateTime, Utc}; +use uuid::Uuid; + +use dirigent_anth::types::{ + Content, ContentBlock, RawAssistantMessage, RawMessage, RawUserMessage, SessionRef, +}; +use dirigent_anth::{classify_noise, discover_projects, load_session, parse_timestamp}; + +use async_trait::async_trait; + +use super::super::progress::ImportProgressSink; +use super::super::trait_def::{ + ConfigField, ConfigFieldKind, ImportConfig, ImportConfigShape, ImportError, ImportTarget, + Importer, +}; +use super::super::{ + import_sessions, DiscoveredSession, ImportDiscovery, ImportProject, ImportStats, +}; +use crate::coordinator::Archivist; +use crate::error::{ArchivistError, Result}; +use crate::types::{MessageRecord, RegisterConnectorRequest}; + +/// Connector type string used for imported Claude Code sessions. +pub const CLAUDE_CONNECTOR_TYPE: &str = "ClaudeCode"; + +/// Fingerprint prefix for locally-imported Claude Code sessions. +pub const CLAUDE_FINGERPRINT_PREFIX: &str = "import/local:claude-code"; + +/// Namespace UUID for deterministic fallback message_ids when Claude's source +/// `uuid` field is missing. Arbitrary, stable constant. +const CLAUDE_MESSAGE_NS: uuid::Uuid = uuid::Uuid::from_bytes([ + 0x43, 0x4c, 0x41, 0x55, 0x44, 0x45, 0x2d, 0x4d, 0x53, 0x47, 0x2d, 0x4e, 0x53, 0x2d, 0x56, 0x35, +]); + +/// Derive a stable `message_id` for a Claude message record. +/// +/// Priority: +/// 1. Parse the `source_uuid` as a UUID (Claude's JSONL carries per-message +/// UUIDs natively). +/// 2. Fall back to `Uuid::new_v5(CLAUDE_MESSAGE_NS, ":::")`. +fn derive_message_id( + source_uuid: Option<&str>, + native_session_id: &str, + ts: &chrono::DateTime, + role: &str, + content_md: &str, +) -> uuid::Uuid { + if let Some(raw) = source_uuid { + if let Ok(parsed) = uuid::Uuid::parse_str(raw) { + return parsed; + } + } + let key = format!("{}:{}:{}:{}", native_session_id, ts.to_rfc3339(), role, content_md); + uuid::Uuid::new_v5(&CLAUDE_MESSAGE_NS, key.as_bytes()) +} + +// --------------------------------------------------------------------------- +// Discovery +// --------------------------------------------------------------------------- + +/// Discover Claude Code sessions available for import. +/// +/// If `claude_home` is `Some`, uses it as the Claude home directory; +/// otherwise auto-detects via `dirigent_anth::discover_claude_home()`. +/// +/// Returns the resolved home path and a summary of discovered projects/sessions. +pub fn discover_claude_import( + claude_home: Option<&str>, +) -> std::result::Result<(Utf8PathBuf, ImportDiscovery), String> { + let home = match claude_home { + Some(p) => Utf8PathBuf::from(p), + None => dirigent_anth::discover_claude_home().map_err(|e| e.to_string())?, + }; + + let projects = discover_projects(&home).map_err(|e| e.to_string())?; + + let mut import_projects = Vec::new(); + let mut total_sessions: usize = 0; + let mut total_estimated_messages: usize = 0; + + for project in &projects { + let session_count = project.sessions.len(); + total_sessions += session_count; + + for session in &project.sessions { + if let Some(ref idx) = session.index_entry { + total_estimated_messages += idx.message_count.unwrap_or(0) as usize; + } + } + + import_projects.push(ImportProject { + name: project.original_path.clone(), + session_count, + }); + } + + let discovery = ImportDiscovery { + source_name: "Claude Code".to_string(), + source_path: home.to_string(), + projects: import_projects, + total_sessions, + total_estimated_messages, + }; + + Ok((home, discovery)) +} + +// --------------------------------------------------------------------------- +// Import orchestration +// --------------------------------------------------------------------------- + +/// Import all Claude Code sessions from `claude_home` into the archivist. +/// +/// This discovers all projects and sessions, registers a connector, and imports +/// each session by parsing its JSONL and converting messages to `MessageRecord`s. +/// +/// Pass `&ImportProgressSink::noop()` when progress reporting is not needed. +pub async fn import_claude_sessions( + archivist: &crate::coordinator::Archivist, + claude_home: &Utf8PathBuf, + archive: Option, + progress: &ImportProgressSink, + project_map: &HashMap, +) -> Result { + let projects = + discover_projects(claude_home).map_err(|e| ArchivistError::InvalidRequest(e.to_string()))?; + + // Build discovered sessions and a lookup map. + let mut discovered: Vec = Vec::new(); + let mut session_lookup: HashMap = HashMap::new(); + + for project in &projects { + for session_ref in &project.sessions { + let idx = session_ref.index_entry.as_ref(); + + let mut title = idx + .and_then(|i| i.summary.clone()) + .or_else(|| idx.and_then(|i| i.first_prompt.clone())); + + // Fallback: derive title from first user message content. + // NOTE: This parses the JSONL file, adding I/O cost during discovery. + // The same session will be parsed again during message conversion. + // Acceptable for a background import task; could be optimized later + // by restructuring to parse once and extract title + messages together. + if title.is_none() { + if let Ok(parsed) = load_session(session_ref) { + for msg in &parsed.messages { + // Skip noise (queue ops, meta, warmup, etc.) + if classify_noise(msg).is_some() { + continue; + } + if let RawMessage::User(user) = msg { + let text = match &user.message.content { + Content::Text(s) => s.clone(), + Content::Blocks(blocks) => blocks + .iter() + .filter_map(|b| match b { + ContentBlock::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::>() + .join(" "), + }; + let trimmed = text.trim(); + // Skip system-injected XML content (commands, caveats, etc.) + if trimmed.is_empty() + || trimmed.starts_with('<') + || is_pure_tool_result_message(user) + { + continue; + } + title = Some(derive_title_from_content(trimmed)); + break; + } + } + } + } + + let created_at = idx + .and_then(|i| i.created.as_ref()) + .and_then(parse_timestamp); + + let updated_at = idx + .and_then(|i| i.modified.as_ref()) + .and_then(parse_timestamp); + + let message_count = idx.and_then(|i| i.message_count).unwrap_or(0) as usize; + + let metadata = serde_json::json!({ + "project_path": project.original_path, + "git_branch": idx.and_then(|i| i.git_branch.clone()), + }); + + // Get file size from the JSONL file for fingerprint-based change detection. + let file_size = std::fs::metadata(session_ref.jsonl_path.as_std_path()) + .ok() + .map(|m| m.len()); + + discovered.push(DiscoveredSession { + native_session_id: session_ref.id.clone(), + title, + created_at, + updated_at, + message_count, + metadata, + project_path: Some(project.original_path.clone()), + file_size, + }); + + session_lookup.insert(session_ref.id.clone(), session_ref.clone()); + } + } + + let fingerprint = format!("{}:{}", CLAUDE_FINGERPRINT_PREFIX, claude_home); + + let connector_req = RegisterConnectorRequest { + r#type: CLAUDE_CONNECTOR_TYPE.to_string(), + title: format!("Claude Code ({})", claude_home), + client_native_id: fingerprint.clone(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: Some(fingerprint), + }; + + let convert = |native_id: &str| -> Result> { + let session_ref = session_lookup.get(native_id).ok_or_else(|| { + ArchivistError::InvalidRequest(format!( + "Session ref not found for native_id: {}", + native_id + )) + })?; + convert_session_to_records(session_ref) + }; + + let mut stats = import_sessions( + archivist, + connector_req, + discovered, + convert, + archive.clone(), + progress, + false, + project_map, + ) + .await?; + + // Phase 2: Import subagent sessions + import_subagents(archivist, &session_lookup, archive, &mut stats).await?; + + Ok(stats) +} + +/// Import subagent sessions for all parent sessions that have them. +/// +/// For each session in the lookup, loads and parses it, then for each subagent: +/// 1. Registers a new session with `is_subagent=true` and `parent_scroll_id` +/// 2. Converts subagent messages to MessageRecords +/// 3. Writes a DAG edge linking parent to child +async fn import_subagents( + archivist: &crate::coordinator::Archivist, + session_lookup: &HashMap, + archive: Option, + stats: &mut ImportStats, +) -> Result<()> { + use crate::types::{ + Continuation, DagEdge, RegisterSessionRequest, RegisterStatus, SessionCompleteness, + }; + + for (native_id, session_ref) in session_lookup { + // Load the full parsed session (includes subagents with linkage) + let parsed = match load_session(session_ref) { + Ok(p) => p, + Err(e) => { + tracing::debug!(native_id = %native_id, error = %e, "Failed to load session for subagent import"); + continue; + } + }; + + if parsed.subagents.is_empty() { + continue; + } + + // Resolve the parent's scroll_id from the archive + let parent_scroll_id = match archivist + .find_session_owner(native_id, archive.clone()) + .await + { + Ok(Some((_connector_uid, scroll_id))) => scroll_id, + _ => { + tracing::debug!(native_id = %native_id, "Parent session not found in archive, skipping subagent import"); + continue; + } + }; + + // Get parent's connector_uid from metadata + let parent_meta = match archivist.get_session_metadata(parent_scroll_id, archive.clone()).await { + Ok(m) => m, + Err(_) => continue, + }; + + for subagent in &parsed.subagents { + // Use composite native ID for idempotent dedup + let subagent_native_id = format!("{}:agent:{}", native_id, subagent.agent_id); + + // Check if already imported (idempotent) + if archivist + .resolve_session(parent_meta.connector_uid, &subagent_native_id, archive.clone()) + .await + .is_ok() + { + tracing::debug!( + agent_id = %subagent.agent_id, + parent = %native_id, + "Subagent already imported, skipping" + ); + continue; + } + + // Derive a title from the Agent tool call description if available + let title = parsed.tool_exchanges.iter() + .find(|ex| { + ex.call.name == dirigent_anth::types::ToolName::Agent + && subagent.parent_tool_call_id.as_deref() == Some(&ex.call.id) + }) + .and_then(|ex| ex.call.input.get("description").and_then(|v| v.as_str())) + .map(|s| format!("[{}] {}", subagent.meta.agent_type.as_deref().unwrap_or("Agent"), s)); + + let register_req = RegisterSessionRequest { + connector_uid: parent_meta.connector_uid, + native_session_id: subagent_native_id.clone(), + title, + custom_scroll_id: None, + metadata: serde_json::json!({ + "parent_native_session_id": native_id, + }), + completeness: SessionCompleteness::Complete, + parent_scroll_id: Some(parent_scroll_id), + is_subagent: true, + continuation: Some(Continuation::Subagent), + agent_id: Some(subagent.agent_id.clone()), + subagent_type: subagent.meta.agent_type.clone(), + spawning_tool_use_id: subagent.parent_tool_call_id.clone(), + }; + + let child_scroll_id = match archivist + .register_session(register_req, archive.clone()) + .await + { + Ok(resp) => match resp.status { + RegisterStatus::Accepted => resp.scroll_id, + RegisterStatus::Aliased => { + tracing::debug!(agent_id = %subagent.agent_id, "Subagent session aliased"); + continue; + } + RegisterStatus::Rejected => { + stats.errors.push(format!( + "Subagent registration rejected for agent_id {}", + subagent.agent_id + )); + continue; + } + }, + Err(e) => { + stats.errors.push(format!( + "Failed to register subagent {}: {}", + subagent.agent_id, e + )); + continue; + } + }; + + // Convert subagent messages to records + let mut records: Vec = subagent + .messages + .iter() + .filter_map(|msg| { + convert_ant_message_with_exchanges( + msg, + child_scroll_id, + &[], + &subagent_native_id, + ) + }) + .collect(); + + // Patch session field + for record in &mut records { + if record.session == Uuid::nil() { + record.session = child_scroll_id; + } + } + + let msg_count = records.len(); + + if let Err(e) = archivist + .append_messages(child_scroll_id, records, archive.clone()) + .await + { + stats.errors.push(format!( + "Failed to append subagent messages for {}: {}", + subagent.agent_id, e + )); + continue; + } + + // Write DAG edge + let edge = DagEdge { + parent: parent_scroll_id, + child: child_scroll_id, + agent_id: subagent.agent_id.clone(), + subagent_type: subagent.meta.agent_type.clone(), + tool_use_id: subagent.parent_tool_call_id.clone(), + ts: Some(chrono::Utc::now()), + }; + + if let Err(e) = archivist.append_dag_edge(edge, archive.clone()).await { + tracing::warn!( + agent_id = %subagent.agent_id, + error = %e, + "Failed to write DAG edge (subagent still imported)" + ); + } + + stats.sessions_imported += 1; + stats.messages_written += msg_count; + + tracing::info!( + agent_id = %subagent.agent_id, + parent_scroll_id = %parent_scroll_id, + child_scroll_id = %child_scroll_id, + messages = msg_count, + "Subagent session imported" + ); + } + } + + Ok(()) +} + +// --------------------------------------------------------------------------- +// Session → MessageRecord conversion +// --------------------------------------------------------------------------- + +/// Convert a single Claude Code session (referenced by `SessionRef`) into a +/// vector of `MessageRecord`s. +/// +/// Uses `dirigent_anth::load_session` to parse, dedup, and correlate the JSONL, +/// then converts each non-noise message. Tool calls are correlated with their +/// results using `ParsedSession.tool_exchanges`, and user messages that contain +/// only tool results are suppressed (their content is merged into the assistant's +/// `MessagePart::Tool` parts). The `session` field on each record is set to +/// `Uuid::nil()` — the import orchestrator patches it to the real scroll_id. +pub fn convert_session_to_records(session_ref: &SessionRef) -> Result> { + let parsed = load_session(session_ref) + .map_err(|e| ArchivistError::InvalidRequest(format!("Failed to load session: {}", e)))?; + + let placeholder = Uuid::nil(); + let native_session_id = session_ref.id.as_str(); + let records: Vec = parsed + .messages + .iter() + .filter_map(|msg| { + convert_ant_message_with_exchanges( + msg, + placeholder, + &parsed.tool_exchanges, + native_session_id, + ) + }) + .collect(); + + Ok(records) +} + +/// Convert a single `RawMessage` from dirigent_anth into a `MessageRecord`. +/// +/// Returns `None` for noise messages (queue operations, meta, warmup, etc.) +/// and for message types that don't carry user-visible content (Progress, System). +/// +/// This is a convenience wrapper that delegates to [`convert_ant_message_with_exchanges`] +/// with an empty exchange list (tool results won't be correlated). +/// +/// Note: callers using this public API do not benefit from stable fallback +/// message_ids (native_session_id is empty). Prefer +/// [`convert_ant_message_with_exchanges`] with a real `native_session_id` for +/// idempotent imports. +pub fn convert_ant_message(msg: &RawMessage, scroll_id: Uuid) -> Option { + convert_ant_message_with_exchanges(msg, scroll_id, &[], "") +} + +/// Convert a single `RawMessage` with tool exchange correlation. +/// +/// User messages that contain ONLY tool results are suppressed (their content +/// is merged into the assistant's `MessagePart::Tool` parts via exchanges). +/// +/// `native_session_id` is used as part of the deterministic fallback +/// `message_id` when Claude's native per-message `uuid` field is absent or +/// unparseable. Pass the native session id (`session_ref.id`) for idempotent +/// imports; an empty string is acceptable for one-off conversions that do not +/// need stable ids. +pub fn convert_ant_message_with_exchanges( + msg: &RawMessage, + scroll_id: Uuid, + tool_exchanges: &[dirigent_anth::types::ToolExchange], + native_session_id: &str, +) -> Option { + if classify_noise(msg).is_some() { + return None; + } + + match msg { + RawMessage::User(user) => { + // Suppress user messages that are purely tool results + if is_pure_tool_result_message(user) { + return None; + } + convert_user_message(user, scroll_id, native_session_id) + } + RawMessage::Assistant(assistant) => convert_assistant_message_with_exchanges( + assistant, + scroll_id, + tool_exchanges, + native_session_id, + ), + RawMessage::Progress(_) + | RawMessage::System(_) + | RawMessage::QueueOperation(_) + | RawMessage::FileHistorySnapshot(_) + | RawMessage::LastPrompt(_) => None, + } +} + +/// Check if a user message contains only ToolResult blocks (no actual user text). +fn is_pure_tool_result_message(user: &RawUserMessage) -> bool { + match &user.message.content { + Content::Text(s) => s.trim().is_empty(), + Content::Blocks(blocks) => { + blocks + .iter() + .all(|b| matches!(b, ContentBlock::ToolResult { .. })) + } + } +} + +// --------------------------------------------------------------------------- +// User message conversion +// --------------------------------------------------------------------------- + +fn convert_user_message( + user: &RawUserMessage, + scroll_id: Uuid, + native_session_id: &str, +) -> Option { + let ts = user + .timestamp + .as_ref() + .and_then(|s| parse_timestamp_value(&serde_json::Value::String(s.clone()))) + .unwrap_or_else(Utc::now); + + let content_md = match &user.message.content { + Content::Text(s) => s.clone(), + Content::Blocks(blocks) => blocks_to_markdown_user(blocks), + }; + + if content_md.trim().is_empty() { + return None; + } + + // Build content_parts as Vec for proper UI rendering + let content_parts = serde_json::to_value(vec![dirigent_protocol::MessagePart::Text { + text: content_md.clone(), + }]) + .ok(); + + let mut meta = serde_json::Map::new(); + if let Some(ref cwd) = user.cwd { + meta.insert("cwd".to_string(), serde_json::Value::String(cwd.clone())); + } + if let Some(ref branch) = user.git_branch { + meta.insert( + "git_branch".to_string(), + serde_json::Value::String(branch.clone()), + ); + } + if let Some(ref version) = user.version { + meta.insert( + "claude_version".to_string(), + serde_json::Value::String(version.clone()), + ); + } + + Some(MessageRecord { + version: 1, + message_id: derive_message_id( + user.uuid.as_deref(), + native_session_id, + &ts, + "user", + &content_md, + ), + session: scroll_id, + parent_id: None, + ts, + role: "user".to_string(), + author: None, + content_md, + content_parts, + attachments: Vec::new(), + metadata: serde_json::Value::Object(meta), + }) +} + +// --------------------------------------------------------------------------- +// Assistant message conversion +// --------------------------------------------------------------------------- + +fn convert_assistant_message_with_exchanges( + assistant: &RawAssistantMessage, + scroll_id: Uuid, + tool_exchanges: &[dirigent_anth::types::ToolExchange], + native_session_id: &str, +) -> Option { + let ts = assistant + .timestamp + .as_ref() + .and_then(|s| parse_timestamp_value(&serde_json::Value::String(s.clone()))) + .unwrap_or_else(Utc::now); + + let parts = assistant_blocks_to_message_parts(&assistant.message.content, tool_exchanges); + + if parts.is_empty() { + return None; + } + + // Build content_md as fallback from parts + let content_md = parts + .iter() + .map(|p| match p { + dirigent_protocol::MessagePart::Text { text } => text.clone(), + dirigent_protocol::MessagePart::Thinking { text } => { + format!("\n{}\n", text) + } + dirigent_protocol::MessagePart::Tool { tool, .. } => format!("[Tool: {}]", tool), + dirigent_protocol::MessagePart::Code { language, code } => { + format!("```{}\n{}\n```", language, code) + } + dirigent_protocol::MessagePart::File { path, .. } => format!("[File: {}]", path), + }) + .collect::>() + .join("\n\n"); + + let content_parts = serde_json::to_value(&parts).ok(); + + let model = assistant.message.model.clone(); + + let mut meta = serde_json::Map::new(); + if let Some(ref m) = model { + meta.insert("model".to_string(), serde_json::Value::String(m.clone())); + } + if let Some(ref usage) = assistant.message.usage { + meta.insert("usage".to_string(), usage.clone()); + } + if let Some(ref stop_reason) = assistant.message.stop_reason { + meta.insert( + "stop_reason".to_string(), + serde_json::Value::String(stop_reason.clone()), + ); + } + + Some(MessageRecord { + version: 1, + message_id: derive_message_id( + assistant.uuid.as_deref(), + native_session_id, + &ts, + "assistant", + &content_md, + ), + session: scroll_id, + parent_id: None, + ts, + role: "assistant".to_string(), + author: model, + content_md, + content_parts, + attachments: Vec::new(), + metadata: serde_json::Value::Object(meta), + }) +} + +// --------------------------------------------------------------------------- +// MessagePart conversion (structured content for UI rendering) +// --------------------------------------------------------------------------- + +/// Convert assistant content blocks to `Vec`, using tool exchanges +/// for correlating tool_use with their results. +/// +/// Each `ToolUse` block becomes a `MessagePart::Tool` with the correlated result +/// (if found in `tool_exchanges`). `Thinking` and `Text` blocks map directly. +fn assistant_blocks_to_message_parts( + blocks: &[ContentBlock], + tool_exchanges: &[dirigent_anth::types::ToolExchange], +) -> Vec { + // Build lookup: tool_use_id → &ToolExchange + let exchange_map: HashMap<&str, &dirigent_anth::types::ToolExchange> = tool_exchanges + .iter() + .map(|e| (e.call.id.as_str(), e)) + .collect(); + + let mut parts = Vec::new(); + + for block in blocks { + match block { + ContentBlock::Text { text } if !text.is_empty() => { + parts.push(dirigent_protocol::MessagePart::Text { text: text.clone() }); + } + ContentBlock::Thinking { thinking } if !thinking.is_empty() => { + parts.push(dirigent_protocol::MessagePart::Thinking { + text: thinking.clone(), + }); + } + ContentBlock::ToolUse { id, name, input, .. } => { + let output = exchange_map.get(id.as_str()).and_then(|ex| { + ex.result.as_ref().map(|r| { + if r.is_error { + serde_json::json!({ "error": r.content.as_deref().unwrap_or("Unknown error") }) + } else { + serde_json::json!({ "result": r.content.as_deref().unwrap_or("") }) + } + }) + }); + + parts.push(dirigent_protocol::MessagePart::Tool { + tool: name.clone(), + tool_call_id: Some(id.clone()), + input: input.clone(), + output, + }); + } + ContentBlock::Image { .. } => { + parts.push(dirigent_protocol::MessagePart::Text { + text: "[Image]".to_string(), + }); + } + _ => {} // Empty text/thinking, ToolResult on assistant side (shouldn't happen) + } + } + + parts +} + +// --------------------------------------------------------------------------- +// Markdown conversion helpers (legacy, used for markdown-only fallback) +// --------------------------------------------------------------------------- + +/// Convert assistant content blocks to a single markdown string. +/// Retained for tests; production path uses `assistant_blocks_to_message_parts`. +#[cfg(test)] +fn assistant_blocks_to_markdown(blocks: &[ContentBlock]) -> String { + let parts: Vec = blocks + .iter() + .filter_map(|block| match block { + ContentBlock::Text { text } => { + if text.is_empty() { + None + } else { + Some(text.clone()) + } + } + ContentBlock::Thinking { thinking } => { + if thinking.is_empty() { + None + } else { + Some(format!("\n{}\n", thinking)) + } + } + ContentBlock::ToolUse { name, input, .. } => { + let input_pretty = + serde_json::to_string_pretty(input).unwrap_or_else(|_| input.to_string()); + Some(format!("**Tool: {}**\n\n```json\n{}\n```", name, input_pretty)) + } + ContentBlock::ToolResult { + content, is_error, .. + } => { + let label = if *is_error { + "**Tool Error:**" + } else { + "**Tool Result:**" + }; + let body = match content { + Some(Content::Text(s)) => format!("\n\n```\n{}\n```", s), + Some(Content::Blocks(inner)) => { + let text = inner + .iter() + .filter_map(|b| match b { + ContentBlock::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::>() + .join("\n"); + if text.is_empty() { + String::new() + } else { + format!("\n\n```\n{}\n```", text) + } + } + None => String::new(), + }; + Some(format!("{}{}", label, body)) + } + ContentBlock::Image { .. } => Some("[Image]".to_string()), + }) + .collect(); + + parts.join("\n\n") +} + +/// Convert user content blocks (typically tool results) to markdown. +fn blocks_to_markdown_user(blocks: &[ContentBlock]) -> String { + let parts: Vec = blocks + .iter() + .filter_map(|block| match block { + ContentBlock::Text { text } => { + if text.is_empty() { + None + } else { + Some(text.clone()) + } + } + ContentBlock::ToolResult { + content, is_error, .. + } => { + let label = if *is_error { + "**Tool Error:**" + } else { + "**Tool Result:**" + }; + let body = match content { + Some(Content::Text(s)) => format!("\n\n```\n{}\n```", s), + Some(Content::Blocks(inner)) => { + let text = inner + .iter() + .filter_map(|b| match b { + ContentBlock::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::>() + .join("\n"); + if text.is_empty() { + String::new() + } else { + format!("\n\n```\n{}\n```", text) + } + } + None => String::new(), + }; + Some(format!("{}{}", label, body)) + } + _ => None, + }) + .collect(); + + parts.join("\n\n") +} + +// --------------------------------------------------------------------------- +// Title helpers +// --------------------------------------------------------------------------- + +/// Derive a session title from the first line of user content. +/// Truncates to ~100 characters at a char boundary if needed. +fn derive_title_from_content(content: &str) -> String { + let first_line = content.lines().next().unwrap_or(content).trim(); + if first_line.len() > 100 { + // Find the last char boundary at or before byte 100 + let end = first_line + .char_indices() + .take_while(|&(i, _)| i <= 100) + .last() + .map(|(i, _)| i) + .unwrap_or(100); + format!("{}...", &first_line[..end]) + } else { + first_line.to_string() + } +} + +// --------------------------------------------------------------------------- +// Timestamp helpers +// --------------------------------------------------------------------------- + +/// Parse a timestamp from a serde_json::Value. +/// +/// Handles both string (ISO 8601) and numeric (Unix seconds/millis) values. +/// This wraps `dirigent_anth::parse_timestamp` for convenience. +pub fn parse_timestamp_value(v: &serde_json::Value) -> Option> { + parse_timestamp(v) +} + +// --------------------------------------------------------------------------- +// Importer trait wrapper +// --------------------------------------------------------------------------- + +/// `Importer` adapter around the Claude Code session importer. +/// +/// Wraps the free functions [`discover_claude_import`] and +/// [`import_claude_sessions`] so the generic import registry can drive Claude +/// imports the same way as any other source. +pub struct ClaudeImporter; + +#[async_trait] +impl Importer for ClaudeImporter { + fn source_name(&self) -> &'static str { + "claude" + } + + fn config_shape(&self) -> ImportConfigShape { + // Expose one field: path to ~/.claude directory (directory picker). + ImportConfigShape { + fields: vec![ConfigField { + key: "path".into(), + label: "Claude directory".into(), + kind: ConfigFieldKind::Path { directory: true }, + required: true, + help: Some("Usually ~/.claude".into()), + }], + example: ImportConfig { + source: "claude".into(), + params: { + let mut m = std::collections::BTreeMap::new(); + m.insert("path".into(), serde_json::json!("~/.claude")); + m + }, + }, + } + } + + fn detect_defaults(&self) -> Option { + let home = dirigent_anth::discover_claude_home().ok()?; + let mut params = std::collections::BTreeMap::new(); + params.insert("path".into(), serde_json::json!(home.to_string())); + Some(ImportConfig { + source: "claude".into(), + params, + }) + } + + async fn discover(&self, cfg: &ImportConfig) -> std::result::Result { + let path = cfg + .params + .get("path") + .and_then(|v| v.as_str()) + .ok_or_else(|| ImportError::Config("missing `path`".into()))?; + let (_home, discovery) = discover_claude_import(Some(path)) + .map_err(ImportError::Discovery)?; + Ok(discovery) + } + + async fn import( + &self, + cfg: &ImportConfig, + archivist: &Archivist, + target: ImportTarget, + progress: ImportProgressSink, + ) -> std::result::Result { + let path = cfg + .params + .get("path") + .and_then(|v| v.as_str()) + .ok_or_else(|| ImportError::Config("missing `path`".into()))?; + let home = Utf8PathBuf::from(path); + import_claude_sessions(archivist, &home, target.archive, &progress, &target.project_map) + .await + .map_err(|e| ImportError::Archivist(e.to_string())) + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn make_user_text(text: &str, timestamp: &str) -> RawMessage { + let json = format!( + r#"{{ + "type": "user", + "uuid": "test-uuid-001", + "parentUuid": null, + "timestamp": "{}", + "sessionId": "test-session", + "cwd": "/home/user/project", + "version": "2.1.71", + "gitBranch": "main", + "isSidechain": false, + "message": {{ + "role": "user", + "content": "{}" + }} + }}"#, + timestamp, text + ); + serde_json::from_str(&json).unwrap() + } + + fn make_assistant_text(text: &str, timestamp: &str) -> RawMessage { + let json = format!( + r#"{{ + "type": "assistant", + "uuid": "test-uuid-002", + "parentUuid": "test-uuid-001", + "timestamp": "{}", + "sessionId": "test-session", + "message": {{ + "model": "claude-opus-4-6", + "id": "msg_test", + "role": "assistant", + "content": [{{"type": "text", "text": "{}"}}], + "stop_reason": "end_turn", + "usage": {{"input_tokens": 100, "output_tokens": 50}} + }} + }}"#, + timestamp, text + ); + serde_json::from_str(&json).unwrap() + } + + fn make_assistant_with_tool(tool_name: &str, timestamp: &str) -> RawMessage { + let json = format!( + r#"{{ + "type": "assistant", + "uuid": "test-uuid-003", + "parentUuid": "test-uuid-001", + "timestamp": "{}", + "sessionId": "test-session", + "message": {{ + "model": "claude-opus-4-6", + "id": "msg_test2", + "role": "assistant", + "content": [ + {{"type": "text", "text": "Let me check that."}}, + {{"type": "tool_use", "id": "toolu_abc", "name": "{}", "input": {{"command": "ls"}}}} + ], + "stop_reason": "tool_use" + }} + }}"#, + timestamp, tool_name + ); + serde_json::from_str(&json).unwrap() + } + + fn make_queue_operation() -> RawMessage { + let json = r#"{ + "type": "queue-operation", + "operation": "enqueue", + "timestamp": "2026-03-14T21:15:17.531Z", + "sessionId": "test-session" + }"#; + serde_json::from_str(json).unwrap() + } + + #[test] + fn convert_user_text_message() { + let msg = make_user_text("Hello world", "2026-04-01T12:00:00Z"); + let record = convert_ant_message(&msg, Uuid::nil()).expect("should produce a record"); + + assert_eq!(record.role, "user"); + assert_eq!(record.content_md, "Hello world"); + assert_eq!(record.session, Uuid::nil()); + assert!(record.author.is_none()); + + // Check metadata + let meta = record.metadata.as_object().unwrap(); + assert_eq!(meta.get("cwd").unwrap(), "/home/user/project"); + assert_eq!(meta.get("git_branch").unwrap(), "main"); + assert_eq!(meta.get("claude_version").unwrap(), "2.1.71"); + + // Check content_parts is present + assert!(record.content_parts.is_some()); + } + + #[test] + fn convert_assistant_text_message() { + let msg = make_assistant_text("Here is your answer.", "2026-04-01T12:00:05Z"); + let record = convert_ant_message(&msg, Uuid::nil()).expect("should produce a record"); + + assert_eq!(record.role, "assistant"); + assert_eq!(record.content_md, "Here is your answer."); + assert_eq!(record.author.as_deref(), Some("claude-opus-4-6")); + + // Check metadata + let meta = record.metadata.as_object().unwrap(); + assert_eq!(meta.get("model").unwrap(), "claude-opus-4-6"); + assert_eq!(meta.get("stop_reason").unwrap(), "end_turn"); + assert!(meta.get("usage").is_some()); + } + + #[test] + fn convert_queue_operation_returns_none() { + let msg = make_queue_operation(); + let result = convert_ant_message(&msg, Uuid::nil()); + assert!(result.is_none(), "QueueOperation should be skipped as noise"); + } + + #[test] + fn convert_assistant_with_tool_use_contains_tool_name() { + let msg = make_assistant_with_tool("Bash", "2026-04-01T12:00:10Z"); + let record = convert_ant_message(&msg, Uuid::nil()).expect("should produce a record"); + + assert_eq!(record.role, "assistant"); + assert!( + record.content_md.contains("[Tool: Bash]"), + "markdown should contain tool reference, got: {}", + record.content_md + ); + // content_parts should contain proper MessagePart::Tool + assert!(record.content_parts.is_some()); + let parts: Vec = + serde_json::from_value(record.content_parts.unwrap()).unwrap(); + assert_eq!(parts.len(), 2); + assert!(matches!(&parts[0], dirigent_protocol::MessagePart::Text { text } if text == "Let me check that.")); + assert!(matches!(&parts[1], dirigent_protocol::MessagePart::Tool { tool, .. } if tool == "Bash")); + } + + #[test] + fn parse_timestamp_value_string() { + let v = serde_json::json!("2026-04-01T12:00:00Z"); + let dt = parse_timestamp_value(&v).unwrap(); + assert_eq!(dt.year(), 2026); + assert_eq!(dt.month(), 4); + assert_eq!(dt.day(), 1); + } + + #[test] + fn parse_timestamp_value_numeric_millis() { + let v = serde_json::json!(1769461914249_i64); + let dt = parse_timestamp_value(&v).unwrap(); + assert!(dt.year() >= 2025); + } + + #[test] + fn parse_timestamp_value_null_returns_none() { + let v = serde_json::json!(null); + assert!(parse_timestamp_value(&v).is_none()); + } + + use chrono::Datelike; + + #[test] + fn convert_empty_user_message_returns_none() { + let msg = make_user_text("", "2026-04-01T12:00:00Z"); + let result = convert_ant_message(&msg, Uuid::nil()); + assert!(result.is_none(), "Empty user message should be skipped"); + } + + #[test] + fn assistant_blocks_to_markdown_thinking() { + let blocks = vec![ContentBlock::Thinking { + thinking: "Let me think about this...".to_string(), + }]; + let md = assistant_blocks_to_markdown(&blocks); + assert!(md.contains("")); + assert!(md.contains("Let me think about this...")); + assert!(md.contains("")); + } + + #[test] + fn assistant_blocks_to_markdown_image() { + let blocks = vec![ContentBlock::Image { + source: serde_json::json!({"type": "base64"}), + }]; + let md = assistant_blocks_to_markdown(&blocks); + assert_eq!(md, "[Image]"); + } + + #[test] + fn convert_assistant_with_tools_produces_message_parts() { + use dirigent_protocol::MessagePart; + + let blocks = vec![ + ContentBlock::Text { + text: "Let me check that.".to_string(), + }, + ContentBlock::ToolUse { + id: "toolu_abc".to_string(), + name: "Bash".to_string(), + input: serde_json::json!({"command": "ls"}), + caller: None, + }, + ]; + + let exchanges: Vec = vec![]; + let parts = assistant_blocks_to_message_parts(&blocks, &exchanges); + + assert_eq!(parts.len(), 2); + assert!( + matches!(&parts[0], MessagePart::Text { text } if text == "Let me check that.") + ); + assert!( + matches!(&parts[1], MessagePart::Tool { tool, output, .. } if tool == "Bash" && output.is_none()) + ); + } + + #[test] + fn convert_assistant_with_correlated_tool_result() { + use dirigent_protocol::MessagePart; + + let blocks = vec![ + ContentBlock::Text { + text: "Let me read that file.".to_string(), + }, + ContentBlock::ToolUse { + id: "toolu_xyz".to_string(), + name: "Read".to_string(), + input: serde_json::json!({"file_path": "/tmp/test.rs"}), + caller: None, + }, + ]; + + let exchanges = vec![dirigent_anth::types::ToolExchange { + call: dirigent_anth::types::ToolCall { + id: "toolu_xyz".to_string(), + name: dirigent_anth::types::ToolName::Read, + input: serde_json::json!({"file_path": "/tmp/test.rs"}), + source_message_uuid: "msg1".to_string(), + }, + result: Some(dirigent_anth::types::ToolResultData { + tool_use_id: "toolu_xyz".to_string(), + content: Some("fn main() {}".to_string()), + is_error: false, + source_message_uuid: "msg2".to_string(), + }), + }]; + + let parts = assistant_blocks_to_message_parts(&blocks, &exchanges); + assert_eq!(parts.len(), 2); + if let MessagePart::Tool { + tool, + tool_call_id, + output, + .. + } = &parts[1] + { + assert_eq!(tool, "Read"); + assert_eq!(tool_call_id.as_deref(), Some("toolu_xyz")); + assert!(output.is_some()); + } else { + panic!("Expected Tool part"); + } + } + + #[test] + fn user_message_with_only_tool_results_is_suppressed() { + let user_json = r#"{ + "type": "user", + "uuid": "test-uuid-005", + "timestamp": "2026-04-01T12:00:15Z", + "sessionId": "test-session", + "message": { + "role": "user", + "content": [ + {"type": "tool_result", "tool_use_id": "toolu_abc", "content": "output text"} + ] + } + }"#; + let msg: RawMessage = serde_json::from_str(user_json).unwrap(); + let record = convert_ant_message_with_exchanges(&msg, Uuid::nil(), &[], ""); + assert!( + record.is_none(), + "Pure tool-result user messages should be suppressed" + ); + } + + #[test] + fn derive_title_from_content_short() { + let title = derive_title_from_content("Hello, can you help me refactor this module?"); + assert_eq!(title, "Hello, can you help me refactor this module?"); + } + + #[test] + fn derive_title_from_content_truncates_long() { + let long = "a".repeat(200); + let title = derive_title_from_content(&long); + assert!(title.len() <= 104); // ~100 + "..." + assert!(title.ends_with("...")); + } + + #[test] + fn derive_title_from_content_handles_multibyte_chars() { + // 'x' repeated with a multi-byte char near the boundary + let mut s = "a".repeat(98); + s.push('\u{2019}'); // right single quote (3 bytes) at byte 98..101 + s.push_str(&"b".repeat(100)); + let title = derive_title_from_content(&s); + assert!(title.ends_with("...")); + // Should NOT panic on multi-byte boundary + assert!(title.is_char_boundary(title.len())); + } + + #[test] + fn derive_title_from_content_uses_first_line() { + let title = derive_title_from_content("First line\nSecond line\nThird line"); + assert_eq!(title, "First line"); + } +} + +#[cfg(test)] +mod id_stability_tests { + use super::*; + + #[test] + fn derive_message_id_stable_when_source_uuid_present() { + let ts = chrono::Utc::now(); + let a = derive_message_id( + Some("12345678-1234-5678-1234-567812345678"), + "session-1", + &ts, + "user", + "hello", + ); + let b = derive_message_id( + Some("12345678-1234-5678-1234-567812345678"), + "session-1", + &ts, + "user", + "hello", + ); + assert_eq!(a, b, "same source uuid must produce same message_id"); + } + + #[test] + fn derive_message_id_stable_fallback_when_uuid_absent() { + let ts = chrono::Utc::now(); + let a = derive_message_id(None, "session-1", &ts, "user", "hello"); + let b = derive_message_id(None, "session-1", &ts, "user", "hello"); + assert_eq!(a, b); + } + + #[test] + fn derive_message_id_different_content_different_id() { + let ts = chrono::Utc::now(); + let a = derive_message_id(None, "session-1", &ts, "user", "hello"); + let b = derive_message_id(None, "session-1", &ts, "user", "world"); + assert_ne!(a, b); + } +} diff --git a/crates/dirigent_archivist/src/import/sources/codex.rs b/crates/dirigent_archivist/src/import/sources/codex.rs new file mode 100644 index 0000000..a9cc422 --- /dev/null +++ b/crates/dirigent_archivist/src/import/sources/codex.rs @@ -0,0 +1,331 @@ +//! OpenAI Codex CLI importer: takes a path to a directory of JSONL session files. + +use std::path::PathBuf; + +use async_trait::async_trait; +use chrono::Utc; +use uuid::Uuid; + +use dirigent_codex::{ParsedMessage, ParsedSession}; + +use super::super::progress::ImportProgressSink; +use super::super::trait_def::{ + ConfigField, ConfigFieldKind, ImportConfig, ImportConfigShape, ImportError, ImportTarget, + Importer, +}; +use super::super::{ + import_sessions, DiscoveredSession, ImportDiscovery, ImportProject, ImportStats, +}; +use crate::coordinator::Archivist; +use crate::error::{ArchivistError, Result}; +use crate::types::{MessageRecord, RegisterConnectorRequest}; + +/// Connector type string used for imported Codex sessions. +pub const CODEX_CONNECTOR_TYPE: &str = "Codex"; + +/// Fingerprint prefix for locally-imported Codex sessions. +pub const CODEX_FINGERPRINT_PREFIX: &str = "import/local:codex"; + +/// Namespace UUID for deterministic UUIDv5 derivations of message ids that +/// Codex does not expose natively. +const CODEX_MESSAGE_NS: Uuid = Uuid::from_u128(0x9e28_b7d4_af9c_4fe2_a8d1_8c41_21b3_2222); + +pub struct CodexImporter; + +#[async_trait] +impl Importer for CodexImporter { + fn source_name(&self) -> &'static str { + "codex" + } + + fn config_shape(&self) -> ImportConfigShape { + ImportConfigShape { + fields: vec![ConfigField { + key: "path".into(), + label: "Codex sessions directory".into(), + kind: ConfigFieldKind::Path { directory: true }, + required: true, + help: Some("Usually ~/.codex/sessions".into()), + }], + example: ImportConfig { + source: "codex".into(), + params: { + let mut m = std::collections::BTreeMap::new(); + m.insert("path".into(), serde_json::json!("~/.codex/sessions")); + m + }, + }, + } + } + + async fn discover( + &self, + cfg: &ImportConfig, + ) -> std::result::Result { + let path = require_path(cfg)?; + let files = dirigent_codex::discover_sessions(&path) + .map_err(|e| ImportError::Discovery(e.to_string()))?; + + // Parse each file to count messages. This is a best-effort estimate — + // malformed lines are skipped by the parser, so counts reflect what + // the importer would actually write. + let mut total_estimated_messages: usize = 0; + for file in &files { + if let Ok(session) = dirigent_codex::parse_file(file) { + total_estimated_messages += session.messages.len(); + } + } + let total_sessions = files.len(); + + // Codex sessions live flat in one directory; bucket them into a + // single synthetic project named after the directory. + let project_name = path + .file_name() + .and_then(|s| s.to_str()) + .unwrap_or("Codex sessions") + .to_string(); + + Ok(ImportDiscovery { + source_name: "Codex".to_string(), + source_path: path.display().to_string(), + projects: vec![ImportProject { + name: project_name, + session_count: total_sessions, + }], + total_sessions, + total_estimated_messages, + }) + } + + async fn import( + &self, + cfg: &ImportConfig, + archivist: &Archivist, + target: ImportTarget, + progress: ImportProgressSink, + ) -> std::result::Result { + let path = require_path(cfg)?; + let files = dirigent_codex::discover_sessions(&path) + .map_err(|e| ImportError::Discovery(e.to_string()))?; + + // Parse every session file up front so that `convert_messages` + // (called by `import_sessions`) can do O(1) lookups. + let mut parsed: Vec = Vec::with_capacity(files.len()); + for file in &files { + match dirigent_codex::parse_file(file) { + Ok(session) => parsed.push(session), + Err(e) => { + tracing::warn!( + path = %file.display(), + error = %e, + "Skipping unreadable Codex session file" + ); + } + } + } + + let mut discovered: Vec = Vec::with_capacity(parsed.len()); + for s in &parsed { + let metadata = serde_json::json!({ + "source": "codex", + "source_path": s.source_path.display().to_string(), + "native_id": s.native_id, + }); + let file_size = std::fs::metadata(&s.source_path).ok().map(|m| m.len()); + + discovered.push(DiscoveredSession { + native_session_id: s.native_id.clone(), + title: None, + created_at: s.created_at, + updated_at: s.updated_at, + message_count: s.messages.len(), + metadata, + project_path: None, + file_size, + }); + } + + // Map native_id -> parsed session for O(1) lookup in `convert`. + let session_lookup: std::collections::HashMap = parsed + .into_iter() + .map(|s| (s.native_id.clone(), s)) + .collect(); + + // Fingerprint the import by the canonical directory path. Re-running + // against the same directory aliases onto the same connector. + let canonical_path = path.canonicalize().unwrap_or_else(|_| path.clone()); + let fingerprint = format!("{}:{}", CODEX_FINGERPRINT_PREFIX, canonical_path.display()); + + let connector_req = RegisterConnectorRequest { + r#type: CODEX_CONNECTOR_TYPE.to_string(), + title: format!("Codex ({})", canonical_path.display()), + client_native_id: fingerprint.clone(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: Some(fingerprint), + }; + + let convert = |native_id: &str| -> Result> { + let session = session_lookup.get(native_id).ok_or_else(|| { + ArchivistError::InvalidRequest(format!( + "Parsed session not found for native_id: {}", + native_id + )) + })?; + Ok(convert_session_to_records(session)) + }; + + import_sessions( + archivist, + connector_req, + discovered, + convert, + target.archive, + &progress, + false, + &target.project_map, + ) + .await + .map_err(|e| ImportError::Archivist(e.to_string())) + } +} + +// --------------------------------------------------------------------------- +// Conversion helpers +// --------------------------------------------------------------------------- + +fn require_path(cfg: &ImportConfig) -> std::result::Result { + cfg.params + .get("path") + .and_then(|v| v.as_str()) + .map(PathBuf::from) + .ok_or_else(|| ImportError::Config("missing `path`".into())) +} + +/// Convert every [`ParsedMessage`] in a session into a [`MessageRecord`], +/// leaving `session = Uuid::nil()` for the generic orchestrator to patch. +fn convert_session_to_records(session: &ParsedSession) -> Vec { + session + .messages + .iter() + .enumerate() + .filter_map(|(idx, m)| convert_parsed_message(&session.native_id, idx, m)) + .collect() +} + +fn convert_parsed_message( + native_session_id: &str, + index: usize, + msg: &ParsedMessage, +) -> Option { + // Skip purely empty messages — nothing to archive. + if msg.content.trim().is_empty() { + return None; + } + + let ts = msg.ts.unwrap_or_else(Utc::now); + + // Codex events don't carry per-message UUIDs, so always derive a stable + // UUIDv5 from (native_session, index, role, ts). Index disambiguates + // otherwise-identical back-to-back messages. + let key = format!( + "{}:{}:{}:{}", + native_session_id, + index, + msg.role, + ts.to_rfc3339(), + ); + let message_id = Uuid::new_v5(&CODEX_MESSAGE_NS, key.as_bytes()); + + let parts = vec![dirigent_protocol::MessagePart::Text { + text: msg.content.clone(), + }]; + let content_parts = serde_json::to_value(&parts).ok(); + + Some(MessageRecord { + version: 1, + message_id, + session: Uuid::nil(), + parent_id: None, + ts, + role: msg.role.clone(), + author: None, + content_md: msg.content.clone(), + content_parts, + attachments: Vec::new(), + metadata: msg.metadata.clone(), + }) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn sample_message(role: &str, content: &str) -> ParsedMessage { + ParsedMessage { + ts: None, + role: role.into(), + content: content.into(), + metadata: serde_json::Value::Null, + } + } + + fn sample_message_at(role: &str, content: &str, ts: chrono::DateTime) -> ParsedMessage { + ParsedMessage { + ts: Some(ts), + role: role.into(), + content: content.into(), + metadata: serde_json::Value::Null, + } + } + + #[test] + fn empty_content_is_skipped() { + let m = sample_message("user", " "); + assert!(convert_parsed_message("s", 0, &m).is_none()); + } + + #[test] + fn non_empty_message_converts() { + let m = sample_message("user", "hello"); + let r = convert_parsed_message("s", 0, &m).expect("converts"); + assert_eq!(r.role, "user"); + assert_eq!(r.content_md, "hello"); + assert_eq!(r.session, Uuid::nil()); + assert!(r.content_parts.is_some()); + } + + #[test] + fn message_id_is_deterministic_per_session_index() { + // Fix ts so we don't accidentally hash Utc::now() into the id key. + let ts = chrono::TimeZone::timestamp_opt(&Utc, 1_735_732_800, 0) + .single() + .unwrap(); + let m = sample_message_at("user", "hello", ts); + let a = convert_parsed_message("session-a", 0, &m).unwrap(); + let b = convert_parsed_message("session-a", 0, &m).unwrap(); + assert_eq!(a.message_id, b.message_id); + + // Different index → different id. + let c = convert_parsed_message("session-a", 1, &m).unwrap(); + assert_ne!(a.message_id, c.message_id); + + // Different session → different id. + let d = convert_parsed_message("session-b", 0, &m).unwrap(); + assert_ne!(a.message_id, d.message_id); + } + + #[test] + fn require_path_reports_missing_config() { + let cfg = ImportConfig { + source: "codex".into(), + params: Default::default(), + }; + let err = require_path(&cfg).expect_err("should fail"); + assert!(matches!(err, ImportError::Config(_))); + } +} diff --git a/crates/dirigent_archivist/src/import/sources/mod.rs b/crates/dirigent_archivist/src/import/sources/mod.rs new file mode 100644 index 0000000..f3cbc31 --- /dev/null +++ b/crates/dirigent_archivist/src/import/sources/mod.rs @@ -0,0 +1,7 @@ +//! Per-source importer implementations. + +pub mod claude; +#[cfg(feature = "importer-chatgpt")] +pub mod chatgpt; +#[cfg(feature = "importer-codex")] +pub mod codex; diff --git a/crates/dirigent_archivist/src/import/trait_def.rs b/crates/dirigent_archivist/src/import/trait_def.rs new file mode 100644 index 0000000..294400b --- /dev/null +++ b/crates/dirigent_archivist/src/import/trait_def.rs @@ -0,0 +1,113 @@ +//! Importer trait and config-shape types consumed by the UI (dynamic form +//! rendering) and the CLI (future). Scripts can serialise ImportConfig as JSON. + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use std::collections::{BTreeMap, HashMap}; +use thiserror::Error; +use uuid::Uuid; + +use crate::coordinator::Archivist; +use super::progress::ImportProgressSink; + +#[async_trait] +pub trait Importer: Send + Sync { + fn source_name(&self) -> &'static str; + fn config_shape(&self) -> ImportConfigShape; + + async fn discover( + &self, + cfg: &ImportConfig, + ) -> Result; + + async fn import( + &self, + cfg: &ImportConfig, + archivist: &Archivist, + target: ImportTarget, + progress: ImportProgressSink, + ) -> Result; + + /// Attempt to auto-detect default configuration values. + /// + /// Importers that can discover their source location automatically + /// (e.g., Claude Code's `~/.claude` directory) should override this. + /// Returns `None` when auto-detection is not supported or fails. + fn detect_defaults(&self) -> Option { + None + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImporterInfo { + pub source_name: String, + pub display_name: String, + pub config_shape: ImportConfigShape, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImportConfigShape { + pub fields: Vec, + pub example: ImportConfig, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConfigField { + pub key: String, + pub label: String, + pub kind: ConfigFieldKind, + pub required: bool, + pub help: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "type")] +pub enum ConfigFieldKind { + Path { directory: bool }, + File { extension: Option }, + String, + Bool, + Enum { variants: Vec }, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct ImportConfig { + pub source: String, + #[serde(default)] + pub params: BTreeMap, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct ImportTarget { + pub archive: Option, + pub connector_alias: Option, + pub project_id: Option, + /// Maps normalized project_path -> project_id (as string UUID). + /// When a session's project_path is found in this map, the corresponding + /// project_id is injected into the session metadata during import. + #[serde(default)] + pub project_map: HashMap, +} + +#[derive(Debug, Error)] +pub enum ImportError { + #[error("source not found: {0}")] SourceNotFound(String), + #[error("config: {0}")] Config(String), + #[error("discovery: {0}")] Discovery(String), + #[error("I/O: {0}")] Io(#[from] std::io::Error), + #[error("archivist: {0}")] Archivist(String), + #[error("parser: {0}")] Parser(String), + #[error("cancelled")] Cancelled, +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn config_round_trips() { + let cfg = ImportConfig { source: "claude".into(), params: BTreeMap::new() }; + let json = serde_json::to_string(&cfg).unwrap(); + let back: ImportConfig = serde_json::from_str(&json).unwrap(); + assert_eq!(back.source, "claude"); + } +} diff --git a/crates/dirigent_archivist/src/lib.rs b/crates/dirigent_archivist/src/lib.rs new file mode 100644 index 0000000..b3c7449 --- /dev/null +++ b/crates/dirigent_archivist/src/lib.rs @@ -0,0 +1,45 @@ +//! Dirigent Archivist +//! +//! Persistent storage for all agentic interactions in Dirigent. +//! +//! The Archivist provides file-based archival storage using NDJSON, JSON, and TSV +//! formats for durability and human-readability. It implements an archive-first +//! architecture with connector API fallback for session data. +//! +//! # Key Features +//! +//! - File-based storage for easy curation and grep-ability +//! - Content-addressable file storage for attachments +//! - Session lineage tracking (splits, continuations, mutations) +//! - Connector registry with UID coordination +//! - Real-time event streaming for archive updates +//! +//! # Architecture +//! +//! See `docs/building/05_archivist/vision.md` for detailed design. + +pub mod accumulator; +pub mod backend; +pub mod backends; +pub mod backfill; +pub mod coordinator; +pub mod error; +pub mod events; +pub mod import; +pub mod registry; +pub mod session; +pub mod storage; +pub mod types; + +// Re-export commonly used types +pub use accumulator::{MessageAccumulator, ToolCallData}; +pub use backend::{ + ArchiveBackend, ArchiveCapability, CapabilitySet, ConnectorRegistryBackend, + DagBackend, HealthStatus, MetaEventsBackend, SearchBackend, SessionMappingBackend, +}; +pub use backends::JsonlBackend; +pub use backfill::{backfill_from_sessions, convert_message_to_record, BackfillStats}; +pub use coordinator::{ArchiveInfo, ArchiveMetadata, Archivist}; +pub use error::{ArchivistError, Result}; +pub use events::EventHandler; +pub use types::*; diff --git a/crates/dirigent_archivist/src/registry/cache.rs b/crates/dirigent_archivist/src/registry/cache.rs new file mode 100644 index 0000000..d25698d --- /dev/null +++ b/crates/dirigent_archivist/src/registry/cache.rs @@ -0,0 +1,116 @@ +//! Positive LRU cache mapping `scroll_id` to the backend that holds the +//! authoritative session metadata, populated on the first successful read. + +use std::num::NonZeroUsize; + +use lru::LruCache; +use tokio::sync::Mutex; +use uuid::Uuid; + +const DEFAULT_CAPACITY: usize = 10_000; + +pub struct ReadCache { + inner: Mutex>, +} + +impl ReadCache { + pub fn new() -> Self { + Self::with_capacity(DEFAULT_CAPACITY) + } + + pub fn with_capacity(capacity: usize) -> Self { + let cap = NonZeroUsize::new(capacity.max(1)).unwrap(); + Self { + inner: Mutex::new(LruCache::new(cap)), + } + } + + pub async fn get(&self, scroll_id: Uuid) -> Option { + let mut guard = self.inner.lock().await; + guard.get(&scroll_id).cloned() + } + + pub async fn put(&self, scroll_id: Uuid, backend_name: String) { + let mut guard = self.inner.lock().await; + guard.put(scroll_id, backend_name); + } + + pub async fn invalidate(&self, scroll_id: Uuid) { + let mut guard = self.inner.lock().await; + guard.pop(&scroll_id); + } + + pub async fn rewrite(&self, scroll_id: Uuid, new_backend: String) { + let mut guard = self.inner.lock().await; + guard.put(scroll_id, new_backend); + } + + pub async fn clear(&self) { + let mut guard = self.inner.lock().await; + guard.clear(); + } + + pub async fn len(&self) -> usize { + let guard = self.inner.lock().await; + guard.len() + } +} + +impl Default for ReadCache { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn id(b: u8) -> Uuid { + Uuid::from_bytes([b; 16]) + } + + #[tokio::test] + async fn put_then_get() { + let c = ReadCache::new(); + c.put(id(1), "main".into()).await; + assert_eq!(c.get(id(1)).await.as_deref(), Some("main")); + assert!(c.get(id(2)).await.is_none()); + } + + #[tokio::test] + async fn invalidate_removes_entry() { + let c = ReadCache::new(); + c.put(id(1), "main".into()).await; + c.invalidate(id(1)).await; + assert!(c.get(id(1)).await.is_none()); + } + + #[tokio::test] + async fn rewrite_changes_backend() { + let c = ReadCache::new(); + c.put(id(1), "a".into()).await; + c.rewrite(id(1), "b".into()).await; + assert_eq!(c.get(id(1)).await.as_deref(), Some("b")); + } + + #[tokio::test] + async fn lru_evicts_oldest() { + let c = ReadCache::with_capacity(2); + c.put(id(1), "a".into()).await; + c.put(id(2), "b".into()).await; + c.put(id(3), "c".into()).await; // evicts id(1) + assert!(c.get(id(1)).await.is_none()); + assert_eq!(c.get(id(2)).await.as_deref(), Some("b")); + assert_eq!(c.get(id(3)).await.as_deref(), Some("c")); + } + + #[tokio::test] + async fn clear_empties() { + let c = ReadCache::new(); + c.put(id(1), "a".into()).await; + c.put(id(2), "b".into()).await; + c.clear().await; + assert_eq!(c.len().await, 0); + } +} diff --git a/crates/dirigent_archivist/src/registry/config.rs b/crates/dirigent_archivist/src/registry/config.rs new file mode 100644 index 0000000..7c13d9c --- /dev/null +++ b/crates/dirigent_archivist/src/registry/config.rs @@ -0,0 +1,253 @@ +//! Declarative `[[archives]]` config block parsed from `dirigent.toml`. +//! +//! The TOML schema is documented in `docs/plans/2026-04-19-archivist-phase3-design.md`. + +use serde::{Deserialize, Serialize}; + +use super::filter::ArchiveFilter; +use super::registration::{FailureMode, OverflowPolicy}; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ArchivesConfig { + #[serde(default, rename = "archives")] + pub entries: Vec, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct ArchiveConfig { + pub name: String, + #[serde(rename = "type")] + pub type_name: String, + #[serde(default = "default_write_active")] + pub write_active: bool, + #[serde(default)] + pub failure_mode: FailureMode, + #[serde(default)] + pub read_priority: u32, + #[serde(default = "default_enabled")] + pub enabled: bool, + #[serde(default)] + pub write_policy: WritePolicyConfig, + /// Per-archive include/exclude filter applied during non-primary write + /// fanout. Absent or `{}` means unrestricted. + #[serde(default)] + pub filter: ArchiveFilter, + #[serde(default = "default_params")] + pub params: toml::Value, +} + +fn default_params() -> toml::Value { + toml::Value::Table(toml::value::Table::new()) +} + +fn default_write_active() -> bool { + true +} +fn default_enabled() -> bool { + true +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum WritePolicyConfig { + Tag(WritePolicyTag), + Detailed(WritePolicyDetailed), +} + +#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum WritePolicyTag { + Inline, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum WritePolicyDetailed { + Inline, + Queued { + #[serde(default = "default_batch_window_ms")] + batch_window_ms: u64, + #[serde(default = "default_capacity")] + capacity: usize, + #[serde(default)] + overflow: OverflowPolicy, + }, +} + +fn default_batch_window_ms() -> u64 { + 50 +} +fn default_capacity() -> usize { + 1024 +} + +impl Default for WritePolicyConfig { + fn default() -> Self { + WritePolicyConfig::Tag(WritePolicyTag::Inline) + } +} + +impl WritePolicyConfig { + pub fn into_runtime(self) -> super::registration::WritePolicy { + use super::registration::WritePolicy; + match self { + WritePolicyConfig::Tag(WritePolicyTag::Inline) => WritePolicy::Inline, + WritePolicyConfig::Detailed(WritePolicyDetailed::Inline) => WritePolicy::Inline, + WritePolicyConfig::Detailed(WritePolicyDetailed::Queued { + batch_window_ms, + capacity, + overflow, + }) => WritePolicy::Queued { + batch_window_ms, + capacity, + overflow, + }, + } + } +} + +use std::collections::BTreeSet; + +#[derive(Debug, thiserror::Error, PartialEq)] +pub enum ConfigValidationError { + #[error("duplicate archive name `{0}`")] + DuplicateName(String), + #[error("no `required` write-active backend configured (need at least one)")] + NoPrimary, +} + +impl ArchivesConfig { + pub fn validate(&self) -> Result<(), ConfigValidationError> { + let mut seen: BTreeSet<&str> = BTreeSet::new(); + for entry in &self.entries { + if !seen.insert(entry.name.as_str()) { + return Err(ConfigValidationError::DuplicateName(entry.name.clone())); + } + } + + // Empty config is allowed (ephemeral mode). + if self.entries.is_empty() { + return Ok(()); + } + + let has_primary = self + .entries + .iter() + .any(|e| e.enabled && e.write_active && e.failure_mode == FailureMode::Required); + + if !has_primary { + return Err(ConfigValidationError::NoPrimary); + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn parse(toml_src: &str) -> ArchivesConfig { + toml::from_str(toml_src).expect("parse") + } + + #[test] + fn empty_config_is_ephemeral() { + let cfg: ArchivesConfig = toml::from_str("").unwrap(); + assert!(cfg.entries.is_empty()); + assert!(cfg.validate().is_ok()); + } + + #[test] + fn minimal_single_archive() { + let cfg = parse( + r#" + [[archives]] + name = "main" + type = "jsonl" + [archives.params] + path = "dirigent_archive" + "#, + ); + assert_eq!(cfg.entries.len(), 1); + let e = &cfg.entries[0]; + assert_eq!(e.name, "main"); + assert_eq!(e.type_name, "jsonl"); + assert!(e.write_active); + assert_eq!(e.failure_mode, FailureMode::Required); + assert_eq!(e.read_priority, 0); + assert!(e.enabled); + assert!(matches!(e.write_policy, WritePolicyConfig::Tag(WritePolicyTag::Inline))); + cfg.validate().unwrap(); + } + + #[test] + fn duplicate_name_rejected() { + let cfg = parse( + r#" + [[archives]] + name = "main" + type = "jsonl" + [archives.params] + path = "a" + + [[archives]] + name = "main" + type = "jsonl" + [archives.params] + path = "b" + "#, + ); + assert_eq!( + cfg.validate(), + Err(ConfigValidationError::DuplicateName("main".into())) + ); + } + + #[test] + fn no_primary_rejected() { + let cfg = parse( + r#" + [[archives]] + name = "mirror" + type = "jsonl" + failure_mode = "best_effort" + [archives.params] + path = "a" + "#, + ); + assert_eq!(cfg.validate(), Err(ConfigValidationError::NoPrimary)); + } + + #[test] + fn queued_write_policy_parses() { + let cfg = parse( + r#" + [[archives]] + name = "main" + type = "jsonl" + [archives.params] + path = "a" + + [archives.write_policy] + type = "queued" + batch_window_ms = 100 + capacity = 4096 + overflow = "drop_oldest" + "#, + ); + let entry = &cfg.entries[0]; + match &entry.write_policy { + WritePolicyConfig::Detailed(WritePolicyDetailed::Queued { + batch_window_ms, + capacity, + overflow, + }) => { + assert_eq!(*batch_window_ms, 100); + assert_eq!(*capacity, 4096); + assert_eq!(*overflow, OverflowPolicy::DropOldest); + } + other => panic!("unexpected write_policy: {:?}", other), + } + } +} diff --git a/crates/dirigent_archivist/src/registry/factory.rs b/crates/dirigent_archivist/src/registry/factory.rs new file mode 100644 index 0000000..dbb5a8a --- /dev/null +++ b/crates/dirigent_archivist/src/registry/factory.rs @@ -0,0 +1,192 @@ +//! Pluggable backend instantiation: type-string → factory → backend. + +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; + +use crate::backend::ArchiveBackend; + +#[derive(Debug, thiserror::Error)] +pub enum BackendBuildError { + #[error("unknown backend type `{0}`")] + UnknownType(String), + #[error("invalid params for backend `{name}` (type `{type_name}`): {source}")] + InvalidParams { + name: String, + type_name: String, + #[source] + source: anyhow::Error, + }, + #[error("backend `{name}` (type `{type_name}`) failed to initialise: {source}")] + BackendInit { + name: String, + type_name: String, + #[source] + source: anyhow::Error, + }, +} + +#[async_trait] +pub trait BackendFactory: Send + Sync { + fn type_name(&self) -> &'static str; + + async fn build( + &self, + archive_name: &str, + params: toml::Value, + ) -> Result, BackendBuildError>; +} + +pub struct BackendRegistry { + factories: HashMap<&'static str, Arc>, +} + +impl BackendRegistry { + pub fn new() -> Self { + Self { + factories: HashMap::new(), + } + } + + pub fn register(&mut self, factory: Arc) { + self.factories.insert(factory.type_name(), factory); + } + + pub fn get(&self, type_name: &str) -> Option<&Arc> { + self.factories.get(type_name) + } + + pub async fn build( + &self, + archive_name: &str, + type_name: &str, + params: toml::Value, + ) -> Result, BackendBuildError> { + let factory = self + .get(type_name) + .ok_or_else(|| BackendBuildError::UnknownType(type_name.into()))?; + factory.build(archive_name, params).await + } +} + +impl Default for BackendRegistry { + fn default() -> Self { + Self::new() + } +} + +use std::path::PathBuf; + +use crate::backends::JsonlBackend; + +#[derive(Debug, serde::Deserialize)] +struct JsonlParams { + path: PathBuf, +} + +pub struct JsonlFactory; + +#[async_trait] +impl BackendFactory for JsonlFactory { + fn type_name(&self) -> &'static str { + "jsonl" + } + + async fn build( + &self, + archive_name: &str, + params: toml::Value, + ) -> Result, BackendBuildError> { + let parsed: JsonlParams = + params + .try_into() + .map_err(|e: toml::de::Error| BackendBuildError::InvalidParams { + name: archive_name.into(), + type_name: "jsonl".into(), + source: anyhow::Error::new(e), + })?; + + let backend = JsonlBackend::new(parsed.path).await.map_err(|e| { + BackendBuildError::BackendInit { + name: archive_name.into(), + type_name: "jsonl".into(), + source: anyhow::Error::new(e), + } + })?; + + Ok(Arc::new(backend) as Arc) + } +} + +impl BackendRegistry { + /// Convenience: a registry with `jsonl` pre-registered. + pub fn with_jsonl() -> Self { + let mut r = Self::new(); + r.register(Arc::new(JsonlFactory)); + r + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::backend::mock::MockBackend; + + struct MockFactory; + #[async_trait] + impl BackendFactory for MockFactory { + fn type_name(&self) -> &'static str { + "mock" + } + async fn build( + &self, + _archive_name: &str, + _params: toml::Value, + ) -> Result, BackendBuildError> { + Ok(Arc::new(MockBackend::new()) as Arc) + } + } + + #[tokio::test] + async fn unknown_type_rejected() { + let r = BackendRegistry::new(); + let err = r + .build("a", "nope", toml::Value::Table(Default::default())) + .await + .map(|_| ()) + .unwrap_err(); + assert!(matches!(err, BackendBuildError::UnknownType(s) if s == "nope")); + } + + #[tokio::test] + async fn registered_factory_builds() { + let mut r = BackendRegistry::new(); + r.register(Arc::new(MockFactory)); + let backend = r + .build("a", "mock", toml::Value::Table(Default::default())) + .await + .unwrap(); + let _: &dyn ArchiveBackend = &*backend; + } + + #[tokio::test] + async fn jsonl_factory_builds_under_tempdir() { + let dir = tempfile::tempdir().unwrap(); + let r = BackendRegistry::with_jsonl(); + let mut params = toml::value::Table::new(); + params.insert( + "path".into(), + toml::Value::String(dir.path().to_string_lossy().into_owned()), + ); + let backend = r + .build("main", "jsonl", toml::Value::Table(params)) + .await + .unwrap(); + let health = backend.health_check().await; + assert!(matches!( + health, + crate::backend::HealthStatus::Healthy | crate::backend::HealthStatus::Degraded { .. } + )); + } +} diff --git a/crates/dirigent_archivist/src/registry/filter.rs b/crates/dirigent_archivist/src/registry/filter.rs new file mode 100644 index 0000000..c2eaba8 --- /dev/null +++ b/crates/dirigent_archivist/src/registry/filter.rs @@ -0,0 +1,187 @@ +//! Per-archive include/exclude filter. Consulted during non-primary write +//! fanout (Task 20). Primary always writes regardless of filter. + +use std::collections::HashSet; + +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::types::SessionMetadata; + +/// Declarative filter applied to non-primary fanout writes. +/// +/// A registration's filter decides whether a given session should be +/// replicated to that archive. The primary write target ignores the filter +/// and always writes. A default filter (`ArchiveFilter::default()`) is +/// unrestricted and allows every session. +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct ArchiveFilter { + /// If `Some`, only sessions whose `connector_uid` is in the set are + /// accepted. If `None`, any connector is allowed (subject to other rules). + #[serde(default)] + pub include_connectors: Option>, + /// Connector UIDs that are explicitly rejected. Takes precedence over + /// `include_connectors`. + #[serde(default)] + pub exclude_connectors: HashSet, + /// If non-empty, the session must carry at least one of these tags. + #[serde(default)] + pub include_tags: HashSet, + /// Tags that cause the session to be rejected. + #[serde(default)] + pub exclude_tags: HashSet, + /// When `false`, sessions whose `metadata.hidden == true` are rejected. + #[serde(default = "default_include_hidden")] + pub include_hidden: bool, +} + +fn default_include_hidden() -> bool { + true +} + +impl Default for ArchiveFilter { + fn default() -> Self { + Self { + include_connectors: None, + exclude_connectors: HashSet::new(), + include_tags: HashSet::new(), + exclude_tags: HashSet::new(), + include_hidden: true, + } + } +} + +impl ArchiveFilter { + /// Returns true when this session should be written to the archive. + pub fn allows(&self, session: &SessionMetadata, connector_uid: &Uuid) -> bool { + // Exclude rules win. + if self.exclude_connectors.contains(connector_uid) { + return false; + } + if let Some(inc) = &self.include_connectors { + if !inc.contains(connector_uid) { + return false; + } + } + if session.tags.iter().any(|t| self.exclude_tags.contains(t)) { + return false; + } + if !self.include_tags.is_empty() + && !session.tags.iter().any(|t| self.include_tags.contains(t)) + { + return false; + } + if !self.include_hidden + && session.metadata.get("hidden") == Some(&serde_json::Value::Bool(true)) + { + return false; + } + true + } + + /// A filter that allows everything is equivalent to no filter. + pub fn is_unrestricted(&self) -> bool { + self.include_connectors.is_none() + && self.exclude_connectors.is_empty() + && self.include_tags.is_empty() + && self.exclude_tags.is_empty() + && self.include_hidden + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::SessionMetadata; + + fn make_session(tags: Vec, hidden: bool) -> SessionMetadata { + let mut s = SessionMetadata::stub(Uuid::now_v7()); + s.tags = tags; + s.metadata = if hidden { + serde_json::json!({ "hidden": true }) + } else { + serde_json::Value::Null + }; + s + } + + #[test] + fn default_allows_all() { + let f = ArchiveFilter::default(); + let s = make_session(vec![], false); + let uid = Uuid::new_v4(); + assert!(f.allows(&s, &uid)); + assert!(f.is_unrestricted()); + } + + #[test] + fn exclude_connector_rejects() { + let excluded = Uuid::new_v4(); + let mut f = ArchiveFilter::default(); + f.exclude_connectors.insert(excluded); + let s = make_session(vec![], false); + assert!(!f.allows(&s, &excluded)); + assert!(f.allows(&s, &Uuid::new_v4())); + assert!(!f.is_unrestricted()); + } + + #[test] + fn include_connector_only_allows_listed() { + let allowed = Uuid::new_v4(); + let mut f = ArchiveFilter::default(); + f.include_connectors = Some(HashSet::from_iter([allowed])); + let s = make_session(vec![], false); + assert!(f.allows(&s, &allowed)); + assert!(!f.allows(&s, &Uuid::new_v4())); + } + + #[test] + fn tag_intersection_semantics() { + let mut f = ArchiveFilter::default(); + f.include_tags = HashSet::from_iter(["prod".into()]); + let s_prod = make_session(vec!["prod".into()], false); + let s_dev = make_session(vec!["dev".into()], false); + let uid = Uuid::new_v4(); + assert!(f.allows(&s_prod, &uid)); + assert!(!f.allows(&s_dev, &uid)); + } + + #[test] + fn exclude_tag_wins_over_include() { + let mut f = ArchiveFilter::default(); + f.include_tags = HashSet::from_iter(["prod".into()]); + f.exclude_tags = HashSet::from_iter(["sensitive".into()]); + let s = make_session(vec!["prod".into(), "sensitive".into()], false); + let uid = Uuid::new_v4(); + assert!(!f.allows(&s, &uid)); + } + + #[test] + fn include_hidden_false_rejects_hidden_sessions() { + let mut f = ArchiveFilter::default(); + f.include_hidden = false; + let s_hidden = make_session(vec![], true); + let s_visible = make_session(vec![], false); + let uid = Uuid::new_v4(); + assert!(!f.allows(&s_hidden, &uid)); + assert!(f.allows(&s_visible, &uid)); + } + + #[test] + fn default_include_hidden_accepts_hidden_sessions() { + let f = ArchiveFilter::default(); + let s = make_session(vec![], true); + let uid = Uuid::new_v4(); + assert!(f.allows(&s, &uid)); + } + + #[test] + fn toml_roundtrip_default() { + // Serializing via TOML is exercised through ArchiveConfig, but a + // plain JSON roundtrip on the struct catches serde attribute typos. + let f = ArchiveFilter::default(); + let json = serde_json::to_string(&f).unwrap(); + let back: ArchiveFilter = serde_json::from_str(&json).unwrap(); + assert_eq!(f, back); + } +} diff --git a/crates/dirigent_archivist/src/registry/health.rs b/crates/dirigent_archivist/src/registry/health.rs new file mode 100644 index 0000000..24fd4d6 --- /dev/null +++ b/crates/dirigent_archivist/src/registry/health.rs @@ -0,0 +1,72 @@ +//! Health drift helpers used by both read and write paths. +//! +//! The drift model: successful writes reset `consecutive_failures` to 0 and +//! promote `Degraded` → `Healthy`. Write failures bump the counter and drift +//! to `Degraded { reason }`; after K consecutive failures, the registration +//! drifts to `Unavailable { reason }`, which causes read walks to skip it. +//! +//! Read successes rescue `Degraded` → `Healthy` but don't touch the failure +//! counter (writes are the authoritative health signal). Read failures drift +//! `Healthy` → `Degraded` but never to `Unavailable` by themselves (a truly +//! broken backend will be caught on the next write attempt). + +use chrono::Utc; + +use crate::backend::HealthStatus; +use crate::registry::ArchiveRegistration; + +const FAILURE_THRESHOLD: u32 = 5; + +impl crate::coordinator::Archivist { + pub(crate) async fn record_write_success(&self, reg: &ArchiveRegistration) { + *reg.consecutive_failures.write().await = 0; + let mut h = reg.last_health.write().await; + if !matches!(*h, HealthStatus::Healthy) { + *h = HealthStatus::Healthy; + } + } + + pub(crate) async fn record_read_success(&self, reg: &ArchiveRegistration) { + // Reads don't reset the failure counter — writes are the authoritative + // health signal. But reads DO recover from `Degraded` to `Healthy`. + let mut h = reg.last_health.write().await; + if matches!(*h, HealthStatus::Degraded { .. }) { + *h = HealthStatus::Healthy; + } + } + + pub(crate) async fn record_write_failure( + &self, + reg: &ArchiveRegistration, + reason: &str, + ) { + let mut n = reg.consecutive_failures.write().await; + *n = n.saturating_add(1); + *reg.last_error.write().await = Some((Utc::now(), reason.to_string())); + let mut h = reg.last_health.write().await; + if *n >= FAILURE_THRESHOLD { + *h = HealthStatus::Unavailable { + reason: format!("{} consecutive failures: {reason}", *n), + }; + } else { + *h = HealthStatus::Degraded { + reason: reason.to_string(), + }; + } + } + + pub(crate) async fn record_read_failure(&self, reg: &ArchiveRegistration) { + // Reads alone do not drift to Unavailable. Only drift to Degraded. + let mut h = reg.last_health.write().await; + if matches!(*h, HealthStatus::Healthy) { + *h = HealthStatus::Degraded { + reason: "read failure".into(), + }; + } + } + + #[allow(dead_code)] + pub(crate) async fn current_health(&self, reg: &ArchiveRegistration) -> HealthStatus { + reg.last_health.read().await.clone() + } +} diff --git a/crates/dirigent_archivist/src/registry/mod.rs b/crates/dirigent_archivist/src/registry/mod.rs new file mode 100644 index 0000000..ed0e284 --- /dev/null +++ b/crates/dirigent_archivist/src/registry/mod.rs @@ -0,0 +1,22 @@ +//! Multi-backend registry: configuration, factory, registration entries, +//! read cache, queued writer tasks, and health drift helpers. +//! +//! The single `registry.rs` file from Phase 2 (on-disk archive metadata +//! persistence) has been replaced; archive declaration moves to +//! `dirigent.toml` and is consumed at boot via +//! `coordinator::boot::Archivist::from_config` in later Phase 3 tasks. + +pub mod cache; +pub mod config; +pub mod factory; +pub mod filter; +pub mod health; +pub mod registration; +pub mod writer; + +pub use config::{ArchiveConfig, ArchivesConfig, ConfigValidationError}; +pub use factory::{BackendBuildError, BackendFactory, BackendRegistry, JsonlFactory}; +pub use filter::ArchiveFilter; +pub use registration::{ + ArchiveRegistration, ArchiveStatus, FailureMode, OverflowPolicy, WritePolicy, +}; diff --git a/crates/dirigent_archivist/src/registry/registration.rs b/crates/dirigent_archivist/src/registry/registration.rs new file mode 100644 index 0000000..6e6568e --- /dev/null +++ b/crates/dirigent_archivist/src/registry/registration.rs @@ -0,0 +1,181 @@ +//! Per-backend configuration value types used by `ArchiveRegistration`. + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum FailureMode { + Required, + BestEffort, +} + +impl Default for FailureMode { + fn default() -> Self { + FailureMode::Required + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum WritePolicy { + Inline, + Queued { + batch_window_ms: u64, + capacity: usize, + overflow: OverflowPolicy, + }, +} + +impl Default for WritePolicy { + fn default() -> Self { + WritePolicy::Inline + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum OverflowPolicy { + Block, + DropOldest, + Error, +} + +impl Default for OverflowPolicy { + fn default() -> Self { + OverflowPolicy::Block + } +} + +use std::sync::Arc; + +use chrono::{DateTime, Utc}; +use tokio::sync::RwLock; + +use crate::backend::{ArchiveBackend, CapabilitySet, HealthStatus}; +use crate::registry::filter::ArchiveFilter; + +use super::writer::WriterHandle; + +pub struct ArchiveRegistration { + pub name: String, + pub type_name: &'static str, + pub backend: Arc, + pub write_active: bool, + pub failure_mode: FailureMode, + pub read_priority: u32, + pub enabled: bool, + pub write_policy: WritePolicy, + /// Per-archive include/exclude filter consulted during non-primary + /// write fanout. Default (unrestricted) accepts every session; the + /// primary target always writes regardless of its filter. + pub filter: ArchiveFilter, + pub last_health: Arc>, + pub last_error: Arc, String)>>>, + pub consecutive_failures: Arc>, + pub writer: Option, +} + +impl ArchiveRegistration { + /// Convenience constructor: builds new `Arc>` instances for the + /// drift trio. Use this for single-process, single-owner registrations + /// (tests and the simple single-archive constructors). + #[allow(clippy::too_many_arguments)] + pub fn new( + name: String, + type_name: &'static str, + backend: Arc, + write_active: bool, + failure_mode: FailureMode, + read_priority: u32, + enabled: bool, + write_policy: WritePolicy, + writer: Option, + initial_health: HealthStatus, + ) -> Self { + Self::new_with_shared_state( + name, + type_name, + backend, + write_active, + failure_mode, + read_priority, + enabled, + write_policy, + writer, + Arc::new(RwLock::new(initial_health)), + Arc::new(RwLock::new(None)), + Arc::new(RwLock::new(0)), + ) + } + + /// Constructor used by `from_config` so the writer task and the + /// registration share the same drift state (both mutate it). + #[allow(clippy::too_many_arguments)] + pub fn new_with_shared_state( + name: String, + type_name: &'static str, + backend: Arc, + write_active: bool, + failure_mode: FailureMode, + read_priority: u32, + enabled: bool, + write_policy: WritePolicy, + writer: Option, + last_health: Arc>, + last_error: Arc, String)>>>, + consecutive_failures: Arc>, + ) -> Self { + Self { + name, + type_name, + backend, + write_active, + failure_mode, + read_priority, + enabled, + write_policy, + filter: ArchiveFilter::default(), + last_health, + last_error, + consecutive_failures, + writer, + } + } + + /// Override the registration's filter. Intended for boot-time wiring + /// (`from_config`) and tests; the field itself is public for other + /// direct consumers. + pub fn with_filter(mut self, filter: ArchiveFilter) -> Self { + self.filter = filter; + self + } + + pub fn capabilities(&self) -> &CapabilitySet { + self.backend.capabilities() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ArchiveStatus { + pub name: String, + pub type_name: String, + pub enabled: bool, + pub write_active: bool, + pub failure_mode: FailureMode, + pub read_priority: u32, + pub capabilities: CapabilitySet, + pub health: HealthStatus, + pub last_error: Option<(DateTime, String)>, + pub queue_depth: Option, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn defaults_are_safe() { + assert_eq!(FailureMode::default(), FailureMode::Required); + assert_eq!(WritePolicy::default(), WritePolicy::Inline); + assert_eq!(OverflowPolicy::default(), OverflowPolicy::Block); + } +} diff --git a/crates/dirigent_archivist/src/registry/writer.rs b/crates/dirigent_archivist/src/registry/writer.rs new file mode 100644 index 0000000..4f4a146 --- /dev/null +++ b/crates/dirigent_archivist/src/registry/writer.rs @@ -0,0 +1,256 @@ +//! Per-backend writer task for `WritePolicy::Queued` backends. +//! +//! The task drains a per-backend mpsc, optionally batching/coalescing within +//! a configured window, and invokes `ArchiveBackend` methods directly. Errors +//! drift health on the parent registration; they do not propagate to the +//! caller. + +use std::sync::Arc; +use std::time::Duration; + +use chrono::Utc; +use tokio::sync::{mpsc, oneshot, watch, RwLock}; +use tokio::task::JoinHandle; +use tracing::{debug, warn}; +use uuid::Uuid; + +use crate::backend::{ArchiveBackend, HealthStatus}; + +use super::OverflowPolicy; + +#[derive(Debug)] +pub enum WriteOp { + PutSession(crate::types::SessionMetadata), + AppendMessages { + scroll_id: Uuid, + msgs: Vec, + }, + DeleteSession { + scroll_id: Uuid, + }, + ClearSessionMessages { + scroll_id: Uuid, + }, + AppendDagEdge(crate::types::DagEdge), + AppendMetaEvents { + scroll_id: Uuid, + events: Vec, + }, + Shutdown(oneshot::Sender<()>), +} + +impl WriteOp { + pub fn op_label(&self) -> &'static str { + match self { + WriteOp::PutSession(_) => "put_session", + WriteOp::AppendMessages { .. } => "append_messages", + WriteOp::DeleteSession { .. } => "delete_session", + WriteOp::ClearSessionMessages { .. } => "clear_session_messages", + WriteOp::AppendDagEdge(_) => "append_dag_edge", + WriteOp::AppendMetaEvents { .. } => "append_meta_events", + WriteOp::Shutdown(_) => "shutdown", + } + } +} + +#[derive(Debug)] +pub struct WriterHandle { + pub sender: mpsc::Sender, + pub overflow: OverflowPolicy, + pub queue_depth: watch::Receiver, + pub join: tokio::sync::Mutex>>, + pub backend_name: String, +} + +impl WriterHandle { + pub async fn enqueue(&self, op: WriteOp) -> Result<(), crate::error::ArchivistError> { + match self.overflow { + OverflowPolicy::Block => self.sender.send(op).await.map_err(|_| { + crate::error::ArchivistError::Other(format!( + "writer task for `{}` has closed", + self.backend_name + )) + }), + OverflowPolicy::Error => self.sender.try_send(op).map_err(|e| match e { + mpsc::error::TrySendError::Full(op) => { + crate::error::ArchivistError::WriteQueueFull { + backend: self.backend_name.clone(), + op: op.op_label(), + } + } + mpsc::error::TrySendError::Closed(_) => { + crate::error::ArchivistError::Other(format!( + "writer task for `{}` has closed", + self.backend_name + )) + } + }), + OverflowPolicy::DropOldest => { + // Tokio mpsc can't truly "drop oldest" without draining from the + // other side; we approximate with "drop newest when full". For + // observability sinks this is acceptable — the contract is + // "never block, may lose data". + let _ = self.sender.try_send(op); + Ok(()) + } + } + } + + pub fn queue_depth_now(&self) -> usize { + *self.queue_depth.borrow() + } +} + +#[allow(clippy::too_many_arguments)] +pub fn spawn_writer( + backend: Arc, + backend_name: String, + capacity: usize, + batch_window: Duration, + overflow: OverflowPolicy, + health: Arc>, + last_error: Arc, String)>>>, + consecutive_failures: Arc>, +) -> WriterHandle { + let (tx, mut rx) = mpsc::channel::(capacity); + let (depth_tx, depth_rx) = watch::channel(0usize); + + let join = tokio::spawn({ + let backend_name = backend_name.clone(); + async move { + const FAILURE_THRESHOLD: u32 = 5; + + loop { + let Some(first) = rx.recv().await else { break }; + let mut batch: Vec = vec![first]; + + let deadline = tokio::time::Instant::now() + batch_window; + while tokio::time::Instant::now() < deadline { + match tokio::time::timeout_at(deadline, rx.recv()).await { + Ok(Some(op)) => batch.push(op), + Ok(None) => break, + Err(_) => break, + } + } + + let _ = depth_tx.send(rx.len()); + + let coalesced = coalesce(batch); + + let mut shutdown_ack: Option> = None; + for op in coalesced { + if let WriteOp::Shutdown(ack) = op { + shutdown_ack = Some(ack); + break; + } + match dispatch_op(&*backend, op).await { + Ok(()) => { + *consecutive_failures.write().await = 0; + let mut h = health.write().await; + if matches!(*h, HealthStatus::Degraded { .. }) { + *h = HealthStatus::Healthy; + } + } + Err(e) => { + warn!( + backend = backend_name.as_str(), + error = %e, + "queued write failed; drifting health" + ); + let mut n = consecutive_failures.write().await; + *n = n.saturating_add(1); + *last_error.write().await = Some((Utc::now(), format!("{e}"))); + let mut h = health.write().await; + if *n >= FAILURE_THRESHOLD { + *h = HealthStatus::Unavailable { + reason: format!("{} consecutive failures", *n), + }; + } else { + *h = HealthStatus::Degraded { reason: format!("{e}") }; + } + } + } + } + + if let Some(ack) = shutdown_ack { + debug!(backend = backend_name.as_str(), "writer task shutting down"); + let _ = ack.send(()); + break; + } + } + } + }); + + WriterHandle { + sender: tx, + overflow, + queue_depth: depth_rx, + join: tokio::sync::Mutex::new(Some(join)), + backend_name, + } +} + +fn coalesce(batch: Vec) -> Vec { + let mut out: Vec = Vec::with_capacity(batch.len()); + for op in batch { + let merged = match (out.last_mut(), &op) { + ( + Some(WriteOp::AppendMessages { scroll_id: a, .. }), + WriteOp::AppendMessages { scroll_id: b, .. }, + ) if a == b => true, + ( + Some(WriteOp::AppendMetaEvents { scroll_id: a, .. }), + WriteOp::AppendMetaEvents { scroll_id: b, .. }, + ) if a == b => true, + _ => false, + }; + + if merged { + match out.last_mut().unwrap() { + WriteOp::AppendMessages { msgs: m1, .. } => { + if let WriteOp::AppendMessages { msgs: m2, .. } = op { + m1.extend(m2); + continue; + } + } + WriteOp::AppendMetaEvents { events: e1, .. } => { + if let WriteOp::AppendMetaEvents { events: e2, .. } = op { + e1.extend(e2); + continue; + } + } + _ => {} + } + } + out.push(op); + } + out +} + +async fn dispatch_op(backend: &dyn ArchiveBackend, op: WriteOp) -> crate::error::Result<()> { + match op { + WriteOp::PutSession(meta) => backend.put_session(meta).await, + WriteOp::AppendMessages { scroll_id, msgs } => { + backend.append_messages(scroll_id, msgs).await + } + WriteOp::DeleteSession { scroll_id } => backend.delete_session(scroll_id).await, + WriteOp::ClearSessionMessages { scroll_id } => { + backend.clear_session_messages(scroll_id).await + } + WriteOp::AppendDagEdge(edge) => { + if let Some(d) = backend.as_dag() { + d.append_dag_edge(edge).await + } else { + Ok(()) + } + } + WriteOp::AppendMetaEvents { scroll_id, events } => { + if let Some(m) = backend.as_meta_events() { + m.append_meta_events(scroll_id, events).await + } else { + Ok(()) + } + } + WriteOp::Shutdown(_) => Ok(()), + } +} diff --git a/crates/dirigent_archivist/src/session.rs b/crates/dirigent_archivist/src/session.rs new file mode 100644 index 0000000..87118b2 --- /dev/null +++ b/crates/dirigent_archivist/src/session.rs @@ -0,0 +1,24 @@ +//! Session management and lineage tracking. +//! +//! Handles session metadata, lineage relationships (splits, continuations), +//! and session lifecycle operations. + +use crate::error::Result; + +/// Session manager for tracking session metadata and lineage +pub struct SessionManager { + // Placeholder - will be populated in implementation phases +} + +impl SessionManager { + /// Create a new session manager + pub fn new() -> Result { + Ok(Self {}) + } +} + +impl Default for SessionManager { + fn default() -> Self { + Self::new().expect("Failed to create default SessionManager") + } +} diff --git a/crates/dirigent_archivist/src/storage/files.rs b/crates/dirigent_archivist/src/storage/files.rs new file mode 100644 index 0000000..90f54b4 --- /dev/null +++ b/crates/dirigent_archivist/src/storage/files.rs @@ -0,0 +1,465 @@ +//! Content-addressable file storage. +//! +//! Handles storage and retrieval of binary files (images, documents, etc.) +//! using content-addressable naming based on SHA-256 hashes. +//! +//! Files are stored with deduplication: +//! - Same content = same file_id = stored once +//! - Multiple sessions can reference the same file +//! - File index tracks all referencing sessions + +use crate::storage::{ndjson, paths::ArchivePaths}; +use crate::types::FileRecord; +use sha2::{Digest, Sha256}; +use uuid::Uuid; + +/// Store a file in the archive +/// +/// This function: +/// 1. Computes SHA-256 hash of content +/// 2. Generates file_id: `sha256:{hex_digest}` +/// 3. Stores blob in sharded directory (if not already exists - deduplication) +/// 4. Updates file index to track the session referencing this file +/// 5. Returns the file_id +/// +/// # Arguments +/// * `paths` - Archive paths helper +/// * `content` - File content bytes +/// * `original_name` - Original filename +/// * `mime` - Optional MIME type +/// * `session` - Session UUID that references this file +/// +/// # Returns +/// The file_id (e.g., "sha256:abc123...") +pub async fn store_file( + paths: &ArchivePaths, + content: &[u8], + original_name: String, + mime: Option, + session: Uuid, +) -> std::io::Result { + // Compute SHA-256 hash + let hash = Sha256::digest(content); + let hex_digest = hex::encode(hash); + let file_id = format!("sha256:{}", hex_digest); + + // Get blob path + let blob_path = paths.file_blob_path(&file_id); + + // Create parent directories for blob + if let Some(parent) = blob_path.parent() { + tokio::fs::create_dir_all(parent).await?; + } + + // Write blob if it doesn't exist (deduplication) + if !blob_path.exists() { + tokio::fs::write(&blob_path, content).await?; + } + + // Update file index + let index_path = paths.root().join(".files").join("file_index.jsonl"); + + // Create .files directory if it doesn't exist + if let Some(parent) = index_path.parent() { + tokio::fs::create_dir_all(parent).await?; + } + + // Serialize the read-modify-rewrite below. Concurrent callers against + // the same archive would otherwise lose records (both read the same + // snapshot) and race on `rename(.tmp → .ndjson)` (second call hits + // ENOENT because the first already consumed the shared temp path). + let index_lock = paths.file_index_lock(); + let _index_guard = index_lock.lock().await; + + // Read existing index + let mut records: Vec = ndjson::read_ndjson(&index_path).await?; + + // Find or create FileRecord + if let Some(existing) = records.iter_mut().find(|r| r.file_id == file_id) { + // File already exists - add session if not already present + if !existing.sessions.contains(&session) { + existing.sessions.push(session); + } + } else { + // New file - create record + let relative_path = blob_path + .strip_prefix(paths.root()) + .unwrap_or(&blob_path) + .to_string_lossy() + .to_string(); + + let new_record = FileRecord { + version: 1, + file_id: file_id.clone(), + path: relative_path, + size: content.len() as u64, + mime: mime.clone(), + original_name: original_name.clone(), + sessions: vec![session], + metadata: serde_json::json!({}), + }; + + records.push(new_record); + } + + // Rewrite entire index atomically + // Use temp file + rename pattern + let temp_index_path = index_path.with_extension("tmp"); + + // Clear temp file and write all records + if temp_index_path.exists() { + tokio::fs::remove_file(&temp_index_path).await?; + } + + for rec in &records { + ndjson::append_ndjson(&temp_index_path, rec).await?; + } + + // Rename to final location + tokio::fs::rename(&temp_index_path, &index_path).await?; + + Ok(file_id) +} + +/// Retrieve a file from the archive +/// +/// # Arguments +/// * `paths` - Archive paths helper +/// * `file_id` - File identifier (e.g., "sha256:abc123...") +/// +/// # Returns +/// File content bytes +pub async fn get_file(paths: &ArchivePaths, file_id: &str) -> std::io::Result> { + let blob_path = paths.file_blob_path(file_id); + tokio::fs::read(&blob_path).await +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_content_deduplication() { + let temp_dir = + std::env::temp_dir().join(format!("archivist_files_test_{}", Uuid::now_v7())); + let paths = ArchivePaths::new(temp_dir.clone()); + + let content = b"Hello, world! This is test content."; + let session1 = Uuid::now_v7(); + let session2 = Uuid::now_v7(); + + // Store same content from two different sessions + let file_id1 = store_file( + &paths, + content, + "test1.txt".to_string(), + Some("text/plain".to_string()), + session1, + ) + .await + .unwrap(); + + let file_id2 = store_file( + &paths, + content, + "test2.txt".to_string(), // Different name + Some("text/plain".to_string()), + session2, + ) + .await + .unwrap(); + + // Same content should produce same file_id + assert_eq!(file_id1, file_id2); + + // Verify blob was only written once + let blob_path = paths.file_blob_path(&file_id1); + assert!(blob_path.exists()); + + // Verify index tracks both sessions + let index_path = paths.root().join(".files").join("file_index.jsonl"); + let records: Vec = ndjson::read_ndjson(&index_path).await.unwrap(); + + let record = records.iter().find(|r| r.file_id == file_id1).unwrap(); + assert_eq!(record.sessions.len(), 2); + assert!(record.sessions.contains(&session1)); + assert!(record.sessions.contains(&session2)); + + // Clean up + tokio::fs::remove_dir_all(&temp_dir).await.ok(); + } + + #[tokio::test] + async fn test_sharding_distributes_files() { + let temp_dir = + std::env::temp_dir().join(format!("archivist_files_test_{}", Uuid::now_v7())); + let paths = ArchivePaths::new(temp_dir.clone()); + + let session = Uuid::now_v7(); + + // Store files with different content + let content1 = b"Content A"; + let content2 = b"Content B"; + let content3 = b"Content C"; + + let file_id1 = store_file(&paths, content1, "file1.txt".to_string(), None, session) + .await + .unwrap(); + + let file_id2 = store_file(&paths, content2, "file2.txt".to_string(), None, session) + .await + .unwrap(); + + let file_id3 = store_file(&paths, content3, "file3.txt".to_string(), None, session) + .await + .unwrap(); + + // Verify different content produces different file_ids + assert_ne!(file_id1, file_id2); + assert_ne!(file_id2, file_id3); + + // Verify files are distributed across sharded directories + let blob_path1 = paths.file_blob_path(&file_id1); + let blob_path2 = paths.file_blob_path(&file_id2); + let blob_path3 = paths.file_blob_path(&file_id3); + + assert!(blob_path1.exists()); + assert!(blob_path2.exists()); + assert!(blob_path3.exists()); + + // Verify sharding creates subdirectories + let files_dir = paths.root().join(".files"); + let mut shard_dirs = Vec::new(); + for entry in std::fs::read_dir(&files_dir).unwrap() { + let entry = entry.unwrap(); + if entry.file_type().unwrap().is_dir() { + shard_dirs.push(entry.path()); + } + } + + // Should have at least one shard directory + assert!(!shard_dirs.is_empty()); + + // Clean up + tokio::fs::remove_dir_all(&temp_dir).await.ok(); + } + + #[tokio::test] + async fn test_index_tracks_sessions() { + let temp_dir = + std::env::temp_dir().join(format!("archivist_files_test_{}", Uuid::now_v7())); + let paths = ArchivePaths::new(temp_dir.clone()); + + let content = b"Shared content"; + let session1 = Uuid::now_v7(); + let session2 = Uuid::now_v7(); + let session3 = Uuid::now_v7(); + + // Store from session1 + let file_id = store_file( + &paths, + content, + "file.txt".to_string(), + Some("text/plain".to_string()), + session1, + ) + .await + .unwrap(); + + // Verify index has 1 session + let index_path = paths.root().join(".files").join("file_index.jsonl"); + let records: Vec = ndjson::read_ndjson(&index_path).await.unwrap(); + let record = records.iter().find(|r| r.file_id == file_id).unwrap(); + assert_eq!(record.sessions.len(), 1); + assert_eq!(record.sessions[0], session1); + + // Store same content from session2 + store_file( + &paths, + content, + "file2.txt".to_string(), + Some("text/plain".to_string()), + session2, + ) + .await + .unwrap(); + + // Verify index now has 2 sessions + let records: Vec = ndjson::read_ndjson(&index_path).await.unwrap(); + let record = records.iter().find(|r| r.file_id == file_id).unwrap(); + assert_eq!(record.sessions.len(), 2); + assert!(record.sessions.contains(&session1)); + assert!(record.sessions.contains(&session2)); + + // Store same content from session3 + store_file(&paths, content, "file3.txt".to_string(), None, session3) + .await + .unwrap(); + + // Verify index now has 3 sessions + let records: Vec = ndjson::read_ndjson(&index_path).await.unwrap(); + let record = records.iter().find(|r| r.file_id == file_id).unwrap(); + assert_eq!(record.sessions.len(), 3); + assert!(record.sessions.contains(&session1)); + assert!(record.sessions.contains(&session2)); + assert!(record.sessions.contains(&session3)); + + // Clean up + tokio::fs::remove_dir_all(&temp_dir).await.ok(); + } + + #[tokio::test] + async fn test_concurrent_writes_different_files() { + let temp_dir = + std::env::temp_dir().join(format!("archivist_files_test_{}", Uuid::now_v7())); + let paths = ArchivePaths::new(temp_dir.clone()); + + let content1 = b"Content 1"; + let content2 = b"Content 2"; + let session = Uuid::now_v7(); + + // Store concurrently + let (file_id1, file_id2) = tokio::join!( + store_file(&paths, content1, "file1.txt".to_string(), None, session,), + store_file(&paths, content2, "file2.txt".to_string(), None, session,) + ); + + let file_id1 = file_id1.unwrap(); + let file_id2 = file_id2.unwrap(); + + // Verify both files exist + assert_ne!(file_id1, file_id2); + + let retrieved1 = get_file(&paths, &file_id1).await.unwrap(); + let retrieved2 = get_file(&paths, &file_id2).await.unwrap(); + + assert_eq!(retrieved1, content1); + assert_eq!(retrieved2, content2); + + // Clean up + tokio::fs::remove_dir_all(&temp_dir).await.ok(); + } + + #[tokio::test] + async fn test_get_file_missing() { + let temp_dir = + std::env::temp_dir().join(format!("archivist_files_test_{}", Uuid::now_v7())); + let paths = ArchivePaths::new(temp_dir.clone()); + + // Try to get non-existent file + let result = get_file(&paths, "sha256:nonexistent").await; + assert!(result.is_err()); + + match result { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::NotFound); + } + Ok(_) => panic!("Expected NotFound error"), + } + + // Clean up + tokio::fs::remove_dir_all(&temp_dir).await.ok(); + } + + #[tokio::test] + async fn test_roundtrip_binary_content() { + let temp_dir = + std::env::temp_dir().join(format!("archivist_files_test_{}", Uuid::now_v7())); + let paths = ArchivePaths::new(temp_dir.clone()); + + // Binary content (not UTF-8) + let content: Vec = (0..256).map(|i| i as u8).collect(); + let session = Uuid::now_v7(); + + // Store + let file_id = store_file( + &paths, + &content, + "binary.dat".to_string(), + Some("application/octet-stream".to_string()), + session, + ) + .await + .unwrap(); + + // Retrieve + let retrieved = get_file(&paths, &file_id).await.unwrap(); + + // Verify exact match + assert_eq!(retrieved, content); + + // Clean up + tokio::fs::remove_dir_all(&temp_dir).await.ok(); + } + + #[tokio::test] + async fn test_file_metadata_preserved() { + let temp_dir = + std::env::temp_dir().join(format!("archivist_files_test_{}", Uuid::now_v7())); + let paths = ArchivePaths::new(temp_dir.clone()); + + let content = b"Test content"; + let session = Uuid::now_v7(); + let original_name = "document.pdf".to_string(); + let mime = Some("application/pdf".to_string()); + + // Store + let file_id = store_file( + &paths, + content, + original_name.clone(), + mime.clone(), + session, + ) + .await + .unwrap(); + + // Read index + let index_path = paths.root().join(".files").join("file_index.jsonl"); + let records: Vec = ndjson::read_ndjson(&index_path).await.unwrap(); + + let record = records.iter().find(|r| r.file_id == file_id).unwrap(); + + // Verify metadata + assert_eq!(record.original_name, original_name); + assert_eq!(record.mime, mime); + assert_eq!(record.size, content.len() as u64); + assert!(record.path.contains(".files")); + + // Clean up + tokio::fs::remove_dir_all(&temp_dir).await.ok(); + } + + #[tokio::test] + async fn test_deduplicate_same_session() { + let temp_dir = + std::env::temp_dir().join(format!("archivist_files_test_{}", Uuid::now_v7())); + let paths = ArchivePaths::new(temp_dir.clone()); + + let content = b"Duplicate content"; + let session = Uuid::now_v7(); + + // Store same content twice from same session + let file_id1 = store_file(&paths, content, "file1.txt".to_string(), None, session) + .await + .unwrap(); + + let file_id2 = store_file(&paths, content, "file2.txt".to_string(), None, session) + .await + .unwrap(); + + // Same file_id + assert_eq!(file_id1, file_id2); + + // Session should only appear once in the index + let index_path = paths.root().join(".files").join("file_index.jsonl"); + let records: Vec = ndjson::read_ndjson(&index_path).await.unwrap(); + let record = records.iter().find(|r| r.file_id == file_id1).unwrap(); + + assert_eq!(record.sessions.len(), 1); + assert_eq!(record.sessions[0], session); + + // Clean up + tokio::fs::remove_dir_all(&temp_dir).await.ok(); + } +} diff --git a/crates/dirigent_archivist/src/storage/json.rs b/crates/dirigent_archivist/src/storage/json.rs new file mode 100644 index 0000000..ea117d8 --- /dev/null +++ b/crates/dirigent_archivist/src/storage/json.rs @@ -0,0 +1,342 @@ +//! JSON storage utilities for session metadata. +//! +//! Handles reading and writing JSON files for session and connector metadata. +//! Uses atomic write operations (write-to-temp + rename) to ensure consistency. + +use serde::{Deserialize, Serialize}; +use std::path::Path; +use tokio::io::AsyncWriteExt; + +/// Write a value to a JSON file atomically +/// +/// This function: +/// 1. Serializes the value to pretty-printed JSON +/// 2. Writes to a temporary file (`{path}.tmp`) +/// 3. Renames the temp file to the target path (atomic operation) +/// +/// The rename operation is atomic on most filesystems, ensuring that +/// readers will either see the old complete file or the new complete file, +/// never a partially written file. +/// +/// # Arguments +/// * `path` - Path to the JSON file +/// * `value` - Value to serialize and write +/// +/// # Example +/// ```no_run +/// use dirigent_archivist::storage::json::write_json; +/// use serde::{Serialize, Deserialize}; +/// +/// #[derive(Serialize, Deserialize)] +/// struct Config { +/// setting: String, +/// } +/// +/// # async fn example() -> std::io::Result<()> { +/// let config = Config { setting: "value".to_string() }; +/// write_json(std::path::Path::new("config.json"), &config).await?; +/// # Ok(()) +/// # } +/// ``` +pub async fn write_json(path: &Path, value: &T) -> std::io::Result<()> { + // Serialize to pretty-printed JSON + let json = serde_json::to_string_pretty(value) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + + // Create temporary file path (same directory for atomic rename) + let temp_path = path.with_extension("tmp"); + + // Write to temporary file + let mut file = tokio::fs::File::create(&temp_path).await?; + file.write_all(json.as_bytes()).await?; + file.sync_all().await?; + drop(file); // Close the file before rename + + // Atomically rename temp file to target path + tokio::fs::rename(&temp_path, path).await?; + + Ok(()) +} + +/// Read a value from a JSON file +/// +/// If the file doesn't exist, returns a NotFound error. +/// +/// # Arguments +/// * `path` - Path to the JSON file +/// +/// # Returns +/// Deserialized value +/// +/// # Example +/// ```no_run +/// use dirigent_archivist::storage::json::read_json; +/// use serde::{Serialize, Deserialize}; +/// +/// #[derive(Serialize, Deserialize)] +/// struct Config { +/// setting: String, +/// } +/// +/// # async fn example() -> std::io::Result<()> { +/// let config: Config = read_json(std::path::Path::new("config.json")).await?; +/// # Ok(()) +/// # } +/// ``` +pub async fn read_json Deserialize<'de>>(path: &Path) -> std::io::Result { + // Read file to string + let content = tokio::fs::read_to_string(path).await?; + + // Deserialize from JSON + let value: T = serde_json::from_str(&content) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + + Ok(value) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde::{Deserialize, Serialize}; + use uuid::Uuid; + + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] + struct TestData { + id: String, + value: i32, + nested: NestedData, + } + + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] + struct NestedData { + flag: bool, + items: Vec, + } + + #[tokio::test] + async fn test_write_and_read_roundtrip() { + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("test_json_{}.json", Uuid::now_v7())); + + let data = TestData { + id: "test-123".to_string(), + value: 42, + nested: NestedData { + flag: true, + items: vec!["a".to_string(), "b".to_string(), "c".to_string()], + }, + }; + + // Write + write_json(&file_path, &data).await.unwrap(); + + // Read back + let read_data: TestData = read_json(&file_path).await.unwrap(); + + // Verify + assert_eq!(read_data, data); + + // Clean up + tokio::fs::remove_file(&file_path).await.ok(); + } + + #[tokio::test] + async fn test_pretty_printed_output() { + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("test_pretty_{}.json", Uuid::now_v7())); + + let data = TestData { + id: "test".to_string(), + value: 100, + nested: NestedData { + flag: false, + items: vec!["x".to_string()], + }, + }; + + // Write + write_json(&file_path, &data).await.unwrap(); + + // Read as raw string + let content = tokio::fs::read_to_string(&file_path).await.unwrap(); + + // Verify it's pretty-printed (contains newlines and indentation) + assert!(content.contains('\n')); + assert!(content.contains(" ")); // Indentation + assert!(content.contains(r#""id""#)); + assert!(content.contains(r#""value""#)); + + // Clean up + tokio::fs::remove_file(&file_path).await.ok(); + } + + #[tokio::test] + async fn test_read_missing_file() { + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("nonexistent_{}.json", Uuid::now_v7())); + + // Should return NotFound error + let result: std::io::Result = read_json(&file_path).await; + assert!(result.is_err()); + + match result { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::NotFound); + } + Ok(_) => panic!("Expected NotFound error"), + } + } + + #[tokio::test] + async fn test_atomic_write() { + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("test_atomic_{}.json", Uuid::now_v7())); + + let data1 = TestData { + id: "first".to_string(), + value: 1, + nested: NestedData { + flag: true, + items: vec![], + }, + }; + + let data2 = TestData { + id: "second".to_string(), + value: 2, + nested: NestedData { + flag: false, + items: vec!["updated".to_string()], + }, + }; + + // Write first version + write_json(&file_path, &data1).await.unwrap(); + + // Verify first version + let read1: TestData = read_json(&file_path).await.unwrap(); + assert_eq!(read1.id, "first"); + + // Overwrite with second version + write_json(&file_path, &data2).await.unwrap(); + + // Verify second version + let read2: TestData = read_json(&file_path).await.unwrap(); + assert_eq!(read2.id, "second"); + assert_eq!(read2.value, 2); + + // Temp file should not exist after rename + let temp_path = file_path.with_extension("tmp"); + assert!(!temp_path.exists()); + + // Clean up + tokio::fs::remove_file(&file_path).await.ok(); + } + + #[tokio::test] + async fn test_invalid_json_error() { + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("test_invalid_{}.json", Uuid::now_v7())); + + // Write invalid JSON manually + tokio::fs::write(&file_path, "{ invalid json }") + .await + .unwrap(); + + // Reading should fail with InvalidData error + let result: std::io::Result = read_json(&file_path).await; + assert!(result.is_err()); + + match result { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + } + Ok(_) => panic!("Expected InvalidData error"), + } + + // Clean up + tokio::fs::remove_file(&file_path).await.ok(); + } + + #[tokio::test] + async fn test_concurrent_writes_different_files() { + let temp_dir = std::env::temp_dir(); + let file1 = temp_dir.join(format!("test_concurrent_1_{}.json", Uuid::now_v7())); + let file2 = temp_dir.join(format!("test_concurrent_2_{}.json", Uuid::now_v7())); + + let data1 = TestData { + id: "file1".to_string(), + value: 1, + nested: NestedData { + flag: true, + items: vec![], + }, + }; + + let data2 = TestData { + id: "file2".to_string(), + value: 2, + nested: NestedData { + flag: false, + items: vec![], + }, + }; + + // Write concurrently + let (r1, r2) = tokio::join!(write_json(&file1, &data1), write_json(&file2, &data2)); + + r1.unwrap(); + r2.unwrap(); + + // Verify both files + let read1: TestData = read_json(&file1).await.unwrap(); + let read2: TestData = read_json(&file2).await.unwrap(); + + assert_eq!(read1, data1); + assert_eq!(read2, data2); + + // Clean up + tokio::fs::remove_file(&file1).await.ok(); + tokio::fs::remove_file(&file2).await.ok(); + } + + #[tokio::test] + async fn test_write_creates_parent_directory() { + let temp_dir = std::env::temp_dir(); + let base_dir = temp_dir.join(format!("test_parent_{}", Uuid::now_v7())); + + // Note: Parent directory does NOT exist yet + // This test verifies that write_json does NOT auto-create parent dirs + // (Caller is responsible for creating parent directories) + + let file_path = base_dir.join("subdir").join("test.json"); + + let data = TestData { + id: "test".to_string(), + value: 42, + nested: NestedData { + flag: true, + items: vec![], + }, + }; + + // This should fail because parent directory doesn't exist + let result = write_json(&file_path, &data).await; + assert!(result.is_err()); + + // Now create parent directory + tokio::fs::create_dir_all(file_path.parent().unwrap()) + .await + .unwrap(); + + // Now write should succeed + write_json(&file_path, &data).await.unwrap(); + + // Verify + let read_data: TestData = read_json(&file_path).await.unwrap(); + assert_eq!(read_data, data); + + // Clean up + tokio::fs::remove_dir_all(&base_dir).await.ok(); + } +} diff --git a/crates/dirigent_archivist/src/storage/mod.rs b/crates/dirigent_archivist/src/storage/mod.rs new file mode 100644 index 0000000..0f49474 --- /dev/null +++ b/crates/dirigent_archivist/src/storage/mod.rs @@ -0,0 +1,118 @@ +//! Storage layer for the Archivist. +//! +//! Provides file-based storage using NDJSON, JSON, and TSV formats, +//! along with content-addressable file storage for attachments. + +use uuid::Uuid; + +pub mod files; +pub mod json; +pub mod ndjson; +pub mod paths; +pub mod tsv; + +// Re-export commonly used types and functions +pub use files::{get_file, store_file}; +pub use json::{read_json, write_json}; +pub use ndjson::{append_ndjson, read_ndjson, write_ndjson}; +pub use paths::ArchivePaths; +pub use tsv::{read_connector_index, write_connector_index}; + +/// Check if a UUID is version 7 (time-ordered). +/// +/// UUID version 7 is used throughout the archivist for scroll_ids and other +/// identifiers that need to be time-ordered and sortable. +/// +/// # Examples +/// +/// ``` +/// use uuid::Uuid; +/// use dirigent_archivist::storage::is_uuid7; +/// +/// let uuid7 = Uuid::now_v7(); +/// assert!(is_uuid7(&uuid7)); +/// +/// let uuid4 = Uuid::new_v4(); +/// assert!(!is_uuid7(&uuid4)); +/// ``` +pub fn is_uuid7(uuid: &Uuid) -> bool { + uuid.get_version_num() == 7 +} + +/// Parse a string as UUID7, returning None for other versions. +/// +/// This function ensures that only UUID version 7 identifiers are accepted, +/// rejecting other UUID versions (v1, v4, v5, etc.) that may be valid UUIDs +/// but don't meet the archivist's time-ordering requirements. +/// +/// # Examples +/// +/// ``` +/// use uuid::Uuid; +/// use dirigent_archivist::storage::parse_uuid7; +/// +/// // UUID7 string parses successfully +/// let uuid7_str = Uuid::now_v7().to_string(); +/// assert!(parse_uuid7(&uuid7_str).is_some()); +/// +/// // UUID4 string is rejected +/// let uuid4_str = Uuid::new_v4().to_string(); +/// assert!(parse_uuid7(&uuid4_str).is_none()); +/// +/// // Invalid UUID string is rejected +/// assert!(parse_uuid7("not-a-uuid").is_none()); +/// ``` +pub fn parse_uuid7(s: &str) -> Option { + Uuid::parse_str(s).ok().filter(|u| is_uuid7(u)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_uuid7_accepts_uuid7() { + let uuid7 = Uuid::now_v7(); + assert!(is_uuid7(&uuid7), "UUID7 should be recognized as version 7"); + } + + #[test] + fn test_is_uuid7_rejects_uuid4() { + let uuid4 = Uuid::new_v4(); + assert!(!is_uuid7(&uuid4), "UUID4 should not be recognized as version 7"); + } + + #[test] + fn test_parse_uuid7_accepts_valid_uuid7_string() { + let uuid7 = Uuid::now_v7(); + let uuid7_str = uuid7.to_string(); + let parsed = parse_uuid7(&uuid7_str); + + assert!(parsed.is_some(), "Valid UUID7 string should parse"); + assert_eq!(parsed.unwrap(), uuid7, "Parsed UUID should match original"); + } + + #[test] + fn test_parse_uuid7_rejects_uuid4_string() { + let uuid4 = Uuid::new_v4(); + let uuid4_str = uuid4.to_string(); + let parsed = parse_uuid7(&uuid4_str); + + assert!(parsed.is_none(), "UUID4 string should be rejected"); + } + + #[test] + fn test_parse_uuid7_rejects_invalid_uuid_string() { + let invalid_strings = vec![ + "not-a-uuid", + "12345678-1234-1234-1234-123456789012-extra", + "", + "invalid", + ]; + + for invalid in invalid_strings { + let parsed = parse_uuid7(invalid); + assert!(parsed.is_none(), "Invalid UUID string '{}' should be rejected", invalid); + } + } +} diff --git a/crates/dirigent_archivist/src/storage/ndjson.rs b/crates/dirigent_archivist/src/storage/ndjson.rs new file mode 100644 index 0000000..34f0676 --- /dev/null +++ b/crates/dirigent_archivist/src/storage/ndjson.rs @@ -0,0 +1,361 @@ +//! NDJSON (Newline Delimited JSON) storage utilities. +//! +//! Handles reading and writing NDJSON files for incremental message logs. +//! NDJSON format stores one JSON object per line, making it ideal for +//! append-only logs that can be read incrementally. + +use serde::{Deserialize, Serialize}; +use std::path::Path; +use tokio::fs::OpenOptions; +use tokio::io::AsyncWriteExt; + +/// Append a record to an NDJSON file +/// +/// This function: +/// 1. Serializes the record to JSON +/// 2. Opens the file in append mode (creates if not exists) +/// 3. Writes the JSON followed by a newline +/// 4. Calls fsync to ensure durability +/// +/// # Arguments +/// * `path` - Path to the NDJSON file +/// * `record` - Record to append (must be serializable) +/// +/// # Example +/// ```no_run +/// use dirigent_archivist::storage::ndjson::append_ndjson; +/// use serde::{Serialize, Deserialize}; +/// +/// #[derive(Serialize, Deserialize)] +/// struct LogEntry { +/// message: String, +/// } +/// +/// # async fn example() -> std::io::Result<()> { +/// let entry = LogEntry { message: "Hello".to_string() }; +/// append_ndjson(std::path::Path::new("log.ndjson"), &entry).await?; +/// # Ok(()) +/// # } +/// ``` +pub async fn append_ndjson(path: &Path, record: &T) -> std::io::Result<()> { + // Serialize to JSON string + let json = serde_json::to_string(record) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + + // Open file in append mode (create if not exists) + let mut file = OpenOptions::new() + .create(true) + .append(true) + .open(path) + .await?; + + // Write JSON + newline + file.write_all(json.as_bytes()).await?; + file.write_all(b"\n").await?; + + // Fsync for durability + file.sync_all().await?; + + Ok(()) +} + +/// Atomically rewrite an NDJSON file with the given records. +/// +/// Uses a temp file + rename for crash safety. If the process crashes during +/// the write, the original file remains untouched. Only after the new content +/// is fully written and fsynced is the old file replaced. +/// +/// # Arguments +/// * `path` - Path to the NDJSON file (will be created or overwritten) +/// * `records` - Records to write (one per line) +/// +/// # Example +/// ```no_run +/// use dirigent_archivist::storage::ndjson::write_ndjson; +/// use serde::{Serialize, Deserialize}; +/// +/// #[derive(Serialize, Deserialize)] +/// struct LogEntry { +/// message: String, +/// } +/// +/// # async fn example() -> std::io::Result<()> { +/// let entries = vec![ +/// LogEntry { message: "First".to_string() }, +/// LogEntry { message: "Second".to_string() }, +/// ]; +/// write_ndjson(std::path::Path::new("log.ndjson"), &entries).await?; +/// # Ok(()) +/// # } +/// ``` +pub async fn write_ndjson(path: &Path, records: &[T]) -> std::io::Result<()> { + let temp_path = path.with_extension("jsonl.tmp"); + let mut file = OpenOptions::new() + .create(true) + .write(true) + .truncate(true) + .open(&temp_path) + .await?; + for record in records { + let json = serde_json::to_string(record) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + file.write_all(json.as_bytes()).await?; + file.write_all(b"\n").await?; + } + file.flush().await?; + file.sync_all().await?; + drop(file); + tokio::fs::rename(&temp_path, path).await?; + Ok(()) +} + +/// Read all records from an NDJSON file +/// +/// This function: +/// 1. Reads the entire file to a string +/// 2. Splits by newlines +/// 3. Deserializes each non-empty line +/// +/// If the file doesn't exist, returns an empty vector. +/// +/// # Arguments +/// * `path` - Path to the NDJSON file +/// +/// # Returns +/// Vector of deserialized records +/// +/// # Example +/// ```no_run +/// use dirigent_archivist::storage::ndjson::read_ndjson; +/// use serde::{Serialize, Deserialize}; +/// +/// #[derive(Serialize, Deserialize)] +/// struct LogEntry { +/// message: String, +/// } +/// +/// # async fn example() -> std::io::Result<()> { +/// let entries: Vec = read_ndjson(std::path::Path::new("log.ndjson")).await?; +/// # Ok(()) +/// # } +/// ``` +pub async fn read_ndjson Deserialize<'de>>(path: &Path) -> std::io::Result> { + // Check if file exists + if !path.exists() { + return Ok(Vec::new()); + } + + // Read entire file to string + let content = tokio::fs::read_to_string(path).await?; + + // Parse line by line + let mut records = Vec::new(); + for (line_num, line) in content.lines().enumerate() { + // Skip empty lines + if line.trim().is_empty() { + continue; + } + + // Deserialize the line + let record: T = serde_json::from_str(line).map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Failed to parse line {}: {}", line_num + 1, e), + ) + })?; + + records.push(record); + } + + Ok(records) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde::{Deserialize, Serialize}; + use uuid::Uuid; + + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] + struct TestRecord { + id: String, + value: i32, + } + + #[tokio::test] + async fn test_append_and_read_roundtrip() { + // Create a temporary file + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("test_ndjson_{}.ndjson", Uuid::now_v7())); + + // Append multiple records + let record1 = TestRecord { + id: "rec1".to_string(), + value: 42, + }; + let record2 = TestRecord { + id: "rec2".to_string(), + value: 100, + }; + let record3 = TestRecord { + id: "rec3".to_string(), + value: -5, + }; + + append_ndjson(&file_path, &record1).await.unwrap(); + append_ndjson(&file_path, &record2).await.unwrap(); + append_ndjson(&file_path, &record3).await.unwrap(); + + // Read back + let records: Vec = read_ndjson(&file_path).await.unwrap(); + + // Verify + assert_eq!(records.len(), 3); + assert_eq!(records[0], record1); + assert_eq!(records[1], record2); + assert_eq!(records[2], record3); + + // Clean up + tokio::fs::remove_file(&file_path).await.ok(); + } + + #[tokio::test] + async fn test_read_empty_file() { + // Create a temporary empty file + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("test_empty_{}.ndjson", Uuid::now_v7())); + + tokio::fs::write(&file_path, "").await.unwrap(); + + // Read should return empty vector + let records: Vec = read_ndjson(&file_path).await.unwrap(); + assert_eq!(records.len(), 0); + + // Clean up + tokio::fs::remove_file(&file_path).await.ok(); + } + + #[tokio::test] + async fn test_read_missing_file() { + // Read from a non-existent file + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("nonexistent_{}.ndjson", Uuid::now_v7())); + + // Should return empty vector, not error + let records: Vec = read_ndjson(&file_path).await.unwrap(); + assert_eq!(records.len(), 0); + } + + #[tokio::test] + async fn test_trailing_newlines() { + // Create a file with trailing newlines + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("test_trailing_{}.ndjson", Uuid::now_v7())); + + // Write manually with extra newlines + let content = r#"{"id":"rec1","value":42} + +{"id":"rec2","value":100} + + +"#; + tokio::fs::write(&file_path, content).await.unwrap(); + + // Read should skip empty lines + let records: Vec = read_ndjson(&file_path).await.unwrap(); + assert_eq!(records.len(), 2); + assert_eq!(records[0].id, "rec1"); + assert_eq!(records[1].id, "rec2"); + + // Clean up + tokio::fs::remove_file(&file_path).await.ok(); + } + + #[tokio::test] + async fn test_concurrent_appends() { + // Test appending to different files concurrently + let temp_dir = std::env::temp_dir(); + let file1 = temp_dir.join(format!("test_concurrent_1_{}.ndjson", Uuid::now_v7())); + let file2 = temp_dir.join(format!("test_concurrent_2_{}.ndjson", Uuid::now_v7())); + + let record1 = TestRecord { + id: "file1".to_string(), + value: 1, + }; + let record2 = TestRecord { + id: "file2".to_string(), + value: 2, + }; + + // Append concurrently + let (r1, r2) = tokio::join!( + append_ndjson(&file1, &record1), + append_ndjson(&file2, &record2) + ); + + r1.unwrap(); + r2.unwrap(); + + // Verify both files have correct content + let records1: Vec = read_ndjson(&file1).await.unwrap(); + let records2: Vec = read_ndjson(&file2).await.unwrap(); + + assert_eq!(records1.len(), 1); + assert_eq!(records1[0], record1); + assert_eq!(records2.len(), 1); + assert_eq!(records2[0], record2); + + // Clean up + tokio::fs::remove_file(&file1).await.ok(); + tokio::fs::remove_file(&file2).await.ok(); + } + + #[tokio::test] + async fn test_invalid_json_error() { + // Create a file with invalid JSON + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("test_invalid_{}.ndjson", Uuid::now_v7())); + + let content = r#"{"id":"rec1","value":42} +invalid json here +{"id":"rec2","value":100}"#; + tokio::fs::write(&file_path, content).await.unwrap(); + + // Reading should fail with InvalidData error + let result: std::io::Result> = read_ndjson(&file_path).await; + assert!(result.is_err()); + + match result { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert!(e.to_string().contains("line 2")); + } + Ok(_) => panic!("Expected error"), + } + + // Clean up + tokio::fs::remove_file(&file_path).await.ok(); + } + + #[tokio::test] + async fn test_fsync_called() { + // This test verifies that append_ndjson completes without error, + // which implies fsync was called successfully + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("test_fsync_{}.ndjson", Uuid::now_v7())); + + let record = TestRecord { + id: "test".to_string(), + value: 42, + }; + + // Should complete without error (including fsync) + append_ndjson(&file_path, &record).await.unwrap(); + + // Verify file was written + assert!(file_path.exists()); + + // Clean up + tokio::fs::remove_file(&file_path).await.ok(); + } +} diff --git a/crates/dirigent_archivist/src/storage/paths.rs b/crates/dirigent_archivist/src/storage/paths.rs new file mode 100644 index 0000000..b8adce9 --- /dev/null +++ b/crates/dirigent_archivist/src/storage/paths.rs @@ -0,0 +1,436 @@ +//! Path management for archive directory structure. +//! +//! Defines the archive directory layout and provides utilities for +//! constructing paths to various archive components. + +use std::path::PathBuf; +use std::sync::Arc; + +use tokio::sync::Mutex; +use uuid::Uuid; + +/// Archive path utilities +/// +/// Provides methods to generate paths for all archive components: +/// - Sessions: `.contexts/{scroll_id}/` +/// - Connectors: `.db/connectors/{connector_uid}/` +/// - Files: `.files/{ab}/{cd}/{ef}/{...}` (sharded by SHA-256) +/// +/// Also carries the per-archive mutex that serialises `store_file`'s +/// read-modify-rewrite of `.files/file_index.ndjson` — the shared +/// `file_index.tmp` path made concurrent calls race on `rename`. +pub struct ArchivePaths { + root: PathBuf, + /// Guards the critical section in `storage::files::store_file` that + /// rewrites the per-archive file index. Cloneable `Arc` so callers + /// that share the same `ArchivePaths` instance serialise correctly. + file_index_lock: Arc>, +} + +impl ArchivePaths { + /// Create a new ArchivePaths instance + pub fn new(root: PathBuf) -> Self { + Self { + root, + file_index_lock: Arc::new(Mutex::new(())), + } + } + + /// Get the archive root directory + pub fn root(&self) -> &PathBuf { + &self.root + } + + /// Acquire the per-archive file-index lock. Held across the + /// `read → modify → append-temp → rename` sequence in + /// `storage::files::store_file`. + pub(crate) fn file_index_lock(&self) -> Arc> { + Arc::clone(&self.file_index_lock) + } + + // ======================================================================== + // Session Paths + // ======================================================================== + + /// Get the directory for a specific session + /// + /// Returns: `{root}/.contexts/{scroll_id}` + pub fn session_dir(&self, scroll_id: Uuid) -> PathBuf { + self.root.join(".contexts").join(scroll_id.to_string()) + } + + /// Get the session metadata JSON file path + /// + /// Returns: `{root}/.contexts/{scroll_id}/session.json` + pub fn session_json(&self, scroll_id: Uuid) -> PathBuf { + self.session_dir(scroll_id).join("session.json") + } + + /// Get the messages NDJSON file path for WRITE operations + /// + /// For read operations, use `messages_path_for_read()` which supports both .jsonl and .ndjson + /// + /// Returns: `{root}/.contexts/{scroll_id}/messages.ndjson` + #[deprecated(since = "0.2.0", note = "Use messages_path_for_write() instead")] + pub fn messages_ndjson(&self, scroll_id: Uuid) -> PathBuf { + self.session_dir(scroll_id).join("messages.ndjson") + } + + /// Resolve messages file path for reading. + /// Checks for .jsonl first, falls back to .ndjson + pub fn messages_path_for_read(&self, scroll_id: Uuid) -> PathBuf { + let session_dir = self.session_dir(scroll_id); + self.resolve_ndjson_or_jsonl(&session_dir, "messages") + } + + /// Get the messages file path for WRITE operations. + /// Always returns .jsonl path (new canonical format). + pub fn messages_path_for_write(&self, scroll_id: Uuid) -> PathBuf { + self.session_dir(scroll_id).join("messages.jsonl") + } + + /// Get the events file path for meta sessions (.jsonl format) + /// + /// Meta sessions (AcpConnection) store connection events in events.jsonl + /// instead of messages. These events track connection lifecycle and session navigation. + /// + /// Returns: `{root}/.contexts/{scroll_id}/events.jsonl` + pub fn events_path(&self, scroll_id: Uuid) -> PathBuf { + self.session_dir(scroll_id).join("events.jsonl") + } + + /// Get the DAG index file path. + /// + /// Returns: `{root}/.db/dag.jsonl` + pub fn dag_path(&self) -> PathBuf { + self.root.join(".db").join("dag.jsonl") + } + + /// Resolve sessions mapping file path for reading. + /// Checks for .jsonl first, falls back to .ndjson + pub fn sessions_path_for_read(&self, connector_uid: Uuid) -> PathBuf { + let connector_dir = self.connector_dir(connector_uid); + self.resolve_ndjson_or_jsonl(&connector_dir, "sessions") + } + + /// Get the sessions file path for WRITE operations. + /// Always returns .jsonl path (new canonical format). + pub fn sessions_path_for_write(&self, connector_uid: Uuid) -> PathBuf { + self.connector_dir(connector_uid).join("sessions.jsonl") + } + + // ======================================================================== + // Connector Paths + // ======================================================================== + + /// Get the directory for a specific connector + /// + /// Returns: `{root}/.db/connectors/{connector_uid}` + pub fn connector_dir(&self, connector_uid: Uuid) -> PathBuf { + self.root + .join(".db") + .join("connectors") + .join(connector_uid.to_string()) + } + + /// Get the connector index TSV file path + /// + /// Returns: `{root}/.db/connectors/index.tsv` + pub fn connector_index_tsv(&self) -> PathBuf { + self.root.join(".db").join("connectors").join("index.tsv") + } + + // ======================================================================== + // File Storage Paths + // ======================================================================== + + /// Get the blob path for a file using sharded storage + /// + /// Sharding strategy: + /// - Input: `sha256:abcdef0123456789...` + /// - Strip `sha256:` prefix + /// - Shard by first 6 characters (2-char segments) + /// - Returns: `{root}/.files/ab/cd/ef/0123456789...` + /// + /// # Arguments + /// * `file_id` - File identifier (e.g., "sha256:abcdef...") + pub fn file_blob_path(&self, file_id: &str) -> PathBuf { + // Strip "sha256:" prefix if present + let hash = file_id.strip_prefix("sha256:").unwrap_or(file_id); + + // Extract first 6 chars for sharding (3 levels of 2 chars each) + // If hash is shorter, we'll just use what we have + let (shard1, remainder) = if hash.len() >= 2 { + hash.split_at(2) + } else { + (hash, "") + }; + + let (shard2, remainder) = if remainder.len() >= 2 { + remainder.split_at(2) + } else { + (remainder, "") + }; + + let (shard3, _) = if remainder.len() >= 2 { + remainder.split_at(2) + } else { + (remainder, remainder) + }; + + // Build the sharded path + let mut path = self.root.join(".files"); + + if !shard1.is_empty() { + path = path.join(shard1); + } + if !shard2.is_empty() { + path = path.join(shard2); + } + if !shard3.is_empty() { + path = path.join(shard3); + } + + // Use the full hash (without prefix) as filename + path.join(hash) + } + + // ======================================================================== + // Directory Creation + // ======================================================================== + + /// Ensure all required directories exist for a session + /// + /// Creates the session directory if it doesn't exist. + pub async fn ensure_dirs(&self, scroll_id: Uuid) -> std::io::Result<()> { + let session_dir = self.session_dir(scroll_id); + tokio::fs::create_dir_all(session_dir).await + } + + /// Ensure the connector directory exists + /// + /// Creates `.db/connectors/{connector_uid}/` if it doesn't exist. + /// This should be called before any operations that write to connector-specific files. + pub async fn ensure_connector_dir(&self, connector_uid: Uuid) -> std::io::Result<()> { + let connector_dir = self.connector_dir(connector_uid); + tokio::fs::create_dir_all(&connector_dir).await + } + + /// Generic resolution: prefer .jsonl, fall back to .ndjson + /// + /// This enables backward compatibility with existing .ndjson archives + /// while supporting the more widely-recognized .jsonl extension. + fn resolve_ndjson_or_jsonl(&self, dir: &std::path::Path, base_name: &str) -> PathBuf { + // Check for .jsonl first (newer, more prominent extension) + let jsonl_path = dir.join(format!("{}.jsonl", base_name)); + if jsonl_path.exists() { + return jsonl_path; + } + + // Fall back to .ndjson (legacy format, still canonical for writes in Phase 1) + dir.join(format!("{}.ndjson", base_name)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::Path; + + #[test] + fn test_session_dir() { + let paths = ArchivePaths::new(PathBuf::from("/archive")); + let scroll_id = Uuid::parse_str("018c8f7e-7b6a-7e3c-9f2d-1a2b3c4d5e6f").unwrap(); + + let session_dir = paths.session_dir(scroll_id); + assert_eq!( + session_dir, + Path::new("/archive/.contexts/018c8f7e-7b6a-7e3c-9f2d-1a2b3c4d5e6f") + ); + } + + #[test] + fn test_session_json() { + let paths = ArchivePaths::new(PathBuf::from("/archive")); + let scroll_id = Uuid::parse_str("018c8f7e-7b6a-7e3c-9f2d-1a2b3c4d5e6f").unwrap(); + + let json_path = paths.session_json(scroll_id); + assert_eq!( + json_path, + Path::new("/archive/.contexts/018c8f7e-7b6a-7e3c-9f2d-1a2b3c4d5e6f/session.json") + ); + } + + #[test] + fn test_messages_ndjson() { + let paths = ArchivePaths::new(PathBuf::from("/archive")); + let scroll_id = Uuid::parse_str("018c8f7e-7b6a-7e3c-9f2d-1a2b3c4d5e6f").unwrap(); + + let messages_path = paths.messages_ndjson(scroll_id); + assert_eq!( + messages_path, + Path::new("/archive/.contexts/018c8f7e-7b6a-7e3c-9f2d-1a2b3c4d5e6f/messages.ndjson") + ); + } + + #[test] + fn test_connector_dir() { + let paths = ArchivePaths::new(PathBuf::from("/archive")); + let connector_uid = Uuid::parse_str("018c8f7e-7b6a-7e3c-9f2d-1a2b3c4d5e6f").unwrap(); + + let connector_dir = paths.connector_dir(connector_uid); + assert_eq!( + connector_dir, + Path::new("/archive/.db/connectors/018c8f7e-7b6a-7e3c-9f2d-1a2b3c4d5e6f") + ); + } + + #[test] + fn test_connector_index_tsv() { + let paths = ArchivePaths::new(PathBuf::from("/archive")); + + let index_path = paths.connector_index_tsv(); + assert_eq!(index_path, Path::new("/archive/.db/connectors/index.tsv")); + } + + #[test] + fn test_file_blob_path_with_prefix() { + let paths = ArchivePaths::new(PathBuf::from("/archive")); + let file_id = "sha256:abcdef0123456789"; + + let blob_path = paths.file_blob_path(file_id); + assert_eq!( + blob_path, + Path::new("/archive/.files/ab/cd/ef/abcdef0123456789") + ); + } + + #[test] + fn test_file_blob_path_without_prefix() { + let paths = ArchivePaths::new(PathBuf::from("/archive")); + let file_id = "abcdef0123456789"; + + let blob_path = paths.file_blob_path(file_id); + assert_eq!( + blob_path, + Path::new("/archive/.files/ab/cd/ef/abcdef0123456789") + ); + } + + #[test] + fn test_file_blob_path_short_hash() { + let paths = ArchivePaths::new(PathBuf::from("/archive")); + + // Very short hash (less than 6 chars) + let file_id = "sha256:abc"; + let blob_path = paths.file_blob_path(file_id); + assert_eq!(blob_path, Path::new("/archive/.files/ab/c/abc")); + + // 4 char hash + let file_id = "sha256:abcd"; + let blob_path = paths.file_blob_path(file_id); + assert_eq!(blob_path, Path::new("/archive/.files/ab/cd/abcd")); + + // 5 char hash + let file_id = "sha256:abcde"; + let blob_path = paths.file_blob_path(file_id); + assert_eq!(blob_path, Path::new("/archive/.files/ab/cd/e/abcde")); + } + + #[test] + fn test_file_blob_path_long_hash() { + let paths = ArchivePaths::new(PathBuf::from("/archive")); + let file_id = "sha256:abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789"; + + let blob_path = paths.file_blob_path(file_id); + assert_eq!( + blob_path, + Path::new("/archive/.files/ab/cd/ef/abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789") + ); + } + + #[test] + fn test_paths_use_correct_separators() { + let paths = ArchivePaths::new(PathBuf::from("/archive")); + let scroll_id = Uuid::parse_str("018c8f7e-7b6a-7e3c-9f2d-1a2b3c4d5e6f").unwrap(); + + // All paths should use PathBuf which handles platform separators + let session_dir = paths.session_dir(scroll_id); + let session_json = paths.session_json(scroll_id); + let messages_ndjson = paths.messages_ndjson(scroll_id); + + // On Windows, these should contain backslashes; on Unix, forward slashes + // PathBuf handles this automatically, so we just verify the components + assert!(session_dir.to_string_lossy().contains(".contexts")); + assert!(session_json.to_string_lossy().contains("session.json")); + assert!(messages_ndjson + .to_string_lossy() + .contains("messages.ndjson")); + } + + #[tokio::test] + async fn test_ensure_dirs() { + // Create a temporary directory for testing + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let paths = ArchivePaths::new(temp_dir.clone()); + let scroll_id = Uuid::now_v7(); + + // Directory should not exist yet + assert!(!paths.session_dir(scroll_id).exists()); + + // Create the directory + paths.ensure_dirs(scroll_id).await.unwrap(); + + // Directory should now exist + assert!(paths.session_dir(scroll_id).exists()); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + } + + #[tokio::test] + async fn test_messages_path_for_read_ndjson_only() { + // Create a temporary directory with only .ndjson file + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let paths = ArchivePaths::new(temp_dir.clone()); + let scroll_id = Uuid::now_v7(); + + // Create session directory and .ndjson file + paths.ensure_dirs(scroll_id).await.unwrap(); + let ndjson_path = paths.messages_ndjson(scroll_id); + tokio::fs::write(&ndjson_path, "test content").await.unwrap(); + + // messages_path_for_read should return the .ndjson path + let resolved_path = paths.messages_path_for_read(scroll_id); + assert_eq!(resolved_path, ndjson_path); + assert!(resolved_path.to_string_lossy().ends_with("messages.ndjson")); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + } + + #[tokio::test] + async fn test_messages_path_for_read_jsonl_preferred() { + // Create a temporary directory with both .ndjson and .jsonl files + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let paths = ArchivePaths::new(temp_dir.clone()); + let scroll_id = Uuid::now_v7(); + + // Create session directory and both files + paths.ensure_dirs(scroll_id).await.unwrap(); + let session_dir = paths.session_dir(scroll_id); + let ndjson_path = session_dir.join("messages.ndjson"); + let jsonl_path = session_dir.join("messages.jsonl"); + + tokio::fs::write(&ndjson_path, "old content").await.unwrap(); + tokio::fs::write(&jsonl_path, "new content").await.unwrap(); + + // messages_path_for_read should prefer .jsonl + let resolved_path = paths.messages_path_for_read(scroll_id); + assert_eq!(resolved_path, jsonl_path); + assert!(resolved_path.to_string_lossy().ends_with("messages.jsonl")); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + } +} diff --git a/crates/dirigent_archivist/src/storage/tsv.rs b/crates/dirigent_archivist/src/storage/tsv.rs new file mode 100644 index 0000000..dc273a2 --- /dev/null +++ b/crates/dirigent_archivist/src/storage/tsv.rs @@ -0,0 +1,552 @@ +//! TSV (Tab-Separated Values) storage utilities. +//! +//! Handles reading and writing TSV files for session listings and indices. +//! TSV format is human-readable and grep-able, making it ideal for manual +//! inspection and command-line processing. + +use crate::types::ConnectorIndexRow; +use std::path::Path; +use tokio::io::AsyncWriteExt; +use uuid::Uuid; + +/// Write connector index to a TSV file atomically +/// +/// This function: +/// 1. Generates the header line +/// 2. Formats each row as tab-separated values +/// 3. Writes to a temporary file +/// 4. Renames to the target path (atomic operation) +/// +/// TSV format: +/// ```text +/// connector_uid\ttype\ttitle\tclient_native_id\talias_of\tcreated_at +/// 018c8f7e-...\tOpenCode\tLocal Dev\topencode@...\t\t2025-01-15T12:34:56Z +/// ``` +/// +/// # Arguments +/// * `path` - Path to the TSV file +/// * `rows` - Rows to write +pub async fn write_connector_index(path: &Path, rows: &[ConnectorIndexRow]) -> std::io::Result<()> { + // Generate header + let header = "connector_uid\ttype\ttitle\tclient_native_id\talias_of\tcreated_at\tfingerprint\n"; + + // Format rows + let mut content = String::from(header); + for row in rows { + let alias_of_str = row + .alias_of + .as_ref() + .map(|u| u.to_string()) + .unwrap_or_default(); + + let fingerprint_str = row.fingerprint.as_deref().unwrap_or(""); + + let line = format!( + "{}\t{}\t{}\t{}\t{}\t{}\t{}\n", + row.connector_uid, + row.r#type, + row.title, + row.client_native_id, + alias_of_str, + row.created_at.to_rfc3339(), + fingerprint_str, + ); + content.push_str(&line); + } + + // Write to temp file + let temp_path = path.with_extension("tmp"); + let mut file = tokio::fs::File::create(&temp_path).await?; + file.write_all(content.as_bytes()).await?; + file.sync_all().await?; + drop(file); + + // Atomically rename + tokio::fs::rename(&temp_path, path).await?; + + Ok(()) +} + +/// Read connector index from a TSV file +/// +/// If the file doesn't exist, returns an empty vector. +/// +/// # Arguments +/// * `path` - Path to the TSV file +/// +/// # Returns +/// Vector of connector index rows +pub async fn read_connector_index(path: &Path) -> std::io::Result> { + // Check if file exists + if !path.exists() { + return Ok(Vec::new()); + } + + // Read file to string + let content = tokio::fs::read_to_string(path).await?; + + // Parse line by line + let mut rows = Vec::new(); + for (line_num, line) in content.lines().enumerate() { + // Skip header (line 0) + if line_num == 0 { + continue; + } + + // Skip empty lines + if line.trim().is_empty() { + continue; + } + + // Split by tab + let parts: Vec<&str> = line.split('\t').collect(); + // Accept 6 columns (legacy, no fingerprint) or 7 columns (with fingerprint) + if parts.len() != 6 && parts.len() != 7 { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!( + "Invalid TSV format at line {}: expected 6 or 7 fields, got {}", + line_num + 1, + parts.len() + ), + )); + } + + // Parse fields + let connector_uid = Uuid::parse_str(parts[0]).map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Invalid UUID at line {}: {}", line_num + 1, e), + ) + })?; + + let r#type = parts[1].to_string(); + let title = parts[2].to_string(); + let client_native_id = parts[3].to_string(); + + let alias_of = if parts[4].is_empty() { + None + } else { + Some(Uuid::parse_str(parts[4]).map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Invalid alias_of UUID at line {}: {}", line_num + 1, e), + ) + })?) + }; + + let created_at = chrono::DateTime::parse_from_rfc3339(parts[5]) + .map_err(|e| { + std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Invalid timestamp at line {}: {}", line_num + 1, e), + ) + })? + .with_timezone(&chrono::Utc); + + // Parse optional fingerprint (7th column, may be absent in legacy files) + let fingerprint = if parts.len() >= 7 && !parts[6].is_empty() { + Some(parts[6].to_string()) + } else { + None + }; + + rows.push(ConnectorIndexRow { + connector_uid, + r#type, + title, + client_native_id, + alias_of, + created_at, + fingerprint, + }); + } + + Ok(rows) +} + +#[cfg(test)] +mod tests { + use super::*; + use chrono::{DateTime, Utc}; + use std::time::SystemTime; + + #[tokio::test] + async fn test_write_and_read_roundtrip() { + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("test_tsv_{}.tsv", Uuid::now_v7())); + + let uid1 = Uuid::now_v7(); + let uid2 = Uuid::now_v7(); + let uid3 = Uuid::now_v7(); + let now = DateTime::::from(SystemTime::now()); + + let rows = vec![ + ConnectorIndexRow { + connector_uid: uid1, + r#type: "OpenCode".to_string(), + title: "Local Dev".to_string(), + client_native_id: "opencode@localhost:12225".to_string(), + alias_of: None, + created_at: now, + fingerprint: None, + }, + ConnectorIndexRow { + connector_uid: uid2, + r#type: "ACP".to_string(), + title: "Remote Agent".to_string(), + client_native_id: "acp@http://localhost:3000".to_string(), + alias_of: Some(uid3), + created_at: now, + fingerprint: None, + }, + ]; + + // Write + write_connector_index(&file_path, &rows).await.unwrap(); + + // Read back + let read_rows = read_connector_index(&file_path).await.unwrap(); + + // Verify + assert_eq!(read_rows.len(), 2); + assert_eq!(read_rows[0].connector_uid, uid1); + assert_eq!(read_rows[0].r#type, "OpenCode"); + assert_eq!(read_rows[0].title, "Local Dev"); + assert_eq!(read_rows[0].alias_of, None); + + assert_eq!(read_rows[1].connector_uid, uid2); + assert_eq!(read_rows[1].r#type, "ACP"); + assert_eq!(read_rows[1].alias_of, Some(uid3)); + + // Clean up + tokio::fs::remove_file(&file_path).await.ok(); + } + + #[tokio::test] + async fn test_optional_field_handling() { + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("test_optional_{}.tsv", Uuid::now_v7())); + + let uid1 = Uuid::now_v7(); + let uid2 = Uuid::now_v7(); + let now = DateTime::::from(SystemTime::now()); + + let rows = vec![ + ConnectorIndexRow { + connector_uid: uid1, + r#type: "Type1".to_string(), + title: "Title1".to_string(), + client_native_id: "client1".to_string(), + alias_of: None, // Empty alias_of + created_at: now, + fingerprint: None, + }, + ConnectorIndexRow { + connector_uid: uid2, + r#type: "Type2".to_string(), + title: "Title2".to_string(), + client_native_id: "client2".to_string(), + alias_of: Some(uid1), // Non-empty alias_of + created_at: now, + fingerprint: None, + }, + ]; + + // Write + write_connector_index(&file_path, &rows).await.unwrap(); + + // Verify raw content has empty string for None + let content = tokio::fs::read_to_string(&file_path).await.unwrap(); + let lines: Vec<&str> = content.lines().collect(); + + // First data line should have empty alias_of (two consecutive tabs) + assert!(lines[1].contains("\t\t")); + + // Second data line should have a UUID for alias_of + assert!(lines[2].contains(&uid1.to_string())); + + // Read back + let read_rows = read_connector_index(&file_path).await.unwrap(); + + // Verify optional field handling + assert_eq!(read_rows[0].alias_of, None); + assert_eq!(read_rows[1].alias_of, Some(uid1)); + + // Clean up + tokio::fs::remove_file(&file_path).await.ok(); + } + + #[tokio::test] + async fn test_header_generation() { + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("test_header_{}.tsv", Uuid::now_v7())); + + // Write empty index + write_connector_index(&file_path, &[]).await.unwrap(); + + // Read raw content + let content = tokio::fs::read_to_string(&file_path).await.unwrap(); + + // Verify header + assert_eq!( + content.trim(), + "connector_uid\ttype\ttitle\tclient_native_id\talias_of\tcreated_at\tfingerprint" + ); + + // Clean up + tokio::fs::remove_file(&file_path).await.ok(); + } + + #[tokio::test] + async fn test_rfc3339_timestamp_formatting() { + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("test_timestamp_{}.tsv", Uuid::now_v7())); + + let uid = Uuid::now_v7(); + let timestamp = DateTime::::from(SystemTime::now()); + + let rows = vec![ConnectorIndexRow { + connector_uid: uid, + r#type: "Test".to_string(), + title: "Title".to_string(), + client_native_id: "client".to_string(), + alias_of: None, + created_at: timestamp, + fingerprint: None, + }]; + + // Write + write_connector_index(&file_path, &rows).await.unwrap(); + + // Read raw content + let content = tokio::fs::read_to_string(&file_path).await.unwrap(); + + // Verify RFC 3339 format in content + assert!(content.contains('T')); + assert!(content.contains('Z') || content.contains('+')); + + // Read back and verify timestamp is preserved + let read_rows = read_connector_index(&file_path).await.unwrap(); + let diff = + (timestamp.timestamp_millis() - read_rows[0].created_at.timestamp_millis()).abs(); + assert!(diff < 1000, "Timestamp difference too large"); + + // Clean up + tokio::fs::remove_file(&file_path).await.ok(); + } + + #[tokio::test] + async fn test_missing_file_returns_empty_vec() { + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("nonexistent_{}.tsv", Uuid::now_v7())); + + // Should return empty vec, not error + let rows = read_connector_index(&file_path).await.unwrap(); + assert_eq!(rows.len(), 0); + } + + #[tokio::test] + async fn test_malformed_tsv_error() { + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("test_malformed_{}.tsv", Uuid::now_v7())); + + // Write malformed TSV (missing fields) + let content = + "connector_uid\ttype\ttitle\tclient_native_id\talias_of\tcreated_at\nuid1\ttype1\n"; + tokio::fs::write(&file_path, content).await.unwrap(); + + // Should fail with InvalidData + let result = read_connector_index(&file_path).await; + assert!(result.is_err()); + + match result { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert!(e.to_string().contains("expected 6 or 7 fields")); + } + Ok(_) => panic!("Expected error"), + } + + // Clean up + tokio::fs::remove_file(&file_path).await.ok(); + } + + #[tokio::test] + async fn test_invalid_uuid_error() { + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("test_invalid_uuid_{}.tsv", Uuid::now_v7())); + + // Write TSV with invalid UUID + let content = "connector_uid\ttype\ttitle\tclient_native_id\talias_of\tcreated_at\ninvalid-uuid\tType\tTitle\tClient\t\t2025-01-15T12:34:56Z\n"; + tokio::fs::write(&file_path, content).await.unwrap(); + + // Should fail with InvalidData + let result = read_connector_index(&file_path).await; + assert!(result.is_err()); + + match result { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert!(e.to_string().contains("Invalid UUID")); + } + Ok(_) => panic!("Expected error"), + } + + // Clean up + tokio::fs::remove_file(&file_path).await.ok(); + } + + #[tokio::test] + async fn test_invalid_timestamp_error() { + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("test_invalid_timestamp_{}.tsv", Uuid::now_v7())); + + let uid = Uuid::now_v7(); + + // Write TSV with invalid timestamp + let content = format!("connector_uid\ttype\ttitle\tclient_native_id\talias_of\tcreated_at\n{}\tType\tTitle\tClient\t\tinvalid-timestamp\n", uid); + tokio::fs::write(&file_path, content).await.unwrap(); + + // Should fail with InvalidData + let result = read_connector_index(&file_path).await; + assert!(result.is_err()); + + match result { + Err(e) => { + assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + assert!(e.to_string().contains("Invalid timestamp")); + } + Ok(_) => panic!("Expected error"), + } + + // Clean up + tokio::fs::remove_file(&file_path).await.ok(); + } + + #[tokio::test] + async fn test_atomic_write() { + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("test_atomic_{}.tsv", Uuid::now_v7())); + + let uid1 = Uuid::now_v7(); + let uid2 = Uuid::now_v7(); + let now = DateTime::::from(SystemTime::now()); + + let rows1 = vec![ConnectorIndexRow { + connector_uid: uid1, + r#type: "First".to_string(), + title: "First Write".to_string(), + client_native_id: "client1".to_string(), + alias_of: None, + created_at: now, + fingerprint: None, + }]; + + let rows2 = vec![ConnectorIndexRow { + connector_uid: uid2, + r#type: "Second".to_string(), + title: "Second Write".to_string(), + client_native_id: "client2".to_string(), + alias_of: None, + created_at: now, + fingerprint: None, + }]; + + // Write first version + write_connector_index(&file_path, &rows1).await.unwrap(); + + // Verify first version + let read1 = read_connector_index(&file_path).await.unwrap(); + assert_eq!(read1.len(), 1); + assert_eq!(read1[0].title, "First Write"); + + // Overwrite with second version + write_connector_index(&file_path, &rows2).await.unwrap(); + + // Verify second version + let read2 = read_connector_index(&file_path).await.unwrap(); + assert_eq!(read2.len(), 1); + assert_eq!(read2[0].title, "Second Write"); + + // Temp file should not exist + let temp_path = file_path.with_extension("tmp"); + assert!(!temp_path.exists()); + + // Clean up + tokio::fs::remove_file(&file_path).await.ok(); + } + + #[tokio::test] + async fn test_legacy_six_column_tsv_compatibility() { + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("test_legacy_tsv_{}.tsv", Uuid::now_v7())); + + let uid = Uuid::now_v7(); + + // Write a legacy 6-column TSV (no fingerprint column) + let content = format!( + "connector_uid\ttype\ttitle\tclient_native_id\talias_of\tcreated_at\n{}\tOpenCode\tLegacy\tclient-legacy\t\t2025-01-15T12:34:56Z\n", + uid + ); + tokio::fs::write(&file_path, content).await.unwrap(); + + // Should parse successfully with fingerprint = None + let rows = read_connector_index(&file_path).await.unwrap(); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].connector_uid, uid); + assert_eq!(rows[0].r#type, "OpenCode"); + assert_eq!(rows[0].title, "Legacy"); + assert_eq!(rows[0].fingerprint, None); + + // Clean up + tokio::fs::remove_file(&file_path).await.ok(); + } + + #[tokio::test] + async fn test_fingerprint_roundtrip() { + let temp_dir = std::env::temp_dir(); + let file_path = temp_dir.join(format!("test_fingerprint_{}.tsv", Uuid::now_v7())); + + let uid1 = Uuid::now_v7(); + let uid2 = Uuid::now_v7(); + let now = DateTime::::from(SystemTime::now()); + + let rows = vec![ + ConnectorIndexRow { + connector_uid: uid1, + r#type: "ACP".to_string(), + title: "Claude CLI".to_string(), + client_native_id: "acp-claude-1".to_string(), + alias_of: None, + created_at: now, + fingerprint: Some("acp/stdio:/usr/bin/claude".to_string()), + }, + ConnectorIndexRow { + connector_uid: uid2, + r#type: "OpenCode".to_string(), + title: "No Fingerprint".to_string(), + client_native_id: "opencode@localhost".to_string(), + alias_of: None, + created_at: now, + fingerprint: None, + }, + ]; + + // Write + write_connector_index(&file_path, &rows).await.unwrap(); + + // Read back + let read_rows = read_connector_index(&file_path).await.unwrap(); + + assert_eq!(read_rows.len(), 2); + assert_eq!( + read_rows[0].fingerprint, + Some("acp/stdio:/usr/bin/claude".to_string()) + ); + assert_eq!(read_rows[1].fingerprint, None); + + // Clean up + tokio::fs::remove_file(&file_path).await.ok(); + } +} diff --git a/crates/dirigent_archivist/src/types.rs b/crates/dirigent_archivist/src/types.rs new file mode 100644 index 0000000..d3f79d4 --- /dev/null +++ b/crates/dirigent_archivist/src/types.rs @@ -0,0 +1,1298 @@ +//! Core types for the Archivist. +//! +//! This module defines the fundamental data structures used throughout the archivist, +//! including session metadata, message records, and connector information. +//! +//! All types follow these conventions: +//! - IDs: UUIDv7 for time-ordered identifiers +//! - Timestamps: RFC 3339 UTC +//! - Versioning: Every record carries a "version" field for schema evolution +//! - Metadata: Free-form JSON object reserved for caller-specific fields + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// Continuation type indicating how a session relates to its parent. +/// +/// Defines the relationship when a session is derived from another session. +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum Continuation { + /// Session splits from parent at a specific message + Split, + /// Session is a compacted version of parent + Compact, + /// Session references parent without duplication + Reference, + /// Session is an edited version of parent + Edit, + /// Session is a subagent spawned by parent via Agent tool + Subagent, + /// Unknown continuation type (for forward compatibility) + Unknown, +} + +/// How complete the session data in the archive is. +/// +/// Tracks whether the archivist has full message history or only +/// discovery-level metadata for this session. +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum SessionCompleteness { + /// Full message history is available (or session is empty-but-known-complete). + /// + /// Set when: + /// - User creates a new session (session/new) + /// - Session history is fully loaded via session/load replay + /// + /// This is the default for backward compatibility: existing session.json + /// files that lack this field will deserialize as Complete. + #[default] + Complete, + /// Session was discovered via connector session/list. + /// Only ID, title, and metadata are known. No messages in the archive. + Discovered, + /// Messages were loaded but may be incomplete (e.g., agent compacted history). + /// Reserved for future refresh/sync functionality. + Partial, +} + +/// Session kind indicating the type of session storage. +/// +/// Distinguishes between regular chat sessions and special meta sessions +/// that track connection events rather than messages. +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum SessionKind { + /// Regular chat session with messages stored in messages.jsonl + #[default] + Chat, + /// Meta session tracking ACP client connection history. + /// Stores connection events in events.jsonl rather than messages. + /// Linked session content is fetched on-demand, not duplicated. + AcpConnection, +} + +/// Event types for ACP meta sessions. +/// +/// These events track client connection lifecycle and session navigation. +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum MetaEventType { + /// Client connected to the ACP server + ClientConnected, + /// Client disconnected from the ACP server + ClientDisconnected, + /// Client opened a new session + SessionOpened, + /// Client switched to a different session + SessionSwitched, + /// Client closed a session + SessionClosed, +} + +/// A single event record in an ACP meta session. +/// +/// Stored in `events.jsonl` (one per line) for meta sessions. +/// These events track connection lifecycle and session navigation, +/// NOT the actual message content of sessions. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MetaEventRecord { + /// Schema version for forward compatibility + pub version: u32, + /// Unique event ID (UUIDv7) + pub event_id: Uuid, + /// Meta session this event belongs to (scroll_id) + pub session: Uuid, + /// Timestamp when the event occurred + pub ts: DateTime, + /// Type of event + pub event_type: MetaEventType, + /// Human-readable description of the event + pub description: String, + /// Linked session ID (for SessionOpened, SessionSwitched, SessionClosed events) + #[serde(skip_serializing_if = "Option::is_none")] + pub linked_session_id: Option, + /// Linked connector ID (for session-related events) + #[serde(skip_serializing_if = "Option::is_none")] + pub linked_connector_id: Option, + /// Linked connector title (for display purposes) + #[serde(skip_serializing_if = "Option::is_none")] + pub linked_connector_title: Option, + /// Free-form metadata for additional event data + #[serde(default)] + pub metadata: serde_json::Value, +} + +/// Status returned when registering a connector or session. +/// +/// Indicates whether the registration was accepted, aliased to an existing +/// entity, or rejected. +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum RegisterStatus { + /// Registration accepted with the provided or generated UID + Accepted, + /// Registration accepted but UID was aliased to an existing entity + Aliased, + /// Registration rejected due to collision or inconsistency + Rejected, +} + +/// Session metadata stored in `session.json`. +/// +/// Contains all metadata about a session including its lineage, connector +/// association, and custom metadata fields. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionMetadata { + /// Schema version for forward compatibility + pub version: u32, + /// Unique scroll ID for this session (UUIDv7) + pub scroll_id: Uuid, + /// When the session was created + pub created_at: DateTime, + /// When the session was last updated + pub updated_at: DateTime, + /// Optional human-readable title + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + /// Connector that owns this session + pub connector_uid: Uuid, + /// Native session ID from the connector (if applicable) + #[serde(skip_serializing_if = "Option::is_none")] + pub native_session_id: Option, + /// Agent ID associated with this session (if applicable) + #[serde(skip_serializing_if = "Option::is_none")] + pub agent_id: Option, + /// Parent session this was derived from (if applicable) + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_scroll_id: Option, + /// How this session continues from its parent (if applicable) + #[serde(skip_serializing_if = "Option::is_none")] + pub continuation: Option, + /// Tags for categorization + #[serde(default)] + pub tags: Vec, + /// Free-form metadata for caller-specific fields + #[serde(default)] + pub metadata: serde_json::Value, + /// If true, this session should not appear in default session listings. + /// + /// Sessions with `no_update=true` exist but are hidden from the main archive list. + /// This is useful for sessions that are: + /// - Archived/inactive and shouldn't clutter the UI + /// - System sessions that users don't need to see + /// - Sessions marked for cleanup but not yet deleted + /// + /// Use `list_sessions_all()` or pass `include_no_update=true` to include these sessions. + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + pub no_update: bool, + /// Session kind (Chat or AcpConnection). Defaults to Chat. + #[serde(default)] + pub kind: SessionKind, + /// ACP client ID (only for AcpConnection sessions) + #[serde(skip_serializing_if = "Option::is_none")] + pub acp_client_id: Option, + /// Whether the ACP client is currently connected (only for AcpConnection sessions) + #[serde(skip_serializing_if = "Option::is_none")] + pub is_connected: Option, + /// Currently active linked session ID (only for AcpConnection sessions) + #[serde(skip_serializing_if = "Option::is_none")] + pub current_session_id: Option, + /// Agent models metadata (e.g., available models and current selection). + /// Stored as JSON for forward compatibility with protocol changes. + /// Contains `availableModels` and `currentModelId`. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub models: Option, + /// Agent modes metadata (e.g., permission modes, plan mode). + /// Stored as JSON for forward compatibility with protocol changes. + /// Contains `availableModes` and `currentModeId`. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub modes: Option, + /// ACP config options (JSON-serialized Vec). + /// Stored as JSON to avoid coupling the archivist to protocol types. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub config_options: Option, + /// How complete the archived session data is. + /// Defaults to Complete for backward compatibility with existing archives. + #[serde(default)] + pub completeness: SessionCompleteness, + /// Matrix room ID this session is shared to (e.g. "!abc:matrix.org"). + /// Retained even when sharing is disabled so re-enabling can reconnect. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub matrix_room_id: Option, + /// Whether Matrix sharing is currently active for this session. + #[serde(default)] + pub matrix_sharing_active: bool, + /// ISO 8601 timestamp of when sharing was first enabled. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub matrix_shared_at: Option>, + /// Whether this session is a subagent (non-loadable, hidden from default lists). + #[serde(default, skip_serializing_if = "std::ops::Not::not")] + pub is_subagent: bool, + /// Subagent type (e.g., "Explore", "rust-task-implementer"). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub subagent_type: Option, + /// The Agent tool_use block ID in the parent that spawned this subagent. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub spawning_tool_use_id: Option, +} + +/// Reference to an attached file in a message. +/// +/// Links to a file stored in the archive's file storage. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AttachmentRef { + /// File identifier (e.g., "sha256:...") + pub file_id: String, + /// Original filename + pub name: String, + /// MIME type of the file (if known) + #[serde(skip_serializing_if = "Option::is_none")] + pub mime_type: Option, +} + +/// A single message record stored in `messages.ndjson`. +/// +/// Each line in the NDJSON file is one message record. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessageRecord { + /// Schema version for forward compatibility + pub version: u32, + /// Unique message ID (UUIDv7) + pub message_id: Uuid, + /// Session this message belongs to (scroll_id) + pub session: Uuid, + /// Parent message ID (if this is a response or continuation) + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_id: Option, + /// Timestamp when the message was created + pub ts: DateTime, + /// Message role (e.g., "system", "user", "assistant") + pub role: String, + /// Optional author identifier + #[serde(skip_serializing_if = "Option::is_none")] + pub author: Option, + /// Message content in Markdown format (for search and fallback display) + pub content_md: String, + /// Original content parts for rich rendering (tool calls, code blocks, etc.) + /// This field preserves the structured MessagePart data for proper UI rendering. + /// If None, the UI should fall back to rendering content_md as plain text. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub content_parts: Option, + /// Attached files + #[serde(default)] + pub attachments: Vec, + /// Free-form metadata for connector-specific fields + #[serde(default)] + pub metadata: serde_json::Value, +} + +/// Connector metadata stored in `connector.json`. +/// +/// Contains information about a connector including its type, title, and +/// native client identifier. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ConnectorRecord { + /// Schema version for forward compatibility + pub version: u32, + /// Unique connector UID (UUIDv7) + pub connector_uid: Uuid, + /// Connector type (e.g., "OpenCode", "ACP", "Other") + #[serde(rename = "type")] + pub r#type: String, + /// Human-readable title + pub title: String, + /// Native client identifier (e.g., "opencode@http://localhost:12225") + pub client_native_id: String, + /// If this connector is an alias of another (for deduplication) + #[serde(skip_serializing_if = "Option::is_none")] + pub alias_of: Option, + /// When the connector was registered + pub created_at: DateTime, + /// Free-form metadata + #[serde(default)] + pub metadata: serde_json::Value, + /// Stable fingerprint for identity matching across connector re-registrations. + /// Format: "{transport}:{command_or_url}" e.g. "acp/stdio:/usr/bin/claude" + #[serde(default, skip_serializing_if = "Option::is_none")] + pub fingerprint: Option, +} + +/// Session mapping entry stored in `sessions.ndjson`. +/// +/// Maps native session IDs from connectors to scroll IDs in the archive. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionMapping { + /// Schema version for forward compatibility + pub version: u32, + /// Connector this mapping belongs to + pub connector_uid: Uuid, + /// Native session ID from the connector + pub native_session_id: String, + /// Scroll ID in the archive + pub scroll_id: Uuid, + /// When this mapping was created + pub created_at: DateTime, + /// If this session mapping is an alias of another + #[serde(skip_serializing_if = "Option::is_none")] + pub alias_of: Option, +} + +/// Row in the connector index TSV file. +/// +/// Note: TSV serialization is custom, not derived from serde. +#[derive(Debug, Clone)] +pub struct ConnectorIndexRow { + /// Connector UID + pub connector_uid: Uuid, + /// Connector type + pub r#type: String, + /// Connector title + pub title: String, + /// Native client identifier + pub client_native_id: String, + /// Alias of another connector (if applicable) + pub alias_of: Option, + /// Creation timestamp + pub created_at: DateTime, + /// Stable fingerprint for identity matching across connector re-registrations. + pub fingerprint: Option, +} + +/// File record stored in `file_index.ndjson`. +/// +/// Tracks files stored in the archive's content-addressable storage. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FileRecord { + /// Schema version for forward compatibility + pub version: u32, + /// File identifier (e.g., "sha256:...") + pub file_id: String, + /// Relative path in archive storage + pub path: String, + /// File size in bytes + pub size: u64, + /// MIME type (if known) + #[serde(skip_serializing_if = "Option::is_none")] + pub mime: Option, + /// Original filename + pub original_name: String, + /// Sessions that reference this file + #[serde(default)] + pub sessions: Vec, + /// Free-form metadata + #[serde(default)] + pub metadata: serde_json::Value, +} + +/// Report from a bulk session move operation. +/// +/// Tracks success/failure counts and collects error messages for any +/// sessions that could not be moved. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct MoveReport { + /// Number of sessions successfully moved + pub moved: usize, + /// Number of sessions that failed to move + pub failed: usize, + /// Error messages for failed moves (one per failure) + pub errors: Vec, +} + +// ============================================================================ +// API Request/Response Types +// ============================================================================ + +/// Request to register a new connector. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RegisterConnectorRequest { + /// Connector type (e.g., "OpenCode", "ACP") + #[serde(rename = "type")] + pub r#type: String, + /// Human-readable title + pub title: String, + /// Native client identifier + pub client_native_id: String, + /// Optional custom UID (if not provided, one will be generated) + #[serde(skip_serializing_if = "Option::is_none")] + pub custom_uid: Option, + /// Free-form metadata + #[serde(default)] + pub metadata: serde_json::Value, + /// Stable fingerprint for identity matching across connector re-registrations. + /// Format: "{transport}:{command_or_url}" e.g. "acp/stdio:/usr/bin/claude" + #[serde(default, skip_serializing_if = "Option::is_none")] + pub fingerprint: Option, +} + +/// Response from registering a connector. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RegisterConnectorResponse { + /// Registration status + pub status: RegisterStatus, + /// Assigned or existing connector UID + pub connector_uid: Uuid, + /// If aliased, the UID this was aliased to + #[serde(skip_serializing_if = "Option::is_none")] + pub alias_of: Option, + /// Optional note explaining the result + #[serde(skip_serializing_if = "Option::is_none")] + pub note: Option, +} + +/// Request to register a new session. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RegisterSessionRequest { + /// Connector that owns this session + pub connector_uid: Uuid, + /// Native session ID from the connector + pub native_session_id: String, + /// Optional human-readable title + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + /// Optional custom scroll ID (if not provided, one will be generated) + #[serde(skip_serializing_if = "Option::is_none")] + pub custom_scroll_id: Option, + /// Free-form metadata + #[serde(default)] + pub metadata: serde_json::Value, + /// Completeness level for this session. Defaults to Complete. + #[serde(default)] + pub completeness: SessionCompleteness, + /// Parent session scroll ID (for subagent linkage). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub parent_scroll_id: Option, + /// Whether this is a subagent session. + #[serde(default)] + pub is_subagent: bool, + /// How this session continues from parent. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub continuation: Option, + /// Agent ID for this session (subagent identifier). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub agent_id: Option, + /// Subagent type (e.g., "Explore"). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub subagent_type: Option, + /// The Agent tool_use block ID in the parent that spawned this subagent. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub spawning_tool_use_id: Option, +} + +/// Response from registering a session. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RegisterSessionResponse { + /// Registration status + pub status: RegisterStatus, + /// Assigned or existing scroll ID + pub scroll_id: Uuid, + /// If aliased, the scroll ID this was aliased to + #[serde(skip_serializing_if = "Option::is_none")] + pub alias_of: Option, +} + +/// An edge in the session DAG — links a parent session to a child subagent session. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DagEdge { + /// Parent session scroll_id + pub parent: Uuid, + /// Child (subagent) session scroll_id + pub child: Uuid, + /// Subagent's agent_id from Claude Code + pub agent_id: String, + /// Subagent type (e.g., "Explore", "rust-task-implementer") + #[serde(default, skip_serializing_if = "Option::is_none")] + pub subagent_type: Option, + /// The tool_use_id of the Agent call that spawned this subagent + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_use_id: Option, + /// Timestamp when the subagent was spawned + #[serde(default, skip_serializing_if = "Option::is_none")] + pub ts: Option>, +} + +// --------------------------------------------------------------------------- +// Session listing — cursor-paged query types +// --------------------------------------------------------------------------- + +/// Maximum number of items a single `list_sessions_paged` call may return. +/// The server clamps `SessionListQuery::limit` to this value. +pub const MAX_PAGE_LIMIT: usize = 200; + +/// Cursor into a sorted session listing. +/// +/// Sessions are ordered by `(updated_at DESC, scroll_id DESC)`. A cursor +/// means "items strictly after this `(updated_at, scroll_id)` point in the +/// ordering". +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SessionCursor { + pub updated_at: chrono::DateTime, + pub scroll_id: uuid::Uuid, +} + +/// Query shape for [`Archivist::list_sessions_paged`]. +/// +/// All filters are AND-combined. `connector_uids` and `project_ids` scope the +/// search; when both are empty, every session in the archive is considered. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionListQuery { + pub archive: Option, + + // Scoping + #[serde(default)] + pub connector_uids: Vec, + #[serde(default)] + pub project_ids: Vec, + /// Filter by `metadata.project_path` (exact match on filesystem path). + /// Useful for imported sessions that have a working directory but no + /// bound `Project` entity. + pub project_path: Option, + + // Visibility + /// `false` = hide `no_update=true` and `is_subagent=true` sessions (matches + /// legacy `list_sessions` default). `true` = include them (matches legacy + /// `list_sessions_all`). + pub include_hidden: bool, + + // Filters + /// Case-insensitive substring on `SessionMetadata.title`. Sessions with + /// `title=None` never match when this is `Some(_)`. + pub title_query: Option, + /// AND across tags; empty = no filter. + #[serde(default)] + pub tags: Vec, + /// Case-insensitive substring on `metadata.model`. Sessions without a + /// `metadata.model` string never match when this is `Some(_)`. + pub model_filter: Option, + + // Pagination + pub cursor: Option, + /// Requested page size. The implementation clamps to [`MAX_PAGE_LIMIT`]. + pub limit: usize, +} + +impl Default for SessionListQuery { + fn default() -> Self { + Self { + archive: None, + connector_uids: Vec::new(), + project_ids: Vec::new(), + project_path: None, + include_hidden: false, + title_query: None, + tags: Vec::new(), + model_filter: None, + cursor: None, + limit: 20, + } + } +} + +impl SessionListQuery { + /// Builder-style helper for tests and UI code. + /// + /// Adds a single connector UID to the filter. Can be called multiple + /// times to add more connectors. + pub fn with_connector(mut self, connector_uid: uuid::Uuid) -> Self { + self.connector_uids.push(connector_uid); + self + } + + /// Adds a single project ID to the filter. Can be called multiple + /// times to add more projects. + pub fn with_project(mut self, project_id: impl Into) -> Self { + self.project_ids.push(project_id.into()); + self + } + + pub fn with_project_path(mut self, path: impl Into) -> Self { + self.project_path = Some(path.into()); + self + } + + pub fn with_archive(mut self, archive: impl Into) -> Self { + self.archive = Some(archive.into()); + self + } + + pub fn with_limit(mut self, limit: usize) -> Self { + self.limit = limit; + self + } + + pub fn with_cursor(mut self, cursor: Option) -> Self { + self.cursor = cursor; + self + } + + pub fn with_title_query(mut self, q: impl Into) -> Self { + self.title_query = Some(q.into()); + self + } + + pub fn with_include_hidden(mut self, include: bool) -> Self { + self.include_hidden = include; + self + } +} + +/// One page of results from [`Archivist::list_sessions_paged`]. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionPage { + pub items: Vec, + /// `None` when the scan reached the end of the ordering. + pub next_cursor: Option, + /// Total number of sessions matching the query (before pagination). + /// `None` when the backend does not compute it. + #[serde(default)] + pub total_count: Option, +} + +// --------------------------------------------------------------------------- +// Message listing — cursor-paged query types (Phase 2) +// --------------------------------------------------------------------------- + +/// Cursor into a chronologically sorted message listing. +/// +/// Messages are ordered by `(ts ASC, message_id ASC)`. A cursor means +/// "items strictly after this `(ts, message_id)` point in the ordering". +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct MessageCursor { + pub ts: chrono::DateTime, + pub message_id: uuid::Uuid, +} + +/// A single page of messages returned from cursor-paged reads. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MessagePage { + pub items: Vec, + pub next_cursor: Option, +} + +#[cfg(any(test, feature = "test-utils"))] +impl SessionMetadata { + /// Test-only: minimal valid session metadata with defaulted fields. + pub fn stub(scroll_id: Uuid) -> Self { + let now = chrono::Utc::now(); + Self { + version: 1, + scroll_id, + created_at: now, + updated_at: now, + title: None, + connector_uid: Uuid::nil(), + native_session_id: None, + agent_id: None, + parent_scroll_id: None, + continuation: None, + tags: vec![], + metadata: serde_json::Value::Null, + no_update: false, + kind: SessionKind::Chat, + acp_client_id: None, + is_connected: None, + current_session_id: None, + models: None, + modes: None, + config_options: None, + completeness: SessionCompleteness::default(), + matrix_room_id: None, + matrix_sharing_active: false, + matrix_shared_at: None, + is_subagent: false, + subagent_type: None, + spawning_tool_use_id: None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::SystemTime; + + #[test] + fn test_continuation_serialization() { + let continuation = Continuation::Split; + let json = serde_json::to_string(&continuation).unwrap(); + assert_eq!(json, r#""SPLIT""#); + + let deserialized: Continuation = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized, continuation); + + // Test all variants + let variants = vec![ + (Continuation::Split, r#""SPLIT""#), + (Continuation::Compact, r#""COMPACT""#), + (Continuation::Reference, r#""REFERENCE""#), + (Continuation::Edit, r#""EDIT""#), + (Continuation::Unknown, r#""UNKNOWN""#), + ]; + + for (variant, expected) in variants { + let json = serde_json::to_string(&variant).unwrap(); + assert_eq!(json, expected); + let deserialized: Continuation = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized, variant); + } + } + + #[test] + fn test_register_status_serialization() { + let status = RegisterStatus::Accepted; + let json = serde_json::to_string(&status).unwrap(); + assert_eq!(json, r#""ACCEPTED""#); + + let deserialized: RegisterStatus = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized, status); + + // Test all variants + let variants = vec![ + (RegisterStatus::Accepted, r#""ACCEPTED""#), + (RegisterStatus::Aliased, r#""ALIASED""#), + (RegisterStatus::Rejected, r#""REJECTED""#), + ]; + + for (variant, expected) in variants { + let json = serde_json::to_string(&variant).unwrap(); + assert_eq!(json, expected); + let deserialized: RegisterStatus = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized, variant); + } + } + + #[test] + fn test_session_metadata_roundtrip() { + let connector_uid = Uuid::now_v7(); + let scroll_id = Uuid::now_v7(); + let now = DateTime::::from(SystemTime::now()); + + let metadata = SessionMetadata { + version: 1, + scroll_id, + created_at: now, + updated_at: now, + title: Some("Test Session".to_string()), + connector_uid, + native_session_id: Some("native-123".to_string()), + agent_id: Some("claude-3-5".to_string()), + parent_scroll_id: None, + continuation: None, + tags: vec!["test".to_string(), "example".to_string()], + metadata: serde_json::json!({ + "source": "OpenCode", + "model": "claude-3-5-sonnet" + }), + no_update: false, + kind: SessionKind::Chat, + acp_client_id: None, + is_connected: None, + current_session_id: None, + models: None, + modes: None, + config_options: None, + completeness: SessionCompleteness::Complete, + matrix_room_id: None, + matrix_sharing_active: false, + matrix_shared_at: None, + is_subagent: false, + subagent_type: None, + spawning_tool_use_id: None, + }; + + // Serialize to JSON + let json = serde_json::to_string_pretty(&metadata).unwrap(); + + // Deserialize back + let deserialized: SessionMetadata = serde_json::from_str(&json).unwrap(); + + // Verify fields + assert_eq!(deserialized.version, 1); + assert_eq!(deserialized.scroll_id, scroll_id); + assert_eq!(deserialized.connector_uid, connector_uid); + assert_eq!(deserialized.title, Some("Test Session".to_string())); + assert_eq!( + deserialized.native_session_id, + Some("native-123".to_string()) + ); + assert_eq!(deserialized.tags, vec!["test", "example"]); + } + + #[test] + fn test_message_record_roundtrip() { + let message_id = Uuid::now_v7(); + let session_id = Uuid::now_v7(); + let parent_id = Uuid::now_v7(); + let now = DateTime::::from(SystemTime::now()); + + let message = MessageRecord { + version: 1, + message_id, + session: session_id, + parent_id: Some(parent_id), + ts: now, + role: "user".to_string(), + author: Some("alice".to_string()), + content_md: "Hello, world!".to_string(), + content_parts: None, + attachments: vec![AttachmentRef { + file_id: "sha256:abc123".to_string(), + name: "spec.pdf".to_string(), + mime_type: Some("application/pdf".to_string()), + }], + metadata: serde_json::json!({ + "connector_msg_id": "native-456" + }), + }; + + // Serialize to JSON + let json = serde_json::to_string(&message).unwrap(); + + // Deserialize back + let deserialized: MessageRecord = serde_json::from_str(&json).unwrap(); + + // Verify fields + assert_eq!(deserialized.version, 1); + assert_eq!(deserialized.message_id, message_id); + assert_eq!(deserialized.session, session_id); + assert_eq!(deserialized.parent_id, Some(parent_id)); + assert_eq!(deserialized.role, "user"); + assert_eq!(deserialized.author, Some("alice".to_string())); + assert_eq!(deserialized.content_md, "Hello, world!"); + assert_eq!(deserialized.attachments.len(), 1); + assert_eq!(deserialized.attachments[0].file_id, "sha256:abc123"); + } + + #[test] + fn test_connector_record_roundtrip() { + let connector_uid = Uuid::now_v7(); + let now = DateTime::::from(SystemTime::now()); + + let connector = ConnectorRecord { + version: 1, + connector_uid, + r#type: "OpenCode".to_string(), + title: "OpenCode Local".to_string(), + client_native_id: "opencode@http://localhost:12225".to_string(), + alias_of: None, + created_at: now, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + // Serialize to JSON + let json = serde_json::to_string_pretty(&connector).unwrap(); + + // Verify "type" field is used (not "r#type") + assert!(json.contains(r#""type""#)); + assert!(!json.contains(r#""r#type""#)); + + // Deserialize back + let deserialized: ConnectorRecord = serde_json::from_str(&json).unwrap(); + + // Verify fields + assert_eq!(deserialized.version, 1); + assert_eq!(deserialized.connector_uid, connector_uid); + assert_eq!(deserialized.r#type, "OpenCode"); + assert_eq!(deserialized.title, "OpenCode Local"); + assert_eq!( + deserialized.client_native_id, + "opencode@http://localhost:12225" + ); + assert_eq!(deserialized.alias_of, None); + } + + #[test] + fn test_session_mapping_roundtrip() { + let connector_uid = Uuid::now_v7(); + let scroll_id = Uuid::now_v7(); + let now = DateTime::::from(SystemTime::now()); + + let mapping = SessionMapping { + version: 1, + connector_uid, + native_session_id: "abc123".to_string(), + scroll_id, + created_at: now, + alias_of: None, + }; + + // Serialize to JSON + let json = serde_json::to_string(&mapping).unwrap(); + + // Deserialize back + let deserialized: SessionMapping = serde_json::from_str(&json).unwrap(); + + // Verify fields + assert_eq!(deserialized.version, 1); + assert_eq!(deserialized.connector_uid, connector_uid); + assert_eq!(deserialized.native_session_id, "abc123"); + assert_eq!(deserialized.scroll_id, scroll_id); + } + + #[test] + fn test_file_record_roundtrip() { + let session1 = Uuid::now_v7(); + let session2 = Uuid::now_v7(); + + let file_record = FileRecord { + version: 1, + file_id: "sha256:abc123def456".to_string(), + path: ".files/ab/cd/abc123def456".to_string(), + size: 123456, + mime: Some("application/pdf".to_string()), + original_name: "spec.pdf".to_string(), + sessions: vec![session1, session2], + metadata: serde_json::json!({ + "source": "upload" + }), + }; + + // Serialize to JSON + let json = serde_json::to_string(&file_record).unwrap(); + + // Deserialize back + let deserialized: FileRecord = serde_json::from_str(&json).unwrap(); + + // Verify fields + assert_eq!(deserialized.version, 1); + assert_eq!(deserialized.file_id, "sha256:abc123def456"); + assert_eq!(deserialized.path, ".files/ab/cd/abc123def456"); + assert_eq!(deserialized.size, 123456); + assert_eq!(deserialized.mime, Some("application/pdf".to_string())); + assert_eq!(deserialized.original_name, "spec.pdf"); + assert_eq!(deserialized.sessions.len(), 2); + } + + #[test] + fn test_uuidv7_generation() { + // Generate several UUIDv7s + let uuid1 = Uuid::now_v7(); + std::thread::sleep(std::time::Duration::from_millis(2)); + let uuid2 = Uuid::now_v7(); + std::thread::sleep(std::time::Duration::from_millis(2)); + let uuid3 = Uuid::now_v7(); + + // Verify they're valid UUIDs + assert_eq!(uuid1.get_version_num(), 7); + assert_eq!(uuid2.get_version_num(), 7); + assert_eq!(uuid3.get_version_num(), 7); + + // Verify they're in time order (UUIDv7 is time-ordered) + // Convert to bytes for comparison + let bytes1 = uuid1.as_bytes(); + let bytes2 = uuid2.as_bytes(); + let bytes3 = uuid3.as_bytes(); + + // UUIDv7 should be sortable by bytes (time-ordered) + assert!(bytes1 < bytes2); + assert!(bytes2 < bytes3); + } + + #[test] + fn test_rfc3339_timestamps() { + let now = DateTime::::from(SystemTime::now()); + + // Serialize to RFC 3339 format + let json = serde_json::to_string(&now).unwrap(); + + // Should be in quotes and RFC 3339 format + assert!(json.starts_with('"')); + assert!(json.ends_with('"')); + assert!(json.contains('T')); + assert!(json.contains('Z') || json.contains('+')); + + // Deserialize back + let deserialized: DateTime = serde_json::from_str(&json).unwrap(); + + // Should be within a second of the original (allowing for microsecond precision loss) + let diff = (now.timestamp_millis() - deserialized.timestamp_millis()).abs(); + assert!(diff < 1000, "Timestamp difference too large: {} ms", diff); + } + + #[test] + fn test_api_request_types() { + // Test RegisterConnectorRequest + let request = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Test Connector".to_string(), + client_native_id: "test@localhost".to_string(), + custom_uid: Some(Uuid::now_v7()), + metadata: serde_json::json!({ "key": "value" }), + fingerprint: None, + }; + + let json = serde_json::to_string(&request).unwrap(); + let deserialized: RegisterConnectorRequest = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.r#type, "OpenCode"); + assert_eq!(deserialized.title, "Test Connector"); + + // Test RegisterSessionRequest + let request = RegisterSessionRequest { + connector_uid: Uuid::now_v7(), + native_session_id: "native-123".to_string(), + title: Some("Test Session".to_string()), + custom_scroll_id: Some(Uuid::now_v7()), + metadata: serde_json::json!({}), + completeness: SessionCompleteness::Complete, + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let json = serde_json::to_string(&request).unwrap(); + let deserialized: RegisterSessionRequest = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.native_session_id, "native-123"); + assert_eq!(deserialized.title, Some("Test Session".to_string())); + } + + #[test] + fn test_api_response_types() { + // Test RegisterConnectorResponse + let response = RegisterConnectorResponse { + status: RegisterStatus::Accepted, + connector_uid: Uuid::now_v7(), + alias_of: None, + note: Some("Successfully registered".to_string()), + }; + + let json = serde_json::to_string(&response).unwrap(); + let deserialized: RegisterConnectorResponse = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.status, RegisterStatus::Accepted); + assert_eq!( + deserialized.note, + Some("Successfully registered".to_string()) + ); + + // Test RegisterSessionResponse + let response = RegisterSessionResponse { + status: RegisterStatus::Aliased, + scroll_id: Uuid::now_v7(), + alias_of: Some(Uuid::now_v7()), + }; + + let json = serde_json::to_string(&response).unwrap(); + let deserialized: RegisterSessionResponse = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.status, RegisterStatus::Aliased); + assert!(deserialized.alias_of.is_some()); + } + + #[test] + fn test_session_kind_serialization() { + let kind_chat = SessionKind::Chat; + let json = serde_json::to_string(&kind_chat).unwrap(); + assert_eq!(json, r#""CHAT""#); + + let deserialized: SessionKind = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized, kind_chat); + + // Test all variants + let variants = vec![ + (SessionKind::Chat, r#""CHAT""#), + (SessionKind::AcpConnection, r#""ACP_CONNECTION""#), + ]; + + for (variant, expected) in variants { + let json = serde_json::to_string(&variant).unwrap(); + assert_eq!(json, expected); + let deserialized: SessionKind = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized, variant); + } + } + + #[test] + fn test_session_kind_default() { + let kind: SessionKind = Default::default(); + assert_eq!(kind, SessionKind::Chat); + } + + #[test] + fn test_meta_event_type_serialization() { + let event_type = MetaEventType::ClientConnected; + let json = serde_json::to_string(&event_type).unwrap(); + assert_eq!(json, r#""CLIENT_CONNECTED""#); + + let deserialized: MetaEventType = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized, event_type); + + // Test all variants + let variants = vec![ + (MetaEventType::ClientConnected, r#""CLIENT_CONNECTED""#), + (MetaEventType::ClientDisconnected, r#""CLIENT_DISCONNECTED""#), + (MetaEventType::SessionOpened, r#""SESSION_OPENED""#), + (MetaEventType::SessionSwitched, r#""SESSION_SWITCHED""#), + (MetaEventType::SessionClosed, r#""SESSION_CLOSED""#), + ]; + + for (variant, expected) in variants { + let json = serde_json::to_string(&variant).unwrap(); + assert_eq!(json, expected); + let deserialized: MetaEventType = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized, variant); + } + } + + #[test] + fn test_meta_event_record_roundtrip() { + let event_id = Uuid::now_v7(); + let session = Uuid::now_v7(); + let linked_session = Uuid::now_v7(); + let now = DateTime::::from(SystemTime::now()); + + let event = MetaEventRecord { + version: 1, + event_id, + session, + ts: now, + event_type: MetaEventType::SessionOpened, + description: "Client opened session".to_string(), + linked_session_id: Some(linked_session), + linked_connector_id: Some("connector-123".to_string()), + linked_connector_title: Some("Claude Session".to_string()), + metadata: serde_json::json!({ + "client_version": "1.0.0" + }), + }; + + // Serialize to JSON + let json = serde_json::to_string(&event).unwrap(); + + // Deserialize back + let deserialized: MetaEventRecord = serde_json::from_str(&json).unwrap(); + + // Verify fields + assert_eq!(deserialized.version, 1); + assert_eq!(deserialized.event_id, event_id); + assert_eq!(deserialized.session, session); + assert_eq!(deserialized.event_type, MetaEventType::SessionOpened); + assert_eq!(deserialized.description, "Client opened session"); + assert_eq!(deserialized.linked_session_id, Some(linked_session)); + assert_eq!(deserialized.linked_connector_id, Some("connector-123".to_string())); + assert_eq!(deserialized.linked_connector_title, Some("Claude Session".to_string())); + } + + #[test] + fn test_meta_event_record_minimal() { + let event_id = Uuid::now_v7(); + let session = Uuid::now_v7(); + let now = DateTime::::from(SystemTime::now()); + + let event = MetaEventRecord { + version: 1, + event_id, + session, + ts: now, + event_type: MetaEventType::ClientConnected, + description: "Client connected".to_string(), + linked_session_id: None, + linked_connector_id: None, + linked_connector_title: None, + metadata: serde_json::json!({}), + }; + + // Serialize to JSON + let json = serde_json::to_string(&event).unwrap(); + + // Verify optional fields are not serialized + assert!(!json.contains("linked_session_id")); + assert!(!json.contains("linked_connector_id")); + assert!(!json.contains("linked_connector_title")); + + // Deserialize back + let deserialized: MetaEventRecord = serde_json::from_str(&json).unwrap(); + + // Verify fields + assert_eq!(deserialized.version, 1); + assert_eq!(deserialized.event_id, event_id); + assert_eq!(deserialized.session, session); + assert_eq!(deserialized.event_type, MetaEventType::ClientConnected); + assert_eq!(deserialized.description, "Client connected"); + assert_eq!(deserialized.linked_session_id, None); + assert_eq!(deserialized.linked_connector_id, None); + assert_eq!(deserialized.linked_connector_title, None); + } + + #[test] + fn test_session_metadata_with_meta_fields() { + let connector_uid = Uuid::now_v7(); + let scroll_id = Uuid::now_v7(); + let current_session_id = Uuid::now_v7(); + let now = DateTime::::from(SystemTime::now()); + + let metadata = SessionMetadata { + version: 1, + scroll_id, + created_at: now, + updated_at: now, + title: Some("ACP Connection".to_string()), + connector_uid, + native_session_id: Some("acp-meta-client-123".to_string()), + agent_id: None, + parent_scroll_id: None, + continuation: None, + tags: vec![], + metadata: serde_json::json!({}), + no_update: false, + kind: SessionKind::AcpConnection, + acp_client_id: Some("client-123".to_string()), + is_connected: Some(true), + current_session_id: Some(current_session_id), + models: None, + modes: None, + config_options: None, + completeness: SessionCompleteness::Complete, + matrix_room_id: None, + matrix_sharing_active: false, + matrix_shared_at: None, + is_subagent: false, + subagent_type: None, + spawning_tool_use_id: None, + }; + + // Serialize to JSON + let json = serde_json::to_string_pretty(&metadata).unwrap(); + + // Deserialize back + let deserialized: SessionMetadata = serde_json::from_str(&json).unwrap(); + + // Verify fields + assert_eq!(deserialized.version, 1); + assert_eq!(deserialized.scroll_id, scroll_id); + assert_eq!(deserialized.kind, SessionKind::AcpConnection); + assert_eq!(deserialized.acp_client_id, Some("client-123".to_string())); + assert_eq!(deserialized.is_connected, Some(true)); + assert_eq!(deserialized.current_session_id, Some(current_session_id)); + } + + #[test] + fn test_session_metadata_chat_defaults() { + let connector_uid = Uuid::now_v7(); + let scroll_id = Uuid::now_v7(); + let now = DateTime::::from(SystemTime::now()); + + // Simulate deserializing old session metadata without new fields + let json = serde_json::json!({ + "version": 1, + "scroll_id": scroll_id, + "created_at": now, + "updated_at": now, + "title": "Old Session", + "connector_uid": connector_uid, + "tags": [], + "metadata": {} + }); + + let deserialized: SessionMetadata = serde_json::from_value(json).unwrap(); + + // Verify new fields use defaults + assert_eq!(deserialized.kind, SessionKind::Chat); + assert_eq!(deserialized.acp_client_id, None); + assert_eq!(deserialized.is_connected, None); + assert_eq!(deserialized.current_session_id, None); + } +} diff --git a/crates/dirigent_archivist/tests/archive_filter_test.rs b/crates/dirigent_archivist/tests/archive_filter_test.rs new file mode 100644 index 0000000..055db13 --- /dev/null +++ b/crates/dirigent_archivist/tests/archive_filter_test.rs @@ -0,0 +1,334 @@ +//! Two-archive fanout tests exercising `ArchiveFilter` semantics. +//! +//! The primary backend is unfiltered; the secondary backend carries a +//! restricted filter. Writes should always reach the primary but only +//! fan out to the secondary when the session passes the filter. + +#![cfg(feature = "test-utils")] + +use std::collections::HashSet; +use std::sync::Arc; + +use chrono::Utc; +use uuid::Uuid; + +use dirigent_archivist::backend::mock::MockBackend; +use dirigent_archivist::backend::{ArchiveBackend, HealthStatus}; +use dirigent_archivist::coordinator::Archivist; +use dirigent_archivist::registry::{ + ArchiveFilter, ArchiveRegistration, FailureMode, WritePolicy, +}; +use dirigent_archivist::types::{ + ConnectorRecord, MessageRecord, RegisterSessionRequest, +}; + +fn reg( + name: &str, + backend: Arc, + priority: u32, + filter: ArchiveFilter, +) -> Arc { + Arc::new( + ArchiveRegistration::new( + name.into(), + "mock", + backend as Arc, + /* write_active */ true, + FailureMode::Required, + priority, + /* enabled */ true, + WritePolicy::Inline, + /* writer */ None, + HealthStatus::Healthy, + ) + .with_filter(filter), + ) +} + +/// Seed a connector into a MockBackend directly, bypassing the coordinator. +async fn seed_connector(backend: &MockBackend, connector_uid: Uuid, client_native_id: &str) { + use dirigent_archivist::backend::ConnectorRegistryBackend; + let rec = ConnectorRecord { + version: 1, + connector_uid, + r#type: "Mock".into(), + title: "Mock connector".into(), + client_native_id: client_native_id.into(), + alias_of: None, + created_at: Utc::now(), + metadata: serde_json::Value::Null, + fingerprint: None, + }; + backend + .put_connector(rec) + .await + .expect("put_connector succeeds"); +} + +fn make_msg(session: Uuid, n: u32) -> MessageRecord { + MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session, + parent_id: None, + ts: Utc::now(), + role: "user".into(), + author: None, + content_md: format!("msg {}", n), + content_parts: None, + attachments: vec![], + metadata: serde_json::Value::Null, + } +} + +#[tokio::test] +async fn secondary_archive_filters_by_exclude_connector() { + let primary_backend = Arc::new(MockBackend::new()); + let secondary_backend = Arc::new(MockBackend::new()); + + let connector_a = Uuid::now_v7(); + let connector_b = Uuid::now_v7(); + + // Connector A is excluded from the secondary. + let mut excluded = HashSet::new(); + excluded.insert(connector_a); + let secondary_filter = ArchiveFilter { + exclude_connectors: excluded, + ..Default::default() + }; + + // Seed connectors on primary (and on secondary so mapping writes that DO + // pass the filter don't fail for unrelated reasons). + seed_connector(&primary_backend, connector_a, "native/a").await; + seed_connector(&primary_backend, connector_b, "native/b").await; + + let archivist = Archivist::from_registrations(vec![ + reg("primary", primary_backend.clone(), 0, ArchiveFilter::default()), + reg("secondary", secondary_backend.clone(), 10, secondary_filter), + ]); + + // Register a session for each connector. + let resp_a = archivist + .register_session( + RegisterSessionRequest { + connector_uid: connector_a, + native_session_id: "sess-a".into(), + title: None, + custom_scroll_id: None, + metadata: serde_json::Value::Null, + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }, + None, + ) + .await + .expect("register session a"); + let scroll_a = resp_a.scroll_id; + + let resp_b = archivist + .register_session( + RegisterSessionRequest { + connector_uid: connector_b, + native_session_id: "sess-b".into(), + title: None, + custom_scroll_id: None, + metadata: serde_json::Value::Null, + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }, + None, + ) + .await + .expect("register session b"); + let scroll_b = resp_b.scroll_id; + + // Append 3 messages to each session. + archivist + .append_messages( + scroll_a, + vec![make_msg(scroll_a, 1), make_msg(scroll_a, 2), make_msg(scroll_a, 3)], + None, + ) + .await + .expect("append to a"); + archivist + .append_messages( + scroll_b, + vec![make_msg(scroll_b, 1), make_msg(scroll_b, 2), make_msg(scroll_b, 3)], + None, + ) + .await + .expect("append to b"); + + // Primary sees every message. + assert_eq!(primary_backend.appended_count(scroll_a), 3); + assert_eq!(primary_backend.appended_count(scroll_b), 3); + + // Secondary excludes connector_a: scroll_a is filtered out, + // scroll_b is replicated. + assert_eq!( + secondary_backend.appended_count(scroll_a), + 0, + "secondary should NOT receive messages for the excluded connector" + ); + assert_eq!( + secondary_backend.appended_count(scroll_b), + 3, + "secondary should receive messages for the allowed connector" + ); + + // Session metadata fanout follows the same rule. + assert!( + primary_backend + .get_session(scroll_a) + .await + .unwrap() + .is_some(), + "primary has scroll_a" + ); + assert!( + secondary_backend + .get_session(scroll_a) + .await + .unwrap() + .is_none(), + "secondary should NOT have scroll_a (excluded connector)" + ); + assert!( + secondary_backend + .get_session(scroll_b) + .await + .unwrap() + .is_some(), + "secondary should have scroll_b (allowed connector)" + ); +} + +#[tokio::test] +async fn secondary_archive_filters_by_include_tag() { + let primary_backend = Arc::new(MockBackend::new()); + let secondary_backend = Arc::new(MockBackend::new()); + + let connector = Uuid::now_v7(); + seed_connector(&primary_backend, connector, "native/tagged").await; + + let mut include = HashSet::new(); + include.insert("prod".to_string()); + let secondary_filter = ArchiveFilter { + include_tags: include, + ..Default::default() + }; + + let archivist = Archivist::from_registrations(vec![ + reg("primary", primary_backend.clone(), 0, ArchiveFilter::default()), + reg("secondary", secondary_backend.clone(), 10, secondary_filter), + ]); + + // Register two sessions on the same connector. + let prod_resp = archivist + .register_session( + RegisterSessionRequest { + connector_uid: connector, + native_session_id: "sess-prod".into(), + title: None, + custom_scroll_id: None, + metadata: serde_json::Value::Null, + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }, + None, + ) + .await + .expect("register prod session"); + let scroll_prod = prod_resp.scroll_id; + + let dev_resp = archivist + .register_session( + RegisterSessionRequest { + connector_uid: connector, + native_session_id: "sess-dev".into(), + title: None, + custom_scroll_id: None, + metadata: serde_json::Value::Null, + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }, + None, + ) + .await + .expect("register dev session"); + let scroll_dev = dev_resp.scroll_id; + + // Tag the prod session directly on the primary so the coordinator can + // see it on the next fanout metadata lookup. We mutate via the primary + // backend to avoid going through update_session_metadata (which doesn't + // expose a tag API). + { + use dirigent_archivist::backend::ArchiveBackend as _; + let mut md = primary_backend + .get_session(scroll_prod) + .await + .unwrap() + .expect("prod session on primary"); + md.tags.push("prod".into()); + primary_backend.put_session(md).await.unwrap(); + } + + // Append messages AFTER tagging — now the filter consults the tagged metadata. + archivist + .append_messages( + scroll_prod, + vec![ + make_msg(scroll_prod, 1), + make_msg(scroll_prod, 2), + make_msg(scroll_prod, 3), + ], + None, + ) + .await + .expect("append prod"); + archivist + .append_messages( + scroll_dev, + vec![make_msg(scroll_dev, 1), make_msg(scroll_dev, 2), make_msg(scroll_dev, 3)], + None, + ) + .await + .expect("append dev"); + + // Primary keeps both. + assert_eq!(primary_backend.appended_count(scroll_prod), 3); + assert_eq!(primary_backend.appended_count(scroll_dev), 3); + + // Secondary only keeps the tagged session. + assert_eq!( + secondary_backend.appended_count(scroll_prod), + 3, + "secondary receives messages for the `prod`-tagged session" + ); + assert_eq!( + secondary_backend.appended_count(scroll_dev), + 0, + "secondary rejects the untagged session" + ); +} diff --git a/crates/dirigent_archivist/tests/fixtures/claude_minimal/projects/-home-user-myproj/abc12345-1234-1234-1234-abcdef123456.jsonl b/crates/dirigent_archivist/tests/fixtures/claude_minimal/projects/-home-user-myproj/abc12345-1234-1234-1234-abcdef123456.jsonl new file mode 100644 index 0000000..86e5b37 --- /dev/null +++ b/crates/dirigent_archivist/tests/fixtures/claude_minimal/projects/-home-user-myproj/abc12345-1234-1234-1234-abcdef123456.jsonl @@ -0,0 +1,2 @@ +{"type":"user","uuid":"11111111-1111-7111-8111-111111111111","parentUuid":null,"timestamp":"2024-01-01T00:00:00Z","sessionId":"abc12345-1234-1234-1234-abcdef123456","cwd":"/home/user/myproj","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":false,"userType":"external","message":{"role":"user","content":"hello"}} +{"type":"assistant","uuid":"22222222-2222-7222-8222-222222222222","parentUuid":"11111111-1111-7111-8111-111111111111","timestamp":"2024-01-01T00:00:01Z","sessionId":"abc12345-1234-1234-1234-abcdef123456","cwd":"/home/user/myproj","version":"2.1.71","gitBranch":"main","isSidechain":false,"requestId":"req-001","message":{"model":"claude-3-5-sonnet","id":"msg-abc","type":"message","role":"assistant","content":[{"type":"text","text":"hi back"}],"stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}} diff --git a/crates/dirigent_archivist/tests/import_claude_idempotency_test.rs b/crates/dirigent_archivist/tests/import_claude_idempotency_test.rs new file mode 100644 index 0000000..da3445c --- /dev/null +++ b/crates/dirigent_archivist/tests/import_claude_idempotency_test.rs @@ -0,0 +1,153 @@ +//! End-to-end test: import a Claude fixture twice, expect no duplication; +//! then append a new message and re-import, expect exactly 1 new message. + +use camino::Utf8PathBuf; +use dirigent_archivist::{ + backends::JsonlBackend, + import::{claude::import_claude_sessions, ImportProgressSink}, + Archivist, SessionListQuery, +}; +use std::sync::Arc; +use uuid::Uuid; + +fn fixture_root() -> Utf8PathBuf { + Utf8PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap()) + .join("tests/fixtures/claude_minimal") +} + +/// Build a self-contained coordinator for a given archive root. +/// +/// Uses `from_single_backend` so that parallel-test runs do not race on a +/// shared `.archives.json` in the tempdir's parent (which is what +/// `new_with_single_archive` would create). +async fn mk_archivist(root: std::path::PathBuf) -> dirigent_archivist::Result { + let backend = Arc::new(JsonlBackend::new(root).await?); + Archivist::from_single_backend("main".into(), backend).await +} + +#[tokio::test] +async fn claude_import_twice_is_idempotent() -> dirigent_archivist::Result<()> { + let tmp = std::env::temp_dir().join(format!("claude_idem_{}", Uuid::now_v7())); + let archivist = mk_archivist(tmp.clone()).await?; + + let fixture = fixture_root(); + + // First run — should import everything. + let stats1 = import_claude_sessions(&archivist, &fixture, None, &ImportProgressSink::noop(), &std::collections::HashMap::new()).await?; + assert!( + stats1.sessions_imported >= 1, + "expected at least one imported session, got stats {:?}", + stats1 + ); + assert!( + stats1.messages_written >= 2, + "expected >=2 messages written, got {:?}", + stats1 + ); + + // Second run — should write nothing (fingerprint gate skips unchanged sessions). + let stats2 = import_claude_sessions(&archivist, &fixture, None, &ImportProgressSink::noop(), &std::collections::HashMap::new()).await?; + assert_eq!( + stats2.messages_written, 0, + "expected no re-write on second import, got {:?}", + stats2 + ); + assert_eq!(stats2.sessions_imported, 0); + assert!( + stats2.sessions_skipped >= 1, + "expected at least one skipped session, got {:?}", + stats2 + ); + + // Verify on disk: no duplicate message_ids within any session. + let page = archivist + .list_sessions_paged(SessionListQuery::default().with_limit(200)) + .await?; + for session in &page.items { + let messages = archivist.get_messages(session.scroll_id, None).await?; + let mut seen = std::collections::HashSet::new(); + for m in &messages { + assert!( + seen.insert(m.message_id), + "duplicate message_id {} in session {}", + m.message_id, + session.scroll_id + ); + } + } + + let _ = tokio::fs::remove_dir_all(tmp).await; + Ok(()) +} + +#[tokio::test] +async fn claude_import_picks_up_additive_growth() -> dirigent_archivist::Result<()> { + // Copy the fixture to a mutable temp dir so we can append a message. + let tmp_src = std::env::temp_dir().join(format!("claude_grow_src_{}", Uuid::now_v7())); + let fixture = fixture_root(); + copy_dir_recursive(&fixture.as_std_path().to_path_buf(), &tmp_src).await; + + let tmp_arch = std::env::temp_dir().join(format!("claude_grow_arch_{}", Uuid::now_v7())); + let archivist = mk_archivist(tmp_arch.clone()).await?; + + let src = Utf8PathBuf::from_path_buf(tmp_src.clone()).unwrap(); + let _ = import_claude_sessions(&archivist, &src, None, &ImportProgressSink::noop(), &std::collections::HashMap::new()).await?; + + // Append a new message to the existing JSONL. + let jsonl = find_jsonl(&tmp_src).expect("fixture jsonl not found"); + let extra = r#"{"type":"user","uuid":"33333333-3333-7333-8333-333333333333","parentUuid":"22222222-2222-7222-8222-222222222222","timestamp":"2024-01-01T00:00:02Z","sessionId":"abc12345-1234-1234-1234-abcdef123456","cwd":"/home/user/myproj","version":"2.1.71","gitBranch":"main","isSidechain":false,"isMeta":false,"userType":"external","message":{"role":"user","content":"follow up"}}"#; + use tokio::io::AsyncWriteExt; + let mut f = tokio::fs::OpenOptions::new() + .append(true) + .open(&jsonl) + .await + .unwrap(); + f.write_all(extra.as_bytes()).await.unwrap(); + f.write_all(b"\n").await.unwrap(); + drop(f); + + let stats = import_claude_sessions(&archivist, &src, None, &ImportProgressSink::noop(), &std::collections::HashMap::new()).await?; + assert_eq!( + stats.messages_written, 1, + "expected 1 new message to be imported, got {:?}", + stats + ); + assert_eq!( + stats.sessions_updated, 1, + "expected 1 session updated, got {:?}", + stats + ); + + let _ = tokio::fs::remove_dir_all(tmp_src).await; + let _ = tokio::fs::remove_dir_all(tmp_arch).await; + Ok(()) +} + +async fn copy_dir_recursive(src: &std::path::Path, dst: &std::path::Path) { + tokio::fs::create_dir_all(dst).await.unwrap(); + let mut stack = vec![(src.to_path_buf(), dst.to_path_buf())]; + while let Some((s, d)) = stack.pop() { + let mut entries = tokio::fs::read_dir(&s).await.unwrap(); + while let Some(entry) = entries.next_entry().await.unwrap() { + let from = entry.path(); + let to = d.join(entry.file_name()); + if entry.file_type().await.unwrap().is_dir() { + tokio::fs::create_dir_all(&to).await.unwrap(); + stack.push((from, to)); + } else { + tokio::fs::copy(&from, &to).await.unwrap(); + } + } + } +} + +fn find_jsonl(dir: &std::path::Path) -> Option { + for entry in walkdir::WalkDir::new(dir).into_iter().flatten() { + if entry.file_type().is_file() + && entry.path().extension().and_then(|s| s.to_str()) == Some("jsonl") + { + return Some(entry.path().to_path_buf()); + } + } + None +} diff --git a/crates/dirigent_archivist/tests/import_progress_test.rs b/crates/dirigent_archivist/tests/import_progress_test.rs new file mode 100644 index 0000000..15176bd --- /dev/null +++ b/crates/dirigent_archivist/tests/import_progress_test.rs @@ -0,0 +1,89 @@ +//! Integration test: importer trait progress events fire in expected order. +//! +//! Drives a full `ChatGptImporter::import` against a fixture and asserts on +//! the `ImportProgressEvent` sequence observed on the paired receiver. + +use std::sync::Arc; +use tempfile::TempDir; + +use dirigent_archivist::{ + backends::JsonlBackend, + coordinator::Archivist, + import::{ + ImportConfig, ImportProgressEvent, ImportProgressSink, ImportTarget, ImporterRegistry, + }, +}; + +#[tokio::test] +async fn progress_event_sequence_is_well_formed() { + // 1. Setup an in-memory archivist (JsonlBackend in tempdir). + let dir = TempDir::new().unwrap(); + let backend = Arc::new(JsonlBackend::new(dir.path().to_path_buf()).await.unwrap()); + let archivist = Archivist::from_single_backend("main".into(), backend) + .await + .unwrap(); + let archivist = Arc::new(archivist); + + // 2. Use the chatgpt fixture — a minimal conversations.json with a + // user + assistant message pair. + let fixture = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("../dirigent_chatgpt/tests/fixtures/minimal.json"); + assert!( + fixture.exists(), + "chatgpt fixture missing at {}", + fixture.display() + ); + + let cfg = ImportConfig { + source: "chatgpt".into(), + params: { + let mut m = std::collections::BTreeMap::new(); + m.insert("path".into(), serde_json::json!(fixture.display().to_string())); + m + }, + }; + + // 3. Run the import with a channel sink. + let registry = ImporterRegistry::default(); + let importer = registry.get("chatgpt").expect("chatgpt registered"); + let (sink, mut rx) = ImportProgressSink::channel(); + + let archivist_for_job = archivist.clone(); + let job = tokio::spawn(async move { + importer + .import(&cfg, &*archivist_for_job, ImportTarget::default(), sink) + .await + }); + + // 4. Collect all events until the sender side is dropped. + let mut events = Vec::new(); + while let Some(evt) = rx.recv().await { + events.push(evt); + } + let stats = job.await.unwrap().expect("import"); + + // 5. Assertions on the event sequence. + // Must contain at least one SessionStarted before any SessionFinished. + let started_idx = events + .iter() + .position(|e| matches!(e, ImportProgressEvent::SessionStarted { .. })); + let finished_idx = events + .iter() + .position(|e| matches!(e, ImportProgressEvent::SessionFinished { .. })); + + assert!(started_idx.is_some(), "expected a SessionStarted event"); + assert!(finished_idx.is_some(), "expected a SessionFinished event"); + assert!( + started_idx.unwrap() < finished_idx.unwrap(), + "SessionStarted must precede SessionFinished" + ); + + // Stats shows at least 2 messages written (chatgpt fixture has a user + // + assistant pair). + assert!( + stats.messages_written >= 2, + "expected messages to be written, got stats {:?}", + stats + ); + assert_eq!(stats.sessions_imported, 1); +} diff --git a/crates/dirigent_archivist/tests/integration_tests.rs b/crates/dirigent_archivist/tests/integration_tests.rs new file mode 100644 index 0000000..266cf56 --- /dev/null +++ b/crates/dirigent_archivist/tests/integration_tests.rs @@ -0,0 +1,2414 @@ +//! Integration tests for dirigent_archivist +//! +//! These tests verify the end-to-end functionality of the archivist, +//! including storage, retrieval, and event streaming. + +#[cfg(test)] +mod tests { + use chrono::Utc; + use dirigent_archivist::{ + Archivist, MessageRecord, RegisterConnectorRequest, + RegisterSessionRequest, RegisterStatus, Result, SessionKind, SessionListQuery, + SessionMetadata, + }; + use dirigent_archivist::storage::ndjson::append_ndjson; + use dirigent_archivist::storage::json::write_json; + use uuid::Uuid; + + #[tokio::test] + async fn test_archivist_creation() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Verify coordinator construction succeeded (smoke test). + let _ = &archivist; + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_register_connector_acceptance() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + let req = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Test Connector".to_string(), + client_native_id: "opencode@localhost:12225".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let response = archivist.register_connector(req, None).await?; + assert_eq!(response.status, RegisterStatus::Accepted); + assert!(response.alias_of.is_none()); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_register_connector_aliasing() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + let req1 = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Test Connector".to_string(), + client_native_id: "opencode@localhost:12225".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let response1 = archivist.register_connector(req1, None).await?; + assert_eq!(response1.status, RegisterStatus::Accepted); + + // Register again with same client_native_id + let req2 = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Test Connector 2".to_string(), + client_native_id: "opencode@localhost:12225".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let response2 = archivist.register_connector(req2, None).await?; + assert_eq!(response2.status, RegisterStatus::Aliased); + assert_eq!(response2.connector_uid, response1.connector_uid); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_register_session_acceptance() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Register connector first + let connector_req = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Test Connector".to_string(), + client_native_id: "opencode@localhost:12225".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let connector_response = archivist.register_connector(connector_req, None).await?; + + // Register session + let session_req = RegisterSessionRequest { + connector_uid: connector_response.connector_uid, + native_session_id: "native-123".to_string(), + title: Some("Test Session".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let session_response = archivist.register_session(session_req, None).await?; + assert_eq!(session_response.status, RegisterStatus::Accepted); + assert!(session_response.alias_of.is_none()); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_register_session_aliasing() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Register connector + let connector_req = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Test Connector".to_string(), + client_native_id: "opencode@localhost:12225".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let connector_response = archivist.register_connector(connector_req, None).await?; + + // Register session + let session_req1 = RegisterSessionRequest { + connector_uid: connector_response.connector_uid, + native_session_id: "native-123".to_string(), + title: Some("Test Session".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let session_response1 = archivist.register_session(session_req1, None).await?; + + // Register again with same native_session_id + let session_req2 = RegisterSessionRequest { + connector_uid: connector_response.connector_uid, + native_session_id: "native-123".to_string(), + title: Some("Test Session 2".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let session_response2 = archivist.register_session(session_req2, None).await?; + assert_eq!(session_response2.status, RegisterStatus::Aliased); + assert_eq!(session_response2.scroll_id, session_response1.scroll_id); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_append_and_get_messages() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Register connector and session + let connector_req = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Test Connector".to_string(), + client_native_id: "opencode@localhost:12225".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let connector_response = archivist.register_connector(connector_req, None).await?; + + let session_req = RegisterSessionRequest { + connector_uid: connector_response.connector_uid, + native_session_id: "native-123".to_string(), + title: Some("Test Session".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let session_response = archivist.register_session(session_req, None).await?; + + // Create and append messages + let message1 = MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: session_response.scroll_id, + parent_id: None, + ts: Utc::now(), + role: "user".to_string(), + author: Some("test".to_string()), + content_md: "Hello, world!".to_string(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }; + + let message2 = MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: session_response.scroll_id, + parent_id: Some(message1.message_id), + ts: Utc::now(), + role: "assistant".to_string(), + author: Some("assistant".to_string()), + content_md: "Hi there!".to_string(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }; + + archivist + .append_messages( + session_response.scroll_id, + vec![message1.clone(), message2.clone()], + None, + ) + .await?; + + // Retrieve messages + let messages = archivist.get_messages(session_response.scroll_id, None).await?; + assert_eq!(messages.len(), 2); + assert_eq!(messages[0].message_id, message1.message_id); + assert_eq!(messages[1].message_id, message2.message_id); + assert_eq!(messages[0].content_md, "Hello, world!"); + assert_eq!(messages[1].content_md, "Hi there!"); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_list_sessions() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Register connector + let connector_req = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Test Connector".to_string(), + client_native_id: "opencode@localhost:12225".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let connector_response = archivist.register_connector(connector_req, None).await?; + + // Register multiple sessions + let session_req1 = RegisterSessionRequest { + connector_uid: connector_response.connector_uid, + native_session_id: "native-1".to_string(), + title: Some("Session 1".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let session_response1 = archivist.register_session(session_req1, None).await?; + + // Wait a moment to ensure different timestamps + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + let session_req2 = RegisterSessionRequest { + connector_uid: connector_response.connector_uid, + native_session_id: "native-2".to_string(), + title: Some("Session 2".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let session_response2 = archivist.register_session(session_req2, None).await?; + + // List sessions + let sessions = archivist + .list_sessions_paged( + SessionListQuery::default() + .with_connector(connector_response.connector_uid) + .with_limit(100), + ) + .await? + .items; + assert_eq!(sessions.len(), 2); + + // Verify sessions are sorted by updated_at descending (newest first) + // Session 2 should be first because it was created later + assert_eq!(sessions[0].scroll_id, session_response2.scroll_id); + assert_eq!(sessions[1].scroll_id, session_response1.scroll_id); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_get_session_metadata() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Register connector and session + let connector_req = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Test Connector".to_string(), + client_native_id: "opencode@localhost:12225".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let connector_response = archivist.register_connector(connector_req, None).await?; + + let session_req = RegisterSessionRequest { + connector_uid: connector_response.connector_uid, + native_session_id: "native-123".to_string(), + title: Some("Test Session".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let session_response = archivist.register_session(session_req, None).await?; + + // Get session metadata + let metadata = archivist + .get_session_metadata(session_response.scroll_id, None) + .await?; + + assert_eq!(metadata.scroll_id, session_response.scroll_id); + assert_eq!(metadata.title, Some("Test Session".to_string())); + assert_eq!(metadata.connector_uid, connector_response.connector_uid); + assert_eq!(metadata.native_session_id, Some("native-123".to_string())); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_resolve_session() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Register connector and session + let connector_req = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Test Connector".to_string(), + client_native_id: "opencode@localhost:12225".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let connector_response = archivist.register_connector(connector_req, None).await?; + + let session_req = RegisterSessionRequest { + connector_uid: connector_response.connector_uid, + native_session_id: "native-123".to_string(), + title: Some("Test Session".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let session_response = archivist.register_session(session_req, None).await?; + + // Resolve native session ID to scroll ID + let scroll_id = archivist + .resolve_session(connector_response.connector_uid, "native-123", None) + .await?; + + assert_eq!(scroll_id, session_response.scroll_id); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_register_connector_custom_uid_collision() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + let custom_uid = Uuid::now_v7(); + + // Register connector with custom UID + let req1 = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Test Connector 1".to_string(), + client_native_id: "opencode@localhost:12225".to_string(), + custom_uid: Some(custom_uid), + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let response1 = archivist.register_connector(req1, None).await?; + assert_eq!(response1.status, RegisterStatus::Accepted); + assert_eq!(response1.connector_uid, custom_uid); + + // Try to register another connector with same custom UID + let req2 = RegisterConnectorRequest { + r#type: "ACP".to_string(), + title: "Test Connector 2".to_string(), + client_native_id: "acp@localhost:3000".to_string(), + custom_uid: Some(custom_uid), + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let result2 = archivist.register_connector(req2, None).await; + assert!(result2.is_err(), "Expected error for custom_uid collision"); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_register_session_with_unknown_connector() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Try to register session with unknown connector + let unknown_connector = Uuid::now_v7(); + let session_req = RegisterSessionRequest { + connector_uid: unknown_connector, + native_session_id: "native-123".to_string(), + title: Some("Test Session".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let result = archivist.register_session(session_req, None).await; + assert!(result.is_err()); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_get_messages_empty_session() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Register connector and session + let connector_req = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Test Connector".to_string(), + client_native_id: "opencode@localhost:12225".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let connector_response = archivist.register_connector(connector_req, None).await?; + + let session_req = RegisterSessionRequest { + connector_uid: connector_response.connector_uid, + native_session_id: "native-123".to_string(), + title: Some("Test Session".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let session_response = archivist.register_session(session_req, None).await?; + + // Get messages from empty session + let messages = archivist.get_messages(session_response.scroll_id, None).await?; + assert_eq!(messages.len(), 0); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_get_messages_unknown_scroll_id() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Try to get messages from unknown session + let unknown_scroll_id = Uuid::now_v7(); + let messages = archivist.get_messages(unknown_scroll_id, None).await?; + + // Should return empty vector for unknown session + assert_eq!(messages.len(), 0); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_multiple_message_appends() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Register connector and session + let connector_req = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Test Connector".to_string(), + client_native_id: "opencode@localhost:12225".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let connector_response = archivist.register_connector(connector_req, None).await?; + + let session_req = RegisterSessionRequest { + connector_uid: connector_response.connector_uid, + native_session_id: "native-123".to_string(), + title: Some("Test Session".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let session_response = archivist.register_session(session_req, None).await?; + + // Append messages in multiple batches + let message1 = MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: session_response.scroll_id, + parent_id: None, + ts: Utc::now(), + role: "user".to_string(), + author: Some("test".to_string()), + content_md: "First message".to_string(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }; + + archivist + .append_messages(session_response.scroll_id, vec![message1.clone()], None) + .await?; + + let message2 = MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: session_response.scroll_id, + parent_id: Some(message1.message_id), + ts: Utc::now(), + role: "assistant".to_string(), + author: Some("assistant".to_string()), + content_md: "Second message".to_string(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }; + + archivist + .append_messages(session_response.scroll_id, vec![message2.clone()], None) + .await?; + + let message3 = MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: session_response.scroll_id, + parent_id: Some(message2.message_id), + ts: Utc::now(), + role: "user".to_string(), + author: Some("test".to_string()), + content_md: "Third message".to_string(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }; + + archivist + .append_messages(session_response.scroll_id, vec![message3.clone()], None) + .await?; + + // Retrieve all messages + let messages = archivist.get_messages(session_response.scroll_id, None).await?; + assert_eq!(messages.len(), 3); + assert_eq!(messages[0].content_md, "First message"); + assert_eq!(messages[1].content_md, "Second message"); + assert_eq!(messages[2].content_md, "Third message"); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_messages_sorted_chronologically() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Register connector and session + let connector_req = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Test Connector".to_string(), + client_native_id: "opencode@localhost:12225".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let connector_response = archivist.register_connector(connector_req, None).await?; + + let session_req = RegisterSessionRequest { + connector_uid: connector_response.connector_uid, + native_session_id: "native-123".to_string(), + title: Some("Test Session".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let session_response = archivist.register_session(session_req, None).await?; + + // Create messages with specific timestamps in chronological order + use chrono::TimeZone; + let base_time = Utc.with_ymd_and_hms(2025, 11, 18, 18, 23, 36).unwrap(); + + let msg_snake_user = MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: session_response.scroll_id, + parent_id: None, + ts: base_time + chrono::Duration::milliseconds(947), + role: "user".to_string(), + author: Some("user".to_string()), + content_md: "hello please tell me a joke about snakes".to_string(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }; + + let msg_snake_assistant = MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: session_response.scroll_id, + parent_id: Some(msg_snake_user.message_id), + ts: base_time + chrono::Duration::milliseconds(969), + role: "assistant".to_string(), + author: Some("claude".to_string()), + content_md: "Why don't snakes need cutlery? They have forked tongues!".to_string(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }; + + let msg_tiger_user = MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: session_response.scroll_id, + parent_id: Some(msg_snake_assistant.message_id), + ts: base_time + chrono::Duration::milliseconds(13429), + role: "user".to_string(), + author: Some("user".to_string()), + content_md: "now one about tigers".to_string(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }; + + let msg_tiger_assistant = MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: session_response.scroll_id, + parent_id: Some(msg_tiger_user.message_id), + ts: base_time + chrono::Duration::milliseconds(13448), + role: "assistant".to_string(), + author: Some("claude".to_string()), + content_md: "What do tigers wear to bed? Striped pajamas!".to_string(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }; + + let msg_hyena_user = MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: session_response.scroll_id, + parent_id: Some(msg_tiger_assistant.message_id), + ts: base_time + chrono::Duration::milliseconds(32623), + role: "user".to_string(), + author: Some("user".to_string()), + content_md: "and a third one about hyenas".to_string(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }; + + // Append messages OUT OF ORDER to simulate real-world event arrival + // (assistant replies often arrive after subsequent user messages) + archivist + .append_messages( + session_response.scroll_id, + vec![msg_snake_user.clone()], + None, + ) + .await?; + + archivist + .append_messages( + session_response.scroll_id, + vec![msg_tiger_user.clone()], + None, + ) + .await?; + + archivist + .append_messages( + session_response.scroll_id, + vec![msg_snake_assistant.clone()], + None, + ) + .await?; + + archivist + .append_messages( + session_response.scroll_id, + vec![msg_hyena_user.clone()], + None, + ) + .await?; + + archivist + .append_messages( + session_response.scroll_id, + vec![msg_tiger_assistant.clone()], + None, + ) + .await?; + + // Retrieve messages - should be sorted chronologically despite out-of-order appends + let messages = archivist.get_messages(session_response.scroll_id, None).await?; + + assert_eq!(messages.len(), 5); + + // Verify chronological order by timestamp + assert_eq!(messages[0].message_id, msg_snake_user.message_id); + assert_eq!(messages[0].content_md, "hello please tell me a joke about snakes"); + + assert_eq!(messages[1].message_id, msg_snake_assistant.message_id); + assert_eq!(messages[1].content_md, "Why don't snakes need cutlery? They have forked tongues!"); + + assert_eq!(messages[2].message_id, msg_tiger_user.message_id); + assert_eq!(messages[2].content_md, "now one about tigers"); + + assert_eq!(messages[3].message_id, msg_tiger_assistant.message_id); + assert_eq!(messages[3].content_md, "What do tigers wear to bed? Striped pajamas!"); + + assert_eq!(messages[4].message_id, msg_hyena_user.message_id); + assert_eq!(messages[4].content_md, "and a third one about hyenas"); + + // Verify timestamps are strictly increasing + for i in 1..messages.len() { + assert!( + messages[i].ts >= messages[i - 1].ts, + "Messages not in chronological order: message {} has ts {} which is before message {} with ts {}", + i, + messages[i].ts, + i - 1, + messages[i - 1].ts + ); + } + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_messages_with_identical_timestamps() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Register connector and session + let connector_req = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Test Connector".to_string(), + client_native_id: "opencode@localhost:12225".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let connector_response = archivist.register_connector(connector_req, None).await?; + + let session_req = RegisterSessionRequest { + connector_uid: connector_response.connector_uid, + native_session_id: "native-123".to_string(), + title: Some("Test Session".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let session_response = archivist.register_session(session_req, None).await?; + + // Create multiple messages with the exact same timestamp + // This tests the secondary sorting by message_id + let same_timestamp = Utc::now(); + + // Create messages with explicitly ordered UUIDs (v7 includes timestamp) + // Sleep briefly between creations to ensure UUIDv7 ordering + let msg1 = MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: session_response.scroll_id, + parent_id: None, + ts: same_timestamp, + role: "user".to_string(), + author: Some("user".to_string()), + content_md: "First message".to_string(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }; + + tokio::time::sleep(tokio::time::Duration::from_micros(1)).await; + + let msg2 = MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: session_response.scroll_id, + parent_id: Some(msg1.message_id), + ts: same_timestamp, + role: "assistant".to_string(), + author: Some("assistant".to_string()), + content_md: "Second message".to_string(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }; + + tokio::time::sleep(tokio::time::Duration::from_micros(1)).await; + + let msg3 = MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: session_response.scroll_id, + parent_id: Some(msg2.message_id), + ts: same_timestamp, + role: "user".to_string(), + author: Some("user".to_string()), + content_md: "Third message".to_string(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }; + + // Append in reverse order to ensure sorting is working + archivist + .append_messages( + session_response.scroll_id, + vec![msg3.clone()], + None, + ) + .await?; + + archivist + .append_messages( + session_response.scroll_id, + vec![msg1.clone()], + None, + ) + .await?; + + archivist + .append_messages( + session_response.scroll_id, + vec![msg2.clone()], + None, + ) + .await?; + + // Retrieve messages - should be sorted by message_id since timestamps are identical + let messages = archivist.get_messages(session_response.scroll_id, None).await?; + + assert_eq!(messages.len(), 3); + + // All timestamps should be the same + assert_eq!(messages[0].ts, same_timestamp); + assert_eq!(messages[1].ts, same_timestamp); + assert_eq!(messages[2].ts, same_timestamp); + + // Messages should be ordered by message_id (UUIDv7 preserves creation order) + assert_eq!(messages[0].message_id, msg1.message_id); + assert_eq!(messages[1].message_id, msg2.message_id); + assert_eq!(messages[2].message_id, msg3.message_id); + + assert_eq!(messages[0].content_md, "First message"); + assert_eq!(messages[1].content_md, "Second message"); + assert_eq!(messages[2].content_md, "Third message"); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_append_messages_updates_timestamp() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Register connector and session + let connector_req = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Test Connector".to_string(), + client_native_id: "opencode@localhost:12225".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let connector_response = archivist.register_connector(connector_req, None).await?; + + let session_req = RegisterSessionRequest { + connector_uid: connector_response.connector_uid, + native_session_id: "native-123".to_string(), + title: Some("Test Session".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let session_response = archivist.register_session(session_req, None).await?; + + // Get initial metadata + let metadata_before = archivist + .get_session_metadata(session_response.scroll_id, None) + .await?; + + // Wait a moment + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + // Append a message + let message = MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: session_response.scroll_id, + parent_id: None, + ts: Utc::now(), + role: "user".to_string(), + author: Some("test".to_string()), + content_md: "Hello!".to_string(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }; + + archivist + .append_messages(session_response.scroll_id, vec![message], None) + .await?; + + // Get updated metadata + let metadata_after = archivist + .get_session_metadata(session_response.scroll_id, None) + .await?; + + // Verify updated_at changed + assert!(metadata_after.updated_at > metadata_before.updated_at); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + // ======================================================================== + // Performance Benchmarks + // These tests are marked with #[ignore] to avoid slowing down regular test runs + // Run with: cargo test --package dirigent_archivist -- --ignored + // ======================================================================== + + #[tokio::test] + #[ignore] + async fn bench_append_1000_messages() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_bench_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Register connector and session + let connector_req = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Bench Connector".to_string(), + client_native_id: "bench@localhost".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let connector_response = archivist.register_connector(connector_req, None).await?; + + let session_req = RegisterSessionRequest { + connector_uid: connector_response.connector_uid, + native_session_id: "bench-session".to_string(), + title: Some("Bench Session".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let session_response = archivist.register_session(session_req, None).await?; + + // Create 1000 messages + let messages: Vec = (0..1000) + .map(|i| MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: session_response.scroll_id, + parent_id: None, + ts: Utc::now(), + role: if i % 2 == 0 { "user" } else { "assistant" }.to_string(), + author: Some("bench".to_string()), + content_md: format!("Message number {} with some realistic content that might appear in a conversation. This helps simulate real-world usage patterns.", i), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({"index": i}), + }) + .collect(); + + // Benchmark appending messages + let start = std::time::Instant::now(); + archivist + .append_messages(session_response.scroll_id, messages, None) + .await?; + let elapsed = start.elapsed(); + + let messages_per_sec = 1000.0 / elapsed.as_secs_f64(); + println!("\nBenchmark: Append 1000 messages"); + println!(" Total time: {:?}", elapsed); + println!(" Messages/sec: {:.2}", messages_per_sec); + println!(" Avg time per message: {:?}", elapsed / 1000); + + // Target: >100 msg/s + assert!( + messages_per_sec > 100.0, + "Performance degraded: {:.2} msg/s < 100 msg/s", + messages_per_sec + ); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + #[ignore] + async fn bench_read_100_messages() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_bench_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Register connector and session + let connector_req = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Bench Connector".to_string(), + client_native_id: "bench@localhost".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let connector_response = archivist.register_connector(connector_req, None).await?; + + let session_req = RegisterSessionRequest { + connector_uid: connector_response.connector_uid, + native_session_id: "bench-session-100".to_string(), + title: Some("Bench Session 100".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let session_response = archivist.register_session(session_req, None).await?; + + // Create and append 100 messages + let messages: Vec = (0..100) + .map(|i| MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: session_response.scroll_id, + parent_id: None, + ts: Utc::now(), + role: if i % 2 == 0 { "user" } else { "assistant" }.to_string(), + author: Some("bench".to_string()), + content_md: format!("Message {}", i), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }) + .collect(); + + archivist + .append_messages(session_response.scroll_id, messages, None) + .await?; + + // Benchmark reading messages + let start = std::time::Instant::now(); + let retrieved = archivist.get_messages(session_response.scroll_id, None).await?; + let elapsed = start.elapsed(); + + println!("\nBenchmark: Read 100 messages"); + println!(" Total time: {:?}", elapsed); + println!(" Messages retrieved: {}", retrieved.len()); + + // Target: sub-100ms for typical sessions + assert!( + elapsed.as_millis() < 100, + "Read performance degraded: {:?} > 100ms", + elapsed + ); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + #[ignore] + async fn bench_read_1000_messages() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_bench_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Register connector and session + let connector_req = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Bench Connector".to_string(), + client_native_id: "bench@localhost".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let connector_response = archivist.register_connector(connector_req, None).await?; + + let session_req = RegisterSessionRequest { + connector_uid: connector_response.connector_uid, + native_session_id: "bench-session-1000".to_string(), + title: Some("Bench Session 1000".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + let session_response = archivist.register_session(session_req, None).await?; + + // Create and append 1000 messages + let messages: Vec = (0..1000) + .map(|i| MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: session_response.scroll_id, + parent_id: None, + ts: Utc::now(), + role: if i % 2 == 0 { "user" } else { "assistant" }.to_string(), + author: Some("bench".to_string()), + content_md: format!("Message {}", i), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }) + .collect(); + + archivist + .append_messages(session_response.scroll_id, messages, None) + .await?; + + // Benchmark reading messages + let start = std::time::Instant::now(); + let retrieved = archivist.get_messages(session_response.scroll_id, None).await?; + let elapsed = start.elapsed(); + + println!("\nBenchmark: Read 1000 messages"); + println!(" Total time: {:?}", elapsed); + println!(" Messages retrieved: {}", retrieved.len()); + + // Log for tracking (no strict requirement for large sessions) + println!(" Note: Performance acceptable for large session"); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + #[ignore] + async fn bench_list_100_sessions() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_bench_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Register connector + let connector_req = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Bench Connector".to_string(), + client_native_id: "bench@localhost".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let connector_response = archivist.register_connector(connector_req, None).await?; + + // Register 100 sessions + for i in 0..100 { + let session_req = RegisterSessionRequest { + connector_uid: connector_response.connector_uid, + native_session_id: format!("bench-session-{}", i), + title: Some(format!("Session {}", i)), + custom_scroll_id: None, + metadata: serde_json::json!({"index": i}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + + archivist.register_session(session_req, None).await?; + } + + // Benchmark listing sessions + let start = std::time::Instant::now(); + let sessions = archivist + .list_sessions_paged( + SessionListQuery::default() + .with_connector(connector_response.connector_uid) + .with_limit(dirigent_archivist::MAX_PAGE_LIMIT), + ) + .await? + .items; + let elapsed = start.elapsed(); + + println!("\nBenchmark: List 100 sessions"); + println!(" Total time: {:?}", elapsed); + println!(" Sessions listed: {}", sessions.len()); + + // Target: sub-100ms for typical connector + assert!( + elapsed.as_millis() < 100, + "List performance degraded: {:?} > 100ms", + elapsed + ); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_mixed_format_compatibility() { + // Create archivist with temp directory + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await.unwrap() + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await.unwrap(); + + // Register connector + let connector_req = RegisterConnectorRequest { + r#type: "Test".to_string(), + title: "Test Connector".to_string(), + client_native_id: "test-connector".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + let connector_resp = archivist.register_connector(connector_req, None).await.unwrap(); + let connector_uid = connector_resp.connector_uid; + + // Manually create a session with .jsonl format for messages + let scroll_id = Uuid::now_v7(); + let session_metadata = SessionMetadata { + version: 1, + scroll_id, + created_at: Utc::now(), + updated_at: Utc::now(), + title: Some("Test Session".to_string()), + connector_uid, + native_session_id: Some("test-123".to_string()), + agent_id: None, + parent_scroll_id: None, + continuation: None, + tags: vec![], + metadata: serde_json::json!({}), + no_update: false, + kind: SessionKind::Chat, + acp_client_id: None, + is_connected: None, + current_session_id: None, + models: None, + modes: None, + config_options: None, + completeness: Default::default(), + matrix_room_id: None, + matrix_sharing_active: false, + matrix_shared_at: None, + is_subagent: false, + subagent_type: None, + spawning_tool_use_id: None, + }; + + backend.paths().ensure_dirs(scroll_id).await.unwrap(); + write_json(&backend.paths().session_json(scroll_id), &session_metadata).await.unwrap(); + + // Create messages.jsonl (not .ndjson) + let jsonl_path = backend.paths().session_dir(scroll_id).join("messages.jsonl"); + let message = MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: scroll_id, + parent_id: None, + ts: Utc::now(), + role: "user".to_string(), + author: None, + content_md: "Hello from .jsonl file".to_string(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }; + append_ndjson(&jsonl_path, &message).await.unwrap(); + + // Read messages using archivist API + let messages = archivist.get_messages(scroll_id, None).await.unwrap(); + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].content_md, "Hello from .jsonl file"); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + } + + #[tokio::test] + async fn test_fingerprint_registration_and_matching() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Register first connector with a fingerprint + let fingerprint = "acp/stdio:/usr/bin/claude".to_string(); + let req1 = RegisterConnectorRequest { + r#type: "ACP".to_string(), + title: "Claude CLI".to_string(), + client_native_id: "acp-session-abc123".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: Some(fingerprint.clone()), + }; + + let response1 = archivist.register_connector(req1, None).await?; + assert_eq!(response1.status, RegisterStatus::Accepted); + let original_uid = response1.connector_uid; + + // Register a second connector with a DIFFERENT client_native_id + // but the SAME fingerprint. Should be ALIASED to the original. + let req2 = RegisterConnectorRequest { + r#type: "ACP".to_string(), + title: "Claude CLI (re-added)".to_string(), + client_native_id: "acp-session-xyz789".to_string(), + custom_uid: None, + metadata: serde_json::json!({"version": 2}), + fingerprint: Some(fingerprint.clone()), + }; + + let response2 = archivist.register_connector(req2, None).await?; + assert_eq!( + response2.status, + RegisterStatus::Aliased, + "Same fingerprint should cause ALIASED status" + ); + assert_eq!( + response2.connector_uid, original_uid, + "Aliased connector should return the original UID" + ); + assert_eq!(response2.alias_of, Some(original_uid)); + assert!( + response2.note.as_deref().unwrap_or("").contains("fingerprint"), + "Note should mention fingerprint matching" + ); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_fingerprint_no_match_different_fingerprints() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Register first connector with fingerprint A + let req1 = RegisterConnectorRequest { + r#type: "ACP".to_string(), + title: "Claude CLI".to_string(), + client_native_id: "acp-claude-1".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: Some("acp/stdio:/usr/bin/claude".to_string()), + }; + + let response1 = archivist.register_connector(req1, None).await?; + assert_eq!(response1.status, RegisterStatus::Accepted); + + // Register second connector with fingerprint B (different) + let req2 = RegisterConnectorRequest { + r#type: "ACP".to_string(), + title: "Codex Agent".to_string(), + client_native_id: "acp-codex-1".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: Some("acp/stdio:/usr/bin/codex".to_string()), + }; + + let response2 = archivist.register_connector(req2, None).await?; + assert_eq!( + response2.status, + RegisterStatus::Accepted, + "Different fingerprints should both be ACCEPTED" + ); + assert_ne!( + response2.connector_uid, response1.connector_uid, + "Different fingerprints should get different UIDs" + ); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_fingerprint_none_skips_matching() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Register first connector WITH a fingerprint + let req1 = RegisterConnectorRequest { + r#type: "ACP".to_string(), + title: "Claude CLI".to_string(), + client_native_id: "acp-claude-1".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: Some("acp/stdio:/usr/bin/claude".to_string()), + }; + + let response1 = archivist.register_connector(req1, None).await?; + assert_eq!(response1.status, RegisterStatus::Accepted); + + // Register second connector WITHOUT a fingerprint (different native ID) + // Should NOT match the first connector even though one exists with a fingerprint + let req2 = RegisterConnectorRequest { + r#type: "ACP".to_string(), + title: "Unknown ACP Agent".to_string(), + client_native_id: "acp-unknown-1".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let response2 = archivist.register_connector(req2, None).await?; + assert_eq!( + response2.status, + RegisterStatus::Accepted, + "Connector with fingerprint=None should not match existing fingerprints" + ); + assert_ne!( + response2.connector_uid, response1.connector_uid, + "Should get a new UID when no fingerprint is provided" + ); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_connector_fingerprint_persistence() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // Register connector with a fingerprint + let fingerprint_value = "acp/stdio:/usr/bin/claude".to_string(); + let connector_req = RegisterConnectorRequest { + r#type: "ACP".to_string(), + title: "Claude CLI".to_string(), + client_native_id: "acp-claude-1".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: Some(fingerprint_value.clone()), + }; + + let response = archivist.register_connector(connector_req, None).await?; + assert_eq!(response.status, RegisterStatus::Accepted); + let connector_uid = response.connector_uid; + + // Verify fingerprint is persisted in connector.json + let connector_json_path = backend.paths() + .connector_dir(connector_uid) + .join("connector.json"); + let raw_json = tokio::fs::read_to_string(&connector_json_path).await.unwrap(); + let connector_record: serde_json::Value = serde_json::from_str(&raw_json).unwrap(); + assert_eq!( + connector_record["fingerprint"].as_str(), + Some(fingerprint_value.as_str()), + "Fingerprint should be persisted in connector.json" + ); + + // Verify fingerprint is persisted in TSV index + let index_path = backend.paths().connector_index_tsv(); + let tsv_content = tokio::fs::read_to_string(&index_path).await.unwrap(); + assert!( + tsv_content.contains(&fingerprint_value), + "Fingerprint should appear in TSV index" + ); + + // Verify fingerprint can be read back via TSV reader + let rows = dirigent_archivist::storage::tsv::read_connector_index(&index_path).await.unwrap(); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].fingerprint, Some(fingerprint_value.clone())); + + // Register a second connector WITHOUT a fingerprint (ensure None is handled) + let connector_req2 = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "OpenCode Local".to_string(), + client_native_id: "opencode@localhost:12225".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + + let response2 = archivist.register_connector(connector_req2, None).await?; + assert_eq!(response2.status, RegisterStatus::Accepted); + + // Re-read TSV and verify both connectors + let rows = dirigent_archivist::storage::tsv::read_connector_index(&index_path).await.unwrap(); + assert_eq!(rows.len(), 2); + + // First connector should have fingerprint + let row_with_fp = rows.iter().find(|r| r.connector_uid == connector_uid).unwrap(); + assert_eq!(row_with_fp.fingerprint, Some(fingerprint_value)); + + // Second connector should have no fingerprint + let row_without_fp = rows.iter().find(|r| r.connector_uid == response2.connector_uid).unwrap(); + assert_eq!(row_without_fp.fingerprint, None); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_list_connectors() -> std::result::Result<(), Box> { + let temp_dir = std::env::temp_dir().join(format!("archivist_lc_{}", uuid::Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + let req1 = RegisterConnectorRequest { + r#type: "Acp".to_string(), + title: "Claude".to_string(), + client_native_id: "c1".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: Some("acp/stdio:/usr/bin/claude".to_string()), + }; + archivist.register_connector(req1, None).await?; + + let req2 = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "OC".to_string(), + client_native_id: "c2".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + archivist.register_connector(req2, None).await?; + + let connectors = archivist.list_connectors(None).await?; + assert_eq!(connectors.len(), 2); + + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_move_session() -> std::result::Result<(), Box> { + let temp_dir = std::env::temp_dir().join(format!("archivist_mv_{}", uuid::Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + let c1 = archivist + .register_connector( + RegisterConnectorRequest { + r#type: "Acp".to_string(), + title: "Source".to_string(), + client_native_id: "src".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }, + None, + ) + .await? + .connector_uid; + + let c2 = archivist + .register_connector( + RegisterConnectorRequest { + r#type: "Acp".to_string(), + title: "Target".to_string(), + client_native_id: "tgt".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }, + None, + ) + .await? + .connector_uid; + + let session = archivist + .register_session( + RegisterSessionRequest { + connector_uid: c1, + native_session_id: "s1".to_string(), + title: Some("Test Session".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }, + None, + ) + .await?; + + // Verify under c1 + assert_eq!( + archivist + .list_sessions_paged( + SessionListQuery::default().with_connector(c1).with_limit(100), + ) + .await? + .items + .len(), + 1 + ); + + // Move to c2 + archivist + .move_session_to_connector(session.scroll_id, c2, None) + .await?; + + // c1 should be empty, c2 should have the session + assert_eq!( + archivist + .list_sessions_paged( + SessionListQuery::default().with_connector(c1).with_limit(100), + ) + .await? + .items + .len(), + 0 + ); + let c2_sessions = archivist + .list_sessions_paged( + SessionListQuery::default().with_connector(c2).with_limit(100), + ) + .await? + .items; + assert_eq!(c2_sessions.len(), 1); + assert_eq!(c2_sessions[0].connector_uid, c2); + + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_copy_session() -> std::result::Result<(), Box> { + let temp_dir = std::env::temp_dir().join(format!("archivist_cp_{}", uuid::Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + let c1 = archivist + .register_connector( + RegisterConnectorRequest { + r#type: "Acp".to_string(), + title: "Source".to_string(), + client_native_id: "src".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }, + None, + ) + .await? + .connector_uid; + + let c2 = archivist + .register_connector( + RegisterConnectorRequest { + r#type: "Acp".to_string(), + title: "Target".to_string(), + client_native_id: "tgt".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }, + None, + ) + .await? + .connector_uid; + + let session = archivist + .register_session( + RegisterSessionRequest { + connector_uid: c1, + native_session_id: "s1".to_string(), + title: Some("Original".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }, + None, + ) + .await?; + + // Add a message + let msg = MessageRecord { + version: 1, + message_id: uuid::Uuid::now_v7(), + session: session.scroll_id, + parent_id: None, + ts: chrono::Utc::now(), + role: "user".to_string(), + author: Some("test".to_string()), + content_md: "Hello".to_string(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }; + archivist + .append_messages(session.scroll_id, vec![msg], None) + .await?; + + // Copy + let new_scroll_id = archivist + .copy_session_to_connector(session.scroll_id, c2, None) + .await?; + assert_ne!(new_scroll_id, session.scroll_id); + + // Original still under c1 + assert_eq!( + archivist + .list_sessions_paged( + SessionListQuery::default().with_connector(c1).with_limit(100), + ) + .await? + .items + .len(), + 1 + ); + // Copy under c2 + let c2_sessions = archivist + .list_sessions_paged( + SessionListQuery::default().with_connector(c2).with_limit(100), + ) + .await? + .items; + assert_eq!(c2_sessions.len(), 1); + assert_eq!(c2_sessions[0].connector_uid, c2); + // Messages copied + let msgs = archivist.get_messages(new_scroll_id, None).await?; + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].content_md, "Hello"); + + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_move_sessions_bulk() -> std::result::Result<(), Box> { + let temp_dir = + std::env::temp_dir().join(format!("archivist_mvb_{}", uuid::Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + let c1 = archivist + .register_connector( + RegisterConnectorRequest { + r#type: "Acp".to_string(), + title: "Source".to_string(), + client_native_id: "src".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }, + None, + ) + .await? + .connector_uid; + + let c2 = archivist + .register_connector( + RegisterConnectorRequest { + r#type: "Acp".to_string(), + title: "Target".to_string(), + client_native_id: "tgt".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }, + None, + ) + .await? + .connector_uid; + + let mut scroll_ids = Vec::new(); + for i in 0..3 { + let s = archivist + .register_session( + RegisterSessionRequest { + connector_uid: c1, + native_session_id: format!("s{}", i), + title: Some(format!("Session {}", i)), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }, + None, + ) + .await?; + scroll_ids.push(s.scroll_id); + } + + let report = archivist + .move_sessions_to_connector(scroll_ids, c2, None) + .await?; + assert_eq!(report.moved, 3); + assert_eq!(report.failed, 0); + assert!(report.errors.is_empty()); + assert_eq!( + archivist + .list_sessions_paged( + SessionListQuery::default().with_connector(c1).with_limit(100), + ) + .await? + .items + .len(), + 0 + ); + assert_eq!( + archivist + .list_sessions_paged( + SessionListQuery::default().with_connector(c2).with_limit(100), + ) + .await? + .items + .len(), + 3 + ); + + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_connector_identity_persistence_e2e() -> std::result::Result<(), Box> + { + let temp_dir = + std::env::temp_dir().join(format!("archivist_e2e_{}", uuid::Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + // 1. Register connector with fingerprint + let req1 = RegisterConnectorRequest { + r#type: "Acp".to_string(), + title: "Claude v1".to_string(), + client_native_id: "first-run-id".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: Some("acp/stdio:/usr/bin/claude".to_string()), + }; + let resp1 = archivist.register_connector(req1, None).await?; + let original_uid = resp1.connector_uid; + assert_eq!(resp1.status, RegisterStatus::Accepted); + + // 2. Create sessions under this connector + let s1 = archivist + .register_session( + RegisterSessionRequest { + connector_uid: original_uid, + native_session_id: "session-1".to_string(), + title: Some("Important Session".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }, + None, + ) + .await?; + + // 3. Add messages + let msg = MessageRecord { + version: 1, + message_id: uuid::Uuid::now_v7(), + session: s1.scroll_id, + parent_id: None, + ts: chrono::Utc::now(), + role: "user".to_string(), + author: Some("test".to_string()), + content_md: "Don't lose me!".to_string(), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + }; + archivist + .append_messages(s1.scroll_id, vec![msg], None) + .await?; + + // 4. Simulate "remove and re-add" -- new connector_id, same fingerprint + let req2 = RegisterConnectorRequest { + r#type: "Acp".to_string(), + title: "Claude v2 (reinstalled)".to_string(), + client_native_id: "second-run-id".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: Some("acp/stdio:/usr/bin/claude".to_string()), + }; + let resp2 = archivist.register_connector(req2, None).await?; + + // 5. Verify: same UID, ALIASED status + assert_eq!(resp2.status, RegisterStatus::Aliased); + assert_eq!(resp2.connector_uid, original_uid); + + // 6. Verify: sessions still accessible under the same UID + let sessions = archivist + .list_sessions_paged( + SessionListQuery::default() + .with_connector(original_uid) + .with_limit(100), + ) + .await? + .items; + assert_eq!(sessions.len(), 1); + assert_eq!(sessions[0].title, Some("Important Session".to_string())); + + // 7. Verify: messages intact + let messages = archivist.get_messages(s1.scroll_id, None).await?; + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].content_md, "Don't lose me!"); + + // 8. Verify: connector record is preserved under the original UID. + // + // NOTE: Pre-Phase-2 `FileBasedArchivist` ALSO refreshed the matched + // connector's `title`/`metadata` on fingerprint-based ALIASED + // registration. The `Archivist` deliberately drops that + // refresh (see `coordinator/connectors.rs` for the rationale — the + // `ConnectorRegistryBackend` trait has no "update metadata" method + // yet, and `put_connector` would append rather than mutate). The + // identity (UID) is stable; the title stays the original. + let connectors = archivist.list_connectors(None).await?; + let connector = connectors + .iter() + .find(|c| c.connector_uid == original_uid) + .unwrap(); + assert_eq!(connector.title, "Claude v1"); + + // 9. Test move_session works after fingerprint re-association + let c2 = archivist + .register_connector( + RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Secondary".to_string(), + client_native_id: "secondary".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }, + None, + ) + .await? + .connector_uid; + + archivist + .move_session_to_connector(s1.scroll_id, c2, None) + .await?; + assert_eq!( + archivist + .list_sessions_paged( + SessionListQuery::default() + .with_connector(original_uid) + .with_limit(100), + ) + .await? + .items + .len(), + 0 + ); + assert_eq!( + archivist + .list_sessions_paged( + SessionListQuery::default().with_connector(c2).with_limit(100), + ) + .await? + .items + .len(), + 1 + ); + + // Move it back + archivist + .move_session_to_connector(s1.scroll_id, original_uid, None) + .await?; + assert_eq!( + archivist + .list_sessions_paged( + SessionListQuery::default() + .with_connector(original_uid) + .with_limit(100), + ) + .await? + .items + .len(), + 1 + ); + + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_paged_walk_fifty_sessions() -> Result<()> { + use chrono::{Duration, Utc}; + use dirigent_archivist::SessionListQuery; + + let temp_dir = std::env::temp_dir().join(format!("paged_walk_{}", Uuid::now_v7())); + let backend = std::sync::Arc::new( + dirigent_archivist::backends::JsonlBackend::new(temp_dir.clone()).await? + ); + let archivist = Archivist::from_single_backend( + "main".into(), backend.clone() + ).await?; + + let connector_req = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "paged-walk".to_string(), + client_native_id: format!("paged-walk@{}", Uuid::now_v7()), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + let cresp = archivist.register_connector(connector_req, None).await?; + let uid = cresp.connector_uid; + + let base = Utc::now(); + for i in 0..50 { + let tag = if i % 2 == 0 { "even" } else { "odd" }; + + let req = RegisterSessionRequest { + connector_uid: uid, + native_session_id: format!("walk-{i}"), + title: Some(format!("title-{i}")), + custom_scroll_id: None, + metadata: serde_json::json!({"model": "claude-3-5-sonnet"}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + let r = archivist.register_session(req, None).await?; + + let mut meta = archivist.get_session_metadata(r.scroll_id, None).await?; + meta.updated_at = base - Duration::seconds(i); + meta.tags = vec![tag.to_string()]; + let path = backend.paths().session_json(r.scroll_id); + dirigent_archivist::storage::json::write_json(&path, &meta) + .await + .map_err(dirigent_archivist::ArchivistError::Io)?; + } + + // Walk in chunks of 10 — 5 pages, 50 items, no dupes. + let mut seen: std::collections::HashSet = std::collections::HashSet::new(); + let mut cursor = None; + let mut page_count = 0; + loop { + let page = archivist + .list_sessions_paged( + SessionListQuery::default() + .with_connector(uid) + .with_limit(10) + .with_cursor(cursor.clone()), + ) + .await?; + page_count += 1; + for s in &page.items { + assert!(seen.insert(s.scroll_id), "duplicate scroll_id across pages"); + } + if page.next_cursor.is_none() { + break; + } + cursor = page.next_cursor; + assert!(page_count <= 10, "runaway pagination"); + } + assert_eq!(seen.len(), 50); + assert_eq!(page_count, 5); + + // Compose filter: tag=even AND title contains "1" → titles 10, 12, 14, 16, 18. + let mut q = SessionListQuery::default().with_connector(uid).with_limit(50); + q.tags = vec!["even".into()]; + q.title_query = Some("1".into()); + let page = archivist.list_sessions_paged(q).await?; + assert_eq!( + page.items.len(), + 5, + "got titles {:?}", + page.items.iter().map(|s| s.title.clone()).collect::>() + ); + + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } +} diff --git a/crates/dirigent_archivist/tests/list_sessions_paged_test.rs b/crates/dirigent_archivist/tests/list_sessions_paged_test.rs new file mode 100644 index 0000000..a89a5ca --- /dev/null +++ b/crates/dirigent_archivist/tests/list_sessions_paged_test.rs @@ -0,0 +1,364 @@ +//! Tests for `Archivist::list_sessions_paged` — pagination, filters, cursor stability. + +use chrono::{Duration, Utc}; +use dirigent_archivist::{ + backends::JsonlBackend, Archivist, RegisterConnectorRequest, RegisterSessionRequest, + Result, SessionListQuery, +}; +use std::sync::Arc; +use uuid::Uuid; + +/// Scaffold: create a coordinator backed by a single `JsonlBackend` in a +/// unique temp dir, returning the backend alongside it so tests can probe +/// disk paths via `backend.paths()`. +async fn mk_archivist() -> Result<(Archivist, Arc, std::path::PathBuf)> { + let temp_dir = std::env::temp_dir().join(format!("paged_test_{}", Uuid::now_v7())); + let backend = Arc::new(JsonlBackend::new(temp_dir.clone()).await?); + let archivist = + Archivist::from_single_backend("main".into(), backend.clone()).await?; + Ok((archivist, backend, temp_dir)) +} + +/// Register a connector, return its UID. +async fn mk_connector(archivist: &Archivist, title: &str) -> Result { + let resp = archivist + .register_connector( + RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: title.to_string(), + client_native_id: format!("{title}@local:{}", Uuid::now_v7()), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }, + None, + ) + .await?; + Ok(resp.connector_uid) +} + +/// Register a session and patch fields that `register_session` does not expose. +#[allow(clippy::too_many_arguments)] +async fn mk_session( + archivist: &Archivist, + backend: &JsonlBackend, + connector_uid: Uuid, + native_id: &str, + title: Option<&str>, + tags: Vec, + model: Option<&str>, + project_id: Option<&str>, + no_update: bool, +) -> Result { + let mut metadata = serde_json::Map::new(); + if let Some(m) = model { + metadata.insert("model".to_string(), serde_json::Value::String(m.to_string())); + } + if let Some(p) = project_id { + metadata.insert( + "project_id".to_string(), + serde_json::Value::String(p.to_string()), + ); + } + + let resp = archivist + .register_session( + RegisterSessionRequest { + connector_uid, + native_session_id: native_id.to_string(), + title: title.map(String::from), + custom_scroll_id: None, + metadata: serde_json::Value::Object(metadata), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }, + None, + ) + .await?; + let scroll_id = resp.scroll_id; + + // Patch tags / no_update into session.json on disk. + if !tags.is_empty() || no_update { + let mut meta = archivist.get_session_metadata(scroll_id, None).await?; + meta.tags = tags; + meta.no_update = no_update; + let path = backend.paths().session_json(scroll_id); + dirigent_archivist::storage::json::write_json(&path, &meta) + .await + .map_err(dirigent_archivist::ArchivistError::Io)?; + } + + Ok(scroll_id) +} + +/// Overwrite a session's updated_at on disk for deterministic ordering. +async fn set_updated_at( + archivist: &Archivist, + backend: &JsonlBackend, + scroll_id: Uuid, + when: chrono::DateTime, +) -> Result<()> { + let mut meta = archivist.get_session_metadata(scroll_id, None).await?; + meta.updated_at = when; + let path = backend.paths().session_json(scroll_id); + dirigent_archivist::storage::json::write_json(&path, &meta) + .await + .map_err(dirigent_archivist::ArchivistError::Io)?; + Ok(()) +} + +fn cleanup(path: std::path::PathBuf) { + let _ = std::fs::remove_dir_all(path); +} + +#[tokio::test] +async fn list_sessions_paged_respects_limit() -> Result<()> { + let (archivist, backend, temp) = mk_archivist().await?; + let uid = mk_connector(&archivist, "connector-a").await?; + + let base = Utc::now(); + for i in 0..30 { + let scroll = mk_session( + &archivist, + &backend, + uid, + &format!("native-{i}"), + Some(&format!("title-{i}")), + Vec::new(), + None, + None, + false, + ) + .await?; + set_updated_at(&archivist, &backend, scroll, base - Duration::seconds(i)).await?; + } + + let page = archivist + .list_sessions_paged(SessionListQuery::default().with_connector(uid).with_limit(10)) + .await?; + + assert_eq!(page.items.len(), 10); + assert!(page.next_cursor.is_some()); + + cleanup(temp); + Ok(()) +} + +#[tokio::test] +async fn list_sessions_paged_end_of_list() -> Result<()> { + let (archivist, backend, temp) = mk_archivist().await?; + let uid = mk_connector(&archivist, "connector-b").await?; + + let base = Utc::now(); + for i in 0..5 { + let scroll = mk_session( + &archivist, + &backend, + uid, + &format!("native-{i}"), + Some(&format!("title-{i}")), + Vec::new(), + None, + None, + false, + ) + .await?; + set_updated_at(&archivist, &backend, scroll, base - Duration::seconds(i)).await?; + } + + let page = archivist + .list_sessions_paged(SessionListQuery::default().with_connector(uid).with_limit(100)) + .await?; + + assert_eq!(page.items.len(), 5); + assert!(page.next_cursor.is_none()); + + cleanup(temp); + Ok(()) +} + +#[tokio::test] +async fn list_sessions_paged_cursor_stability() -> Result<()> { + let (archivist, backend, temp) = mk_archivist().await?; + let uid = mk_connector(&archivist, "connector-c").await?; + + let fixed = Utc::now(); + for i in 0..6 { + let scroll = mk_session( + &archivist, + &backend, + uid, + &format!("native-{i}"), + Some(&format!("title-{i}")), + Vec::new(), + None, + None, + false, + ) + .await?; + set_updated_at(&archivist, &backend, scroll, fixed).await?; + } + + let p1 = archivist + .list_sessions_paged(SessionListQuery::default().with_connector(uid).with_limit(3)) + .await?; + assert_eq!(p1.items.len(), 3); + assert!(p1.next_cursor.is_some()); + + let p2 = archivist + .list_sessions_paged( + SessionListQuery::default() + .with_connector(uid) + .with_limit(3) + .with_cursor(p1.next_cursor.clone()), + ) + .await?; + assert_eq!(p2.items.len(), 3); + + let ids1: std::collections::HashSet<_> = p1.items.iter().map(|s| s.scroll_id).collect(); + let ids2: std::collections::HashSet<_> = p2.items.iter().map(|s| s.scroll_id).collect(); + assert!(ids1.is_disjoint(&ids2), "page 1 and page 2 must not overlap"); + assert!(p2.next_cursor.is_none()); + + cleanup(temp); + Ok(()) +} + +#[tokio::test] +async fn list_sessions_paged_title_filter() -> Result<()> { + let (archivist, backend, temp) = mk_archivist().await?; + let uid = mk_connector(&archivist, "connector-d").await?; + + mk_session(&archivist, &backend, uid, "n1", Some("Alpha beta"), vec![], None, None, false).await?; + mk_session(&archivist, &backend, uid, "n2", Some("BETA only"), vec![], None, None, false).await?; + mk_session(&archivist, &backend, uid, "n3", Some("gamma"), vec![], None, None, false).await?; + mk_session(&archivist, &backend, uid, "n4", None, vec![], None, None, false).await?; + + let page = archivist + .list_sessions_paged( + SessionListQuery::default() + .with_connector(uid) + .with_limit(50) + .with_title_query("beta"), + ) + .await?; + + let titles: Vec<_> = page.items.iter().filter_map(|s| s.title.clone()).collect(); + assert_eq!(titles.len(), 2, "got {titles:?}"); + assert!(titles.iter().any(|t| t == "Alpha beta")); + assert!(titles.iter().any(|t| t == "BETA only")); + + cleanup(temp); + Ok(()) +} + +#[tokio::test] +async fn list_sessions_paged_tags_and() -> Result<()> { + let (archivist, backend, temp) = mk_archivist().await?; + let uid = mk_connector(&archivist, "connector-e").await?; + + mk_session( + &archivist, &backend, uid, "n1", Some("s1"), + vec!["red".into(), "blue".into()], None, None, false, + ).await?; + mk_session( + &archivist, &backend, uid, "n2", Some("s2"), + vec!["red".into()], None, None, false, + ).await?; + mk_session( + &archivist, &backend, uid, "n3", Some("s3"), + vec!["blue".into()], None, None, false, + ).await?; + + let mut q = SessionListQuery::default().with_connector(uid).with_limit(50); + q.tags = vec!["red".into(), "blue".into()]; + + let page = archivist.list_sessions_paged(q).await?; + + assert_eq!(page.items.len(), 1); + assert_eq!(page.items[0].title.as_deref(), Some("s1")); + + cleanup(temp); + Ok(()) +} + +#[tokio::test] +async fn list_sessions_paged_model_filter() -> Result<()> { + let (archivist, backend, temp) = mk_archivist().await?; + let uid = mk_connector(&archivist, "connector-f").await?; + + mk_session(&archivist, &backend, uid, "n1", Some("s1"), vec![], Some("claude-3-5-sonnet"), None, false).await?; + mk_session(&archivist, &backend, uid, "n2", Some("s2"), vec![], Some("gpt-4o"), None, false).await?; + mk_session(&archivist, &backend, uid, "n3", Some("s3"), vec![], None, None, false).await?; + + let mut q = SessionListQuery::default().with_connector(uid).with_limit(50); + q.model_filter = Some("sonnet".into()); + + let page = archivist.list_sessions_paged(q).await?; + + assert_eq!(page.items.len(), 1); + assert_eq!(page.items[0].title.as_deref(), Some("s1")); + + cleanup(temp); + Ok(()) +} + +#[tokio::test] +async fn list_sessions_paged_include_hidden() -> Result<()> { + let (archivist, backend, temp) = mk_archivist().await?; + let uid = mk_connector(&archivist, "connector-g").await?; + + mk_session(&archivist, &backend, uid, "n1", Some("visible"), vec![], None, None, false).await?; + mk_session(&archivist, &backend, uid, "n2", Some("hidden"), vec![], None, None, true).await?; + + let visible_only = archivist + .list_sessions_paged(SessionListQuery::default().with_connector(uid).with_limit(50)) + .await?; + assert_eq!(visible_only.items.len(), 1); + assert_eq!(visible_only.items[0].title.as_deref(), Some("visible")); + + let all = archivist + .list_sessions_paged( + SessionListQuery::default() + .with_connector(uid) + .with_limit(50) + .with_include_hidden(true), + ) + .await?; + assert_eq!(all.items.len(), 2); + + cleanup(temp); + Ok(()) +} + +#[tokio::test] +async fn list_sessions_paged_project_scope() -> Result<()> { + let (archivist, backend, temp) = mk_archivist().await?; + let c1 = mk_connector(&archivist, "connector-h1").await?; + let c2 = mk_connector(&archivist, "connector-h2").await?; + + mk_session(&archivist, &backend, c1, "n1", Some("proj-a-1"), vec![], None, Some("proj-a"), false).await?; + mk_session(&archivist, &backend, c1, "n2", Some("proj-b-1"), vec![], None, Some("proj-b"), false).await?; + mk_session(&archivist, &backend, c2, "n3", Some("proj-a-2"), vec![], None, Some("proj-a"), false).await?; + + let page = archivist + .list_sessions_paged( + SessionListQuery::default() + .with_project("proj-a") + .with_limit(50), + ) + .await?; + + assert_eq!(page.items.len(), 2); + for s in &page.items { + assert_eq!(s.metadata.get("project_id").and_then(|v| v.as_str()), Some("proj-a")); + } + + cleanup(temp); + Ok(()) +} diff --git a/crates/dirigent_archivist/tests/multi_backend_boot_test.rs b/crates/dirigent_archivist/tests/multi_backend_boot_test.rs new file mode 100644 index 0000000..8ee7a78 --- /dev/null +++ b/crates/dirigent_archivist/tests/multi_backend_boot_test.rs @@ -0,0 +1,130 @@ +use dirigent_archivist::{ + coordinator::Archivist, + error::ArchivistBootError, + registry::{ArchivesConfig, BackendRegistry}, +}; + +fn parse(toml_src: &str) -> ArchivesConfig { + toml::from_str(toml_src).unwrap() +} + +#[tokio::test] +async fn boot_with_one_jsonl_archive() { + let dir = tempfile::tempdir().unwrap(); + let cfg = parse(&format!( + r#" + [[archives]] + name = "main" + type = "jsonl" + [archives.params] + path = "{}" + "#, + dir.path().to_string_lossy().replace('\\', "/") + )); + let registry = BackendRegistry::with_jsonl(); + let _archivist = Archivist::from_config(cfg, ®istry, None).await.unwrap(); +} + +#[tokio::test] +async fn boot_empty_config_is_ephemeral() { + let cfg: ArchivesConfig = toml::from_str("").unwrap(); + let registry = BackendRegistry::with_jsonl(); + let archivist = Archivist::from_config(cfg, ®istry, None).await.unwrap(); + let archives = archivist.list_archives().await.unwrap(); + assert!(archives.is_empty()); +} + +#[tokio::test] +async fn boot_unknown_type_errors() { + let cfg = parse( + r#" + [[archives]] + name = "x" + type = "nope" + [archives.params] + "#, + ); + let registry = BackendRegistry::with_jsonl(); + let result = Archivist::from_config(cfg, ®istry, None).await; + match result { + Ok(_) => panic!("expected UnknownType error"), + Err(err) => assert!( + matches!(err, ArchivistBootError::UnknownType { .. }), + "expected UnknownType, got {err:?}" + ), + } +} + +#[tokio::test] +async fn boot_no_primary_errors() { + let cfg = parse( + r#" + [[archives]] + name = "mirror" + type = "jsonl" + failure_mode = "best_effort" + [archives.params] + path = "/tmp/whatever" + "#, + ); + let registry = BackendRegistry::with_jsonl(); + let result = Archivist::from_config(cfg, ®istry, None).await; + match result { + Ok(_) => panic!("expected Validation error"), + Err(err) => assert!( + matches!(err, ArchivistBootError::Validation(_)), + "expected Validation, got {err:?}" + ), + } +} + +#[tokio::test] +async fn boot_duplicate_name_errors() { + let dir = tempfile::tempdir().unwrap(); + let cfg = parse(&format!( + r#" + [[archives]] + name = "main" + type = "jsonl" + [archives.params] + path = "{p}" + + [[archives]] + name = "main" + type = "jsonl" + [archives.params] + path = "{p}" + "#, + p = dir.path().to_string_lossy().replace('\\', "/"), + )); + let registry = BackendRegistry::with_jsonl(); + let result = Archivist::from_config(cfg, ®istry, None).await; + match result { + Ok(_) => panic!("expected Validation error"), + Err(err) => assert!( + matches!(err, ArchivistBootError::Validation(_)), + "expected Validation, got {err:?}" + ), + } +} + +#[test] +fn example_toml_parses() { + // Load the full dirigent.toml.example and parse just the [[archives]] + // section as `ArchivesConfig`. Confirms the example's archive syntax is + // valid Phase 3 schema. + let src = std::fs::read_to_string( + std::path::Path::new(env!("CARGO_MANIFEST_DIR")) + .join("../../dirigent.toml.example"), + ) + .expect("dirigent.toml.example present at workspace root"); + // Parse the whole file as a TOML value, then try to deserialize the + // full document into `ArchivesConfig`. Any `archives` subtable gets picked up; + // other top-level fields (connectors, matrix, ...) are ignored because + // `ArchivesConfig` only has `entries: Vec` via + // `#[serde(rename = "archives")]`. + let cfg: ArchivesConfig = + toml::from_str(&src).expect("ArchivesConfig from full example"); + cfg.validate().expect("example config validates"); + assert!(!cfg.entries.is_empty(), "example must declare at least one archive"); +} diff --git a/crates/dirigent_archivist/tests/multi_backend_capability_test.rs b/crates/dirigent_archivist/tests/multi_backend_capability_test.rs new file mode 100644 index 0000000..0da0620 --- /dev/null +++ b/crates/dirigent_archivist/tests/multi_backend_capability_test.rs @@ -0,0 +1,76 @@ +#![cfg(feature = "test-utils")] + +use std::sync::Arc; + +use dirigent_archivist::backend::mock::MockBackend; +use dirigent_archivist::backend::{ArchiveBackend, ArchiveCapability, CapabilitySet, HealthStatus}; +use dirigent_archivist::coordinator::Archivist; +use dirigent_archivist::registry::{ArchiveRegistration, FailureMode, WritePolicy}; +use dirigent_archivist::types::{MetaEventRecord, MetaEventType}; +use uuid::Uuid; + +fn reg(name: &str, backend: Arc, priority: u32) -> Arc { + Arc::new(ArchiveRegistration::new( + name.into(), + "mock", + backend as Arc, + true, + FailureMode::Required, + priority, + true, + WritePolicy::Inline, + None, + HealthStatus::Healthy, + )) +} + +fn stub_meta_event(scroll_id: Uuid) -> MetaEventRecord { + MetaEventRecord { + version: 1, + event_id: Uuid::now_v7(), + session: scroll_id, + ts: chrono::Utc::now(), + event_type: MetaEventType::ClientConnected, + description: "test event".into(), + linked_session_id: None, + linked_connector_id: None, + linked_connector_title: None, + metadata: serde_json::Value::Null, + } +} + +#[tokio::test] +async fn capability_filter_skips_backend_without_meta_events() { + let mut caps_with_meta = CapabilitySet::new(); + caps_with_meta.insert(ArchiveCapability::MetaEvents); + caps_with_meta.insert(ArchiveCapability::SessionMapping); + caps_with_meta.insert(ArchiveCapability::ConnectorRegistry); + let with_meta = Arc::new(MockBackend::with_capabilities(caps_with_meta)); + + let mut caps_without_meta = CapabilitySet::new(); + caps_without_meta.insert(ArchiveCapability::SessionMapping); + caps_without_meta.insert(ArchiveCapability::ConnectorRegistry); + let without_meta = Arc::new(MockBackend::with_capabilities(caps_without_meta)); + + let archivist = Archivist::from_registrations(vec![ + reg("primary", with_meta.clone(), 0), + reg("secondary", without_meta.clone(), 10), + ]); + + let scroll = Uuid::new_v4(); + archivist + .append_meta_events(scroll, vec![stub_meta_event(scroll)], None) + .await + .unwrap(); + + // Primary received the meta event. + assert!( + with_meta.has_meta_events(scroll), + "primary should receive meta event" + ); + // Secondary was capability-skipped. + assert!( + !without_meta.has_meta_events(scroll), + "secondary should be skipped" + ); +} diff --git a/crates/dirigent_archivist/tests/multi_backend_cross_test.rs b/crates/dirigent_archivist/tests/multi_backend_cross_test.rs new file mode 100644 index 0000000..42c464c --- /dev/null +++ b/crates/dirigent_archivist/tests/multi_backend_cross_test.rs @@ -0,0 +1,121 @@ +#![cfg(feature = "test-utils")] + +use std::sync::Arc; + +use dirigent_archivist::backend::mock::MockBackend; +use dirigent_archivist::backend::ArchiveBackend; +use dirigent_archivist::backend::HealthStatus; +use dirigent_archivist::coordinator::Archivist; +use dirigent_archivist::error::ArchivistError; +use dirigent_archivist::registry::{ArchiveRegistration, FailureMode, WritePolicy}; +use dirigent_archivist::types::SessionMetadata; +use uuid::Uuid; + +async fn dual_backend_coordinator() -> (Archivist, Arc, Arc) { + let a = Arc::new(MockBackend::new()); + let b = Arc::new(MockBackend::new()); + let regs = vec![ + Arc::new(ArchiveRegistration::new( + "a".into(), + "mock", + a.clone() as Arc, + true, + FailureMode::Required, + 0, + true, + WritePolicy::Inline, + None, + HealthStatus::Healthy, + )), + Arc::new(ArchiveRegistration::new( + "b".into(), + "mock", + b.clone() as Arc, + true, + FailureMode::Required, + 10, + true, + WritePolicy::Inline, + None, + HealthStatus::Healthy, + )), + ]; + (Archivist::from_registrations(regs), a, b) +} + +#[tokio::test] +async fn copy_session_carries_metadata_and_messages() { + let (archivist, a, b) = dual_backend_coordinator().await; + let scroll = Uuid::new_v4(); + + // Seed `a` only. + a.put_session(SessionMetadata::stub(scroll)).await.unwrap(); + a.append_messages(scroll, vec![]).await.unwrap(); + + archivist.copy_session(scroll, "a", "b").await.unwrap(); + + assert!(b.get_session(scroll).await.unwrap().is_some()); + assert!(a.get_session(scroll).await.unwrap().is_some()); +} + +#[tokio::test] +async fn move_session_removes_from_source() { + let (archivist, a, b) = dual_backend_coordinator().await; + let scroll = Uuid::new_v4(); + a.put_session(SessionMetadata::stub(scroll)).await.unwrap(); + + archivist.move_session(scroll, "a", "b").await.unwrap(); + + assert!(a.get_session(scroll).await.unwrap().is_none()); + assert!(b.get_session(scroll).await.unwrap().is_some()); + assert_eq!( + archivist.read_cache_size().await, + 1, + "cache should now reflect the move" + ); +} + +#[tokio::test] +async fn move_session_partial_failure_returns_partial_move_error() { + let (archivist, a, b) = dual_backend_coordinator().await; + let scroll = Uuid::new_v4(); + a.put_session(SessionMetadata::stub(scroll)).await.unwrap(); + + // The source-side delete happens AFTER the copy. Inject ONE write failure + // AFTER the copy has already consumed the write capacity. `MockBackend`'s + // inject_write_failures decrements on every mutating call — so we: + // 1. perform the copy through the archivist (uses put_session+append on `b`, + // but NO writes on `a`, since reads happen on the source side). + // 2. THEN inject a write failure on `a` to make the delete fail. + // + // Actually `copy_session` reads from `a` then writes to `b`, no writes on `a`. + // So we can safely inject BEFORE calling move_session: the only write on `a` + // during move_session is the delete, which will hit the injected failure. + + a.inject_write_failures(1); + + let err = archivist.move_session(scroll, "a", "b").await.unwrap_err(); + assert!(matches!(err, ArchivistError::PartialMove { .. })); + + // Both backends now have the session. + assert!(a.get_session(scroll).await.unwrap().is_some()); + assert!(b.get_session(scroll).await.unwrap().is_some()); +} + +#[tokio::test] +async fn delete_session_fans_out_and_invalidates_cache() { + let (archivist, a, b) = dual_backend_coordinator().await; + let scroll = Uuid::new_v4(); + a.put_session(SessionMetadata::stub(scroll)).await.unwrap(); + b.put_session(SessionMetadata::stub(scroll)).await.unwrap(); + + // Prime the cache with a read. + let _ = archivist.get_session_metadata(scroll, None).await.unwrap(); + assert_eq!(archivist.read_cache_size().await, 1); + + archivist.delete_session(scroll, None).await.unwrap(); + + assert!(a.get_session(scroll).await.unwrap().is_none()); + assert!(b.get_session(scroll).await.unwrap().is_none()); + assert_eq!(archivist.read_cache_size().await, 0); +} diff --git a/crates/dirigent_archivist/tests/multi_backend_fanout_test.rs b/crates/dirigent_archivist/tests/multi_backend_fanout_test.rs new file mode 100644 index 0000000..e4e411d --- /dev/null +++ b/crates/dirigent_archivist/tests/multi_backend_fanout_test.rs @@ -0,0 +1,124 @@ +#![cfg(feature = "test-utils")] + +use std::sync::Arc; + +use dirigent_archivist::backend::mock::MockBackend; +use dirigent_archivist::backend::{ArchiveBackend, HealthStatus}; +use dirigent_archivist::coordinator::Archivist; +use dirigent_archivist::registry::{ArchiveRegistration, FailureMode, WritePolicy}; +use uuid::Uuid; + +fn reg( + name: &str, + backend: Arc, + priority: u32, + failure: FailureMode, +) -> Arc { + Arc::new(ArchiveRegistration::new( + name.into(), + "mock", + backend as Arc, + true, + failure, + priority, + true, + WritePolicy::Inline, + None, + HealthStatus::Healthy, + )) +} + +fn sample_message(session: Uuid) -> dirigent_archivist::types::MessageRecord { + dirigent_archivist::types::MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session, + parent_id: None, + ts: chrono::Utc::now(), + role: "user".into(), + author: None, + content_md: "hi".into(), + content_parts: None, + attachments: vec![], + metadata: serde_json::Value::Null, + } +} + +#[tokio::test] +async fn write_fans_out_to_both_backends() { + let a = Arc::new(MockBackend::new()); + let b = Arc::new(MockBackend::new()); + let archivist = Archivist::from_registrations(vec![ + reg("a", a.clone(), 0, FailureMode::Required), + reg("b", b.clone(), 10, FailureMode::BestEffort), + ]); + + // Using a non-empty message vec for a robust positive-count check: + let scroll = Uuid::new_v4(); + let m = sample_message(scroll); + archivist + .append_messages(scroll, vec![m], None) + .await + .unwrap(); + assert_eq!(a.appended_count(scroll), 1); + assert_eq!(b.appended_count(scroll), 1); +} + +#[tokio::test] +async fn best_effort_failure_does_not_propagate() { + let a = Arc::new(MockBackend::new()); + let b = Arc::new(MockBackend::new()); + b.inject_write_failures(1); + + let archivist = Archivist::from_registrations(vec![ + reg("a", a.clone(), 0, FailureMode::Required), + reg("b", b.clone(), 10, FailureMode::BestEffort), + ]); + + archivist + .append_messages(Uuid::new_v4(), vec![], None) + .await + .unwrap(); // Ok despite secondary failure + + let snapshot = archivist.list_archives_with_health().await; + let b_status = snapshot.iter().find(|s| s.name == "b").unwrap(); + assert!(matches!(b_status.health, HealthStatus::Degraded { .. })); +} + +#[tokio::test] +async fn required_secondary_failure_propagates() { + let a = Arc::new(MockBackend::new()); + let b = Arc::new(MockBackend::new()); + b.inject_write_failures(1); + + let archivist = Archivist::from_registrations(vec![ + reg("a", a.clone(), 0, FailureMode::Required), + reg("b", b.clone(), 10, FailureMode::Required), + ]); + + let err = archivist + .append_messages(Uuid::new_v4(), vec![], None) + .await; + assert!(err.is_err(), "expected error when required secondary fails"); +} + +#[tokio::test] +async fn explicit_archive_overrides_default_primary() { + let a = Arc::new(MockBackend::new()); + let b = Arc::new(MockBackend::new()); + let archivist = Archivist::from_registrations(vec![ + reg("a", a.clone(), 0, FailureMode::Required), + reg("b", b.clone(), 10, FailureMode::Required), + ]); + + let scroll = Uuid::new_v4(); + let m = sample_message(scroll); + archivist + .append_messages(scroll, vec![m], Some("b".into())) + .await + .unwrap(); + + // Both receive the write: b is explicit primary, a is secondary via fanout. + assert_eq!(a.appended_count(scroll), 1); + assert_eq!(b.appended_count(scroll), 1); +} diff --git a/crates/dirigent_archivist/tests/multi_backend_health_test.rs b/crates/dirigent_archivist/tests/multi_backend_health_test.rs new file mode 100644 index 0000000..65363c4 --- /dev/null +++ b/crates/dirigent_archivist/tests/multi_backend_health_test.rs @@ -0,0 +1,129 @@ +#![cfg(feature = "test-utils")] + +use std::sync::Arc; + +use dirigent_archivist::backend::mock::MockBackend; +use dirigent_archivist::backend::{ArchiveBackend, HealthStatus}; +use dirigent_archivist::coordinator::Archivist; +use dirigent_archivist::registry::{ArchiveRegistration, FailureMode, WritePolicy}; +use dirigent_archivist::types::SessionMetadata; +use uuid::Uuid; + +fn reg( + name: &str, + backend: Arc, + priority: u32, + failure: FailureMode, +) -> Arc { + Arc::new(ArchiveRegistration::new( + name.into(), + "mock", + backend as Arc, + true, + failure, + priority, + true, + WritePolicy::Inline, + None, + HealthStatus::Healthy, + )) +} + +#[tokio::test] +async fn five_consecutive_failures_drifts_to_unavailable() { + let primary = Arc::new(MockBackend::new()); + let secondary = Arc::new(MockBackend::new()); + secondary.inject_write_failures(10); + + let archivist = Archivist::from_registrations(vec![ + reg("primary", primary.clone(), 0, FailureMode::Required), + reg("secondary", secondary.clone(), 10, FailureMode::BestEffort), + ]); + + for _ in 0..5 { + archivist + .append_messages(Uuid::new_v4(), vec![], None) + .await + .ok(); + } + + let snapshot = archivist.list_archives_with_health().await; + let secondary_status = snapshot.iter().find(|s| s.name == "secondary").unwrap(); + assert!( + matches!(secondary_status.health, HealthStatus::Unavailable { .. }), + "secondary should be Unavailable after 5 failures; got {:?}", + secondary_status.health + ); +} + +#[tokio::test] +async fn success_after_failure_recovers_to_healthy() { + let backend = Arc::new(MockBackend::new()); + backend.inject_write_failures(1); + + let archivist = Archivist::from_registrations(vec![reg( + "only", + backend.clone(), + 0, + FailureMode::Required, + )]); + + // First call fails. + let _ = archivist + .append_messages(Uuid::new_v4(), vec![], None) + .await; + let snapshot = archivist.list_archives_with_health().await; + assert!( + matches!(snapshot[0].health, HealthStatus::Degraded { .. }), + "expected Degraded after first failure; got {:?}", + snapshot[0].health + ); + + // Second call succeeds — health returns to Healthy. + archivist + .append_messages(Uuid::new_v4(), vec![], None) + .await + .unwrap(); + let snapshot = archivist.list_archives_with_health().await; + assert!( + matches!(snapshot[0].health, HealthStatus::Healthy), + "expected Healthy after recovery; got {:?}", + snapshot[0].health + ); +} + +#[tokio::test] +async fn unavailable_backend_skipped_during_read_walk() { + let primary = Arc::new(MockBackend::new()); + let secondary = Arc::new(MockBackend::new()); + + let scroll = Uuid::new_v4(); + secondary + .put_session(SessionMetadata::stub(scroll)) + .await + .unwrap(); + secondary.break_permanently("kaput"); + + let primary_reg = reg("primary", primary.clone(), 0, FailureMode::Required); + let secondary_reg = reg("secondary", secondary.clone(), 10, FailureMode::Required); + // Force secondary's cached health to Unavailable BEFORE the walk, + // so the routing layer skips it entirely rather than attempting + failing. + *secondary_reg.last_health.write().await = HealthStatus::Unavailable { + reason: "test".into(), + }; + + let archivist = Archivist::from_registrations(vec![primary_reg, secondary_reg]); + + // Primary doesn't have the session; secondary has it but is marked Unavailable. + // Read walk should skip secondary → Ok(None)-style ergonomics, bubbling up + // as `SessionUnknown` per `get_session_metadata`'s contract. + let result = archivist.get_session_metadata(scroll, None).await; + assert!( + result.is_err(), + "expected SessionUnknown error when Unavailable backend is skipped" + ); + assert!(matches!( + result.unwrap_err(), + dirigent_archivist::error::ArchivistError::SessionUnknown(_) + )); +} diff --git a/crates/dirigent_archivist/tests/multi_backend_routing_test.rs b/crates/dirigent_archivist/tests/multi_backend_routing_test.rs new file mode 100644 index 0000000..7255957 --- /dev/null +++ b/crates/dirigent_archivist/tests/multi_backend_routing_test.rs @@ -0,0 +1,102 @@ +#![cfg(feature = "test-utils")] + +use std::sync::Arc; + +use dirigent_archivist::backend::mock::MockBackend; +use dirigent_archivist::backend::{ArchiveBackend, HealthStatus}; +use dirigent_archivist::coordinator::Archivist; +use dirigent_archivist::registry::{ArchiveRegistration, FailureMode, WritePolicy}; +use dirigent_archivist::types::SessionMetadata; +use uuid::Uuid; + +fn reg(name: &str, backend: Arc, priority: u32) -> Arc { + Arc::new(ArchiveRegistration::new( + name.into(), + "mock", + backend as Arc, + true, + FailureMode::Required, + priority, + true, + WritePolicy::Inline, + None, + HealthStatus::Healthy, + )) +} + +#[tokio::test] +async fn high_priority_backend_serves_first() { + let high = Arc::new(MockBackend::new()); + let low = Arc::new(MockBackend::new()); + let scroll = Uuid::new_v4(); + high.put_session(SessionMetadata::stub(scroll)).await.unwrap(); + + let archivist = Archivist::from_registrations(vec![ + reg("high", high.clone(), 0), + reg("low", low.clone(), 10), + ]); + + let meta = archivist.get_session_metadata(scroll, None).await; + assert!(meta.is_ok(), "expected Ok; got {:?}", meta); + assert_eq!(archivist.read_cache_size().await, 1); +} + +#[tokio::test] +async fn falls_through_to_lower_priority_when_high_misses() { + let high = Arc::new(MockBackend::new()); + let low = Arc::new(MockBackend::new()); + let scroll = Uuid::new_v4(); + low.put_session(SessionMetadata::stub(scroll)).await.unwrap(); + + let archivist = Archivist::from_registrations(vec![ + reg("high", high.clone(), 0), + reg("low", low.clone(), 10), + ]); + + let meta = archivist.get_session_metadata(scroll, None).await; + assert!(meta.is_ok(), "expected Ok; got {:?}", meta); + assert_eq!(archivist.read_cache_size().await, 1); +} + +#[tokio::test] +async fn cache_makes_second_read_skip_priority_walk() { + let high = Arc::new(MockBackend::new()); + let low = Arc::new(MockBackend::new()); + let scroll = Uuid::new_v4(); + low.put_session(SessionMetadata::stub(scroll)).await.unwrap(); + + let archivist = Archivist::from_registrations(vec![ + reg("high", high.clone(), 0), + reg("low", low.clone(), 10), + ]); + + // Prime the cache. + let _ = archivist.get_session_metadata(scroll, None).await.unwrap(); + + // Inject a read failure on `high` to detect whether the second read walks it. + // If the cache works, `high` must NOT be touched. + high.inject_read_failures(1); + let _ = archivist.get_session_metadata(scroll, None).await.unwrap(); + + let snapshot = archivist.list_archives_with_health().await; + let high_status = snapshot.iter().find(|s| s.name == "high").unwrap(); + assert!( + matches!(high_status.health, HealthStatus::Healthy), + "cache should have skipped `high`; got {:?}", + high_status.health + ); +} + +#[tokio::test] +async fn delete_invalidates_cache() { + let high = Arc::new(MockBackend::new()); + let scroll = Uuid::new_v4(); + high.put_session(SessionMetadata::stub(scroll)).await.unwrap(); + + let archivist = Archivist::from_registrations(vec![reg("high", high.clone(), 0)]); + let _ = archivist.get_session_metadata(scroll, None).await.unwrap(); + assert_eq!(archivist.read_cache_size().await, 1); + + archivist.delete_session(scroll, None).await.unwrap(); + assert_eq!(archivist.read_cache_size().await, 0); +} diff --git a/crates/dirigent_archivist/tests/multi_backend_writer_test.rs b/crates/dirigent_archivist/tests/multi_backend_writer_test.rs new file mode 100644 index 0000000..9ba8565 --- /dev/null +++ b/crates/dirigent_archivist/tests/multi_backend_writer_test.rs @@ -0,0 +1,252 @@ +#![cfg(feature = "test-utils")] + +//! Integration tests for Task 17's per-backend queued writer task. +//! +//! These exercise the full enqueue → batch → coalesce → dispatch pipeline +//! end-to-end by constructing real writer tasks against `MockBackend` +//! instances and driving them through the `Archivist` coordinator. +//! +//! The tests are timing-sensitive: the batch window is 25ms and the +//! backpressure test artificially slows the backend. Assertions use +//! tolerant margins so they survive CI jitter. + +use std::sync::Arc; +use std::time::Duration; + +use dirigent_archivist::backend::mock::MockBackend; +use dirigent_archivist::backend::{ArchiveBackend, HealthStatus}; +use dirigent_archivist::coordinator::Archivist; +use dirigent_archivist::registry::writer::spawn_writer; +use dirigent_archivist::registry::{ + ArchiveRegistration, FailureMode, OverflowPolicy, WritePolicy, +}; +use uuid::Uuid; + +fn sample_message(scroll: Uuid) -> dirigent_archivist::types::MessageRecord { + dirigent_archivist::types::MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: scroll, + parent_id: None, + ts: chrono::Utc::now(), + role: "user".into(), + author: None, + content_md: "hi".into(), + content_parts: None, + attachments: vec![], + metadata: serde_json::Value::Null, + } +} + +fn queued_reg( + name: &str, + backend: Arc, + priority: u32, + overflow: OverflowPolicy, +) -> Arc { + let initial_health = HealthStatus::Healthy; + let policy = WritePolicy::Queued { + batch_window_ms: 25, + capacity: 8, + overflow, + }; + + let health = Arc::new(tokio::sync::RwLock::new(initial_health)); + let last_error = Arc::new(tokio::sync::RwLock::new(None)); + let consecutive = Arc::new(tokio::sync::RwLock::new(0u32)); + + let writer = Some(spawn_writer( + backend.clone() as Arc, + name.into(), + 8, + Duration::from_millis(25), + overflow, + health.clone(), + last_error.clone(), + consecutive.clone(), + )); + + Arc::new(ArchiveRegistration::new_with_shared_state( + name.into(), + "mock", + backend as Arc, + true, + FailureMode::Required, + priority, + true, + policy, + writer, + health, + last_error, + consecutive, + )) +} + +#[tokio::test] +async fn queued_write_returns_immediately_then_eventually_lands() { + let mock = Arc::new(MockBackend::new()); + let archivist = Archivist::from_registrations(vec![queued_reg( + "queued", + mock.clone(), + 0, + OverflowPolicy::Block, + )]); + + let scroll = Uuid::new_v4(); + archivist + .append_messages(scroll, vec![sample_message(scroll)], None) + .await + .unwrap(); + + // Wait up to 500ms for the writer to drain. + let mut landed = false; + for _ in 0..50 { + if mock.appended_count(scroll) > 0 { + landed = true; + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + assert!(landed, "writer task did not drain within 500ms"); + assert_eq!(mock.appended_count(scroll), 1); + + archivist.shutdown().await.unwrap(); +} + +#[tokio::test] +async fn coalescing_merges_consecutive_appends_for_same_scroll() { + let mock = Arc::new(MockBackend::new()); + let archivist = Archivist::from_registrations(vec![queued_reg( + "queued", + mock.clone(), + 0, + OverflowPolicy::Block, + )]); + + let scroll = Uuid::new_v4(); + for _ in 0..5 { + archivist + .append_messages(scroll, vec![sample_message(scroll)], None) + .await + .unwrap(); + } + + // Give the writer time to drain + coalesce, then shut down to guarantee + // any still-queued ops are flushed before we assert. + tokio::time::sleep(Duration::from_millis(200)).await; + archivist.shutdown().await.unwrap(); + + // Five enqueued ops may have been coalesced into fewer backend calls. + // The only strict invariant we can reliably assert is: the total number + // of backend `append_messages` INVOCATIONS is <= 5. + assert!( + mock.append_call_count(scroll) <= 5, + "expected <= 5 backend calls, got {}", + mock.append_call_count(scroll) + ); + assert_eq!( + mock.appended_count(scroll), + 5, + "all 5 messages should land" + ); +} + +#[tokio::test] +async fn overflow_block_applies_backpressure() { + // For backpressure to visibly stall the sender, we need four things: + // 1. A tight queue (capacity=2) so the channel actually fills up. + // 2. A slow backend (per-op 50ms) so the writer stalls in dispatch + // long enough for the channel to fill. + // 3. batch_window=0 so the writer spends (almost) all its time in + // the 50ms per-op sleep instead of draining fast inside the + // batch-collection phase. + // 4. Distinct scroll IDs so the writer's same-scroll coalescing + // doesn't merge everything into one dispatch call (which would + // collapse the entire batch into a single 50ms sleep). + // With those, the writer dispatches N serial 50ms calls; while it's + // sleeping the sender can't fit its next op into the full channel + // and must wait for a drain. + let mock = Arc::new(MockBackend::new()); + mock.set_per_op_delay(Duration::from_millis(50)); + + let capacity = 2usize; + let overflow = OverflowPolicy::Block; + // batch_window=0 means the writer dispatches each op immediately and + // spends (almost) all its time in the 50ms per-op sleep — so the + // channel stays full and the sender has to wait on every drain. + let policy = WritePolicy::Queued { + batch_window_ms: 0, + capacity, + overflow, + }; + + let health = Arc::new(tokio::sync::RwLock::new(HealthStatus::Healthy)); + let last_error = Arc::new(tokio::sync::RwLock::new(None)); + let consecutive = Arc::new(tokio::sync::RwLock::new(0u32)); + + let writer = Some(spawn_writer( + mock.clone() as Arc, + "queued".into(), + capacity, + Duration::from_millis(0), + overflow, + health.clone(), + last_error.clone(), + consecutive.clone(), + )); + + let reg = Arc::new(ArchiveRegistration::new_with_shared_state( + "queued".into(), + "mock", + mock.clone() as Arc, + true, + FailureMode::Required, + 0, + true, + policy, + writer, + health, + last_error, + consecutive, + )); + + let archivist = Archivist::from_registrations(vec![reg]); + + // Prime the writer with one op and wait just long enough for it to + // enter its first 50ms dispatch sleep. After that the writer is NOT + // recv'ing, so the tight capacity=2 channel fills and further sends + // must wait for a drain. + let scroll0 = Uuid::new_v4(); + archivist + .append_messages(scroll0, vec![sample_message(scroll0)], None) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(10)).await; + + // Now measure the cost of many more sends with distinct scroll IDs + // so the writer can't coalesce them. Each dispatch call is 50ms, the + // queue holds only 2, so the sender must wait repeatedly for the + // writer to drain cycles. + let start = std::time::Instant::now(); + for _ in 0..24 { + let scroll = Uuid::new_v4(); + archivist + .append_messages(scroll, vec![sample_message(scroll)], None) + .await + .unwrap(); + } + let elapsed = start.elapsed(); + + // With 24 distinct-scroll sends, a capacity=2 queue, batch_window=0, + // and a 50ms per-op delay, the sender cannot finish instantly — the + // writer needs many drain cycles and the sender waits on each. A + // 100ms floor keeps the test meaningful (a non-blocking run measures + // in microseconds) while being lenient on CI jitter. + assert!( + elapsed >= Duration::from_millis(100), + "block policy did not apply backpressure (elapsed: {:?})", + elapsed + ); + + archivist.shutdown().await.unwrap(); +} diff --git a/crates/dirigent_archivist/tests/pagination_test.rs b/crates/dirigent_archivist/tests/pagination_test.rs new file mode 100644 index 0000000..b26191b --- /dev/null +++ b/crates/dirigent_archivist/tests/pagination_test.rs @@ -0,0 +1,142 @@ +//! Pagination tests for dirigent_archivist +//! +//! These tests verify the count_messages and get_messages_range functionality. + +#[cfg(test)] +mod pagination_tests { + use chrono::Utc; + use dirigent_archivist::{ + backends::JsonlBackend, Archivist, MessageRecord, RegisterConnectorRequest, + RegisterSessionRequest, Result, + }; + use std::sync::Arc; + use uuid::Uuid; + + /// Build a self-contained coordinator rooted at `archive_root`, backed by + /// a single `JsonlBackend`. Avoids the shared `.archives.json` race that + /// `new_with_single_archive` creates in the tempdir's parent. + async fn mk_archivist(archive_root: std::path::PathBuf) -> Result { + let backend = Arc::new(JsonlBackend::new(archive_root).await?); + Archivist::from_single_backend("main".into(), backend).await + } + + #[tokio::test] + async fn test_pagination_count_and_range() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let archivist = mk_archivist(temp_dir.clone()).await?; + + // Register connector + let connector_req = RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Test Connector".to_string(), + client_native_id: "opencode@localhost:12225".to_string(), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }; + let connector_response = archivist.register_connector(connector_req, None).await?; + + // Register session + let session_req = RegisterSessionRequest { + connector_uid: connector_response.connector_uid, + native_session_id: "native-123".to_string(), + title: Some("Pagination Test".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }; + let session_response = archivist.register_session(session_req, None).await?; + let scroll_id = session_response.scroll_id; + + // Test empty session + let count = archivist.count_messages(scroll_id, None).await?; + assert_eq!(count, 0, "Empty session should have 0 messages"); + + let range = archivist.get_messages_range(scroll_id, 0, 10, None).await?; + assert_eq!(range.len(), 0, "Empty session should return empty range"); + + // Add 25 messages with varying timestamps + let mut messages = Vec::new(); + let base_time = Utc::now(); + for i in 0..25 { + messages.push(MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: scroll_id, + parent_id: None, + ts: base_time + chrono::Duration::seconds(i), + role: if i % 2 == 0 { "user" } else { "assistant" }.to_string(), + author: Some("test".to_string()), + content_md: format!("Message {}", i), + content_parts: None, + attachments: Vec::new(), + metadata: serde_json::json!({}), + }); + } + archivist.append_messages(scroll_id, messages, None).await?; + + // Test count_messages + let count = archivist.count_messages(scroll_id, None).await?; + assert_eq!(count, 25, "Should count 25 messages"); + + // Test get_messages_range - first page + let page1 = archivist.get_messages_range(scroll_id, 0, 10, None).await?; + assert_eq!(page1.len(), 10, "First page should have 10 messages"); + assert_eq!(page1[0].content_md, "Message 0", "First message should be Message 0"); + assert_eq!(page1[9].content_md, "Message 9", "10th message should be Message 9"); + + // Test get_messages_range - second page + let page2 = archivist.get_messages_range(scroll_id, 10, 10, None).await?; + assert_eq!(page2.len(), 10, "Second page should have 10 messages"); + assert_eq!(page2[0].content_md, "Message 10", "11th message should be Message 10"); + assert_eq!(page2[9].content_md, "Message 19", "20th message should be Message 19"); + + // Test get_messages_range - partial last page + let page3 = archivist.get_messages_range(scroll_id, 20, 10, None).await?; + assert_eq!(page3.len(), 5, "Last page should have 5 messages"); + assert_eq!(page3[0].content_md, "Message 20", "21st message should be Message 20"); + assert_eq!(page3[4].content_md, "Message 24", "25th message should be Message 24"); + + // Test get_messages_range - offset beyond messages + let page4 = archivist.get_messages_range(scroll_id, 30, 10, None).await?; + assert_eq!(page4.len(), 0, "Offset beyond messages should return empty"); + + // Verify chronological ordering is maintained in pagination + let all_messages = archivist.get_messages(scroll_id, None).await?; + let first_10_from_all = &all_messages[0..10]; + let first_10_from_page = &page1[..]; + + for i in 0..10 { + assert_eq!( + first_10_from_all[i].message_id, + first_10_from_page[i].message_id, + "Pagination should maintain same order as get_messages()" + ); + } + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } + + #[tokio::test] + async fn test_count_messages_nonexistent_session() -> Result<()> { + let temp_dir = std::env::temp_dir().join(format!("archivist_test_{}", Uuid::now_v7())); + let archivist = mk_archivist(temp_dir.clone()).await?; + + // Count messages for non-existent session (should return 0, not error) + let nonexistent_scroll_id = Uuid::now_v7(); + let count = archivist.count_messages(nonexistent_scroll_id, None).await?; + assert_eq!(count, 0, "Non-existent session should have 0 messages"); + + // Clean up + tokio::fs::remove_dir_all(temp_dir).await.ok(); + Ok(()) + } +} diff --git a/crates/dirigent_auth/Cargo.toml b/crates/dirigent_auth/Cargo.toml new file mode 100644 index 0000000..864beaa --- /dev/null +++ b/crates/dirigent_auth/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "dirigent_auth" +version = "0.1.0" +edition = "2021" + +[dependencies] +chrono = { version = "0.4", features = ["serde"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +thiserror = "2.0" +uuid = { version = "1.0", features = ["js", "serde", "v7"] } + +[dev-dependencies] diff --git a/crates/dirigent_auth/src/account.rs b/crates/dirigent_auth/src/account.rs new file mode 100644 index 0000000..53c5e9b --- /dev/null +++ b/crates/dirigent_auth/src/account.rs @@ -0,0 +1,173 @@ +//! Account model for Dirigent identity management. + +use std::collections::HashMap; +use serde::{Deserialize, Serialize}; +use crate::{secret::SecretSource, UserId}; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum AccountKind { + Local, + Matrix, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct AccountProfile { + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub email: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub username: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub display_name: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Account { + #[serde(rename = "type")] + pub kind: AccountKind, + + #[serde(skip)] + pub config_name: String, + + #[serde(skip)] + pub user_id: Option, + + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + pub credentials: HashMap, + + #[serde(flatten)] + pub profile: AccountProfile, + + #[serde(flatten)] + pub properties: HashMap, +} + +impl Account { + pub fn resolve_credential(&self, name: &str) -> Result { + self.credentials + .get(name) + .ok_or_else(|| crate::SecretError::EnvNotSet { + key: format!("", name), + }) + .and_then(|s| s.resolve()) + } + + pub fn property_str(&self, key: &str) -> Option<&str> { + self.properties.get(key).and_then(|v| v.as_str()) + } + + pub fn property_str_or<'a>(&'a self, key: &str, default: &'a str) -> &'a str { + self.property_str(key).unwrap_or(default) + } + + pub fn display_name(&self) -> &str { + self.profile.display_name.as_deref() + .or(self.profile.name.as_deref()) + .or(self.profile.username.as_deref()) + .unwrap_or(&self.config_name) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn sample_local() -> Account { + Account { + kind: AccountKind::Local, + config_name: "local".to_string(), + user_id: None, + credentials: HashMap::new(), + profile: AccountProfile { name: Some("Gabriel".to_string()), ..Default::default() }, + properties: HashMap::new(), + } + } + + fn sample_matrix() -> Account { + let mut creds = HashMap::new(); + creds.insert("password".to_string(), SecretSource::Inline { value: "bot_pass".to_string() }); + let mut props = HashMap::new(); + props.insert("homeserver".to_string(), serde_json::json!("https://matrix.example.com")); + props.insert("device_id".to_string(), serde_json::json!("DIRIGENT_01")); + Account { + kind: AccountKind::Matrix, + config_name: "matrix-bot".to_string(), + user_id: None, + credentials: creds, + profile: AccountProfile { + username: Some("dirigent_bot".to_string()), + display_name: Some("Dirigent Bot".to_string()), + ..Default::default() + }, + properties: props, + } + } + + #[test] + fn test_local_display_name() { + assert_eq!(sample_local().display_name(), "Gabriel"); + } + + #[test] + fn test_matrix_display_name() { + assert_eq!(sample_matrix().display_name(), "Dirigent Bot"); + } + + #[test] + fn test_fallback_to_config_name() { + let acct = Account { + kind: AccountKind::Local, + config_name: "fallback".to_string(), + user_id: None, + credentials: HashMap::new(), + profile: AccountProfile::default(), + properties: HashMap::new(), + }; + assert_eq!(acct.display_name(), "fallback"); + } + + #[test] + fn test_resolve_credential() { + assert_eq!(sample_matrix().resolve_credential("password").unwrap(), "bot_pass"); + } + + #[test] + fn test_missing_credential() { + assert!(sample_local().resolve_credential("password").is_err()); + } + + #[test] + fn test_property_str() { + let acct = sample_matrix(); + assert_eq!(acct.property_str("homeserver"), Some("https://matrix.example.com")); + assert_eq!(acct.property_str("nonexistent"), None); + } + + #[test] + fn test_property_str_or() { + let acct = sample_matrix(); + assert_eq!(acct.property_str_or("device_id", "DEFAULT"), "DIRIGENT_01"); + assert_eq!(acct.property_str_or("missing", "DEFAULT"), "DEFAULT"); + } + + #[test] + fn test_account_kind_serde() { + let json = serde_json::to_string(&AccountKind::Matrix).unwrap(); + assert_eq!(json, r#""matrix""#); + let back: AccountKind = serde_json::from_str(&json).unwrap(); + assert_eq!(back, AccountKind::Matrix); + } + + #[test] + fn test_account_serde_roundtrip() { + let acct = sample_matrix(); + let json = serde_json::to_string(&acct).unwrap(); + let back: Account = serde_json::from_str(&json).unwrap(); + assert_eq!(back.kind, AccountKind::Matrix); + assert_eq!(back.profile.display_name, Some("Dirigent Bot".to_string())); + assert!(back.credentials.contains_key("password")); + assert_eq!(back.property_str("homeserver"), Some("https://matrix.example.com")); + } +} diff --git a/crates/dirigent_auth/src/lib.rs b/crates/dirigent_auth/src/lib.rs new file mode 100644 index 0000000..803eb4e --- /dev/null +++ b/crates/dirigent_auth/src/lib.rs @@ -0,0 +1,156 @@ +//! Dirigent Auth +//! +//! Tiny user identity crate for the Dirigent system. No async deps. +//! Usable from both server and WASM targets. +//! +//! Provides the core `UserId`, `User`, and `UserProfile` types used +//! throughout the Dirigent ecosystem for ownership tracking and +//! authorization. + +pub mod secret; +pub use secret::{SecretSource, SecretError}; + +pub mod account; +pub use account::{Account, AccountKind, AccountProfile}; + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +/// Unique identifier for a user. +/// +/// Uses UUID v7 (time-ordered) for new users. Existing string-based +/// IDs should be migrated to UUIDs at creation time. +pub type UserId = Uuid; + +/// User information. +/// +/// Represents a user in the Dirigent system with profile data and +/// creation timestamp. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct User { + /// Unique user identifier (UUID v7) + pub id: UserId, + /// User profile with optional display fields + pub profile: UserProfile, + /// When this user was created + pub created_at: DateTime, +} + +impl User { + /// Create a new user with a generated UUID v7 and the given profile. + pub fn new(profile: UserProfile) -> Self { + Self { + id: Uuid::now_v7(), + profile, + created_at: Utc::now(), + } + } + + /// Create a new user with a specific ID. + pub fn with_id(id: UserId, profile: UserProfile) -> Self { + Self { + id, + profile, + created_at: Utc::now(), + } + } + + /// Get the display name, falling back to username or "Unknown". + pub fn display_name(&self) -> &str { + self.profile + .name + .as_deref() + .or(self.profile.username.as_deref()) + .unwrap_or("Unknown") + } +} + +/// User profile information. +/// +/// All fields are optional to allow progressive enrichment. +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct UserProfile { + /// Display name + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + /// Email address + #[serde(skip_serializing_if = "Option::is_none")] + pub email: Option, + /// Username / handle + #[serde(skip_serializing_if = "Option::is_none")] + pub username: Option, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_user_creation() { + let user = User::new(UserProfile { + name: Some("Test User".to_string()), + ..Default::default() + }); + assert_eq!(user.display_name(), "Test User"); + assert_eq!(user.id.get_version_num(), 7); + } + + #[test] + fn test_user_with_id() { + let id = Uuid::nil(); + let user = User::with_id( + id, + UserProfile { + name: Some("Nil User".to_string()), + ..Default::default() + }, + ); + assert_eq!(user.id, Uuid::nil()); + assert_eq!(user.display_name(), "Nil User"); + } + + #[test] + fn test_display_name_fallbacks() { + // Falls back to username + let user = User::new(UserProfile { + username: Some("jdoe".to_string()), + ..Default::default() + }); + assert_eq!(user.display_name(), "jdoe"); + + // Falls back to "Unknown" + let user = User::new(UserProfile::default()); + assert_eq!(user.display_name(), "Unknown"); + } + + #[test] + fn test_serialization_roundtrip() { + let user = User::new(UserProfile { + name: Some("Test".to_string()), + email: Some("test@example.com".to_string()), + username: None, + }); + + let json = serde_json::to_string(&user).expect("serialize"); + let deserialized: User = serde_json::from_str(&json).expect("deserialize"); + + assert_eq!(deserialized.id, user.id); + assert_eq!(deserialized.profile.name, user.profile.name); + assert_eq!(deserialized.profile.email, user.profile.email); + assert!(deserialized.profile.username.is_none()); + } + + #[test] + fn test_profile_skip_none_fields() { + let profile = UserProfile { + name: Some("Test".to_string()), + email: None, + username: None, + }; + let json = serde_json::to_string(&profile).expect("serialize"); + assert!(!json.contains("email")); + assert!(!json.contains("username")); + assert!(json.contains("name")); + } +} diff --git a/crates/dirigent_auth/src/secret.rs b/crates/dirigent_auth/src/secret.rs new file mode 100644 index 0000000..38353b9 --- /dev/null +++ b/crates/dirigent_auth/src/secret.rs @@ -0,0 +1,98 @@ +//! Credential resolution for Dirigent accounts. + +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum SecretError { + #[error("Environment variable '{key}' not set")] + EnvNotSet { key: String }, + #[error("Failed to read secret file '{path}': {reason}")] + FileReadFailed { path: String, reason: String }, +} + +/// Describes how to retrieve a secret value (password, token, etc.). +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "source")] +pub enum SecretSource { + #[serde(rename = "env")] + Env { key: String }, + #[serde(rename = "inline")] + Inline { value: String }, + #[serde(rename = "file")] + File { path: String }, +} + +impl SecretSource { + pub fn resolve(&self) -> Result { + match self { + SecretSource::Env { key } => { + std::env::var(key).map_err(|_| SecretError::EnvNotSet { key: key.clone() }) + } + SecretSource::Inline { value } => Ok(value.clone()), + SecretSource::File { path } => std::fs::read_to_string(path) + .map(|s| s.trim().to_string()) + .map_err(|e| SecretError::FileReadFailed { + path: path.clone(), + reason: e.to_string(), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_inline_resolve() { + let src = SecretSource::Inline { value: "hunter2".to_string() }; + assert_eq!(src.resolve().unwrap(), "hunter2"); + } + + #[test] + fn test_env_resolve() { + std::env::set_var("DIRIGENT_TEST_SECRET_7382", "env_value"); + let src = SecretSource::Env { key: "DIRIGENT_TEST_SECRET_7382".to_string() }; + assert_eq!(src.resolve().unwrap(), "env_value"); + std::env::remove_var("DIRIGENT_TEST_SECRET_7382"); + } + + #[test] + fn test_env_missing() { + let src = SecretSource::Env { key: "DIRIGENT_NONEXISTENT_VAR_99999".to_string() }; + assert!(src.resolve().is_err()); + } + + #[test] + fn test_file_missing() { + let src = SecretSource::File { path: "/tmp/dirigent_nonexistent_secret_file".to_string() }; + assert!(src.resolve().is_err()); + } + + #[test] + fn test_serde_roundtrip_env() { + let src = SecretSource::Env { key: "MY_VAR".to_string() }; + let json = serde_json::to_string(&src).unwrap(); + assert!(json.contains(r#""source":"env"#)); + let back: SecretSource = serde_json::from_str(&json).unwrap(); + assert!(matches!(back, SecretSource::Env { key } if key == "MY_VAR")); + } + + #[test] + fn test_serde_roundtrip_inline() { + let src = SecretSource::Inline { value: "secret".to_string() }; + let json = serde_json::to_string(&src).unwrap(); + let back: SecretSource = serde_json::from_str(&json).unwrap(); + assert!(matches!(back, SecretSource::Inline { value } if value == "secret")); + } + + #[test] + fn test_serde_roundtrip_file() { + let src = SecretSource::File { path: "/run/secrets/pw".to_string() }; + let json = serde_json::to_string(&src).unwrap(); + assert!(json.contains(r#""source":"file"#)); + let back: SecretSource = serde_json::from_str(&json).unwrap(); + assert!(matches!(back, SecretSource::File { path } if path == "/run/secrets/pw")); + } +} diff --git a/crates/dirigent_chatgpt/CLAUDE.md b/crates/dirigent_chatgpt/CLAUDE.md new file mode 100644 index 0000000..9af4a4f --- /dev/null +++ b/crates/dirigent_chatgpt/CLAUDE.md @@ -0,0 +1,32 @@ +# Package: dirigent_chatgpt + +Pure-Rust parser for OpenAI's ChatGPT `conversations.json` data export. + +## Scope + +- `parse_export(path)` — reads a `conversations.json` file on disk and + returns `Vec`. +- `parse_str(json)` — parses an in-memory JSON string (useful for tests + and piped inputs). +- Types: `ParsedConversation`, `ParsedMessage`, `ContentPart` (`Text`, + `Code`, `Tool`). + +No dirigent-specific types. `dirigent_archivist::import::sources::chatgpt` +consumes this crate and maps into the archivist's internal types. + +## Example + +```rust +let convs = dirigent_chatgpt::parse_export(path)?; +for c in convs { + println!("{}: {} messages", c.title.as_deref().unwrap_or("(untitled)"), c.messages.len()); +} +``` + +## Failure modes + +- Truly broken JSON → `ParseError::Json`. +- Malformed individual messages are skipped where possible. +- Unknown content shapes are preserved as best-effort text in + `ContentPart::Text { text: raw_json }` so no user data is silently + lost. diff --git a/crates/dirigent_chatgpt/Cargo.toml b/crates/dirigent_chatgpt/Cargo.toml new file mode 100644 index 0000000..f52e6a0 --- /dev/null +++ b/crates/dirigent_chatgpt/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "dirigent_chatgpt" +version = "0.1.0" +edition = "2021" + +[dependencies] +chrono = { version = "0.4", features = ["serde"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +thiserror = "1" +uuid = { version = "1", features = ["v4", "v7", "serde"] } + +[dev-dependencies] diff --git a/crates/dirigent_chatgpt/src/lib.rs b/crates/dirigent_chatgpt/src/lib.rs new file mode 100644 index 0000000..164f075 --- /dev/null +++ b/crates/dirigent_chatgpt/src/lib.rs @@ -0,0 +1,7 @@ +//! ChatGPT export parser. Zero dirigent-specific types. + +pub mod parser; +pub mod types; + +pub use parser::{parse_export, parse_str, ParseError}; +pub use types::{ContentPart, ParsedConversation, ParsedMessage}; diff --git a/crates/dirigent_chatgpt/src/parser.rs b/crates/dirigent_chatgpt/src/parser.rs new file mode 100644 index 0000000..00370f7 --- /dev/null +++ b/crates/dirigent_chatgpt/src/parser.rs @@ -0,0 +1,349 @@ +use std::path::Path; +use chrono::{DateTime, TimeZone, Utc}; +use thiserror::Error; + +use crate::types::{ContentPart, ParsedConversation, ParsedMessage}; + +#[derive(Debug, Error)] +pub enum ParseError { + #[error("I/O: {0}")] Io(#[from] std::io::Error), + #[error("JSON: {0}")] Json(#[from] serde_json::Error), + #[error("unsupported shape: {0}")] UnsupportedShape(String), +} + +/// Parse a ChatGPT `conversations.json` file into a list of conversations. +pub fn parse_export(path: &Path) -> Result, ParseError> { + let text = std::fs::read_to_string(path)?; + parse_str(&text) +} + +/// Parse a JSON string of conversations. +pub fn parse_str(json: &str) -> Result, ParseError> { + // ChatGPT conversations.json is a JSON array of conversation objects. + let root: serde_json::Value = serde_json::from_str(json)?; + let arr = root.as_array() + .ok_or_else(|| ParseError::UnsupportedShape("expected JSON array at root".into()))?; + let mut out = Vec::with_capacity(arr.len()); + for conv in arr { + out.push(convert_conversation(conv)?); + } + Ok(out) +} + +fn convert_conversation(conv: &serde_json::Value) -> Result { + let id = conv.get("id") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let title = conv.get("title").and_then(|v| v.as_str()).map(String::from); + let created_at = conv.get("create_time").and_then(parse_unix_time); + let updated_at = conv.get("update_time").and_then(parse_unix_time); + + // Walk the mapping tree if present; otherwise return empty messages. + let messages = if let Some(mapping) = conv.get("mapping").and_then(|v| v.as_object()) { + walk_mapping(mapping) + } else { + Vec::new() + }; + + // Preserve whatever metadata we can't otherwise capture. + let mut metadata = serde_json::Map::new(); + for key in &["conversation_id", "gizmo_id", "model", "default_model_slug", "moderation_results"] { + if let Some(v) = conv.get(*key) { + metadata.insert((*key).to_string(), v.clone()); + } + } + + Ok(ParsedConversation { + id, + title, + created_at, + updated_at, + messages, + metadata: if metadata.is_empty() { + serde_json::Value::Null + } else { + serde_json::Value::Object(metadata) + }, + }) +} + +/// Walk the `mapping` tree starting at the root (parent=null), DFS in +/// `create_time` order, collecting non-null messages. +fn walk_mapping( + mapping: &serde_json::Map, +) -> Vec { + // Find roots: nodes with parent == null, or (fallback) nodes not + // referenced as a child by any other node. + let mut roots: Vec<&str> = mapping + .iter() + .filter_map(|(id, node)| { + let parent = node.get("parent"); + let is_root = match parent { + None => true, + Some(serde_json::Value::Null) => true, + _ => false, + }; + if is_root { + Some(id.as_str()) + } else { + None + } + }) + .collect(); + + // Fallback: if we didn't find any via the `parent` signal, derive from + // the child-set (a root is a node that nobody lists as a child). + if roots.is_empty() { + let mut referenced: std::collections::HashSet<&str> = std::collections::HashSet::new(); + for node in mapping.values() { + if let Some(children) = node.get("children").and_then(|c| c.as_array()) { + for child in children { + if let Some(s) = child.as_str() { + referenced.insert(s); + } + } + } + } + roots = mapping + .keys() + .filter(|k| !referenced.contains(k.as_str())) + .map(|s| s.as_str()) + .collect(); + } + + let mut out: Vec = Vec::new(); + let mut visited: std::collections::HashSet = std::collections::HashSet::new(); + for root_id in roots { + dfs_collect(mapping, root_id, &mut out, &mut visited); + } + out +} + +fn dfs_collect( + mapping: &serde_json::Map, + node_id: &str, + out: &mut Vec, + visited: &mut std::collections::HashSet, +) { + if !visited.insert(node_id.to_string()) { + return; + } + let node = match mapping.get(node_id) { + Some(n) => n, + None => return, + }; + if let Some(msg) = node.get("message") { + if !msg.is_null() { + if let Some(parsed) = parse_mapping_message(msg) { + out.push(parsed); + } + } + } + + // Children, sorted by create_time when available for deterministic order. + if let Some(children) = node.get("children").and_then(|c| c.as_array()) { + let mut child_refs: Vec<(&str, Option)> = children + .iter() + .filter_map(|v| v.as_str()) + .map(|id| { + let ct = mapping + .get(id) + .and_then(|n| n.get("message")) + .and_then(|m| m.get("create_time")) + .and_then(|v| v.as_f64()); + (id, ct) + }) + .collect(); + child_refs.sort_by(|a, b| match (a.1, b.1) { + (Some(x), Some(y)) => x.partial_cmp(&y).unwrap_or(std::cmp::Ordering::Equal), + (Some(_), None) => std::cmp::Ordering::Less, + (None, Some(_)) => std::cmp::Ordering::Greater, + (None, None) => std::cmp::Ordering::Equal, + }); + for (child_id, _) in child_refs { + dfs_collect(mapping, child_id, out, visited); + } + } +} + +fn parse_mapping_message(msg: &serde_json::Value) -> Option { + let id = msg.get("id").and_then(|v| v.as_str()).unwrap_or("").to_string(); + + let role = msg + .get("author") + .and_then(|a| a.get("role")) + .and_then(|r| r.as_str()) + .unwrap_or("unknown") + .to_string(); + + let ts = msg.get("create_time").and_then(parse_unix_time); + + let content = msg + .get("content") + .map(content_to_parts) + .unwrap_or_default(); + + // Skip purely empty system placeholders (common at the root of chats). + if content.iter().all(|p| is_part_empty(p)) && role == "system" { + return None; + } + + let mut metadata = serde_json::Map::new(); + if let Some(m) = msg.get("metadata").and_then(|v| v.as_object()) { + for (k, v) in m { + metadata.insert(k.clone(), v.clone()); + } + } + if let Some(author) = msg.get("author").and_then(|v| v.as_object()) { + if let Some(name) = author.get("name") { + metadata.insert("author_name".to_string(), name.clone()); + } + } + + Some(ParsedMessage { + id, + role, + ts, + content, + metadata: if metadata.is_empty() { + serde_json::Value::Null + } else { + serde_json::Value::Object(metadata) + }, + }) +} + +fn is_part_empty(p: &ContentPart) -> bool { + match p { + ContentPart::Text { text } => text.trim().is_empty(), + ContentPart::Code { text, .. } => text.trim().is_empty(), + ContentPart::Tool { .. } => false, + } +} + +/// Convert a `content` blob (various shapes) into a list of `ContentPart`s. +fn content_to_parts(content: &serde_json::Value) -> Vec { + // Typical shape: { "content_type": "text", "parts": [ ... ] } + // Other content types seen in the wild: "code", "tether_browsing_display", + // "multimodal_text", "execution_output", "system_error". We do a best-effort + // normalisation here; Task 8+ can specialise further. + let content_type = content.get("content_type").and_then(|v| v.as_str()).unwrap_or("text"); + + if let Some(parts) = content.get("parts").and_then(|v| v.as_array()) { + return parts.iter().map(|p| part_to_content_part(p, content_type)).collect(); + } + + // `content_type = "code"` carries { language, text } + if content_type == "code" { + let language = content.get("language").and_then(|v| v.as_str()).map(String::from); + let text = content + .get("text") + .and_then(|v| v.as_str()) + .map(String::from) + .unwrap_or_default(); + return vec![ContentPart::Code { language, text }]; + } + + // `content_type = "tether_browsing_display"` / execution_output carry + // `text` or `result` fields — treat as text. + if let Some(text) = content.get("text").and_then(|v| v.as_str()) { + return vec![ContentPart::Text { text: text.to_string() }]; + } + if let Some(text) = content.get("result").and_then(|v| v.as_str()) { + return vec![ContentPart::Text { text: text.to_string() }]; + } + + // Unknown shape — serialize the raw JSON. + vec![ContentPart::Text { text: content.to_string() }] +} + +fn part_to_content_part(part: &serde_json::Value, outer_type: &str) -> ContentPart { + // String — plain text (or code, depending on outer content_type). + if let Some(s) = part.as_str() { + if outer_type == "code" { + return ContentPart::Code { language: None, text: s.to_string() }; + } + return ContentPart::Text { text: s.to_string() }; + } + + // Object — inspect fields. + if let Some(obj) = part.as_object() { + // Multimodal text part: { "text": "...", ... } + if let Some(text) = obj.get("text").and_then(|v| v.as_str()) { + return ContentPart::Text { text: text.to_string() }; + } + + // Tool-ish shape: { "tool": "...", "input": {...}, "output": ... } + // (ChatGPT's actual tool shape varies; this is a best-effort catch.) + if let (Some(name), Some(input)) = ( + obj.get("name").or_else(|| obj.get("tool")).and_then(|v| v.as_str()), + obj.get("input"), + ) { + return ContentPart::Tool { + name: name.to_string(), + input: input.clone(), + output: obj.get("output").cloned(), + }; + } + + // Image / asset-pointer parts: describe them inline. + if let Some(asset) = obj.get("asset_pointer").and_then(|v| v.as_str()) { + return ContentPart::Text { text: format!("[asset: {}]", asset) }; + } + } + + // Unknown shape — serialise the raw JSON. + ContentPart::Text { text: part.to_string() } +} + +/// Parse a ChatGPT unix-seconds timestamp (may be float or int, may be null). +fn parse_unix_time(v: &serde_json::Value) -> Option> { + let seconds = v.as_f64()?; + if !seconds.is_finite() { + return None; + } + let secs = seconds.trunc() as i64; + let nanos = ((seconds - secs as f64) * 1_000_000_000.0).round() as u32; + Utc.timestamp_opt(secs, nanos).single() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parses_minimal_fixture() { + let path = std::path::Path::new("tests/fixtures/minimal.json"); + let convs = parse_export(path).expect("parse"); + assert_eq!(convs.len(), 1); + assert_eq!(convs[0].title.as_deref(), Some("Hello")); + assert_eq!(convs[0].messages.len(), 2); + assert_eq!(convs[0].messages[0].role, "user"); + assert_eq!(convs[0].messages[1].role, "assistant"); + } + + #[test] + fn parses_timestamps() { + let path = std::path::Path::new("tests/fixtures/minimal.json"); + let convs = parse_export(path).expect("parse"); + let c = &convs[0]; + assert!(c.created_at.is_some()); + assert!(c.updated_at.is_some()); + // Message timestamps should be derived from create_time. + assert!(c.messages[0].ts.is_some()); + assert!(c.messages[1].ts.is_some()); + } + + #[test] + fn text_content_extracted() { + let path = std::path::Path::new("tests/fixtures/minimal.json"); + let convs = parse_export(path).expect("parse"); + let msg0 = &convs[0].messages[0]; + assert_eq!(msg0.content.len(), 1); + match &msg0.content[0] { + ContentPart::Text { text } => assert_eq!(text, "Hello, world"), + other => panic!("expected Text, got {:?}", other), + } + } +} diff --git a/crates/dirigent_chatgpt/src/types.rs b/crates/dirigent_chatgpt/src/types.rs new file mode 100644 index 0000000..55b4bab --- /dev/null +++ b/crates/dirigent_chatgpt/src/types.rs @@ -0,0 +1,31 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ParsedConversation { + pub id: String, // ChatGPT's conversation_id (hex-ish; may or may not be a UUID) + pub title: Option, + pub created_at: Option>, + pub updated_at: Option>, + pub messages: Vec, + #[serde(default)] + pub metadata: serde_json::Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ParsedMessage { + pub id: String, + pub role: String, // "user" | "assistant" | "system" | "tool" + pub ts: Option>, + pub content: Vec, + #[serde(default)] + pub metadata: serde_json::Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ContentPart { + Text { text: String }, + Code { language: Option, text: String }, + Tool { name: String, input: serde_json::Value, output: Option }, +} diff --git a/crates/dirigent_chatgpt/tests/fixtures/minimal.json b/crates/dirigent_chatgpt/tests/fixtures/minimal.json new file mode 100644 index 0000000..1912486 --- /dev/null +++ b/crates/dirigent_chatgpt/tests/fixtures/minimal.json @@ -0,0 +1,31 @@ +[ + { + "id": "c1", + "title": "Hello", + "create_time": 1700000000.0, + "update_time": 1700000100.0, + "mapping": { + "root": { "id": "root", "message": null, "children": ["m1"] }, + "m1": { + "id": "m1", + "message": { + "id": "m1", + "author": { "role": "user" }, + "create_time": 1700000010.0, + "content": { "content_type": "text", "parts": ["Hello, world"] } + }, + "children": ["m2"] + }, + "m2": { + "id": "m2", + "message": { + "id": "m2", + "author": { "role": "assistant" }, + "create_time": 1700000020.0, + "content": { "content_type": "text", "parts": ["Hi!"] } + }, + "children": [] + } + } + } +] diff --git a/crates/dirigent_codex/CLAUDE.md b/crates/dirigent_codex/CLAUDE.md new file mode 100644 index 0000000..03b4702 --- /dev/null +++ b/crates/dirigent_codex/CLAUDE.md @@ -0,0 +1,30 @@ +# Package: dirigent_codex + +Pure-Rust parser for OpenAI Codex JSONL session files. + +## Scope + +- `parse_file(path)` — reads one `*.jsonl` session file on disk and + returns a `ParsedSession`. +- `discover_sessions(dir)` — scans a directory (e.g. + `~/.codex/sessions/`) for session files. +- Types: `ParsedSession`, `ParsedMessage`. + +No dirigent-specific types. `dirigent_archivist::import::sources::codex` +consumes this crate and maps into the archivist's internal types. + +## Example + +```rust +let sessions = dirigent_codex::discover_sessions(dir)?; +for s in sessions { + println!("{}: {} messages", s.id, s.messages.len()); +} +``` + +## Failure modes + +- Individual malformed JSONL lines are skipped where possible. +- Truly broken files return `ParseError::Json`. +- Unknown message shapes are preserved as best-effort text so no user + data is silently lost. diff --git a/crates/dirigent_codex/Cargo.toml b/crates/dirigent_codex/Cargo.toml new file mode 100644 index 0000000..903c73c --- /dev/null +++ b/crates/dirigent_codex/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "dirigent_codex" +version = "0.1.0" +edition = "2021" + +[dependencies] +chrono = { version = "0.4", features = ["serde"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +thiserror = "1" +uuid = { version = "1", features = ["v4", "v7", "serde"] } + +[dev-dependencies] +tempfile = "3" diff --git a/crates/dirigent_codex/src/lib.rs b/crates/dirigent_codex/src/lib.rs new file mode 100644 index 0000000..96d1272 --- /dev/null +++ b/crates/dirigent_codex/src/lib.rs @@ -0,0 +1,14 @@ +//! OpenAI Codex on-disk session parser. Zero dirigent-specific types. +//! +//! The Codex CLI persists its sessions as JSONL files under +//! `~/.codex/sessions/*.jsonl` (or a caller-supplied equivalent). Each line +//! is a best-effort event object with a `role`, some `content`, and an +//! optional timestamp. Exact schema varies across Codex versions, so this +//! parser is intentionally lenient: unknown/malformed lines are skipped, +//! not failed. + +pub mod parser; +pub mod types; + +pub use parser::{discover_sessions, parse_file, ParseError}; +pub use types::{ParsedMessage, ParsedSession}; diff --git a/crates/dirigent_codex/src/parser.rs b/crates/dirigent_codex/src/parser.rs new file mode 100644 index 0000000..7e5a9e3 --- /dev/null +++ b/crates/dirigent_codex/src/parser.rs @@ -0,0 +1,274 @@ +use std::path::{Path, PathBuf}; + +use chrono::{DateTime, TimeZone, Utc}; +use thiserror::Error; + +use crate::types::{ParsedMessage, ParsedSession}; + +#[derive(Debug, Error)] +pub enum ParseError { + #[error("I/O: {0}")] + Io(#[from] std::io::Error), + #[error("JSON: {0}")] + Json(#[from] serde_json::Error), + #[error("not found: {0}")] + NotFound(String), +} + +/// Walk `dir` (non-recursively) for Codex session JSONL files. +/// +/// Returns a deterministically ordered list (lexical by path) of every +/// `*.jsonl` file directly under `dir`. Returns `NotFound` if the directory +/// itself doesn't exist. +pub fn discover_sessions(dir: &Path) -> Result, ParseError> { + if !dir.exists() { + return Err(ParseError::NotFound(dir.display().to_string())); + } + let mut out = Vec::new(); + for entry in std::fs::read_dir(dir)? { + let entry = entry?; + let path = entry.path(); + if path.is_file() + && path.extension().and_then(|s| s.to_str()) == Some("jsonl") + { + out.push(path); + } + } + out.sort(); + Ok(out) +} + +/// Parse a single Codex session JSONL file. +/// +/// Malformed / unexpected lines are skipped (not fatal). Every line that +/// exposes a `role` and a `content` is turned into a [`ParsedMessage`]. +/// The session's `created_at` / `updated_at` bracket the first and last +/// message timestamps seen. +pub fn parse_file(path: &Path) -> Result { + let text = std::fs::read_to_string(path)?; + let native_id = path + .file_stem() + .and_then(|s| s.to_str()) + .unwrap_or("unknown") + .to_string(); + + let mut messages = Vec::new(); + let mut created_at: Option> = None; + let mut updated_at: Option> = None; + + for line in text.lines() { + if line.trim().is_empty() { + continue; + } + let val: serde_json::Value = match serde_json::from_str(line) { + Ok(v) => v, + Err(_) => continue, // skip malformed lines + }; + if let Some(msg) = extract_message(&val) { + if let Some(ts) = msg.ts { + if created_at.is_none() { + created_at = Some(ts); + } + updated_at = Some(ts); + } + messages.push(msg); + } + } + + Ok(ParsedSession { + native_id, + source_path: path.to_path_buf(), + created_at, + updated_at, + messages, + }) +} + +/// Best-effort extraction of a [`ParsedMessage`] from an arbitrary JSONL +/// event. Returns `None` if the shape doesn't carry a role + content. +fn extract_message(val: &serde_json::Value) -> Option { + let role = val.get("role").and_then(|v| v.as_str()).map(String::from)?; + let content = extract_content(val.get("content")?)?; + let ts = extract_ts(val); + Some(ParsedMessage { + ts, + role, + content, + metadata: val.clone(), + }) +} + +/// Flatten a `content` field into a single string. +/// +/// - string → as-is +/// - array → strings joined by `\n`; objects with a `text` field use that, +/// otherwise their raw JSON is stringified +/// - object → `text` field if present, otherwise the raw JSON +/// - null → `None` +fn extract_content(content: &serde_json::Value) -> Option { + match content { + serde_json::Value::String(s) => Some(s.clone()), + serde_json::Value::Array(arr) => { + let parts: Vec = arr + .iter() + .filter_map(|p| { + if let Some(s) = p.as_str() { + Some(s.to_string()) + } else if let Some(t) = p.get("text").and_then(|v| v.as_str()) { + Some(t.to_string()) + } else { + Some(p.to_string()) + } + }) + .collect(); + Some(parts.join("\n")) + } + serde_json::Value::Object(_) => { + if let Some(t) = content.get("text").and_then(|v| v.as_str()) { + Some(t.to_string()) + } else { + Some(content.to_string()) + } + } + serde_json::Value::Null => None, + other => Some(other.to_string()), + } +} + +/// Extract a timestamp from one of several possible fields. +/// +/// Accepts RFC 3339 strings or numeric unix-seconds (integer or float). +fn extract_ts(val: &serde_json::Value) -> Option> { + let candidate = val + .get("ts") + .or_else(|| val.get("timestamp")) + .or_else(|| val.get("created_at")) + .or_else(|| val.get("time"))?; + + if let Some(s) = candidate.as_str() { + if let Ok(dt) = DateTime::parse_from_rfc3339(s) { + return Some(dt.with_timezone(&Utc)); + } + } + if let Some(f) = candidate.as_f64() { + if f.is_finite() { + let secs = f.trunc() as i64; + let nanos = ((f - secs as f64) * 1_000_000_000.0).round() as u32; + return Utc.timestamp_opt(secs, nanos).single(); + } + } + if let Some(i) = candidate.as_i64() { + return Utc.timestamp_opt(i, 0).single(); + } + None +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + + fn write_jsonl(dir: &Path, name: &str, lines: &[&str]) -> PathBuf { + let path = dir.join(name); + let mut f = std::fs::File::create(&path).unwrap(); + for line in lines { + writeln!(f, "{}", line).unwrap(); + } + path + } + + #[test] + fn discover_sessions_missing_dir_returns_not_found() { + let err = discover_sessions(Path::new("/tmp/this/does/not/exist/ever")) + .expect_err("should fail"); + assert!(matches!(err, ParseError::NotFound(_))); + } + + #[test] + fn discover_sessions_lists_only_jsonl() { + let tmp = tempfile::tempdir().unwrap(); + let _ = write_jsonl(tmp.path(), "a.jsonl", &[]); + let _ = write_jsonl(tmp.path(), "b.jsonl", &[]); + let _ = write_jsonl(tmp.path(), "not-this.txt", &[]); + let found = discover_sessions(tmp.path()).unwrap(); + assert_eq!(found.len(), 2); + assert!(found[0].ends_with("a.jsonl")); + assert!(found[1].ends_with("b.jsonl")); + } + + #[test] + fn parse_file_extracts_basic_messages() { + let tmp = tempfile::tempdir().unwrap(); + let path = write_jsonl( + tmp.path(), + "session-abc.jsonl", + &[ + r#"{"role":"user","content":"hi","ts":"2025-01-01T12:00:00Z"}"#, + r#"{"role":"assistant","content":"hello","ts":"2025-01-01T12:00:01Z"}"#, + ], + ); + let session = parse_file(&path).unwrap(); + assert_eq!(session.native_id, "session-abc"); + assert_eq!(session.messages.len(), 2); + assert_eq!(session.messages[0].role, "user"); + assert_eq!(session.messages[0].content, "hi"); + assert_eq!(session.messages[1].role, "assistant"); + assert!(session.created_at.is_some()); + assert!(session.updated_at.is_some()); + assert_ne!(session.created_at, session.updated_at); + } + + #[test] + fn parse_file_skips_malformed_and_empty_lines() { + let tmp = tempfile::tempdir().unwrap(); + let path = write_jsonl( + tmp.path(), + "session.jsonl", + &[ + r#"{"role":"user","content":"hi"}"#, + "", + "not json at all", + r#"{"garbled":true}"#, // no role/content → skipped + r#"{"role":"assistant","content":"ok"}"#, + ], + ); + let session = parse_file(&path).unwrap(); + assert_eq!(session.messages.len(), 2); + } + + #[test] + fn parse_file_handles_content_array_and_object() { + let tmp = tempfile::tempdir().unwrap(); + let path = write_jsonl( + tmp.path(), + "session.jsonl", + &[ + r#"{"role":"user","content":["a","b","c"]}"#, + r#"{"role":"user","content":[{"text":"x"},{"text":"y"}]}"#, + r#"{"role":"assistant","content":{"text":"nested"}}"#, + ], + ); + let session = parse_file(&path).unwrap(); + assert_eq!(session.messages.len(), 3); + assert_eq!(session.messages[0].content, "a\nb\nc"); + assert_eq!(session.messages[1].content, "x\ny"); + assert_eq!(session.messages[2].content, "nested"); + } + + #[test] + fn parse_file_accepts_unix_ts() { + let tmp = tempfile::tempdir().unwrap(); + let path = write_jsonl( + tmp.path(), + "session.jsonl", + &[r#"{"role":"user","content":"hi","ts":1735732800}"#], + ); + let session = parse_file(&path).unwrap(); + assert_eq!(session.messages.len(), 1); + assert!(session.messages[0].ts.is_some()); + } +} diff --git a/crates/dirigent_codex/src/types.rs b/crates/dirigent_codex/src/types.rs new file mode 100644 index 0000000..1ae3d8c --- /dev/null +++ b/crates/dirigent_codex/src/types.rs @@ -0,0 +1,32 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; + +/// A Codex session parsed from a single JSONL file. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ParsedSession { + /// The native session id. For Codex this is the JSONL file stem. + pub native_id: String, + /// The source file the session was loaded from. + pub source_path: PathBuf, + /// First message timestamp seen (if any). + pub created_at: Option>, + /// Last message timestamp seen (if any). + pub updated_at: Option>, + /// Parsed messages in file order. + pub messages: Vec, +} + +/// A single message event from a Codex JSONL session file. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ParsedMessage { + /// Timestamp if one could be extracted (RFC 3339 or unix epoch). + pub ts: Option>, + /// Free-form role, e.g. "user", "assistant", "system", "tool". + pub role: String, + /// Best-effort concatenated text content. + pub content: String, + /// Raw event for provenance. + #[serde(default)] + pub metadata: serde_json::Value, +} diff --git a/crates/dirigent_config/CLAUDE.md b/crates/dirigent_config/CLAUDE.md new file mode 100644 index 0000000..4a1958f --- /dev/null +++ b/crates/dirigent_config/CLAUDE.md @@ -0,0 +1,22 @@ +# dirigent_config + +Platform-native configuration and data path resolution. + +## Purpose +Provides `DirigentPaths` for resolving config/data directories across Linux, macOS, Windows. +Creates a symlink on Linux/macOS from config_dir/data -> data_dir for discoverability. + +## Key Types +- `DirigentPaths` -- resolved config_dir + data_dir with convenience methods +- `ConfigPathError` -- error enum for path resolution failures + +## Usage +```rust +let paths = DirigentPaths::resolve()?; +paths.ensure_dirs()?; // creates dirs + symlink +let config = paths.config_file(); // ~/.config/dirigent/dirigent.toml +``` + +## Dependencies +- `dirs` -- cross-platform directory resolution +- Zero UI dependency -- used by core, archivist, zed crates diff --git a/crates/dirigent_config/Cargo.toml b/crates/dirigent_config/Cargo.toml new file mode 100644 index 0000000..77ad5ab --- /dev/null +++ b/crates/dirigent_config/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "dirigent_config" +version = "0.1.0" +edition = "2021" +description = "Platform-native configuration and data path resolution for Dirigent" + +[dependencies] +dirs = "5" +thiserror = "2.0" + +[dev-dependencies] +serial_test = "3" diff --git a/crates/dirigent_config/src/lib.rs b/crates/dirigent_config/src/lib.rs new file mode 100644 index 0000000..4c1fcbb --- /dev/null +++ b/crates/dirigent_config/src/lib.rs @@ -0,0 +1,10 @@ +//! Platform-native configuration and data path resolution for Dirigent. +//! +//! Provides `DirigentPaths` for resolving config and data directories +//! across Linux, macOS, and Windows. On Linux/macOS, creates a symlink +//! from config_dir/data to data_dir for discoverability. + +mod paths; + +pub use paths::ConfigPathError; +pub use paths::DirigentPaths; diff --git a/crates/dirigent_config/src/paths.rs b/crates/dirigent_config/src/paths.rs new file mode 100644 index 0000000..762c5b1 --- /dev/null +++ b/crates/dirigent_config/src/paths.rs @@ -0,0 +1,282 @@ +use std::path::{Path, PathBuf}; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum ConfigPathError { + #[error("Could not determine config directory for this platform")] + NoConfigDir, + #[error("Could not determine data directory for this platform")] + NoDataDir, + #[error("Failed to create directory {path}: {source}")] + CreateDir { + path: PathBuf, + source: std::io::Error, + }, + #[error("Failed to create symlink from {from} to {to}: {source}")] + Symlink { + from: PathBuf, + to: PathBuf, + source: std::io::Error, + }, +} + +/// Platform-native paths for dirigent configuration and data. +/// +/// Environment variable overrides (highest priority): +/// - `DIRIGENT_CONFIG_DIR` -- overrides the config directory +/// - `DIRIGENT_DATA_DIR` -- overrides the data directory +/// +/// Platform-native defaults (fallback): +/// +/// | Platform | Config Dir | Data Dir | +/// |----------|-----------|----------| +/// | Linux | `$XDG_CONFIG_HOME/dirigent/` | `$XDG_DATA_HOME/dirigent/` | +/// | macOS | `~/.config/dirigent/` | `~/Library/Application Support/dirigent/` | +/// | Windows | `%APPDATA%\dirigent\` | `%LOCALAPPDATA%\dirigent\` | +#[derive(Debug, Clone)] +pub struct DirigentPaths { + config_dir: PathBuf, + data_dir: PathBuf, +} + +impl DirigentPaths { + /// Resolve platform-native paths. + pub fn resolve() -> Result { + let config_dir = Self::resolve_config_dir()?; + let data_dir = Self::resolve_data_dir()?; + Ok(Self { + config_dir, + data_dir, + }) + } + + /// Create directories and symlink if they don't exist. + pub fn ensure_dirs(&self) -> Result<(), ConfigPathError> { + std::fs::create_dir_all(&self.config_dir).map_err(|e| ConfigPathError::CreateDir { + path: self.config_dir.clone(), + source: e, + })?; + + std::fs::create_dir_all(&self.data_dir).map_err(|e| ConfigPathError::CreateDir { + path: self.data_dir.clone(), + source: e, + })?; + + let noproject = self.noproject_home_dir(); + std::fs::create_dir_all(&noproject).map_err(|e| ConfigPathError::CreateDir { + path: noproject, + source: e, + })?; + + // Create symlink on Linux/macOS: config_dir/data -> data_dir + #[cfg(unix)] + { + let symlink_path = self.config_dir.join("data"); + if symlink_path.symlink_metadata().is_err() { + std::os::unix::fs::symlink(&self.data_dir, &symlink_path).map_err(|e| { + ConfigPathError::Symlink { + from: symlink_path, + to: self.data_dir.clone(), + source: e, + } + })?; + } + } + + Ok(()) + } + + /// Where `dirigent.toml` lives. + pub fn config_dir(&self) -> &Path { + &self.config_dir + } + + /// Where archives, projects, caches live. + pub fn data_dir(&self) -> &Path { + &self.data_dir + } + + /// Archive storage directory. + pub fn archive_dir(&self) -> PathBuf { + self.data_dir.join("archives") + } + + /// Project storage directory. + pub fn projects_dir(&self) -> PathBuf { + self.data_dir.join("projects") + } + + /// Default working directory for connectors when no project is active. + /// Lives under the data directory alongside other runtime artifacts. + pub fn noproject_home_dir(&self) -> PathBuf { + self.data_dir.join("noproject_home") + } + + /// Task output storage directory. + pub fn tasks_dir(&self) -> PathBuf { + self.data_dir.join("tasks") + } + + /// Log storage directory. + pub fn logs_dir(&self) -> PathBuf { + self.data_dir.join("logs") + } + + /// Main config file path. + pub fn config_file(&self) -> PathBuf { + self.config_dir.join("dirigent.toml") + } + + fn resolve_config_dir() -> Result { + // Environment variable override (highest priority) + if let Ok(dir) = std::env::var("DIRIGENT_CONFIG_DIR") { + if !dir.trim().is_empty() { + // Absolutize relative paths against CWD to prevent nested resolution + // when subprocesses run from different working directories. + let path = PathBuf::from(&dir); + if path.is_relative() { + return Ok(std::env::current_dir() + .map_err(|_| ConfigPathError::NoConfigDir)? + .join(path)); + } + return Ok(path); + } + } + + // Platform-native fallback + #[cfg(target_os = "macos")] + { + dirs::home_dir() + .map(|d| d.join(".config").join("dirigent")) + .ok_or(ConfigPathError::NoConfigDir) + } + #[cfg(not(target_os = "macos"))] + { + dirs::config_dir() + .map(|d| d.join("dirigent")) + .ok_or(ConfigPathError::NoConfigDir) + } + } + + fn resolve_data_dir() -> Result { + // Environment variable override (highest priority) + if let Ok(dir) = std::env::var("DIRIGENT_DATA_DIR") { + if !dir.trim().is_empty() { + // Absolutize relative paths against CWD to prevent nested resolution + // when subprocesses run from different working directories. + let path = PathBuf::from(&dir); + if path.is_relative() { + return Ok(std::env::current_dir() + .map_err(|_| ConfigPathError::NoDataDir)? + .join(path)); + } + return Ok(path); + } + } + + // Platform-native fallback + #[cfg(target_os = "windows")] + { + dirs::data_local_dir() + .map(|d| d.join("dirigent")) + .ok_or(ConfigPathError::NoDataDir) + } + #[cfg(not(target_os = "windows"))] + { + dirs::data_dir() + .map(|d| d.join("dirigent")) + .ok_or(ConfigPathError::NoDataDir) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_resolve_returns_paths() { + let paths = DirigentPaths::resolve().expect("should resolve paths"); + assert!(paths.config_dir().to_string_lossy().contains("dirigent")); + assert!(paths.data_dir().to_string_lossy().contains("dirigent")); + } + + #[test] + fn test_config_file_path() { + let paths = DirigentPaths::resolve().unwrap(); + let config = paths.config_file(); + assert!(config.to_string_lossy().ends_with("dirigent.toml")); + } + + #[test] + fn test_archive_dir() { + let paths = DirigentPaths::resolve().unwrap(); + let archive = paths.archive_dir(); + assert!(archive.to_string_lossy().ends_with("archives")); + } + + #[test] + fn test_projects_dir() { + let paths = DirigentPaths::resolve().unwrap(); + let projects = paths.projects_dir(); + assert!(projects.to_string_lossy().ends_with("projects")); + } + + #[test] + fn test_config_file_under_config_dir() { + let paths = DirigentPaths::resolve().unwrap(); + assert!(paths.config_file().starts_with(paths.config_dir())); + } + + #[test] + fn test_subdirs_under_data_dir() { + let paths = DirigentPaths::resolve().unwrap(); + assert!(paths.archive_dir().starts_with(paths.data_dir())); + assert!(paths.projects_dir().starts_with(paths.data_dir())); + assert!(paths.logs_dir().starts_with(paths.data_dir())); + } + + use serial_test::serial; + + #[test] + #[serial] + fn test_config_dir_override_from_env() { + let tmp = std::env::temp_dir().join("dirigent_test_config_override"); + unsafe { std::env::set_var("DIRIGENT_CONFIG_DIR", &tmp) }; + let paths = DirigentPaths::resolve().expect("should resolve with env override"); + assert_eq!(paths.config_dir(), tmp.as_path()); + unsafe { std::env::remove_var("DIRIGENT_CONFIG_DIR") }; + } + + #[test] + #[serial] + fn test_data_dir_override_from_env() { + let tmp = std::env::temp_dir().join("dirigent_test_data_override"); + unsafe { std::env::set_var("DIRIGENT_DATA_DIR", &tmp) }; + let paths = DirigentPaths::resolve().expect("should resolve with env override"); + assert_eq!(paths.data_dir(), tmp.as_path()); + unsafe { std::env::remove_var("DIRIGENT_DATA_DIR") }; + } + + #[test] + fn test_noproject_home_under_data_dir() { + let paths = DirigentPaths::resolve().unwrap(); + assert!(paths.noproject_home_dir().starts_with(paths.data_dir())); + } + + #[test] + #[serial] + fn test_both_dir_overrides_from_env() { + let tmp_config = std::env::temp_dir().join("dirigent_test_both_config"); + let tmp_data = std::env::temp_dir().join("dirigent_test_both_data"); + unsafe { std::env::set_var("DIRIGENT_CONFIG_DIR", &tmp_config) }; + unsafe { std::env::set_var("DIRIGENT_DATA_DIR", &tmp_data) }; + let paths = DirigentPaths::resolve().expect("should resolve with both overrides"); + assert_eq!(paths.config_dir(), tmp_config.as_path()); + assert_eq!(paths.data_dir(), tmp_data.as_path()); + assert!(paths.archive_dir().starts_with(&tmp_data)); + assert!(paths.projects_dir().starts_with(&tmp_data)); + unsafe { std::env::remove_var("DIRIGENT_CONFIG_DIR") }; + unsafe { std::env::remove_var("DIRIGENT_DATA_DIR") }; + } +} diff --git a/crates/dirigent_core/CLAUDE.md b/crates/dirigent_core/CLAUDE.md new file mode 100644 index 0000000..855b66c --- /dev/null +++ b/crates/dirigent_core/CLAUDE.md @@ -0,0 +1,616 @@ +# Package: dirigent_core + +Core orchestration engine for multi-connector agent system management. + +## Quick Facts +- **Type**: Library +- **Main Entry**: src/lib.rs +- **Dependencies**: dirigent_protocol, tokio, axum, serde, uuid + +## Architecture Overview + +The dirigent_core package provides a **runtime-based architecture** for managing long-lived connections to external agent systems (OpenCode.ai, ACP agents, etc.). The core abstraction is the **Connector**, which represents a bidirectional communication channel to an agent system. + +### Core Components + +#### CoreRuntime +The central orchestrator for managing connectors. It maintains: +- Registry of active connectors (keyed by ConnectorId) +- Global event broadcast channel for system-wide events +- User registry for ownership and authorization +- Configuration state (with persistence support) + +#### CoreHandle +Lightweight, cloneable wrapper around CoreRuntime. Uses Arc internally for cheap cloning across async tasks and server functions. + +#### Connector Trait +Defines the interface for connector implementations: +- Command channel (mpsc) for control operations +- Event broadcast channel for publishing events +- State tracking (Initializing, Connecting, Ready, Error, Stopped) +- User ownership for authorization + +#### ConnectorHandle +Concrete implementation of the Connector trait that wraps: +- Metadata (id, kind, owner, title) +- Shared state (protected by RwLock) +- Command and event channels +- Optional task handle for lifecycle management + +### Connector Implementations + +#### OpenCodeConnector +Connector for OpenCode.ai REST + SSE API: +- HTTP client for session/message operations +- SSE event stream for real-time updates +- Background task loop for command processing +- State machine for connection lifecycle +- **TurnComplete emission**: Uses `TurnCompleteTrigger::ExplicitSignal` after `MessageCompleted` events (based on upstream session.idle signals) + +#### AcpConnector (Future) +Connector for Agent-Client Protocol: +- WebSocket or HTTP/2 transport +- ACP message protocol handling +- Tool execution and streaming support + +## Key Files + +### Core Runtime +- `src/runtime.rs` - CoreRuntime and CoreHandle implementation +- `src/types.rs` - Core types (ConnectorId, ConnectorState, User, etc.) +- `src/error.rs` - Error types for the runtime +- `src/config.rs` - Configuration types and template system + +### Connectors +- `src/connectors/mod.rs` - Connector trait and ConnectorHandle +- `src/connectors/opencode/mod.rs` - OpenCode connector implementation +- `src/connectors/opencode/config.rs` - OpenCode-specific configuration +- `src/connectors/acp/mod.rs` - ACP connector implementation (in progress) + +### ACP Protocol Implementation +- `src/acp/protocol/initialize.rs` - Protocol initialization and capability negotiation +- `src/acp/protocol/authenticate.rs` - Optional authentication flow +- `src/acp/protocol/session.rs` - Session lifecycle (new, load, set_mode, cancel) +- `src/acp/protocol/prompt.rs` - Prompt requests with content blocks (Phases 3-7 complete) +- `src/acp/protocol/streaming.rs` - Session update notifications and handlers +- `src/acp/protocol/stop_reason.rs` - Stop reason interpretation and actions +- `src/acp/protocol/cancellation.rs` - Cancellation and disconnect handling +- `src/acp/protocol/error.rs` - Error classification and retry logic +- `src/acp/connector_state.rs` - Connection and session state management +- `src/acp/transport/mod.rs` - Transport abstraction layer +- `src/acp/transport/stdio.rs` - Stdio transport (process spawning) +- `src/acp/transport/http.rs` - HTTP+SSE transport + +### Bidirectional Request Handling Pattern + +The ACP connector implements **true bidirectional communication** where both client and agent can send JSON-RPC requests at any time. This creates an architectural challenge: how to maintain synchronous request/response semantics (async/await) while handling incoming agent requests during outgoing client requests. + +**The Challenge:** + +When the connector sends `session/prompt` to the agent: +1. It calls `send_request()` which awaits the JSON-RPC response +2. The agent may send permission requests (e.g., `tools/write`) before responding +3. These agent requests arrive as `ConnectorCommand::AgentResponse` via the command channel +4. If `send_request()` only polls transport and response channels, the command channel isn't polled +5. **Result**: Deadlock - agent waits for permission → permission stuck in channel → client waits for prompt response + +**The Solution (src/connectors/acp/connector.rs:1253-1523):** + +The `send_request()` method uses `tokio::select!` to poll **three** sources simultaneously: +- **Response channel** (`response_rx`) - Waiting for the correlated JSON-RPC response +- **Transport channel** - Receiving messages/notifications from agent (may trigger response) +- **Command channel** (`cmd_rx`) - Receiving commands from event bridge (e.g., `AgentResponse`) + +When an `AgentResponse` command arrives during `send_request()`: +1. Extract the response payload from the command +2. Send it to the agent via transport immediately +3. Remove from pending requests map +4. Continue waiting for the original prompt response + +This pattern is **idiomatic async Rust** for implementing synchronous abstractions over bidirectional transports. It's similar to: +- gRPC bidirectional streaming with request/response correlation +- WebSocket clients with RPC-style method calls +- HTTP/2 multiplexing with concurrent streams + +**Why Not Separate Tasks?** + +Alternative architectures (separate command processing task, full actor model) introduce: +- Complex synchronization between tasks +- Race conditions on shared state +- Message ordering guarantees across channels +- Significantly more code and cognitive overhead + +The single-task, multi-channel select pattern keeps all state local and eliminates these issues. + +**Key Invariant:** + +Any method that blocks waiting for a response MUST also poll the command channel to process `AgentResponse` commands, otherwise bidirectional flows deadlock. + +## Main Exports + +### Runtime +- `CoreRuntime` - Main orchestrator +- `CoreHandle` - Cloneable runtime handle +- `CoreConfig` - Runtime configuration + +### Types +- `ConnectorId`, `UserId` - Type aliases for IDs +- `ConnectorKind` - Enum of connector types (OpenCode, Acp, Mock) +- `ConnectorState` - Lifecycle state enum +- `ConnectorSummary` - Lightweight connector view +- `User` - User information + +### Connectors +- `Connector` - Trait for connector implementations +- `ConnectorHandle` - Handle to a running connector +- `ConnectorCommand` - Commands sent to connectors +- `OpenCodeConnector` - OpenCode.ai integration +- `OpenCodeConfig` - OpenCode connector configuration + +### Configuration +- `ConnectorConfig` - Configuration for creating connectors +- `apply_template()` - Apply connector templates with patches + +### ACP Protocol Types (Phases 3-7) +- **Prompt Turn**: + - `SessionPromptRequest`, `SessionPromptResponse` - Prompt requests/responses + - `ContentBlock` - Text, Image, Audio, Resource, ResourceLink content + - `EmbeddedResource` - Text or Blob embedded resources + - `StopReason` - EndTurn, MaxTokens, MaxTurnRequests, Refusal, Cancelled + - `PromptError` - Timeout, JsonRpcError, TransportError, Cancelled, ValidationError + +- **Streaming Updates**: + - `SessionUpdate` - All update types (agent_message_chunk, tool_call, plan, etc.) + - `SessionUpdateNotification` - Notification wrapper with session_id + - `ToolCallInfo` - Tool call tracking with status and content + - `ToolKind` - Read, Edit, Search, Execute, Think, Other + - `ToolCallStatus` - Pending, InProgress, Completed, Failed, Cancelled + - `ToolCallContent` - Content, Diff, Terminal output + - `MessageAccumulator` - Message chunk accumulation helper + - `PlanEntry`, `Command` - Plan and command structures + +- **Stop Reason Handling**: + - `StopReasonAction` - Complete, ShowWarning, ShowError, ShowInfo + - `handle_stop_reason()` - Interpret stop reasons + - `is_continuable()`, `is_error()` - Stop reason classification + +- **Cancellation**: + - `handle_cancellation()` - Cancel pending operations + - `cancel_pending_tool_calls()` - Mark tool calls as cancelled + - `handle_disconnect()` - Handle transport disconnect + +- **Error Classification**: + - `ClassifiedError` - Error with class, message, details, retry_after + - `ErrorClass` - Transient, Terminal, User + - `ErrorSeverity` - Info, Warning, Error, Fatal + - `ErrorAction` - Retry, Reconnect, CheckConfig, ContactSupport, Dismiss + - `classify_jsonrpc_error()`, `classify_transport_error()` - Classify errors + - `exponential_backoff()` - Calculate retry delays + +## Usage Examples + +### Creating a Runtime + +```rust +use dirigent_core::{CoreRuntime, CoreConfig, CoreHandle}; + +// Load config from file or use default +let config = CoreConfig::load_config(None)?; + +// Create runtime +let runtime = CoreRuntime::new(config); + +// Wrap in handle for cheap cloning +let handle = CoreHandle::new(runtime); +``` + +### Creating a Connector + +```rust +use dirigent_core::{ConnectorConfig, ConnectorKind, OpenCodeConfig}; +use serde_json::json; + +// Build OpenCode connector config +let config = OpenCodeConfig { + base_url: "http://localhost:12225".to_string(), + title: "My OpenCode".to_string(), + initial_session: None, +}; + +// Serialize to JSON for ConnectorConfig +let params = serde_json::to_value(&config)?; + +let connector_config = ConnectorConfig { + id: None, // Runtime generates ID + kind: ConnectorKind::OpenCode, + owner: Some("user-123".to_string()), + title: Some("My OpenCode".to_string()), + params, +}; + +// Create connector via runtime +let connector_id = handle.create_connector( + "user-123".to_string(), + connector_config +).await?; +``` + +### Using Templates + +```rust +use dirigent_core::{apply_template, ConnectorKind}; +use serde_json::json; + +// Use default template with custom URL +let connector_config = apply_template( + ConnectorKind::OpenCode, + "default", + json!({ + "base_url": "http://localhost:8080", + "title": "Custom OpenCode" + }) +)?; + +let connector_id = handle.create_connector( + "user-123".to_string(), + connector_config +).await?; +``` + +### Managing Connectors + +```rust +// List all connectors +let all_connectors = handle.list_connectors(None).await; + +// List connectors for a specific user +let user_connectors = handle.list_connectors(Some("user-123".to_string())).await; + +// Get a specific connector +let connector = handle.get_connector(&connector_id).await; + +// Stop a connector +handle.stop_connector(&connector_id).await?; + +// Restart a stopped connector +handle.restart_connector(&connector_id).await?; + +// Remove a connector +handle.remove_connector(&connector_id).await?; +``` + +### Connector Lifecycle with Restart + +Connectors can be restarted after being stopped or entering an error state: + +```rust +// Create and start a connector +let connector_id = handle.create_connector( + "user-123".to_string(), + connector_config +).await?; +// Connector is now in Ready state + +// Stop the connector +handle.stop_connector(&connector_id).await?; +// Connector is now in Stopped state + +// Restart the connector (recreates background task with fresh channels) +handle.restart_connector(&connector_id).await?; +// Connector transitions: Stopped → Initializing → Connecting → Ready + +// Restart preserves: +// - Connector ID (same instance) +// - Configuration (base_url, title, etc.) +// - Event broadcast channel (subscribers continue receiving) +// - State Arc (observers see real-time updates) + +// Restart recreates: +// - Command channel (new sender/receiver pair) +// - Connector instance (fresh OpenCodeConnector) +// - Background task (new spawn) +// - Task handle (new JoinHandle) +``` + +### Sending Commands + +```rust +use dirigent_core::connectors::ConnectorCommand; + +// Get connector handle +let connector = handle.get_connector(&connector_id).await.unwrap(); + +// Subscribe to events +let mut events = connector.subscribe(); + +// Send a command +let cmd_tx = connector.command_tx(); +cmd_tx.send(ConnectorCommand::ListSessions).await?; + +// Receive events +while let Ok(event) = events.recv().await { + match event { + Event::SessionsListed { sessions } => { + println!("Got {} sessions", sessions.len()); + break; + } + Event::Error { message } => { + eprintln!("Error: {}", message); + break; + } + _ => {} + } +} +``` + +### Global Event Stream + +```rust +// Subscribe to every event on the SharingBus. Callers can also pick an +// `EventFilter` via `subscribe_filtered()` to receive only the events +// they care about. +let mut bus_rx = handle.sharing_bus().subscribe_all().await; + +tokio::spawn(async move { + while let Some(bus_event) = bus_rx.rx.recv().await { + println!("Bus event: {:?}", bus_event.event); + } +}); +``` + +## Configuration + +### Runtime Configuration (dirigent.toml or dirigent.json) + +```toml +port = 3000 +project_dir = "." +project_name = "my_project" +templates_enabled = true + +[[connectors]] +id = "opencode-1" +kind = "OpenCode" +owner = "user-123" +title = "OpenCode Local" + +[connectors.params] +base_url = "http://localhost:12225" +title = "OpenCode Local" +initial_session = null +``` + +### Templates + +Available templates: +- `opencode/default` - Standard localhost OpenCode connector +- `acp/claude-default` - Claude API connector (stub for future) + +## Connector Event Emission Patterns + +### TurnComplete Event Semantics + +All connectors emit `Event::TurnComplete` to signal that a turn/message is finalized. This is the **primary signal** for: +- Archivist to finalize and write message to disk +- UI cache to lock message state as immutable +- Conductor bridge to flush response to upstream + +**Event ordering guarantee**: +```text +MessageCompleted → TurnComplete → SessionIdle +``` + +### Connector-Specific TurnComplete Strategies + +Different connectors use different strategies to determine when a turn is complete: + +#### OpenCode Connector +- **Trigger**: `TurnCompleteTrigger::ExplicitSignal` +- **Strategy**: Relies on upstream `session.idle` events from OpenCode.ai +- **Implementation**: After translating `MessageCompleted`, emits `TurnComplete` then `SessionIdle` +- **Code location**: `crates/dirigent_core/src/connectors/opencode.rs:550-575` + +#### ACP Connector (stdio transport) +- **Trigger**: `TurnCompleteTrigger::ResponseReceived` +- **Strategy**: JSON-RPC response message is the final message in a turn +- **Implementation**: Emits `TurnComplete` after receiving the JSON-RPC response to `session/prompt` +- **Code location**: `crates/dirigent_core/src/connectors/acp/connector.rs:328` + +#### Gateway Connector +- **Trigger**: `TurnCompleteTrigger::OperationsComplete` +- **Strategy**: Tracks pending tool calls and emits when all operations resolve +- **Implementation**: Monitors tool call status changes and emits when last pending call completes +- **Code location**: `crates/dirigent_core/src/connectors/gateway/mod.rs:464,556,626,678` + +**Important**: Connectors MUST emit `TurnComplete` exactly once per turn. Duplicate emissions can cause archiving issues and UI state corruption. + +### Gateway Session Transfer Mechanics + +The Gateway connector serves as an **entry point** for incoming ACP connections. Sessions can be transferred to real agent connectors (Claude, etc.) via `/select-connector` commands. + +#### Key Principle: New Connector Is Authority + +**After transfer, the target connector becomes the sole authority for session configuration.** + +```text +Before transfer: + Gateway has placeholder modes: "ask", "write", "yolo" + Gateway has placeholder models: "simple", "default", "high" + +After transfer to Claude: + Claude's actual modes/models become authoritative + Gateway's placeholders are irrelevant + Editor receives config_option_update with Claude's real options +``` + +#### Transfer Flow + +1. User sends `/select-connector claude` in Gateway session +2. Gateway emits `SessionTransferRequest` to CoreRuntime +3. CoreRuntime creates/loads session in target connector +4. Target connector emits `SessionCreated` with its modes/models +5. CoreRuntime extracts modes/models from `SessionCreated` event +6. CoreRuntime emits `SessionTransferred` event with modes/models +7. Event bridge sends `config_option_update` to editor with target's modes/models + +#### What Does NOT Happen + +- Gateway does NOT adjust or map its values to the target connector +- Gateway does NOT remain involved after transfer completes +- Target connector does NOT inherit Gateway's mode/model selections +- No "mapping" between Gateway placeholders and real connector values + +#### Code Locations + +- Transfer request handling: `src/runtime.rs:execute_transfer()` +- Gateway commands: `src/connectors/gateway/commands.rs` +- SessionTransferred event: `dirigent_protocol/src/events/mod.rs` +- config_option_update emission: `dirigent_acp_api/src/event_bridge.rs:handle_session_transferred_internal()` + +## Architecture Patterns + +### Request-Response Pattern +For operations that need results (list_sessions, list_messages): +1. Subscribe to connector events +2. Send command via command channel +3. Wait for corresponding response event (with timeout) +4. Return result or error + +### Fire-and-Forget Pattern +For operations that stream results (send_message): +1. Send command via command channel +2. Return immediately +3. Clients subscribe to event stream for updates + +### Lifecycle Management +Connectors progress through states: +1. **Initializing** - Created but not connecting +2. **Connecting** - Attempting to establish connection +3. **Ready** - Connected and operational +4. **Error** - Encountered failure (with error message) +5. **Stopped** - Shutdown or unrecoverable error + +#### Restart Support +Connectors in `Stopped` or `Error` state can be restarted: +- **restart_connector()** recreates the connector's background task with fresh channels +- Preserves connector identity (ID, owner, configuration) +- Preserves event broadcast channel (existing subscribers continue receiving) +- Recreates command channel and background task +- State transitions: `Stopped`/`Error` → `Initializing` → `Connecting` → `Ready` + +### Configuration Persistence +The runtime automatically saves configuration to disk when: +- A connector is created +- A connector is removed +- Configuration is explicitly saved + +This ensures connectors are restored on server restart. + +## Session Tracking Responsibilities + +**IMPORTANT**: CoreRuntime is a **stateless orchestrator** for session operations. It does NOT maintain session state or cache message history. + +### What CoreRuntime Does + +- **Route Commands**: Forward session/message commands to appropriate connectors +- **Broadcast Events**: Relay connector events to global event stream +- **Manage Connectors**: Track which connectors are active and available +- **Persist Configuration**: Save/load connector configuration (not session data) + +### What CoreRuntime Does NOT Do + +- **Cache Sessions**: Does not maintain lists of sessions or session metadata +- **Store Messages**: Does not retain message content or history +- **Buffer Events**: Does not cache events for replay or historical access +- **Track Session State**: Does not know which sessions exist or their current state + +### Stateless Design Rationale + +1. **Multi-Connector Support**: With multiple connectors (OpenCode, ACP, etc.), caching sessions would require complex invalidation +2. **Memory Efficiency**: Long-running server should not accumulate unbounded session history +3. **Single Source of Truth**: External APIs (OpenCode.ai, ACP agents) are authoritative for session state +4. **Scalability**: Stateless design supports future horizontal scaling + +### Data Flow for Session Operations + +**List Sessions**: +``` +Server Function → CoreRuntime.get_connector() → Send Command → Connector queries API + ↓ ↓ + Returns handle Broadcasts SessionsListed + ↓ + Server function returns to UI + (CoreRuntime does NOT cache list) +``` + +**Send Message**: +``` +Server Function → CoreRuntime.get_connector() → Send Command → Connector sends to API + ↓ ↓ + Returns handle Broadcasts MessageSent + ↓ + SSE pushes to UI in real-time + (CoreRuntime does NOT store message) +``` + +All session data passes through CoreRuntime but is **never cached**. See `docs/architecture/session_tracking.md` for the complete three-layer architecture (CoreRuntime, UI Cache, Archivist). + +## Phase 4: `SharingBus` + `StreamRegistry` (2026-04-21) + +Every `Event` emitted by a connector or the runtime is published onto a +single `SharingBus` that owns: + +- A `broadcast::Sender` as the internal multicast. +- A worker task that receives from that broadcast and dispatches per- + subscriber `mpsc::Sender` pipes with filter matching. +- A `HashMap<(connector_id, native_session_id), scroll_id>` cache that + late-binds `routing.scroll_id` for events emitted before their + `SessionRegistered` event arrived. + +### Subscriber model + +`SharingBus::subscribe_all()` and `subscribe_filtered(EventFilter, cap)` +return a `BusReceiver { id, rx, lagged }`. Filters are applied by the +worker — subscribers never allocate a closure for skipped events. + +`EventFilter` variants: `All | ScrollId | ConnectorUid | Kinds | AnyOf +| AllOf`. + +### Stream registry + +The runtime owns a `StreamRegistry` and a `StreamFactoryRegistry`. +Streams are attached via `CoreRuntime::attach_stream(StreamConfig)`: + +- Factory registry resolves the `kind` string → concrete + `Arc`. +- The new stream gets its own `BusReceiver` scoped via `scope_to_filter` + (Session → ScrollId, Connector → ConnectorUid, ArchiveWide → All). +- A worker task pumps events into `stream.on_event(&bus_event).await`. + +Health drift: `record_failure` / `record_success` in `sharing/health.rs` +transitions `HealthStatus` on 5 consecutive failures +(`Healthy → Degraded → Unavailable`). + +### Replay + +`CoreRuntime::replay_session_to_stream(scroll_id, stream_id, opts)` +loads archived messages via the archivist's read API and calls +`stream.on_event` directly, bypassing the bus. Each replayed event has +`origin = EventOrigin::Replay { replay_id }` and +`routing.scroll_id = Some(scroll_id)`. + +### Config + +`[[streams]]` blocks in `dirigent.toml` are parsed into `StreamsConfig` +and applied at boot (best-effort: failures log + continue). + +## Related Packages +- **api** - Server functions that wrap CoreRuntime operations +- **web** - Dioxus UI that calls server functions +- **dirigent_protocol** - Shared event and message types +- **opencode_client** - Low-level OpenCode.ai HTTP client (used by OpenCodeConnector) + +## Documentation +- README: ./README.md +- Architecture: ../../docs/architecture/overview.md +- Migration Guide: ../../docs/migration/singleton_to_runtime.md diff --git a/crates/dirigent_core/Cargo.toml b/crates/dirigent_core/Cargo.toml new file mode 100644 index 0000000..eabfa72 --- /dev/null +++ b/crates/dirigent_core/Cargo.toml @@ -0,0 +1,120 @@ +[package] +name = "dirigent_core" +version = "0.1.0" +edition = "2021" + +[lib] +path = "src/lib.rs" + +[[bin]] +name = "dirigent-core" +path = "src/bin/main.rs" +required-features = ["server"] + +[dependencies] +# Server-only dependencies (optional) +# ACP (Agent-Client Protocol) - the main dependency for connecting to agents +# Using the rust-sdk from https://github.com/agentclientprotocol/rust-sdk +agent-client-protocol = { version = "0.6", optional = true } +# Async streams +async-stream = { version = "0.3", optional = true } +# Async trait support +async-trait = { version = "0.1", optional = true } +# Web server +axum = { version = "0.8", optional = true } +# Base64 encoding for embedded resources +base64 = { version = "0.22", optional = true } +# BLAKE3 hashing for stable URIs +blake3 = { version = "1.5", optional = true } +chrono = { version = "0.4", features = ["serde"] } +dirigent_acp_api = { path = "../dirigent_acp_api", optional = true } +# Workspace dependencies +dirigent_archivist = { path = "../dirigent_archivist", optional = true } +dirigent_config = { path = "../dirigent_config", optional = true } +dirigent_auth = { path = "../dirigent_auth" } +dirigent_process = { path = "../dirigent_process", features = ["tokio"], optional = true } +dirigent_taskrunner = { path = "../dirigent_taskrunner", optional = true } +dirigent_matrix = { path = "../dirigent_matrix", optional = true } +dirigent_zed = { path = "../dirigent_zed", optional = true } +dirigent_inspector = { path = "../dirigent_inspector", optional = true } +dirigent_protocol = { path = "../dirigent_protocol", features = ["adapters"], optional = true } +dirigent_tools = { path = "../dirigent_tools", optional = true } +# SSE client for ACP transport +eventsource-client = { version = "0.13", optional = true } +futures = { version = "0.3", optional = true } +# Lazy static initialization +once_cell = { version = "1.20", optional = true } +opencode_client = { path = "../opencode_client", optional = true } +# HTTP client for ACP transport +reqwest = { version = "0.12", features = ["json", "stream"], optional = true } +# Core types - always available (WASM-compatible) +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +# Error handling +thiserror = { version = "2.0", optional = true } +tokio = { version = "1", features = ["full"], optional = true } +toml = { version = "0.8", optional = true } +tower = { version = "0.5", optional = true } +tower-http = { version = "0.6", features = ["cors"], optional = true } +# Logging +tracing = { version = "0.1", optional = true } +tracing-subscriber = { version = "0.3", features = ["env-filter"], optional = true } +uuid = { version = "1.0", features = ["js", "serde", "v4", "v5", "v7"] } + +[dev-dependencies] +anyhow = "1.0" +axum = { version = "0.8", features = ["macros"] } +base64 = "0.22" +blake3 = "1.5" +dirigent_protocol = { path = "../dirigent_protocol" } +dirigent_tools = { path = "../dirigent_tools" } +tempfile = "3.0" +# Test dependencies - mirror the optional server dependencies +tokio = { version = "1", features = ["full", "test-util"] } +toml = "0.8" + +[[test]] +name = "stream_registry_test" +required-features = ["test-utils"] + +[[test]] +name = "replay_test" +required-features = ["test-utils", "server"] + +[[test]] +name = "matrix_migration_test" +required-features = ["server"] + +[features] +default = [] +test-utils = [] +server = [ + "dep:agent-client-protocol", + "dep:async-stream", + "dep:async-trait", + "dep:axum", + "dep:base64", + "dep:blake3", + "dep:dirigent_acp_api", + "dep:dirigent_archivist", + "dep:dirigent_config", + "dep:dirigent_inspector", + "dep:dirigent_protocol", + "dep:dirigent_tools", + "dep:eventsource-client", + "dep:futures", + "dep:once_cell", + "dep:opencode_client", + "dep:reqwest", + "dep:thiserror", + "dep:tokio", + "dep:toml", + "dep:tower", + "dep:tower-http", + "dep:tracing", + "dep:tracing-subscriber", + "dep:dirigent_matrix", + "dep:dirigent_zed", + "dep:dirigent_taskrunner", + "dep:dirigent_process", +] diff --git a/crates/dirigent_core/README.md b/crates/dirigent_core/README.md new file mode 100644 index 0000000..b65c6f8 --- /dev/null +++ b/crates/dirigent_core/README.md @@ -0,0 +1,7 @@ +The Dirigent Core should implement the core functionality of Dirigent: + + - setup and manage ACP Agents + - connect as ACP client to ACP Agents + - manage Projects + - manage Sessions + - manage file access, terminal access diff --git a/crates/dirigent_core/dirigent.json b/crates/dirigent_core/dirigent.json new file mode 100644 index 0000000..ebbecee --- /dev/null +++ b/crates/dirigent_core/dirigent.json @@ -0,0 +1,9 @@ +{ + "port": 3000, + "project_dir": ".", + "project_name": "default", + "connectors": [], + "storage_dir": null, + "templates_enabled": false, + "archive_root": null +} \ No newline at end of file diff --git a/crates/dirigent_core/src/acp/connector_state.rs b/crates/dirigent_core/src/acp/connector_state.rs new file mode 100644 index 0000000..08f4cae --- /dev/null +++ b/crates/dirigent_core/src/acp/connector_state.rs @@ -0,0 +1,547 @@ +//! ACP connector state management. +//! +//! This module defines the state structure for tracking ACP connection lifecycle, +//! including protocol negotiation, authentication status, and capabilities. + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; + +/// Connection state for an ACP connector. +/// +/// Tracks the lifecycle of an ACP connection from initial connection through +/// initialization, optional authentication, and ready state. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ConnectionState { + /// Not yet initialized (initial state). + Uninitialized, + /// Currently performing initialization handshake. + Initializing, + /// Initialization complete, ready for optional authentication. + Initialized, + /// Currently authenticating. + Authenticating, + /// Fully connected and ready for session operations. + Ready, + /// Connection lost or transport error. + Disconnected, + /// Terminal error state. + Error(String), +} + +/// Complete ACP connector state. +/// +/// This structure tracks all state related to an active ACP connection, +/// including protocol version, capabilities, authentication status, and +/// agent information. +#[derive(Debug, Clone)] +pub struct AcpConnectorState { + /// Current connection state. + pub connection_state: ConnectionState, + + /// Negotiated protocol version (set after successful initialize). + pub protocol_version: Option, + + /// Client capabilities that were advertised to the agent. + pub client_capabilities: Option, + + /// Agent capabilities received from the agent. + pub agent_capabilities: Option, + + /// Agent implementation info (for display/logging). + pub agent_info: Option, + + /// Authentication methods supported by the agent. + pub auth_methods: Vec, + + /// Whether authentication has been completed successfully. + pub authenticated: bool, + + /// Active sessions (keyed by session ID). + /// + /// Tracks all sessions created through this connector. Each session + /// maintains its own state including working directory, MCP servers, + /// current mode, and cancellation status. + pub sessions: HashMap, +} + +impl AcpConnectorState { + /// Create a new uninitialized ACP connector state. + pub fn new() -> Self { + Self { + connection_state: ConnectionState::Uninitialized, + protocol_version: None, + client_capabilities: None, + agent_capabilities: None, + agent_info: None, + auth_methods: Vec::new(), + authenticated: false, + sessions: HashMap::new(), + } + } + + /// Add a new session to the state. + pub fn add_session(&mut self, session_state: SessionState) { + self.sessions + .insert(session_state.session_id.clone(), session_state); + } + + /// Get a session by ID. + pub fn get_session(&self, session_id: &str) -> Option<&SessionState> { + self.sessions.get(session_id) + } + + /// Get a mutable reference to a session by ID. + pub fn get_session_mut(&mut self, session_id: &str) -> Option<&mut SessionState> { + self.sessions.get_mut(session_id) + } + + /// Remove a session by ID. + pub fn remove_session(&mut self, session_id: &str) -> Option { + self.sessions.remove(session_id) + } + + /// Get all active session IDs. + pub fn session_ids(&self) -> Vec { + self.sessions.keys().cloned().collect() + } + + /// Check if the connector is ready for session operations. + pub fn is_ready(&self) -> bool { + matches!(self.connection_state, ConnectionState::Ready) + } + + /// Check if authentication is required. + pub fn requires_auth(&self) -> bool { + !self.auth_methods.is_empty() && !self.authenticated + } + + /// Transition to error state with message. + pub fn set_error(&mut self, message: impl Into) { + self.connection_state = ConnectionState::Error(message.into()); + } +} + +impl Default for AcpConnectorState { + fn default() -> Self { + Self::new() + } +} + +/// Shared thread-safe wrapper for ACP connector state. +pub type SharedAcpState = Arc>; + +/// Client capabilities advertised to the agent. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct ClientCapabilities { + /// Filesystem capabilities. + #[serde(skip_serializing_if = "Option::is_none")] + pub fs: Option, + + /// Terminal/command execution support. + #[serde(skip_serializing_if = "Option::is_none")] + pub terminal: Option, + + /// Extension metadata. + #[serde(skip_serializing_if = "Option::is_none")] + pub _meta: Option, +} + +impl ClientCapabilities { + /// Create default client capabilities (safe defaults: read-only filesystem, no terminal). + pub fn default_safe() -> Self { + Self { + fs: Some(FsCapabilities { + read_text_file: Some(true), + write_text_file: Some(false), + }), + terminal: Some(false), + _meta: None, + } + } + + /// Create client capabilities with all features enabled. + pub fn all_enabled() -> Self { + Self { + fs: Some(FsCapabilities { + read_text_file: Some(true), + write_text_file: Some(true), + }), + terminal: Some(true), + _meta: None, + } + } +} + +impl Default for ClientCapabilities { + /// Default capabilities: CURRENTLY DISABLED (handlers not implemented). + /// + /// **IMPORTANT**: These capabilities are currently disabled because the request handlers + /// in the ACP connector are not yet implemented. Advertising capabilities without + /// functioning handlers violates ACP compliance and will cause agent requests to fail. + /// + /// - **Filesystem Operations** (`fs`): + /// - `read_text_file`: Disabled (handler not implemented) + /// - `write_text_file`: Disabled (handler not implemented) + /// + /// - **Terminal Operations** (`terminal`): + /// - Disabled (handler not implemented) + /// + /// **Future Implementation**: While the underlying tools exist in `packages/dirigent_tools/`, + /// the ACP connector's request handler loop does not yet route agent requests to these + /// tools. Once the handlers are implemented and tested, these capabilities can be enabled. + /// + /// **Note**: Use `ClientCapabilities::all_enabled()` to override if you have implemented + /// custom handlers, but this is NOT recommended until proper request handling is in place. + fn default() -> Self { + Self { + fs: Some(FsCapabilities { + read_text_file: Some(false), // Not implemented yet + write_text_file: Some(false), // Not implemented yet + }), + terminal: Some(false), // Not implemented yet + _meta: None, + } + } +} + +/// Filesystem capabilities. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct FsCapabilities { + /// Support for reading text files. + #[serde(skip_serializing_if = "Option::is_none")] + pub read_text_file: Option, + + /// Support for writing text files. + #[serde(skip_serializing_if = "Option::is_none")] + pub write_text_file: Option, +} + +/// Agent capabilities received from the agent. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)] +pub struct AgentCapabilities { + /// Support for loading existing sessions. + #[serde(skip_serializing_if = "Option::is_none")] + pub load_session: Option, + + /// Prompt-related capabilities. + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_capabilities: Option, + + /// MCP (Model Context Protocol) server support. + #[serde(skip_serializing_if = "Option::is_none")] + pub mcp: Option, + + /// Extension metadata. + #[serde(skip_serializing_if = "Option::is_none")] + pub _meta: Option, +} + +impl AgentCapabilities { + /// Check if the agent supports loading sessions. + pub fn supports_load_session(&self) -> bool { + self.load_session.unwrap_or(false) + } + + /// Check if the agent supports image content in prompts. + pub fn supports_image(&self) -> bool { + self.prompt_capabilities + .as_ref() + .and_then(|pc| pc.image) + .unwrap_or(false) + } + + /// Check if the agent supports audio content in prompts. + pub fn supports_audio(&self) -> bool { + self.prompt_capabilities + .as_ref() + .and_then(|pc| pc.audio) + .unwrap_or(false) + } + + /// Check if the agent supports embedded context (resource blocks). + pub fn supports_embedded_context(&self) -> bool { + self.prompt_capabilities + .as_ref() + .and_then(|pc| pc.embedded_context) + .unwrap_or(false) + } +} + +/// Prompt-related capabilities. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)] +pub struct PromptCapabilities { + /// Support for image content blocks. + #[serde(skip_serializing_if = "Option::is_none")] + pub image: Option, + + /// Support for audio content blocks. + #[serde(skip_serializing_if = "Option::is_none")] + pub audio: Option, + + /// Support for embedded context (resource blocks). + #[serde(skip_serializing_if = "Option::is_none")] + pub embedded_context: Option, +} + +/// MCP server capabilities. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct McpCapabilities { + /// Support for HTTP transport MCP servers. + #[serde(skip_serializing_if = "Option::is_none")] + pub http: Option, + + /// Support for SSE transport MCP servers. + #[serde(skip_serializing_if = "Option::is_none")] + pub sse: Option, +} + +/// Implementation information (for client or agent). +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct ImplementationInfo { + /// Implementation name (e.g., "dirigent", "claude"). + pub name: String, + + /// Human-readable title. + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, + + /// Version string (e.g., "0.1.0"). + #[serde(skip_serializing_if = "Option::is_none")] + pub version: Option, +} + +/// Session state tracking. +/// +/// Tracks the state of a single session including working directory, +/// MCP servers, current mode, and cancellation status. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SessionState { + /// Session ID. + pub session_id: String, + /// Working directory. + pub cwd: String, + /// MCP servers configured for this session. + pub mcp_servers: Vec, + /// Current mode (e.g., "code", "chat"), if applicable. + pub current_mode: Option, + /// Current model, if applicable. + pub current_model: Option, + /// Whether a prompt is currently in progress. + pub prompt_in_progress: bool, + /// Whether cancellation is in progress. + pub cancelling: bool, + /// Whether this session is currently loading (replaying history). + pub loading: bool, +} + +impl SessionState { + /// Create a new session state. + pub fn new(session_id: String, cwd: String, mcp_servers: Vec) -> Self { + Self { + session_id, + cwd, + mcp_servers, + current_mode: None, + current_model: None, + prompt_in_progress: false, + cancelling: false, + loading: false, + } + } + + /// Check if the session is ready for new prompts. + pub fn is_ready(&self) -> bool { + !self.prompt_in_progress && !self.cancelling && !self.loading + } + + /// Mark a prompt as started. + pub fn start_prompt(&mut self) { + self.prompt_in_progress = true; + } + + /// Mark a prompt as completed. + pub fn complete_prompt(&mut self) { + self.prompt_in_progress = false; + self.cancelling = false; + } + + /// Initiate cancellation. + pub fn start_cancellation(&mut self) { + self.cancelling = true; + } + + /// Mark the session as loading (history replay). + pub fn start_loading(&mut self) { + self.loading = true; + } + + /// Mark loading as complete. + pub fn complete_loading(&mut self) { + self.loading = false; + } + + /// Update the current mode. + pub fn set_mode(&mut self, mode: String) { + self.current_mode = Some(mode); + } + + /// Update the current model. + pub fn set_model(&mut self, model: String) { + self.current_model = Some(model); + } +} + +/// MCP (Model Context Protocol) server configuration. +/// +/// Specifies how to connect to an MCP server that provides additional +/// tools and context to the agent. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum McpServer { + /// Stdio transport - spawn a process and communicate via stdin/stdout. + Stdio { + /// Name of the MCP server (for display). + name: String, + /// Command to execute. + command: String, + /// Command-line arguments. + args: Vec, + /// Environment variables. + env: Vec, + }, + /// HTTP transport - connect to an HTTP endpoint. + Http { + /// Name of the MCP server (for display). + name: String, + /// Base URL of the HTTP endpoint. + url: String, + /// HTTP headers to include in requests. + headers: Vec, + }, + /// SSE (Server-Sent Events) transport - connect to an SSE endpoint. + Sse { + /// Name of the MCP server (for display). + name: String, + /// URL of the SSE endpoint. + url: String, + /// HTTP headers to include in connection. + headers: Vec, + }, +} + +/// Environment variable for MCP server. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct EnvVariable { + pub name: String, + pub value: String, +} + +/// HTTP header for MCP server. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct HttpHeader { + pub name: String, + pub value: String, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_connection_state_transitions() { + let mut state = AcpConnectorState::new(); + assert_eq!(state.connection_state, ConnectionState::Uninitialized); + assert!(!state.is_ready()); + + state.connection_state = ConnectionState::Initializing; + assert!(!state.is_ready()); + + state.connection_state = ConnectionState::Initialized; + assert!(!state.is_ready()); + + state.connection_state = ConnectionState::Ready; + assert!(state.is_ready()); + } + + #[test] + fn test_requires_auth() { + let mut state = AcpConnectorState::new(); + assert!(!state.requires_auth()); + + state.auth_methods = vec!["api_key".to_string()]; + assert!(state.requires_auth()); + + state.authenticated = true; + assert!(!state.requires_auth()); + } + + #[test] + fn test_client_capabilities_default_safe() { + let caps = ClientCapabilities::default_safe(); + assert!(caps.fs.is_some()); + assert_eq!(caps.fs.as_ref().unwrap().read_text_file, Some(true)); + assert_eq!(caps.fs.as_ref().unwrap().write_text_file, Some(false)); + assert_eq!(caps.terminal, Some(false)); + } + + #[test] + fn test_client_capabilities_all_enabled() { + let caps = ClientCapabilities::all_enabled(); + assert!(caps.fs.is_some()); + assert_eq!(caps.fs.as_ref().unwrap().read_text_file, Some(true)); + assert_eq!(caps.fs.as_ref().unwrap().write_text_file, Some(true)); + assert_eq!(caps.terminal, Some(true)); + } + + #[test] + fn test_agent_capabilities_checks() { + let caps = AgentCapabilities { + load_session: Some(true), + prompt_capabilities: Some(PromptCapabilities { + image: Some(true), + audio: Some(false), + embedded_context: Some(true), + }), + mcp: None, + _meta: None, + }; + + assert!(caps.supports_load_session()); + assert!(caps.supports_image()); + assert!(!caps.supports_audio()); + assert!(caps.supports_embedded_context()); + } + + #[test] + fn test_serialization() { + let caps = ClientCapabilities { + fs: Some(FsCapabilities { + read_text_file: Some(true), + write_text_file: Some(true), + }), + terminal: Some(true), + _meta: Some(serde_json::json!({"custom": "data"})), + }; + + let json = serde_json::to_string(&caps).unwrap(); + let deserialized: ClientCapabilities = serde_json::from_str(&json).unwrap(); + + assert_eq!(caps, deserialized); + } + + #[test] + fn test_implementation_info() { + let info = ImplementationInfo { + name: "test-impl".to_string(), + title: Some("Test Implementation".to_string()), + version: Some("1.0.0".to_string()), + }; + + let json = serde_json::to_string(&info).unwrap(); + let deserialized: ImplementationInfo = serde_json::from_str(&json).unwrap(); + + assert_eq!(info, deserialized); + } +} diff --git a/crates/dirigent_core/src/acp/content_blocks.rs b/crates/dirigent_core/src/acp/content_blocks.rs new file mode 100644 index 0000000..777c37c --- /dev/null +++ b/crates/dirigent_core/src/acp/content_blocks.rs @@ -0,0 +1,546 @@ +//! Content block generation from file attachments. +//! +//! This module provides the `ContentBlockBuilder` for converting file paths into +//! appropriate `ContentBlock` variants based on: +//! - Agent capabilities (embedded context support) +//! - File properties (size, type, content) +//! - Embedding configuration (size limits, redaction, snippet strategy) +//! - Sandbox restrictions (allowed roots, blocklists) + +use crate::acp::connector_state::AgentCapabilities; +use crate::acp::protocol::prompt::{Annotations, ContentBlock, EmbeddedResource}; +use base64::Engine; +use dirigent_tools::config::{EmbeddingConfig, SandboxConfig}; +use dirigent_tools::embedding::{EmbeddingDecider, EmbeddingStrategy}; +use dirigent_tools::error::{ToolError, ToolResult}; +use dirigent_tools::path::validate_path; +use std::path::{Path, PathBuf}; +use tracing::{debug, warn}; + +/// Builder for generating content blocks from file attachments. +/// +/// This orchestrates the file embedding pipeline: +/// 1. Validate paths against sandbox +/// 2. Decide embedding strategy per file +/// 3. Apply redaction to embedded content +/// 4. Generate appropriate ContentBlock variants +/// 5. Track accumulated bytes and file counts +pub struct ContentBlockBuilder { + /// Agent capabilities (embedded context support). + // TODO: Capability validation - will be used to validate content blocks against agent capabilities + #[allow(dead_code)] + agent_caps: AgentCapabilities, + /// Embedding configuration. + embedding_config: EmbeddingConfig, + /// Sandbox configuration. + sandbox_config: SandboxConfig, + /// Embedding decider (tracks accumulated state). + decider: EmbeddingDecider, +} + +impl ContentBlockBuilder { + /// Create a new content block builder. + /// + /// # Arguments + /// + /// * `agent_caps` - Agent capabilities (from Initialize response) + /// * `embedding_config` - Embedding configuration + /// * `sandbox_config` - Sandbox configuration + pub fn new( + agent_caps: AgentCapabilities, + embedding_config: EmbeddingConfig, + sandbox_config: SandboxConfig, + ) -> Self { + let agent_supports_embedded = agent_caps.supports_embedded_context(); + let decider = EmbeddingDecider::new(embedding_config.clone(), agent_supports_embedded); + + Self { + agent_caps, + embedding_config, + sandbox_config, + decider, + } + } + + /// Build content blocks from a list of file paths. + /// + /// This is the main entry point for converting file attachments into content blocks. + /// It validates each path, decides the embedding strategy, and generates the + /// appropriate ContentBlock variant. + /// + /// # Arguments + /// + /// * `files` - List of file paths to process (must be absolute) + /// + /// # Returns + /// + /// A vector of `ContentBlock` variants, or an error if any path is invalid + /// or if the embedding pipeline fails. + /// + /// # Errors + /// + /// - `ToolError::SandboxViolation` - Path outside allowed roots or blocked + /// - `ToolError::InvalidPath` - Path doesn't exist or is not a file + /// - `ToolError::FileReadError` - Failed to read file content + /// + /// # Security + /// + /// All paths are validated against the sandbox before processing. Absolute + /// paths are never exposed in error messages. + pub fn build_from_files(&mut self, files: &[PathBuf]) -> ToolResult> { + let mut blocks = Vec::new(); + + for file_path in files { + // Validate path against sandbox first + let validated_path = self.validate_file_path(file_path)?; + + // Decide embedding strategy + let strategy = self.decider.decide(&validated_path)?; + + // Generate content block based on strategy + match self.generate_content_block(&validated_path, strategy)? { + Some(block) => blocks.push(block), + None => { + // Strategy was Deny - log and skip + let filename = validated_path + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("unknown"); + warn!("File denied for embedding: {}", filename); + } + } + } + + debug!( + "Generated {} content blocks from {} files ({} bytes accumulated)", + blocks.len(), + files.len(), + self.decider.accumulated_bytes() + ); + + Ok(blocks) + } + + /// Validate a file path against sandbox configuration. + /// + /// This ensures the path: + /// - Is within allowed roots + /// - Is not in blocklist + /// - Exists and is a file + /// + /// # Security + /// + /// Error messages never expose absolute paths - only basenames. + fn validate_file_path(&self, path: &Path) -> ToolResult { + // Convert to string for validation + let path_str = path + .to_str() + .ok_or_else(|| ToolError::InvalidConfig( + format!("Path contains invalid UTF-8: {}", path.display()) + ))?; + + // Validate against sandbox + let canonical_path = validate_path(path_str, &self.sandbox_config)?; + + // Ensure it's a file, not a directory + if !canonical_path.is_file() { + let filename = canonical_path + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("unknown"); + return Err(ToolError::FileReadError { + path: filename.to_string(), + source: std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Not a file" + ), + }); + } + + Ok(canonical_path) + } + + /// Generate a content block from a validated path and strategy. + /// + /// Returns `None` if the strategy is `Deny`. + fn generate_content_block( + &self, + path: &Path, + strategy: EmbeddingStrategy, + ) -> ToolResult> { + match strategy { + EmbeddingStrategy::EmbedText { content, mime_type } => { + // Apply redaction if configured + let redacted_content = self.apply_redaction(&content); + + // Generate URI for the resource + let uri = self.generate_uri(path); + + // Create embedded resource + let resource = EmbeddedResource::Text { + uri: uri.clone(), + text: redacted_content, + mime_type: Some(mime_type), + }; + + // Create annotations (mark as assistant-focused) + let annotations = Some(Annotations { + audience: Some(vec!["assistant".to_string()]), + priority: Some(0.5), + }); + + Ok(Some(ContentBlock::Resource { + resource, + annotations, + })) + } + + EmbeddingStrategy::EmbedBlob { data, mime_type } => { + // Encode as base64 + let blob = base64::engine::general_purpose::STANDARD.encode(&data); + + // Generate URI + let uri = self.generate_uri(path); + + // Create embedded resource + let resource = EmbeddedResource::Blob { + uri: uri.clone(), + blob, + mime_type: Some(mime_type), + }; + + // Create annotations + let annotations = Some(Annotations { + audience: Some(vec!["assistant".to_string()]), + priority: Some(0.5), + }); + + Ok(Some(ContentBlock::Resource { + resource, + annotations, + })) + } + + EmbeddingStrategy::Link { + uri, + name, + size, + mime_type, + } => { + // Create resource link + let annotations = Some(Annotations { + audience: Some(vec!["assistant".to_string()]), + priority: Some(0.3), + }); + + Ok(Some(ContentBlock::ResourceLink { + uri, + name, + mime_type, + title: None, + description: None, + size: Some(size), + annotations, + })) + } + + EmbeddingStrategy::Snippet { + head, + tail, + total_size: _, + mime_type, + } => { + // For snippets, create both a resource (embedded snippet) and a link (full file) + // Combine head and tail with separator + let separator = "\n\n[... truncated ...]\n\n"; + let snippet_content = format!("{}{}{}", head, separator, tail); + + // Apply redaction to snippet + let redacted_snippet = self.apply_redaction(&snippet_content); + + // Generate URI + let uri = self.generate_uri(path); + + // Create embedded resource for snippet + let resource = EmbeddedResource::Text { + uri: uri.clone(), + text: redacted_snippet, + mime_type: Some(mime_type.clone()), + }; + + // Create annotations for snippet + let annotations = Some(Annotations { + audience: Some(vec!["assistant".to_string()]), + priority: Some(0.4), + }); + + // Return the snippet as embedded resource + // Note: Could optionally also include a ResourceLink to the full file + Ok(Some(ContentBlock::Resource { + resource, + annotations, + })) + } + + EmbeddingStrategy::Deny { reason } => { + // Log denial reason + let filename = path + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("unknown"); + warn!("File embedding denied for {}: {}", filename, reason); + + // Return None to skip this file + Ok(None) + } + } + } + + /// Generate a stable URI for a file path. + /// + /// Uses BLAKE3 hash of the canonical path for stability and opacity. + fn generate_uri(&self, path: &Path) -> String { + use blake3::Hasher; + + let mut hasher = Hasher::new(); + hasher.update(path.to_string_lossy().as_bytes()); + let hash = hasher.finalize(); + + format!("dirigent://resource/{}", hash.to_hex()) + } + + /// Apply redaction patterns to content. + /// + /// This is a best-effort operation that doesn't modify files on disk. + fn apply_redaction(&self, content: &str) -> String { + if self.embedding_config.redact_patterns.is_empty() { + return content.to_string(); + } + + // Use ContentRedactor from dirigent_tools + match dirigent_tools::embedding::ContentRedactor::new(&self.embedding_config.redact_patterns) + { + Ok(redactor) => redactor.redact(content), + Err(e) => { + warn!("Failed to create redactor: {}", e); + content.to_string() + } + } + } + + /// Get the accumulated bytes processed so far. + pub fn accumulated_bytes(&self) -> usize { + self.decider.accumulated_bytes() + } + + /// Get the number of files processed so far. + pub fn file_count(&self) -> usize { + self.decider.file_count() + } +} + +/// Convenience function for building content blocks from files. +/// +/// This is a simplified API for one-off conversions. +/// +/// # Arguments +/// +/// * `files` - List of file paths to convert (must be absolute) +/// * `agent_caps` - Agent capabilities +/// * `embedding_config` - Embedding configuration +/// * `sandbox_config` - Sandbox configuration +/// +/// # Returns +/// +/// A vector of `ContentBlock` variants. +pub fn build_content_blocks_from_files( + files: &[PathBuf], + agent_caps: &AgentCapabilities, + embedding_config: &EmbeddingConfig, + sandbox_config: &SandboxConfig, +) -> ToolResult> { + let mut builder = ContentBlockBuilder::new( + agent_caps.clone(), + embedding_config.clone(), + sandbox_config.clone(), + ); + + builder.build_from_files(files) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::acp::connector_state::PromptCapabilities; + use tempfile::TempDir; + + fn create_test_agent_caps(embedded_context: bool) -> AgentCapabilities { + AgentCapabilities { + prompt_capabilities: Some(PromptCapabilities { + embedded_context: Some(embedded_context), + ..Default::default() + }), + ..Default::default() + } + } + + fn create_test_config(temp_dir: &TempDir) -> (EmbeddingConfig, SandboxConfig) { + let embedding_config = EmbeddingConfig { + max_embed_bytes: 1000, + allow_resource_link: true, + redact_patterns: vec![], + snippet_strategy: dirigent_tools::config::SnippetStrategy::HeadTail, + max_files_per_prompt: 10, + }; + + let mut sandbox_config = SandboxConfig::default(); + sandbox_config.allowed_roots = vec![temp_dir.path().to_path_buf()]; + sandbox_config.normalize_roots(); + + (embedding_config, sandbox_config) + } + + #[test] + fn test_small_text_file_embeds() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("test.txt"); + std::fs::write(&file_path, "Hello, world!").unwrap(); + + let agent_caps = create_test_agent_caps(true); + let (embedding_config, sandbox_config) = create_test_config(&temp_dir); + + let blocks = build_content_blocks_from_files( + &[file_path], + &agent_caps, + &embedding_config, + &sandbox_config, + ) + .unwrap(); + + assert_eq!(blocks.len(), 1); + match &blocks[0] { + ContentBlock::Resource { resource, .. } => match resource { + EmbeddedResource::Text { text, .. } => { + assert_eq!(text, "Hello, world!"); + } + _ => panic!("Expected text resource"), + }, + _ => panic!("Expected resource block"), + } + } + + #[test] + fn test_large_file_creates_link() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("large.txt"); + let large_content = "x".repeat(2000); // Exceeds 1000 byte limit + std::fs::write(&file_path, &large_content).unwrap(); + + let agent_caps = create_test_agent_caps(true); + let (embedding_config, sandbox_config) = create_test_config(&temp_dir); + + let blocks = build_content_blocks_from_files( + &[file_path], + &agent_caps, + &embedding_config, + &sandbox_config, + ) + .unwrap(); + + assert_eq!(blocks.len(), 1); + match &blocks[0] { + ContentBlock::ResourceLink { size, .. } => { + assert_eq!(*size, Some(2000)); + } + _ => panic!("Expected resource link block"), + } + } + + #[test] + fn test_sandbox_violation_rejected() { + let temp_dir = TempDir::new().unwrap(); + let outside_dir = TempDir::new().unwrap(); + let file_path = outside_dir.path().join("outside.txt"); + std::fs::write(&file_path, "Outside sandbox").unwrap(); + + let agent_caps = create_test_agent_caps(true); + let (embedding_config, sandbox_config) = create_test_config(&temp_dir); + + let result = build_content_blocks_from_files( + &[file_path], + &agent_caps, + &embedding_config, + &sandbox_config, + ); + + assert!(result.is_err()); + assert!(matches!(result, Err(ToolError::SandboxViolation { .. }))); + } + + #[test] + fn test_multiple_files_accumulate_bytes() { + let temp_dir = TempDir::new().unwrap(); + + let file1 = temp_dir.path().join("file1.txt"); + std::fs::write(&file1, "File 1 content").unwrap(); + + let file2 = temp_dir.path().join("file2.txt"); + std::fs::write(&file2, "File 2 content").unwrap(); + + let agent_caps = create_test_agent_caps(true); + let (embedding_config, sandbox_config) = create_test_config(&temp_dir); + + let blocks = build_content_blocks_from_files( + &[file1, file2], + &agent_caps, + &embedding_config, + &sandbox_config, + ) + .unwrap(); + + assert_eq!(blocks.len(), 2); + } + + #[test] + fn test_uri_generation_is_stable() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("stable.txt"); + std::fs::write(&file_path, "Stable content").unwrap(); + + let agent_caps = create_test_agent_caps(true); + let (embedding_config, sandbox_config) = create_test_config(&temp_dir); + + let blocks1 = build_content_blocks_from_files( + &[file_path.clone()], + &agent_caps, + &embedding_config, + &sandbox_config, + ) + .unwrap(); + + let blocks2 = build_content_blocks_from_files( + &[file_path], + &agent_caps, + &embedding_config, + &sandbox_config, + ) + .unwrap(); + + // Extract URIs and compare + match (&blocks1[0], &blocks2[0]) { + ( + ContentBlock::Resource { + resource: EmbeddedResource::Text { uri: uri1, .. }, + .. + }, + ContentBlock::Resource { + resource: EmbeddedResource::Text { uri: uri2, .. }, + .. + }, + ) => { + assert_eq!(uri1, uri2, "URIs should be stable"); + } + _ => panic!("Expected resource blocks"), + } + } +} diff --git a/crates/dirigent_core/src/acp/mod.rs b/crates/dirigent_core/src/acp/mod.rs new file mode 100644 index 0000000..83727ce --- /dev/null +++ b/crates/dirigent_core/src/acp/mod.rs @@ -0,0 +1,67 @@ +//! ACP (Agent-Client Protocol) client implementation. +//! +//! This module provides a complete ACP client implementation, including: +//! - Transport layer (stdio and HTTP+SSE) +//! - Protocol implementation (initialization, sessions, prompts) +//! - Tool execution and permission handling +//! - Streaming updates and notifications +//! +//! # Architecture +//! +//! The ACP implementation is organized into layers: +//! +//! 1. **Transport Layer** (`transport/`) - JSON-RPC types and shared utilities +//! - JSON-RPC message types (request, response, notification, error) +//! - `JsonLineReader` for resilient multi-line JSON reading over stdio +//! - Transport trait definition +//! - Actual transport implementations live in `connectors::acp::transport/` +//! +//! 2. **Protocol Layer** (`protocol/`) - ACP message handling +//! - Initialization and capability negotiation +//! - Authentication (optional) +//! - Capability validation +//! - Session lifecycle (new, load, cancel) +//! - Prompt turns and streaming updates +//! - Error handling and classification +//! +//! # Example +//! +//! ```ignore +//! use dirigent_core::connectors::acp::transport::{AcpTransport, StdioTransport}; +//! +//! # async fn example() -> Result<(), Box> { +//! let mut transport = StdioTransport::new("claude", &["--acp"]); +//! transport.connect().await?; +//! +//! // Send and receive JSON-RPC messages +//! transport.send(serde_json::json!({ +//! "jsonrpc": "2.0", +//! "method": "initialize", +//! "id": 1, +//! "params": {} +//! })).await?; +//! +//! if let Some(response) = transport.recv().await? { +//! println!("Response: {:?}", response); +//! } +//! # Ok(()) +//! # } +//! ``` + +pub mod connector_state; +pub mod content_blocks; +pub mod protocol; +pub mod transport; + +// Re-export commonly used types +pub use connector_state::{ + AcpConnectorState, AgentCapabilities, ClientCapabilities, ConnectionState, EnvVariable, + FsCapabilities, HttpHeader, ImplementationInfo, McpCapabilities, McpServer, PromptCapabilities, + SessionState, SharedAcpState, +}; +pub use content_blocks::{build_content_blocks_from_files, ContentBlockBuilder}; +pub use protocol::{authenticate, capabilities, initialize}; +pub use transport::{ + JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, JsonRpcResult, Transport, + TransportError, TransportState, +}; diff --git a/crates/dirigent_core/src/acp/protocol/authenticate.rs b/crates/dirigent_core/src/acp/protocol/authenticate.rs new file mode 100644 index 0000000..5c2e8d4 --- /dev/null +++ b/crates/dirigent_core/src/acp/protocol/authenticate.rs @@ -0,0 +1,254 @@ +//! Authentication method implementation for ACP. +//! +//! This module implements the `authenticate` JSON-RPC method for agents that +//! require authentication after initialization. The authentication flow is optional +//! and only used if the agent advertises auth_methods in the initialize response. + +use crate::acp::transport::{JsonRpcRequest, JsonRpcResult, Transport, TransportError}; +use serde::{Deserialize, Serialize}; + +/// Authentication request sent to the agent. +/// +/// This is sent after initialization if the agent advertised auth_methods +/// and the client has credentials available. +#[derive(Clone, Serialize, Deserialize)] +pub struct AuthenticateRequest { + /// Authentication method to use (must be one of the methods advertised by agent). + pub method: String, + + /// Method-specific credentials (structure depends on the authentication method). + /// Common examples: + /// - API key: `{"api_key": "secret"}` + /// - OAuth: `{"token": "bearer_token"}` + pub credentials: serde_json::Value, +} + +impl AuthenticateRequest { + /// Create a new authentication request. + pub fn new(method: impl Into, credentials: serde_json::Value) -> Self { + Self { + method: method.into(), + credentials, + } + } + + /// Convert to JSON-RPC request. + pub fn to_jsonrpc(&self, id: u64) -> JsonRpcRequest { + JsonRpcRequest::new( + id, + "authenticate", + Some(serde_json::to_value(self).expect("Failed to serialize AuthenticateRequest")), + ) + } +} + +// Custom Debug implementation that redacts credentials (no derive(Debug) above) +impl std::fmt::Debug for AuthenticateRequest { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AuthenticateRequest") + .field("method", &self.method) + .field("credentials", &"[REDACTED]") + .finish() + } +} + +/// Authentication response from the agent. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct AuthenticateResponse { + /// Whether authentication succeeded. + pub success: bool, + + /// Error message if authentication failed. + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +impl AuthenticateResponse { + /// Parse from JSON-RPC result. + pub fn from_jsonrpc(result: &serde_json::Value) -> Result { + serde_json::from_value(result.clone()) + } +} + +/// Result of authentication attempt. +#[derive(Debug)] +pub enum AuthenticateResult { + /// Authentication succeeded. + Success, + /// Authentication failed with error message. + Failed(String), + /// Agent does not require authentication. + NotRequired, +} + +/// Perform authentication with the agent. +/// +/// This sends an authenticate request if the agent requires authentication. +/// If auth_methods is empty, returns NotRequired without sending a request. +/// +/// # Arguments +/// +/// * `transport` - The transport to use for communication +/// * `auth_methods` - Authentication methods advertised by the agent (from initialize response) +/// * `method` - Authentication method to use (must be in auth_methods) +/// * `credentials` - Method-specific credentials +/// +/// # Returns +/// +/// Result of the authentication attempt. +/// +/// # Security +/// +/// Credentials are never logged in plain text. The AuthenticateRequest Debug +/// implementation redacts credentials. +pub async fn authenticate( + transport: &mut dyn Transport, + auth_methods: &[String], + method: impl Into, + credentials: serde_json::Value, +) -> Result { + let method_str = method.into(); + + // Check if authentication is required + if auth_methods.is_empty() { + return Ok(AuthenticateResult::NotRequired); + } + + // Validate that the requested method is supported + if !auth_methods.contains(&method_str) { + return Ok(AuthenticateResult::Failed(format!( + "Authentication method '{}' not supported by agent. Supported methods: {:?}", + method_str, auth_methods + ))); + } + + // Create authenticate request + let request = AuthenticateRequest::new(method_str, credentials); + let jsonrpc_request = request.to_jsonrpc(2); // Use ID 2 (ID 1 is for initialize) + + // Send request and await response + let result = transport.send_request(jsonrpc_request).await?; + + // Handle response + match result { + JsonRpcResult::Success(response) => { + // Parse authenticate response + let auth_response = AuthenticateResponse::from_jsonrpc(&response.result) + .map_err(|e| TransportError::SerializationError(e))?; + + if auth_response.success { + Ok(AuthenticateResult::Success) + } else { + Ok(AuthenticateResult::Failed( + auth_response + .error + .unwrap_or_else(|| "Authentication failed".to_string()), + )) + } + } + JsonRpcResult::Error(error_response) => { + // Authentication failed with JSON-RPC error + Ok(AuthenticateResult::Failed(error_response.error.message)) + } + } +} + +/// Check if authentication is required based on agent's auth_methods. +pub fn is_auth_required(auth_methods: &[String]) -> bool { + !auth_methods.is_empty() +} + +/// Validate that credentials do not appear in log output. +/// +/// This is a utility function for testing that ensures credentials are properly +/// redacted in debug output. +#[cfg(test)] +pub fn ensure_credentials_redacted(debug_str: &str) -> bool { + !debug_str.contains("secret") && !debug_str.contains("password") && !debug_str.contains("token") +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_authenticate_request_creation() { + let creds = serde_json::json!({"api_key": "secret123"}); + let request = AuthenticateRequest::new("api_key", creds.clone()); + + assert_eq!(request.method, "api_key"); + assert_eq!(request.credentials, creds); + } + + #[test] + fn test_authenticate_request_redacts_credentials() { + let creds = serde_json::json!({"api_key": "secret123"}); + let request = AuthenticateRequest::new("api_key", creds); + + let debug_str = format!("{:?}", request); + + // Should contain method + assert!(debug_str.contains("api_key")); + // Should NOT contain credentials + assert!(!debug_str.contains("secret123")); + // Should show redacted marker + assert!(debug_str.contains("REDACTED")); + } + + #[test] + fn test_authenticate_request_to_jsonrpc() { + let creds = serde_json::json!({"api_key": "secret"}); + let request = AuthenticateRequest::new("api_key", creds); + + let jsonrpc = request.to_jsonrpc(42); + + assert_eq!(jsonrpc.method, "authenticate"); + assert_eq!(jsonrpc.id, serde_json::Value::Number(42.into())); + assert!(jsonrpc.params.is_some()); + } + + #[test] + fn test_authenticate_response_success() { + let json = serde_json::json!({ + "success": true + }); + + let response = AuthenticateResponse::from_jsonrpc(&json).unwrap(); + + assert!(response.success); + assert!(response.error.is_none()); + } + + #[test] + fn test_authenticate_response_failure() { + let json = serde_json::json!({ + "success": false, + "error": "Invalid API key" + }); + + let response = AuthenticateResponse::from_jsonrpc(&json).unwrap(); + + assert!(!response.success); + assert_eq!(response.error, Some("Invalid API key".to_string())); + } + + #[test] + fn test_is_auth_required() { + assert!(!is_auth_required(&[])); + assert!(is_auth_required(&["api_key".to_string()])); + assert!(is_auth_required(&["api_key".to_string(), "oauth".to_string()])); + } + + #[test] + fn test_authenticate_response_serialization() { + let response = AuthenticateResponse { + success: true, + error: None, + }; + + let json = serde_json::to_string(&response).unwrap(); + let deserialized: AuthenticateResponse = serde_json::from_str(&json).unwrap(); + + assert_eq!(response, deserialized); + } +} diff --git a/crates/dirigent_core/src/acp/protocol/cancellation.rs b/crates/dirigent_core/src/acp/protocol/cancellation.rs new file mode 100644 index 0000000..37112f3 --- /dev/null +++ b/crates/dirigent_core/src/acp/protocol/cancellation.rs @@ -0,0 +1,237 @@ +//! Cancellation and disconnect handling for ACP. +//! +//! This module implements: +//! - Cancellation cleanup (tool calls, permission requests, state reset) +//! - Disconnect detection and handling +//! - Timeout handling for cancellation + +use crate::acp::connector_state::{ConnectionState, SessionState}; +use crate::acp::protocol::streaming::{ToolCallInfo, ToolCallStatus}; +use crate::acp::transport::{Transport, TransportError, TransportState}; +use std::time::Duration; + +/// Errors that can occur during cancellation. +#[derive(Debug, thiserror::Error)] +pub enum CancellationError { + #[error("Transport error during cancellation: {0}")] + TransportError(#[from] TransportError), + + #[error("Cancellation timed out after {0:?}")] + Timeout(Duration), + + #[error("Session not found: {0}")] + SessionNotFound(String), +} + +/// Handle cancellation of a prompt turn. +/// +/// This function performs comprehensive cleanup when a user cancels a prompt: +/// 1. Marks all pending tool calls as cancelled +/// 2. Sends session/cancel notification to agent +/// 3. Waits for agent to respond with stopReason: cancelled (with timeout) +/// +/// # Arguments +/// +/// * `session_state` - The session state to clean up +/// * `transport` - The transport to send cancellation notification +/// +/// # Returns +/// +/// Ok(()) if cancellation completed successfully, or an error if it failed. +pub async fn handle_cancellation( + session_state: &mut SessionState, + transport: &mut dyn Transport, +) -> Result<(), CancellationError> { + // Mark session as cancelling + session_state.start_cancellation(); + + // Send session/cancel notification + let notification = crate::acp::protocol::session::SessionCancelNotification::new( + session_state.session_id.clone(), + ); + + let jsonrpc_notification = crate::acp::transport::JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: "session/cancel".to_string(), + params: Some( + serde_json::to_value(¬ification) + .expect("Failed to serialize SessionCancelNotification"), + ), + }; + + transport.send_notification(jsonrpc_notification).await?; + + // Note: Actual waiting for stopReason: cancelled is done by the prompt + // response handler. This function just initiates the cancellation. + + Ok(()) +} + +/// Mark all pending tool calls as cancelled. +/// +/// This is called immediately when cancellation is initiated, before +/// waiting for the agent to respond. +pub fn cancel_pending_tool_calls(tool_calls: &mut [ToolCallInfo]) { + for tool_call in tool_calls.iter_mut() { + if !tool_call.is_terminal() { + tool_call.status = ToolCallStatus::Cancelled; + } + } +} + +/// Handle transport disconnect. +/// +/// This function is called when the transport connection is lost. It: +/// 1. Marks all pending operations as failed +/// 2. Updates connection state to Disconnected +/// 3. Cleans up resources +/// +/// # Arguments +/// +/// * `session_state` - The session state to clean up +/// * `connection_state` - The connection state to update +/// * `tool_calls` - Active tool calls to mark as failed +pub fn handle_disconnect( + session_state: &mut SessionState, + connection_state: &mut ConnectionState, + tool_calls: &mut [ToolCallInfo], +) { + // Mark all pending tool calls as failed + for tool_call in tool_calls.iter_mut() { + if !tool_call.is_terminal() { + tool_call.status = ToolCallStatus::Failed; + } + } + + // Clear session flags + session_state.complete_prompt(); + session_state.cancelling = false; + session_state.loading = false; + + // Update connection state + *connection_state = ConnectionState::Disconnected; +} + +/// Check if a transport is connected. +pub fn is_transport_connected(transport: &dyn Transport) -> bool { + transport.is_connected() && transport.state() == TransportState::Connected +} + +/// Detect if a disconnect has occurred. +/// +/// This checks the transport state and returns true if the connection +/// has been lost. +pub fn detect_disconnect(transport: &dyn Transport) -> bool { + !is_transport_connected(transport) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::acp::protocol::streaming::ToolKind; + + #[test] + fn test_cancel_pending_tool_calls() { + let mut tool_calls = vec![ + ToolCallInfo::new( + "tool-1".to_string(), + "Test 1".to_string(), + ToolKind::Read, + ToolCallStatus::Pending, + None, + None, + ), + ToolCallInfo::new( + "tool-2".to_string(), + "Test 2".to_string(), + ToolKind::Edit, + ToolCallStatus::InProgress, + None, + None, + ), + ToolCallInfo::new( + "tool-3".to_string(), + "Test 3".to_string(), + ToolKind::Execute, + ToolCallStatus::Completed, + None, + None, + ), + ]; + + cancel_pending_tool_calls(&mut tool_calls); + + // First two should be cancelled (not terminal) + assert_eq!(tool_calls[0].status, ToolCallStatus::Cancelled); + assert_eq!(tool_calls[1].status, ToolCallStatus::Cancelled); + + // Last one should remain completed (already terminal) + assert_eq!(tool_calls[2].status, ToolCallStatus::Completed); + } + + #[test] + fn test_handle_disconnect() { + let mut session_state = + SessionState::new("session-123".to_string(), "/path".to_string(), vec![]); + session_state.start_prompt(); + + let mut connection_state = ConnectionState::Ready; + + let mut tool_calls = vec![ + ToolCallInfo::new( + "tool-1".to_string(), + "Test".to_string(), + ToolKind::Read, + ToolCallStatus::InProgress, + None, + None, + ), + ToolCallInfo::new( + "tool-2".to_string(), + "Test 2".to_string(), + ToolKind::Edit, + ToolCallStatus::Completed, + None, + None, + ), + ]; + + handle_disconnect(&mut session_state, &mut connection_state, &mut tool_calls); + + // In-progress tool call should be marked failed + assert_eq!(tool_calls[0].status, ToolCallStatus::Failed); + + // Completed tool call should remain completed + assert_eq!(tool_calls[1].status, ToolCallStatus::Completed); + + // Session state should be cleared + assert!(!session_state.prompt_in_progress); + assert!(!session_state.cancelling); + assert!(!session_state.loading); + + // Connection state should be disconnected + assert_eq!(connection_state, ConnectionState::Disconnected); + } + + #[test] + fn test_session_cancellation_state() { + let mut session_state = + SessionState::new("session-123".to_string(), "/path".to_string(), vec![]); + + // Start a prompt + session_state.start_prompt(); + assert!(session_state.prompt_in_progress); + assert!(!session_state.is_ready()); + + // Start cancellation + session_state.start_cancellation(); + assert!(session_state.cancelling); + assert!(!session_state.is_ready()); + + // Complete the prompt (after cancel) + session_state.complete_prompt(); + assert!(!session_state.prompt_in_progress); + assert!(!session_state.cancelling); + assert!(session_state.is_ready()); + } +} diff --git a/crates/dirigent_core/src/acp/protocol/capabilities.rs b/crates/dirigent_core/src/acp/protocol/capabilities.rs new file mode 100644 index 0000000..13aff7b --- /dev/null +++ b/crates/dirigent_core/src/acp/protocol/capabilities.rs @@ -0,0 +1,261 @@ +//! Capability validation for ACP. +//! +//! This module implements validation logic to ensure that agent requests only +//! use capabilities that the client advertised as supported. Requests for +//! unsupported capabilities are rejected with clear error messages. + +use crate::acp::connector_state::ClientCapabilities; +use crate::acp::transport::JsonRpcError; + +/// JSON-RPC error code for method not found (unsupported capability). +pub const ERROR_METHOD_NOT_FOUND: i32 = -32601; + +/// Result of capability validation. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum CapabilityValidation { + /// The capability is supported and the method can proceed. + Supported, + /// The capability is not supported. + Unsupported(String), +} + +/// Validate that a method is supported by the client's advertised capabilities. +/// +/// This checks the method name against the client capabilities and returns +/// whether the method should be allowed to proceed. +/// +/// # Arguments +/// +/// * `method` - The JSON-RPC method name to validate +/// * `capabilities` - The client capabilities advertised during initialization +/// +/// # Returns +/// +/// `CapabilityValidation::Supported` if the method is allowed, or +/// `CapabilityValidation::Unsupported` with an error message if not. +pub fn validate_capability( + method: &str, + capabilities: &ClientCapabilities, +) -> CapabilityValidation { + match method { + // Filesystem capabilities + "fs/read_text_file" => { + if capabilities + .fs + .as_ref() + .and_then(|fs| fs.read_text_file) + .unwrap_or(false) + { + CapabilityValidation::Supported + } else { + CapabilityValidation::Unsupported( + "fs/read_text_file capability not advertised".to_string(), + ) + } + } + "fs/write_text_file" => { + if capabilities + .fs + .as_ref() + .and_then(|fs| fs.write_text_file) + .unwrap_or(false) + { + CapabilityValidation::Supported + } else { + CapabilityValidation::Unsupported( + "fs/write_text_file capability not advertised".to_string(), + ) + } + } + + // Terminal capabilities + method if method.starts_with("terminal/") => { + if capabilities.terminal.unwrap_or(false) { + CapabilityValidation::Supported + } else { + CapabilityValidation::Unsupported( + "terminal capability not advertised".to_string(), + ) + } + } + + // All other methods are assumed to be core protocol methods that don't + // require capability checks (initialize, authenticate, session/*, etc.) + _ => CapabilityValidation::Supported, + } +} + +/// Create a JSON-RPC error for an unsupported capability. +/// +/// This generates a properly formatted JSON-RPC error response that can be +/// sent back to the agent when it requests an unsupported method. +pub fn capability_not_supported_error(method: &str, reason: &str) -> JsonRpcError { + JsonRpcError { + code: ERROR_METHOD_NOT_FOUND, + message: format!("Method not found: {} ({})", method, reason), + data: None, + } +} + +/// Check if a method requires capability validation. +/// +/// Core protocol methods (initialize, authenticate, session operations) do not +/// require capability checks. Tool methods (fs/*, terminal/*) do require checks. +pub fn requires_capability_check(method: &str) -> bool { + method.starts_with("fs/") || method.starts_with("terminal/") +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::acp::connector_state::FsCapabilities; + + #[test] + fn test_validate_read_capability_supported() { + let caps = ClientCapabilities { + fs: Some(FsCapabilities { + read_text_file: Some(true), + write_text_file: Some(false), + }), + terminal: Some(false), + _meta: None, + }; + + let result = validate_capability("fs/read_text_file", &caps); + assert_eq!(result, CapabilityValidation::Supported); + } + + #[test] + fn test_validate_read_capability_unsupported() { + let caps = ClientCapabilities { + fs: Some(FsCapabilities { + read_text_file: Some(false), + write_text_file: Some(false), + }), + terminal: Some(false), + _meta: None, + }; + + let result = validate_capability("fs/read_text_file", &caps); + assert!(matches!(result, CapabilityValidation::Unsupported(_))); + } + + #[test] + fn test_validate_write_capability_supported() { + let caps = ClientCapabilities { + fs: Some(FsCapabilities { + read_text_file: Some(true), + write_text_file: Some(true), + }), + terminal: Some(false), + _meta: None, + }; + + let result = validate_capability("fs/write_text_file", &caps); + assert_eq!(result, CapabilityValidation::Supported); + } + + #[test] + fn test_validate_write_capability_unsupported() { + let caps = ClientCapabilities { + fs: Some(FsCapabilities { + read_text_file: Some(true), + write_text_file: Some(false), + }), + terminal: Some(false), + _meta: None, + }; + + let result = validate_capability("fs/write_text_file", &caps); + assert!(matches!(result, CapabilityValidation::Unsupported(_))); + } + + #[test] + fn test_validate_terminal_capability_supported() { + let caps = ClientCapabilities { + fs: None, + terminal: Some(true), + _meta: None, + }; + + let result = validate_capability("terminal/execute", &caps); + assert_eq!(result, CapabilityValidation::Supported); + + let result = validate_capability("terminal/read", &caps); + assert_eq!(result, CapabilityValidation::Supported); + } + + #[test] + fn test_validate_terminal_capability_unsupported() { + let caps = ClientCapabilities { + fs: None, + terminal: Some(false), + _meta: None, + }; + + let result = validate_capability("terminal/execute", &caps); + assert!(matches!(result, CapabilityValidation::Unsupported(_))); + } + + #[test] + fn test_validate_core_methods_always_supported() { + let caps = ClientCapabilities { + fs: None, + terminal: Some(false), + _meta: None, + }; + + // Core protocol methods should always be supported + assert_eq!( + validate_capability("initialize", &caps), + CapabilityValidation::Supported + ); + assert_eq!( + validate_capability("authenticate", &caps), + CapabilityValidation::Supported + ); + assert_eq!( + validate_capability("session/new", &caps), + CapabilityValidation::Supported + ); + assert_eq!( + validate_capability("session/prompt", &caps), + CapabilityValidation::Supported + ); + } + + #[test] + fn test_validate_capability_missing_fs() { + let caps = ClientCapabilities { + fs: None, + terminal: Some(false), + _meta: None, + }; + + let result = validate_capability("fs/read_text_file", &caps); + assert!(matches!(result, CapabilityValidation::Unsupported(_))); + } + + #[test] + fn test_capability_not_supported_error() { + let error = capability_not_supported_error("fs/write_text_file", "write capability not enabled"); + + assert_eq!(error.code, ERROR_METHOD_NOT_FOUND); + assert!(error.message.contains("fs/write_text_file")); + assert!(error.message.contains("write capability not enabled")); + } + + #[test] + fn test_requires_capability_check() { + // Tool methods require checks + assert!(requires_capability_check("fs/read_text_file")); + assert!(requires_capability_check("fs/write_text_file")); + assert!(requires_capability_check("terminal/execute")); + + // Core methods don't require checks + assert!(!requires_capability_check("initialize")); + assert!(!requires_capability_check("authenticate")); + assert!(!requires_capability_check("session/new")); + assert!(!requires_capability_check("session/prompt")); + } +} diff --git a/crates/dirigent_core/src/acp/protocol/error.rs b/crates/dirigent_core/src/acp/protocol/error.rs new file mode 100644 index 0000000..b425206 --- /dev/null +++ b/crates/dirigent_core/src/acp/protocol/error.rs @@ -0,0 +1,398 @@ +//! Error classification and handling for ACP. +//! +//! This module implements: +//! - Error classification (transient vs terminal vs user) +//! - Retry logic with exponential backoff +//! - Error surfacing to UI +//! - JSON-RPC error code mapping + +use crate::acp::transport::{JsonRpcError, TransportError}; +use std::time::Duration; + +/// Classification of an error. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ErrorClass { + /// Transient error that may succeed on retry (network issues, timeouts). + Transient, + /// Terminal error that requires user intervention (auth failure, unsupported feature). + Terminal, + /// User error (invalid input, permission denied). + User, +} + +/// Classified error with context. +#[derive(Debug, Clone)] +pub struct ClassifiedError { + /// Error classification. + pub class: ErrorClass, + /// JSON-RPC error code (if applicable). + pub code: Option, + /// User-facing message. + pub message: String, + /// Technical details for logs. + pub details: Option, + /// Suggested retry delay (for transient errors). + pub retry_after: Option, +} + +impl ClassifiedError { + /// Create a new classified error. + pub fn new( + class: ErrorClass, + message: impl Into, + ) -> Self { + Self { + class, + code: None, + message: message.into(), + details: None, + retry_after: None, + } + } + + /// Set the error code. + pub fn with_code(mut self, code: i32) -> Self { + self.code = Some(code); + self + } + + /// Set technical details. + pub fn with_details(mut self, details: impl Into) -> Self { + self.details = Some(details.into()); + self + } + + /// Set retry delay. + pub fn with_retry_after(mut self, duration: Duration) -> Self { + self.retry_after = Some(duration); + self + } + + /// Check if this error is retryable. + pub fn is_retryable(&self) -> bool { + self.class == ErrorClass::Transient + } +} + +/// Classify a JSON-RPC error. +pub fn classify_jsonrpc_error(error: &JsonRpcError) -> ClassifiedError { + match error.code { + // JSON-RPC standard errors + -32700 => ClassifiedError::new( + ErrorClass::Terminal, + "Invalid JSON received", + ) + .with_code(-32700) + .with_details(error.message.clone()), + + -32600 => ClassifiedError::new( + ErrorClass::Terminal, + "Invalid request format", + ) + .with_code(-32600) + .with_details(error.message.clone()), + + -32601 => ClassifiedError::new( + ErrorClass::Terminal, + "Method not supported", + ) + .with_code(-32601) + .with_details(error.message.clone()), + + -32602 => ClassifiedError::new( + ErrorClass::User, + "Invalid parameters", + ) + .with_code(-32602) + .with_details(error.message.clone()), + + -32603 => ClassifiedError::new( + ErrorClass::Transient, + "Internal error, please try again", + ) + .with_code(-32603) + .with_details(error.message.clone()), + + // Custom error codes (examples) + -40001 => ClassifiedError::new( + ErrorClass::Terminal, + "Authentication required", + ) + .with_code(-40001) + .with_details(error.message.clone()), + + -40003 => { + // Rate limit error + let retry_after = parse_retry_after(&error.data); + ClassifiedError::new( + ErrorClass::Transient, + "Rate limit exceeded, please wait", + ) + .with_code(-40003) + .with_details(error.message.clone()) + .with_retry_after(retry_after) + } + + -40401 => ClassifiedError::new( + ErrorClass::User, + error.message.clone(), + ) + .with_code(-40401), + + // Default: treat unknown errors as transient + _ => ClassifiedError::new( + ErrorClass::Transient, + error.message.clone(), + ) + .with_code(error.code) + .with_details(format!("Unknown error code: {}", error.code)), + } +} + +/// Classify a transport error. +pub fn classify_transport_error(error: &TransportError) -> ClassifiedError { + match error { + TransportError::ConnectionError(msg) => { + ClassifiedError::new(ErrorClass::Transient, format!("Connection error: {}", msg)) + .with_retry_after(Duration::from_secs(2)) + } + TransportError::IoError(io_err) => { + ClassifiedError::new(ErrorClass::Transient, format!("I/O error: {}", io_err)) + .with_retry_after(Duration::from_secs(2)) + } + TransportError::SerializationError(serde_err) => { + ClassifiedError::new(ErrorClass::Terminal, format!("Serialization error: {}", serde_err)) + } + TransportError::JsonRpcError(jsonrpc_err) => classify_jsonrpc_error(jsonrpc_err), + TransportError::Closed => { + ClassifiedError::new(ErrorClass::Terminal, "Transport closed") + } + TransportError::Timeout => { + ClassifiedError::new(ErrorClass::Transient, "Request timed out") + .with_retry_after(Duration::from_secs(5)) + } + TransportError::ProcessExited => { + ClassifiedError::new(ErrorClass::Terminal, "Process exited unexpectedly") + } + TransportError::HttpError(msg) => { + ClassifiedError::new(ErrorClass::Transient, format!("HTTP error: {}", msg)) + .with_retry_after(Duration::from_secs(2)) + } + TransportError::SseError(msg) => { + ClassifiedError::new(ErrorClass::Transient, format!("SSE error: {}", msg)) + .with_retry_after(Duration::from_secs(2)) + } + TransportError::Other(msg) => { + ClassifiedError::new(ErrorClass::Transient, msg.clone()) + } + } +} + +/// Parse retry-after duration from error data. +/// +/// This looks for a "retry_after" field in the error data that specifies +/// how long to wait before retrying. +fn parse_retry_after(data: &Option) -> Duration { + if let Some(serde_json::Value::Object(map)) = data { + if let Some(serde_json::Value::Number(secs)) = map.get("retry_after") { + if let Some(secs) = secs.as_u64() { + return Duration::from_secs(secs); + } + } + } + + // Default retry delay + Duration::from_secs(10) +} + +/// Calculate exponential backoff delay. +/// +/// # Arguments +/// +/// * `attempt` - The attempt number (0-indexed) +/// * `base_delay` - Base delay for first retry +/// * `max_delay` - Maximum delay cap +/// +/// # Returns +/// +/// The delay to wait before the next attempt. +pub fn exponential_backoff(attempt: u32, base_delay: Duration, max_delay: Duration) -> Duration { + let delay_secs = base_delay.as_secs() * 2u64.pow(attempt); + let delay = Duration::from_secs(delay_secs); + delay.min(max_delay) +} + +/// Severity level for UI error surfacing. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ErrorSeverity { + /// Informational message. + Info, + /// Warning that user should be aware of. + Warning, + /// Error requiring user attention. + Error, + /// Fatal error requiring restart or reconnect. + Fatal, +} + +/// Suggested actions for errors. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ErrorAction { + /// Retry the operation. + Retry, + /// Reconnect to the agent. + Reconnect, + /// Check configuration. + CheckConfig, + /// Contact support. + ContactSupport, + /// Dismiss the error. + Dismiss, +} + +/// Map error classification to UI severity. +pub fn error_severity(classified: &ClassifiedError) -> ErrorSeverity { + match classified.class { + ErrorClass::Transient => ErrorSeverity::Warning, + ErrorClass::Terminal => ErrorSeverity::Error, + ErrorClass::User => ErrorSeverity::Info, + } +} + +/// Suggest actions for an error. +pub fn error_actions(classified: &ClassifiedError) -> Vec { + match classified.class { + ErrorClass::Transient => vec![ErrorAction::Retry, ErrorAction::Dismiss], + ErrorClass::Terminal => { + if classified.code == Some(-40001) { + // Auth error + vec![ErrorAction::CheckConfig, ErrorAction::ContactSupport] + } else { + vec![ErrorAction::CheckConfig, ErrorAction::Dismiss] + } + } + ErrorClass::User => vec![ErrorAction::Dismiss], + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_classify_parse_error() { + let error = JsonRpcError { + code: -32700, + message: "Parse error".to_string(), + data: None, + }; + + let classified = classify_jsonrpc_error(&error); + assert_eq!(classified.class, ErrorClass::Terminal); + assert_eq!(classified.code, Some(-32700)); + assert!(!classified.is_retryable()); + } + + #[test] + fn test_classify_invalid_params() { + let error = JsonRpcError { + code: -32602, + message: "Invalid params".to_string(), + data: None, + }; + + let classified = classify_jsonrpc_error(&error); + assert_eq!(classified.class, ErrorClass::User); + assert!(!classified.is_retryable()); + } + + #[test] + fn test_classify_internal_error() { + let error = JsonRpcError { + code: -32603, + message: "Internal error".to_string(), + data: None, + }; + + let classified = classify_jsonrpc_error(&error); + assert_eq!(classified.class, ErrorClass::Transient); + assert!(classified.is_retryable()); + } + + #[test] + fn test_classify_rate_limit() { + let error = JsonRpcError { + code: -40003, + message: "Rate limited".to_string(), + data: Some(serde_json::json!({"retry_after": 30})), + }; + + let classified = classify_jsonrpc_error(&error); + assert_eq!(classified.class, ErrorClass::Transient); + assert_eq!(classified.retry_after, Some(Duration::from_secs(30))); + } + + #[test] + fn test_classify_transport_timeout() { + let error = TransportError::Timeout; + let classified = classify_transport_error(&error); + assert_eq!(classified.class, ErrorClass::Transient); + assert!(classified.is_retryable()); + } + + #[test] + fn test_classify_transport_closed() { + let error = TransportError::Closed; + let classified = classify_transport_error(&error); + assert_eq!(classified.class, ErrorClass::Terminal); + assert!(!classified.is_retryable()); + } + + #[test] + fn test_exponential_backoff() { + let base = Duration::from_secs(1); + let max = Duration::from_secs(60); + + assert_eq!(exponential_backoff(0, base, max), Duration::from_secs(1)); + assert_eq!(exponential_backoff(1, base, max), Duration::from_secs(2)); + assert_eq!(exponential_backoff(2, base, max), Duration::from_secs(4)); + assert_eq!(exponential_backoff(3, base, max), Duration::from_secs(8)); + assert_eq!(exponential_backoff(10, base, max), Duration::from_secs(60)); // Capped + } + + #[test] + fn test_error_severity_mapping() { + let transient = ClassifiedError::new(ErrorClass::Transient, "Test"); + assert_eq!(error_severity(&transient), ErrorSeverity::Warning); + + let terminal = ClassifiedError::new(ErrorClass::Terminal, "Test"); + assert_eq!(error_severity(&terminal), ErrorSeverity::Error); + + let user = ClassifiedError::new(ErrorClass::User, "Test"); + assert_eq!(error_severity(&user), ErrorSeverity::Info); + } + + #[test] + fn test_error_actions() { + let transient = ClassifiedError::new(ErrorClass::Transient, "Test"); + let actions = error_actions(&transient); + assert!(actions.contains(&ErrorAction::Retry)); + + let terminal = ClassifiedError::new(ErrorClass::Terminal, "Test"); + let actions = error_actions(&terminal); + assert!(actions.contains(&ErrorAction::CheckConfig)); + + let user = ClassifiedError::new(ErrorClass::User, "Test"); + let actions = error_actions(&user); + assert!(actions.contains(&ErrorAction::Dismiss)); + } + + #[test] + fn test_parse_retry_after() { + let data = Some(serde_json::json!({"retry_after": 15})); + assert_eq!(parse_retry_after(&data), Duration::from_secs(15)); + + let no_data = None; + assert_eq!(parse_retry_after(&no_data), Duration::from_secs(10)); // Default + } +} diff --git a/crates/dirigent_core/src/acp/protocol/initialize.rs b/crates/dirigent_core/src/acp/protocol/initialize.rs new file mode 100644 index 0000000..8bff8b8 --- /dev/null +++ b/crates/dirigent_core/src/acp/protocol/initialize.rs @@ -0,0 +1,290 @@ +//! Initialize method implementation for ACP. +//! +//! This module implements the `initialize` JSON-RPC method, which is the first +//! required interaction in the ACP protocol. Both client and agent must agree +//! on protocol version and exchange capability information. + +use crate::acp::connector_state::{ + AgentCapabilities, ClientCapabilities, ImplementationInfo, +}; +use crate::acp::transport::{JsonRpcRequest, JsonRpcResult, Transport, TransportError}; +use serde::{Deserialize, Serialize}; + +/// Current ACP protocol version supported by this implementation. +pub const PROTOCOL_VERSION: u32 = 1; + +/// Initialize request sent to the agent. +/// +/// This is the first message sent after establishing a transport connection. +/// The client advertises its capabilities and requests the agent's capabilities. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct InitializeRequest { + /// Protocol version supported by the client. + pub protocol_version: u32, + + /// Capabilities that the client supports. + pub client_capabilities: ClientCapabilities, + + /// Information about the client implementation. + #[serde(skip_serializing_if = "Option::is_none")] + pub client_info: Option, +} + +impl InitializeRequest { + /// Create a new initialize request with the current protocol version and capabilities. + pub fn new( + client_capabilities: ClientCapabilities, + client_info: Option, + ) -> Self { + Self { + protocol_version: PROTOCOL_VERSION, + client_capabilities, + client_info, + } + } + + /// Convert to JSON-RPC request. + pub fn to_jsonrpc(&self, id: u64) -> JsonRpcRequest { + JsonRpcRequest::new( + id, + "initialize", + Some(serde_json::to_value(self).expect("Failed to serialize InitializeRequest")), + ) + } +} + +/// Initialize response from the agent. +/// +/// The agent responds with its supported protocol version, capabilities, +/// and optional authentication requirements. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct InitializeResponse { + /// Protocol version supported by the agent. + pub protocol_version: u32, + + /// Capabilities that the agent supports. + pub agent_capabilities: AgentCapabilities, + + /// Information about the agent implementation. + #[serde(skip_serializing_if = "Option::is_none")] + pub agent_info: Option, + + /// Authentication methods supported by the agent (empty if no auth required). + #[serde(default)] + pub auth_methods: Vec, +} + +impl InitializeResponse { + /// Parse from JSON-RPC result. + pub fn from_jsonrpc(result: &serde_json::Value) -> Result { + serde_json::from_value(result.clone()) + } +} + +/// Result of initialization attempt. +#[derive(Debug)] +pub enum InitializeResult { + /// Initialization succeeded with negotiated version and agent capabilities. + Success { + protocol_version: u32, + agent_capabilities: AgentCapabilities, + agent_info: Option, + auth_methods: Vec, + }, + /// Version mismatch - client and agent don't support compatible versions. + VersionMismatch { + client_version: u32, + agent_version: u32, + }, + /// Other error during initialization. + Error(String), +} + +/// Perform initialization handshake with the agent. +/// +/// This sends an initialize request and processes the response, performing +/// version negotiation and capability exchange. +/// +/// # Arguments +/// +/// * `transport` - The transport to use for communication +/// * `client_capabilities` - Capabilities that the client supports +/// * `client_info` - Optional client implementation information +/// +/// # Returns +/// +/// Result of the initialization attempt, including negotiated version and +/// agent capabilities on success. +pub async fn initialize( + transport: &mut dyn Transport, + client_capabilities: ClientCapabilities, + client_info: Option, +) -> Result { + // Create initialize request + let request = InitializeRequest::new(client_capabilities, client_info); + let jsonrpc_request = request.to_jsonrpc(1); + + // Send request and await response + let result = transport.send_request(jsonrpc_request).await?; + + // Handle response + match result { + JsonRpcResult::Success(response) => { + // Parse initialize response + let init_response = InitializeResponse::from_jsonrpc(&response.result) + .map_err(|e| TransportError::SerializationError(e))?; + + // Check protocol version compatibility + if init_response.protocol_version != PROTOCOL_VERSION { + return Ok(InitializeResult::VersionMismatch { + client_version: PROTOCOL_VERSION, + agent_version: init_response.protocol_version, + }); + } + + // Version matches, initialization successful + Ok(InitializeResult::Success { + protocol_version: init_response.protocol_version, + agent_capabilities: init_response.agent_capabilities, + agent_info: init_response.agent_info, + auth_methods: init_response.auth_methods, + }) + } + JsonRpcResult::Error(error_response) => { + // Initialization failed with JSON-RPC error + Ok(InitializeResult::Error(error_response.error.message)) + } + } +} + +/// Validate that client and agent protocol versions are compatible. +/// +/// Currently, we require exact version match. In the future, this could +/// be enhanced to support backward compatibility. +pub fn is_version_compatible(client_version: u32, agent_version: u32) -> bool { + client_version == agent_version +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::acp::connector_state::FsCapabilities; + + #[test] + fn test_initialize_request_creation() { + let caps = ClientCapabilities { + fs: Some(FsCapabilities { + read_text_file: Some(true), + write_text_file: Some(false), + }), + terminal: Some(false), + _meta: None, + }; + + let request = InitializeRequest::new(caps.clone(), None); + + assert_eq!(request.protocol_version, PROTOCOL_VERSION); + assert_eq!(request.client_capabilities, caps); + assert!(request.client_info.is_none()); + } + + #[test] + fn test_initialize_request_with_client_info() { + let caps = ClientCapabilities::default_safe(); + let info = ImplementationInfo { + name: "dirigent".to_string(), + title: Some("Dirigent ACP Client".to_string()), + version: Some("0.1.0".to_string()), + }; + + let request = InitializeRequest::new(caps, Some(info.clone())); + + assert!(request.client_info.is_some()); + assert_eq!(request.client_info.unwrap(), info); + } + + #[test] + fn test_initialize_request_serialization() { + let caps = ClientCapabilities::default_safe(); + let request = InitializeRequest::new(caps, None); + + let json = serde_json::to_string(&request).unwrap(); + let deserialized: InitializeRequest = serde_json::from_str(&json).unwrap(); + + assert_eq!(request, deserialized); + } + + #[test] + fn test_initialize_request_to_jsonrpc() { + let caps = ClientCapabilities::default_safe(); + let request = InitializeRequest::new(caps, None); + + let jsonrpc = request.to_jsonrpc(42); + + assert_eq!(jsonrpc.method, "initialize"); + assert_eq!(jsonrpc.id, serde_json::Value::Number(42.into())); + assert!(jsonrpc.params.is_some()); + } + + #[test] + fn test_initialize_response_parsing() { + let json = serde_json::json!({ + "protocol_version": 1, + "agent_capabilities": { + "load_session": true, + "prompt_capabilities": { + "image": true, + "audio": false, + "embedded_context": true + } + }, + "agent_info": { + "name": "test-agent", + "version": "1.0.0" + }, + "auth_methods": [] + }); + + let response = InitializeResponse::from_jsonrpc(&json).unwrap(); + + assert_eq!(response.protocol_version, 1); + assert!(response.agent_capabilities.load_session.unwrap()); + assert_eq!(response.auth_methods.len(), 0); + } + + #[test] + fn test_initialize_response_with_auth() { + let json = serde_json::json!({ + "protocol_version": 1, + "agent_capabilities": {}, + "auth_methods": ["api_key", "oauth"] + }); + + let response = InitializeResponse::from_jsonrpc(&json).unwrap(); + + assert_eq!(response.auth_methods.len(), 2); + assert!(response.auth_methods.contains(&"api_key".to_string())); + assert!(response.auth_methods.contains(&"oauth".to_string())); + } + + #[test] + fn test_version_compatibility() { + assert!(is_version_compatible(1, 1)); + assert!(!is_version_compatible(1, 2)); + assert!(!is_version_compatible(2, 1)); + } + + #[test] + fn test_initialize_response_missing_optional_fields() { + let json = serde_json::json!({ + "protocol_version": 1, + "agent_capabilities": {} + }); + + let response = InitializeResponse::from_jsonrpc(&json).unwrap(); + + assert_eq!(response.protocol_version, 1); + assert!(response.agent_info.is_none()); + assert_eq!(response.auth_methods.len(), 0); + } +} diff --git a/crates/dirigent_core/src/acp/protocol/mod.rs b/crates/dirigent_core/src/acp/protocol/mod.rs new file mode 100644 index 0000000..455ddb9 --- /dev/null +++ b/crates/dirigent_core/src/acp/protocol/mod.rs @@ -0,0 +1,39 @@ +//! ACP protocol implementation. +//! +//! This module implements the Agent-Client Protocol (ACP) message handling, +//! including initialization, authentication, session lifecycle, prompt turns, +//! streaming updates, cancellation, and error handling. + +pub mod initialize; +pub mod authenticate; +pub mod capabilities; +pub mod session; +pub mod prompt; +pub mod streaming; +pub mod stop_reason; +pub mod cancellation; +pub mod error; + +// Re-export commonly used types +pub use initialize::{InitializeRequest, InitializeResponse}; +pub use authenticate::{AuthenticateRequest, AuthenticateResponse}; +pub use capabilities::validate_capability; +pub use session::{ + SessionCancelNotification, SessionLoadRequest, SessionLoadResponse, SessionNewRequest, + SessionNewResponse, SessionSetModeRequest, SessionSetModeResponse, SessionSetModelRequest, + SessionSetModelResponse, +}; +pub use prompt::{ + ContentBlock, SessionPromptRequest, SessionPromptResponse, StopReason, EmbeddedResource, + Annotations, PromptError, +}; +pub use streaming::{ + SessionUpdate, SessionUpdateNotification, ToolKind, ToolCallStatus, ToolCallInfo, + ToolCallLocation, ToolCallContent, MessageAccumulator, PlanEntry, Command, +}; +pub use stop_reason::{StopReasonAction, handle_stop_reason, is_continuable, is_error}; +pub use cancellation::{handle_cancellation, cancel_pending_tool_calls, handle_disconnect}; +pub use error::{ + ErrorClass, ClassifiedError, ErrorSeverity, ErrorAction, classify_jsonrpc_error, + classify_transport_error, exponential_backoff, error_severity, error_actions, +}; diff --git a/crates/dirigent_core/src/acp/protocol/prompt.rs b/crates/dirigent_core/src/acp/protocol/prompt.rs new file mode 100644 index 0000000..3da7d85 --- /dev/null +++ b/crates/dirigent_core/src/acp/protocol/prompt.rs @@ -0,0 +1,525 @@ +//! Prompt turn handling for ACP. +//! +//! This module implements session/prompt requests and responses, including: +//! - Content blocks (text, image, audio, resource, resource_link) +//! - Prompt request handling with timeout +//! - Stop reason interpretation +//! - Content validation against agent capabilities + +use crate::acp::connector_state::AgentCapabilities; +use crate::acp::transport::{JsonRpcRequest, JsonRpcResult, Transport, TransportError}; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +/// Request to send a prompt to the agent. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct SessionPromptRequest { + /// Session ID to send prompt to. + pub session_id: String, + /// Content blocks that make up the prompt. + pub prompt: Vec, +} + +impl SessionPromptRequest { + /// Create a new session/prompt request. + pub fn new(session_id: impl Into, prompt: Vec) -> Self { + Self { + session_id: session_id.into(), + prompt, + } + } + + /// Convert to JSON-RPC request. + pub fn to_jsonrpc(&self, id: u64) -> JsonRpcRequest { + JsonRpcRequest::new( + id, + "session/prompt", + Some(serde_json::to_value(self).expect("Failed to serialize SessionPromptRequest")), + ) + } + + /// Validate content blocks against agent capabilities. + /// + /// Returns an error if any content block uses a capability the agent + /// doesn't support. + pub fn validate(&self, agent_capabilities: &AgentCapabilities) -> Result<(), String> { + for block in &self.prompt { + match block { + ContentBlock::Text { .. } => { + // Text is always supported + } + ContentBlock::Image { .. } => { + if !agent_capabilities.supports_image() { + return Err("Agent does not support image content blocks".to_string()); + } + } + ContentBlock::Audio { .. } => { + if !agent_capabilities.supports_audio() { + return Err("Agent does not support audio content blocks".to_string()); + } + } + ContentBlock::Resource { .. } => { + if !agent_capabilities.supports_embedded_context() { + return Err( + "Agent does not support embedded context (resource blocks)" + .to_string(), + ); + } + } + ContentBlock::ResourceLink { .. } => { + // ResourceLink is always supported + } + } + } + Ok(()) + } +} + +/// Response from session/prompt. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SessionPromptResponse { + /// Reason the prompt turn stopped. + pub stop_reason: StopReason, +} + +impl SessionPromptResponse { + /// Parse from JSON-RPC result. + pub fn from_jsonrpc(result: &serde_json::Value) -> Result { + serde_json::from_value(result.clone()) + } +} + +/// Reason a prompt turn stopped. +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "snake_case")] +pub enum StopReason { + /// Normal completion (agent finished its turn). + EndTurn, + /// Hit token limit for the response. + MaxTokens, + /// Hit maximum number of LLM requests per turn. + MaxTurnRequests, + /// Agent refused to continue. + Refusal, + /// User cancelled the prompt. + Cancelled, +} + +/// Content block in a prompt or message. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ContentBlock { + /// Plain text content. + Text { + text: String, + #[serde(skip_serializing_if = "Option::is_none")] + annotations: Option, + }, + /// Image content (base64-encoded). + Image { + /// Base64-encoded image data. + data: String, + /// MIME type (e.g., "image/png", "image/jpeg"). + mime_type: String, + /// Optional URI reference. + #[serde(skip_serializing_if = "Option::is_none")] + uri: Option, + #[serde(skip_serializing_if = "Option::is_none")] + annotations: Option, + }, + /// Audio content (base64-encoded). + Audio { + /// Base64-encoded audio data. + data: String, + /// MIME type (e.g., "audio/wav", "audio/mp3"). + mime_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + annotations: Option, + }, + /// Embedded resource (file content or blob). + Resource { + resource: EmbeddedResource, + #[serde(skip_serializing_if = "Option::is_none")] + annotations: Option, + }, + /// Link to a resource (reference without full content). + ResourceLink { + /// URI of the resource. + uri: String, + /// Name/filename of the resource. + name: String, + /// MIME type (if known). + #[serde(skip_serializing_if = "Option::is_none")] + mime_type: Option, + /// Human-readable title. + #[serde(skip_serializing_if = "Option::is_none")] + title: Option, + /// Description of the resource. + #[serde(skip_serializing_if = "Option::is_none")] + description: Option, + /// Size in bytes (if known). + #[serde(skip_serializing_if = "Option::is_none")] + size: Option, + #[serde(skip_serializing_if = "Option::is_none")] + annotations: Option, + }, +} + +/// Embedded resource content. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(untagged)] +pub enum EmbeddedResource { + /// Text resource (e.g., file contents). + Text { + uri: String, + text: String, + #[serde(skip_serializing_if = "Option::is_none")] + mime_type: Option, + }, + /// Binary resource (base64-encoded). + Blob { + uri: String, + /// Base64-encoded binary data. + blob: String, + #[serde(skip_serializing_if = "Option::is_none")] + mime_type: Option, + }, +} + +/// Annotations for content blocks. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Annotations { + /// Intended audience for this content (e.g., ["assistant"]). + #[serde(skip_serializing_if = "Option::is_none")] + pub audience: Option>, + /// Priority hint (higher = more important). + #[serde(skip_serializing_if = "Option::is_none")] + pub priority: Option, +} + +/// Error type for prompt operations. +#[derive(Debug, thiserror::Error)] +pub enum PromptError { + #[error("Prompt timed out after {0:?}")] + Timeout(Duration), + + #[error("JSON-RPC error: {0}")] + JsonRpcError(crate::acp::transport::JsonRpcError), + + #[error("Transport error: {0}")] + TransportError(#[from] TransportError), + + #[error("Prompt was cancelled")] + Cancelled, + + #[error("Validation error: {0}")] + ValidationError(String), +} + +impl From for PromptError { + fn from(err: crate::acp::transport::JsonRpcError) -> Self { + PromptError::JsonRpcError(err) + } +} + +/// Send a prompt request to the agent. +/// +/// # Arguments +/// +/// * `transport` - The transport to use for communication +/// * `session_id` - ID of the session to prompt +/// * `prompt` - Content blocks to send +/// * `agent_capabilities` - Agent capabilities (for validation) +/// +/// # Returns +/// +/// The stop reason for the prompt turn. +pub async fn session_prompt( + transport: &mut dyn Transport, + session_id: impl Into, + prompt: Vec, + agent_capabilities: &AgentCapabilities, +) -> Result { + // Create and validate request + let request = SessionPromptRequest::new(session_id, prompt); + request + .validate(agent_capabilities) + .map_err(PromptError::ValidationError)?; + + // Convert to JSON-RPC + let jsonrpc_request = request.to_jsonrpc(10); // Use ID 10+ for prompts + + // Send request + let result = transport.send_request(jsonrpc_request).await?; + + // Handle response + match result { + JsonRpcResult::Success(response) => { + let prompt_response = SessionPromptResponse::from_jsonrpc(&response.result) + .map_err(|e| TransportError::SerializationError(e))?; + Ok(prompt_response.stop_reason) + } + JsonRpcResult::Error(error_response) => Err(PromptError::JsonRpcError(error_response.error)), + } +} + +/// Send a prompt request with a timeout. +/// +/// This function sends a prompt and waits for a response, but will return +/// a timeout error if no response is received within the specified duration. +pub async fn session_prompt_with_timeout( + transport: &mut dyn Transport, + session_id: impl Into, + prompt: Vec, + agent_capabilities: &AgentCapabilities, + timeout: Duration, +) -> Result { + let session_id = session_id.into(); + + match tokio::time::timeout( + timeout, + session_prompt(transport, session_id, prompt, agent_capabilities), + ) + .await + { + Ok(result) => result, + Err(_) => Err(PromptError::Timeout(timeout)), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::acp::connector_state::PromptCapabilities; + + #[test] + fn test_session_prompt_request() { + let request = SessionPromptRequest::new( + "session-123", + vec![ContentBlock::Text { + text: "Hello".to_string(), + annotations: None, + }], + ); + + assert_eq!(request.session_id, "session-123"); + assert_eq!(request.prompt.len(), 1); + } + + #[test] + fn test_session_prompt_request_serialization() { + let request = SessionPromptRequest::new( + "session-123", + vec![ContentBlock::Text { + text: "Hello".to_string(), + annotations: None, + }], + ); + + let json = serde_json::to_string(&request).unwrap(); + let deserialized: SessionPromptRequest = serde_json::from_str(&json).unwrap(); + + assert_eq!(request, deserialized); + } + + #[test] + fn test_content_block_text() { + let block = ContentBlock::Text { + text: "Test content".to_string(), + annotations: Some(Annotations { + audience: Some(vec!["assistant".to_string()]), + priority: Some(1.0), + }), + }; + + let json = serde_json::to_string(&block).unwrap(); + assert!(json.contains("\"type\":\"text\"")); + assert!(json.contains("Test content")); + } + + #[test] + fn test_content_block_image() { + let block = ContentBlock::Image { + data: "base64data".to_string(), + mime_type: "image/png".to_string(), + uri: Some("file://test.png".to_string()), + annotations: None, + }; + + let json = serde_json::to_string(&block).unwrap(); + assert!(json.contains("\"type\":\"image\"")); + assert!(json.contains("base64data")); + assert!(json.contains("image/png")); + } + + #[test] + fn test_content_block_resource_link() { + let block = ContentBlock::ResourceLink { + uri: "file:///path/to/file.txt".to_string(), + name: "file.txt".to_string(), + mime_type: Some("text/plain".to_string()), + title: Some("Test File".to_string()), + description: Some("A test file".to_string()), + size: Some(1024), + annotations: None, + }; + + let json = serde_json::to_string(&block).unwrap(); + assert!(json.contains("\"type\":\"resource_link\"")); + assert!(json.contains("file.txt")); + } + + #[test] + fn test_embedded_resource_text() { + let resource = EmbeddedResource::Text { + uri: "file:///test.txt".to_string(), + text: "File contents".to_string(), + mime_type: Some("text/plain".to_string()), + }; + + let json = serde_json::to_string(&resource).unwrap(); + assert!(json.contains("File contents")); + } + + #[test] + fn test_stop_reason_serialization() { + let reasons = vec![ + StopReason::EndTurn, + StopReason::MaxTokens, + StopReason::MaxTurnRequests, + StopReason::Refusal, + StopReason::Cancelled, + ]; + + for reason in reasons { + let json = serde_json::to_string(&reason).unwrap(); + let deserialized: StopReason = serde_json::from_str(&json).unwrap(); + assert_eq!(reason, deserialized); + } + } + + #[test] + fn test_validate_text_block() { + let request = SessionPromptRequest::new( + "session-123", + vec![ContentBlock::Text { + text: "Hello".to_string(), + annotations: None, + }], + ); + + let caps = AgentCapabilities { + load_session: None, + prompt_capabilities: None, + mcp: None, + _meta: None, + }; + + // Text is always supported + assert!(request.validate(&caps).is_ok()); + } + + #[test] + fn test_validate_image_unsupported() { + let request = SessionPromptRequest::new( + "session-123", + vec![ContentBlock::Image { + data: "base64".to_string(), + mime_type: "image/png".to_string(), + uri: None, + annotations: None, + }], + ); + + let caps = AgentCapabilities { + load_session: None, + prompt_capabilities: None, + mcp: None, + _meta: None, + }; + + // Image not supported + assert!(request.validate(&caps).is_err()); + } + + #[test] + fn test_validate_image_supported() { + let request = SessionPromptRequest::new( + "session-123", + vec![ContentBlock::Image { + data: "base64".to_string(), + mime_type: "image/png".to_string(), + uri: None, + annotations: None, + }], + ); + + let caps = AgentCapabilities { + load_session: None, + prompt_capabilities: Some(PromptCapabilities { + image: Some(true), + audio: None, + embedded_context: None, + }), + mcp: None, + _meta: None, + }; + + // Image supported + assert!(request.validate(&caps).is_ok()); + } + + #[test] + fn test_validate_resource_unsupported() { + let request = SessionPromptRequest::new( + "session-123", + vec![ContentBlock::Resource { + resource: EmbeddedResource::Text { + uri: "file://test".to_string(), + text: "content".to_string(), + mime_type: None, + }, + annotations: None, + }], + ); + + let caps = AgentCapabilities { + load_session: None, + prompt_capabilities: None, + mcp: None, + _meta: None, + }; + + // Embedded context not supported + assert!(request.validate(&caps).is_err()); + } + + #[test] + fn test_validate_resource_supported() { + let request = SessionPromptRequest::new( + "session-123", + vec![ContentBlock::Resource { + resource: EmbeddedResource::Text { + uri: "file://test".to_string(), + text: "content".to_string(), + mime_type: None, + }, + annotations: None, + }], + ); + + let caps = AgentCapabilities { + load_session: None, + prompt_capabilities: Some(PromptCapabilities { + image: None, + audio: None, + embedded_context: Some(true), + }), + mcp: None, + _meta: None, + }; + + // Embedded context supported + assert!(request.validate(&caps).is_ok()); + } +} diff --git a/crates/dirigent_core/src/acp/protocol/session.rs b/crates/dirigent_core/src/acp/protocol/session.rs new file mode 100644 index 0000000..62d1065 --- /dev/null +++ b/crates/dirigent_core/src/acp/protocol/session.rs @@ -0,0 +1,665 @@ +//! Session lifecycle methods for ACP. +//! +//! This module implements session management methods including: +//! - `session/new` - Create new sessions with working directory and MCP servers +//! - `session/load` - Load existing sessions with history replay +//! - `session/set_mode` - Change session mode (e.g., "code", "chat") +//! - `session/set_model` - Change model (optional, unstable in spec) +//! - `session/cancel` - Cancel ongoing prompt turns + +use crate::acp::connector_state::{ + AgentCapabilities, McpServer, +}; +use crate::acp::transport::{JsonRpcRequest, JsonRpcResult, Transport, TransportError}; +use serde::{Deserialize, Serialize}; +use std::path::Path; + +/// Request to create a new session. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SessionNewRequest { + /// Working directory for the session (absolute path). + pub cwd: String, + /// MCP servers to use for this session. + #[serde(default)] + pub mcp_servers: Vec, +} + +impl SessionNewRequest { + /// Create a new session/new request. + /// + /// # Arguments + /// + /// * `cwd` - Working directory (absolute path) + /// * `mcp_servers` - MCP servers to use + /// + /// # Panics + /// + /// Panics if cwd is not an absolute path. + pub fn new(cwd: impl Into, mcp_servers: Vec) -> Self { + let cwd = cwd.into(); + assert!( + Path::new(&cwd).is_absolute(), + "Working directory must be an absolute path: {}", + cwd + ); + Self { cwd, mcp_servers } + } + + /// Convert to JSON-RPC request. + pub fn to_jsonrpc(&self, id: u64) -> JsonRpcRequest { + JsonRpcRequest::new( + id, + "session/new", + Some(serde_json::to_value(self).expect("Failed to serialize SessionNewRequest")), + ) + } +} + +/// Response from session/new. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SessionNewResponse { + /// Unique identifier for the created session. + pub session_id: String, +} + +impl SessionNewResponse { + /// Parse from JSON-RPC result. + pub fn from_jsonrpc(result: &serde_json::Value) -> Result { + serde_json::from_value(result.clone()) + } +} + +/// Request to load an existing session. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SessionLoadRequest { + /// Session ID to load. + pub session_id: String, + /// Working directory for the session (absolute path). + pub cwd: String, + /// MCP servers to use for this session. + #[serde(default)] + pub mcp_servers: Vec, +} + +impl SessionLoadRequest { + /// Create a new session/load request. + /// + /// # Arguments + /// + /// * `session_id` - ID of the session to load + /// * `cwd` - Working directory (absolute path) + /// * `mcp_servers` - MCP servers to use + /// + /// # Panics + /// + /// Panics if cwd is not an absolute path. + pub fn new( + session_id: impl Into, + cwd: impl Into, + mcp_servers: Vec, + ) -> Self { + let cwd = cwd.into(); + assert!( + Path::new(&cwd).is_absolute(), + "Working directory must be an absolute path: {}", + cwd + ); + Self { + session_id: session_id.into(), + cwd, + mcp_servers, + } + } + + /// Convert to JSON-RPC request. + pub fn to_jsonrpc(&self, id: u64) -> JsonRpcRequest { + JsonRpcRequest::new( + id, + "session/load", + Some(serde_json::to_value(self).expect("Failed to serialize SessionLoadRequest")), + ) + } +} + +/// Response from session/load. +/// +/// The response itself is empty (null in JSON), but the agent will send +/// session/update notifications to replay the conversation history. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SessionLoadResponse; + +impl SessionLoadResponse { + /// Parse from JSON-RPC result. + pub fn from_jsonrpc(_result: &serde_json::Value) -> Result { + Ok(SessionLoadResponse) + } +} + +/// Request to change session mode. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SessionSetModeRequest { + /// Session ID to update. + pub session_id: String, + /// Mode to switch to (e.g., "code", "chat"). + pub mode: String, +} + +impl SessionSetModeRequest { + /// Create a new session/set_mode request. + pub fn new(session_id: impl Into, mode: impl Into) -> Self { + Self { + session_id: session_id.into(), + mode: mode.into(), + } + } + + /// Convert to JSON-RPC request. + pub fn to_jsonrpc(&self, id: u64) -> JsonRpcRequest { + JsonRpcRequest::new( + id, + "session/set_mode", + Some(serde_json::to_value(self).expect("Failed to serialize SessionSetModeRequest")), + ) + } +} + +/// Response from session/set_mode (empty). +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SessionSetModeResponse; + +impl SessionSetModeResponse { + /// Parse from JSON-RPC result. + pub fn from_jsonrpc(_result: &serde_json::Value) -> Result { + Ok(SessionSetModeResponse) + } +} + +/// Request to change session model (UNSTABLE in spec). +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SessionSetModelRequest { + /// Session ID to update. + pub session_id: String, + /// Model identifier to switch to. + pub model: String, +} + +impl SessionSetModelRequest { + /// Create a new session/set_model request. + pub fn new(session_id: impl Into, model: impl Into) -> Self { + Self { + session_id: session_id.into(), + model: model.into(), + } + } + + /// Convert to JSON-RPC request. + pub fn to_jsonrpc(&self, id: u64) -> JsonRpcRequest { + JsonRpcRequest::new( + id, + "session/set_model", + Some(serde_json::to_value(self).expect("Failed to serialize SessionSetModelRequest")), + ) + } +} + +/// Response from session/set_model (empty). +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SessionSetModelResponse; + +impl SessionSetModelResponse { + /// Parse from JSON-RPC result. + pub fn from_jsonrpc(_result: &serde_json::Value) -> Result { + Ok(SessionSetModelResponse) + } +} + +/// Notification to cancel an ongoing prompt turn. +/// +/// This is a notification (no response expected). The agent should stop +/// LLM requests and tool executions, then respond to the original +/// session/prompt with stopReason: cancelled. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SessionCancelNotification { + /// Session ID to cancel. + pub session_id: String, +} + +impl SessionCancelNotification { + /// Create a new session/cancel notification. + pub fn new(session_id: impl Into) -> Self { + Self { + session_id: session_id.into(), + } + } + + /// Convert to JSON-RPC notification (no id field). + pub fn to_jsonrpc(&self) -> JsonRpcRequest { + JsonRpcRequest::notification( + "session/cancel", + Some(serde_json::to_value(self).expect("Failed to serialize SessionCancelNotification")), + ) + } +} + +/// Validate MCP server configurations against agent capabilities. +/// +/// Returns an error if any MCP server uses a transport that the agent +/// doesn't support. +pub fn validate_mcp_servers( + mcp_servers: &[McpServer], + agent_capabilities: &AgentCapabilities, +) -> Result<(), String> { + let mcp_caps = agent_capabilities.mcp.as_ref(); + + for server in mcp_servers { + match server { + McpServer::Stdio { .. } => { + // Stdio is always supported + } + McpServer::Http { .. } => { + let supported = mcp_caps + .and_then(|caps| caps.http) + .unwrap_or(false); + if !supported { + return Err(format!( + "Agent does not support HTTP MCP servers (server: {})", + match server { + McpServer::Http { name, .. } => name, + _ => unreachable!(), + } + )); + } + } + McpServer::Sse { .. } => { + let supported = mcp_caps + .and_then(|caps| caps.sse) + .unwrap_or(false); + if !supported { + return Err(format!( + "Agent does not support SSE MCP servers (server: {})", + match server { + McpServer::Sse { name, .. } => name, + _ => unreachable!(), + } + )); + } + } + } + } + + Ok(()) +} + +/// Create a new session. +/// +/// Sends a session/new request to the agent and returns the session ID. +/// +/// # Arguments +/// +/// * `transport` - The transport to use for communication +/// * `cwd` - Working directory (absolute path) +/// * `mcp_servers` - MCP servers to configure +/// * `agent_capabilities` - Agent capabilities (for validation) +/// +/// # Returns +/// +/// The session ID of the created session. +pub async fn session_new( + transport: &mut dyn Transport, + cwd: impl Into, + mcp_servers: Vec, + agent_capabilities: &AgentCapabilities, +) -> Result { + // Validate MCP servers + validate_mcp_servers(&mcp_servers, agent_capabilities) + .map_err(|e| TransportError::Other(e))?; + + // Create request + let request = SessionNewRequest::new(cwd, mcp_servers); + let jsonrpc_request = request.to_jsonrpc(2); // Use ID 2 (after initialize which uses 1) + + // Send request + let result = transport.send_request(jsonrpc_request).await?; + + // Handle response + match result { + JsonRpcResult::Success(response) => { + let session_response = SessionNewResponse::from_jsonrpc(&response.result) + .map_err(|e| TransportError::SerializationError(e))?; + Ok(session_response.session_id) + } + JsonRpcResult::Error(error_response) => { + Err(TransportError::JsonRpcError(error_response.error)) + } + } +} + +/// Load an existing session. +/// +/// Sends a session/load request to the agent. The agent will respond with +/// an empty response, then send session/update notifications to replay +/// the conversation history. +/// +/// # Arguments +/// +/// * `transport` - The transport to use for communication +/// * `session_id` - ID of the session to load +/// * `cwd` - Working directory (absolute path) +/// * `mcp_servers` - MCP servers to configure +/// * `agent_capabilities` - Agent capabilities (for validation) +/// +/// # Returns +/// +/// Ok(()) if the session load was initiated successfully. History will be +/// replayed via session/update notifications. +pub async fn session_load( + transport: &mut dyn Transport, + session_id: impl Into, + cwd: impl Into, + mcp_servers: Vec, + agent_capabilities: &AgentCapabilities, +) -> Result<(), TransportError> { + // Check if agent supports load_session + if !agent_capabilities.supports_load_session() { + return Err(TransportError::Other( + "Agent does not support session loading".to_string(), + )); + } + + // Validate MCP servers + validate_mcp_servers(&mcp_servers, agent_capabilities) + .map_err(|e| TransportError::Other(e))?; + + // Create request + let request = SessionLoadRequest::new(session_id, cwd, mcp_servers); + let jsonrpc_request = request.to_jsonrpc(3); // Use next available ID + + // Send request + let result = transport.send_request(jsonrpc_request).await?; + + // Handle response + match result { + JsonRpcResult::Success(_) => Ok(()), + JsonRpcResult::Error(error_response) => { + Err(TransportError::JsonRpcError(error_response.error)) + } + } +} + +/// Change the session mode. +/// +/// # Arguments +/// +/// * `transport` - The transport to use for communication +/// * `session_id` - ID of the session to update +/// * `mode` - Mode to switch to (e.g., "code", "chat") +pub async fn session_set_mode( + transport: &mut dyn Transport, + session_id: impl Into, + mode: impl Into, +) -> Result<(), TransportError> { + let request = SessionSetModeRequest::new(session_id, mode); + let jsonrpc_request = request.to_jsonrpc(4); // Use next available ID + + let result = transport.send_request(jsonrpc_request).await?; + + match result { + JsonRpcResult::Success(_) => Ok(()), + JsonRpcResult::Error(error_response) => { + Err(TransportError::JsonRpcError(error_response.error)) + } + } +} + +/// Change the session model (UNSTABLE). +/// +/// # Arguments +/// +/// * `transport` - The transport to use for communication +/// * `session_id` - ID of the session to update +/// * `model` - Model identifier to switch to +pub async fn session_set_model( + transport: &mut dyn Transport, + session_id: impl Into, + model: impl Into, +) -> Result<(), TransportError> { + let request = SessionSetModelRequest::new(session_id, model); + let jsonrpc_request = request.to_jsonrpc(5); // Use next available ID + + let result = transport.send_request(jsonrpc_request).await?; + + match result { + JsonRpcResult::Success(_) => Ok(()), + JsonRpcResult::Error(error_response) => { + Err(TransportError::JsonRpcError(error_response.error)) + } + } +} + +/// Cancel an ongoing prompt turn. +/// +/// Sends a session/cancel notification (no response expected). The agent +/// should stop LLM requests and tool executions, then respond to the +/// original session/prompt with stopReason: cancelled. +/// +/// # Arguments +/// +/// * `transport` - The transport to use for communication +/// * `session_id` - ID of the session to cancel +pub async fn session_cancel( + transport: &mut dyn Transport, + session_id: impl Into, +) -> Result<(), TransportError> { + let notification = SessionCancelNotification::new(session_id); + let jsonrpc_request = notification.to_jsonrpc(); + + // Convert the notification-style request to a JsonRpcNotification + let jsonrpc_notification = crate::acp::transport::JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: jsonrpc_request.method, + params: jsonrpc_request.params, + }; + + // Send notification (no response expected) + transport.send_notification(jsonrpc_notification).await +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::acp::connector_state::{ + EnvVariable, McpCapabilities, SessionState, + }; + + #[test] + fn test_mcp_server_serialization() { + let server = McpServer::Stdio { + name: "test-server".to_string(), + command: "test-cmd".to_string(), + args: vec!["arg1".to_string()], + env: vec![EnvVariable { + name: "VAR".to_string(), + value: "value".to_string(), + }], + }; + + let json = serde_json::to_string(&server).unwrap(); + let deserialized: McpServer = serde_json::from_str(&json).unwrap(); + + assert_eq!(server, deserialized); + } + + #[test] + fn test_session_new_request() { + let request = SessionNewRequest::new("/absolute/path", vec![]); + + assert_eq!(request.cwd, "/absolute/path"); + assert_eq!(request.mcp_servers.len(), 0); + } + + #[test] + #[should_panic(expected = "Working directory must be an absolute path")] + fn test_session_new_request_relative_path() { + SessionNewRequest::new("relative/path", vec![]); + } + + #[test] + fn test_session_new_request_serialization() { + let request = SessionNewRequest::new("/test/path", vec![]); + + let json = serde_json::to_string(&request).unwrap(); + let deserialized: SessionNewRequest = serde_json::from_str(&json).unwrap(); + + assert_eq!(request, deserialized); + } + + #[test] + fn test_session_load_request() { + let request = SessionLoadRequest::new("session-123", "/absolute/path", vec![]); + + assert_eq!(request.session_id, "session-123"); + assert_eq!(request.cwd, "/absolute/path"); + } + + #[test] + fn test_session_set_mode_request() { + let request = SessionSetModeRequest::new("session-123", "code"); + + assert_eq!(request.session_id, "session-123"); + assert_eq!(request.mode, "code"); + } + + #[test] + fn test_session_cancel_notification() { + let notification = SessionCancelNotification::new("session-123"); + + assert_eq!(notification.session_id, "session-123"); + } + + #[test] + fn test_session_state() { + let mut state = SessionState::new("session-123".to_string(), "/path".to_string(), vec![]); + + assert!(state.is_ready()); + assert!(!state.prompt_in_progress); + + state.start_prompt(); + assert!(!state.is_ready()); + assert!(state.prompt_in_progress); + + state.complete_prompt(); + assert!(state.is_ready()); + assert!(!state.prompt_in_progress); + } + + #[test] + fn test_session_state_cancellation() { + let mut state = SessionState::new("session-123".to_string(), "/path".to_string(), vec![]); + + state.start_prompt(); + state.start_cancellation(); + + assert!(!state.is_ready()); + assert!(state.cancelling); + + state.complete_prompt(); + assert!(state.is_ready()); + assert!(!state.cancelling); + } + + #[test] + fn test_session_state_loading() { + let mut state = SessionState::new("session-123".to_string(), "/path".to_string(), vec![]); + + state.start_loading(); + assert!(!state.is_ready()); + assert!(state.loading); + + state.complete_loading(); + assert!(state.is_ready()); + assert!(!state.loading); + } + + #[test] + fn test_validate_mcp_servers_stdio() { + let servers = vec![McpServer::Stdio { + name: "test".to_string(), + command: "cmd".to_string(), + args: vec![], + env: vec![], + }]; + + let caps = AgentCapabilities { + load_session: None, + prompt_capabilities: None, + mcp: None, + _meta: None, + }; + + // Stdio is always supported + assert!(validate_mcp_servers(&servers, &caps).is_ok()); + } + + #[test] + fn test_validate_mcp_servers_http_unsupported() { + let servers = vec![McpServer::Http { + name: "test".to_string(), + url: "http://localhost".to_string(), + headers: vec![], + }]; + + let caps = AgentCapabilities { + load_session: None, + prompt_capabilities: None, + mcp: None, + _meta: None, + }; + + // HTTP not supported + assert!(validate_mcp_servers(&servers, &caps).is_err()); + } + + #[test] + fn test_validate_mcp_servers_http_supported() { + let servers = vec![McpServer::Http { + name: "test".to_string(), + url: "http://localhost".to_string(), + headers: vec![], + }]; + + let caps = AgentCapabilities { + load_session: None, + prompt_capabilities: None, + mcp: Some(McpCapabilities { + http: Some(true), + sse: None, + }), + _meta: None, + }; + + // HTTP supported + assert!(validate_mcp_servers(&servers, &caps).is_ok()); + } + + #[test] + fn test_validate_mcp_servers_sse_unsupported() { + let servers = vec![McpServer::Sse { + name: "test".to_string(), + url: "http://localhost/events".to_string(), + headers: vec![], + }]; + + let caps = AgentCapabilities { + load_session: None, + prompt_capabilities: None, + mcp: Some(McpCapabilities { + http: Some(true), + sse: None, + }), + _meta: None, + }; + + // SSE not supported + assert!(validate_mcp_servers(&servers, &caps).is_err()); + } +} diff --git a/crates/dirigent_core/src/acp/protocol/stop_reason.rs b/crates/dirigent_core/src/acp/protocol/stop_reason.rs new file mode 100644 index 0000000..908091a --- /dev/null +++ b/crates/dirigent_core/src/acp/protocol/stop_reason.rs @@ -0,0 +1,146 @@ +//! Stop reason handling for ACP. +//! +//! This module implements interpretation and action determination for +//! different stop reasons (end_turn, max_tokens, max_turn_requests, refusal, cancelled). + +use crate::acp::protocol::prompt::StopReason; + +/// Action to take based on a stop reason. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum StopReasonAction { + /// Normal completion, no special action needed. + Complete, + /// Show a warning to the user with optional continuation. + ShowWarning(String), + /// Show an error to the user. + ShowError(String), + /// Show an info message to the user. + ShowInfo(String), +} + +/// Interpret a stop reason and determine the appropriate action. +/// +/// # Arguments +/// +/// * `stop_reason` - The stop reason from the agent +/// +/// # Returns +/// +/// The action that should be taken in response to this stop reason. +pub fn handle_stop_reason(stop_reason: StopReason) -> StopReasonAction { + match stop_reason { + StopReason::EndTurn => { + // Normal completion, no special action needed + StopReasonAction::Complete + } + StopReason::MaxTokens => { + // Hit token limit, offer to continue + StopReasonAction::ShowWarning( + "Response truncated due to token limit. You can continue the conversation to get more output.".to_string() + ) + } + StopReason::MaxTurnRequests => { + // Too many LLM requests in turn, offer to continue + StopReasonAction::ShowWarning( + "Maximum turn requests reached. The agent made many requests. You can continue if needed.".to_string() + ) + } + StopReason::Refusal => { + // Agent refused to continue + StopReasonAction::ShowError( + "The agent refused to continue with this request.".to_string() + ) + } + StopReason::Cancelled => { + // User cancelled, show confirmation + StopReasonAction::ShowInfo("Prompt cancelled.".to_string()) + } + } +} + +/// Check if a stop reason indicates the turn can be continued. +/// +/// Some stop reasons (max_tokens, max_turn_requests) indicate the agent +/// stopped due to limits but could continue if prompted again. +pub fn is_continuable(stop_reason: StopReason) -> bool { + matches!( + stop_reason, + StopReason::MaxTokens | StopReason::MaxTurnRequests + ) +} + +/// Check if a stop reason indicates an error condition. +pub fn is_error(stop_reason: StopReason) -> bool { + matches!(stop_reason, StopReason::Refusal) +} + +/// Get a user-facing message for a stop reason. +pub fn stop_reason_message(stop_reason: StopReason) -> String { + match handle_stop_reason(stop_reason) { + StopReasonAction::Complete => "Turn completed".to_string(), + StopReasonAction::ShowWarning(msg) => msg, + StopReasonAction::ShowError(msg) => msg, + StopReasonAction::ShowInfo(msg) => msg, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_end_turn_action() { + let action = handle_stop_reason(StopReason::EndTurn); + assert_eq!(action, StopReasonAction::Complete); + assert!(!is_continuable(StopReason::EndTurn)); + assert!(!is_error(StopReason::EndTurn)); + } + + #[test] + fn test_max_tokens_action() { + let action = handle_stop_reason(StopReason::MaxTokens); + assert!(matches!(action, StopReasonAction::ShowWarning(_))); + assert!(is_continuable(StopReason::MaxTokens)); + assert!(!is_error(StopReason::MaxTokens)); + } + + #[test] + fn test_max_turn_requests_action() { + let action = handle_stop_reason(StopReason::MaxTurnRequests); + assert!(matches!(action, StopReasonAction::ShowWarning(_))); + assert!(is_continuable(StopReason::MaxTurnRequests)); + assert!(!is_error(StopReason::MaxTurnRequests)); + } + + #[test] + fn test_refusal_action() { + let action = handle_stop_reason(StopReason::Refusal); + assert!(matches!(action, StopReasonAction::ShowError(_))); + assert!(!is_continuable(StopReason::Refusal)); + assert!(is_error(StopReason::Refusal)); + } + + #[test] + fn test_cancelled_action() { + let action = handle_stop_reason(StopReason::Cancelled); + assert!(matches!(action, StopReasonAction::ShowInfo(_))); + assert!(!is_continuable(StopReason::Cancelled)); + assert!(!is_error(StopReason::Cancelled)); + } + + #[test] + fn test_stop_reason_messages() { + let reasons = vec![ + StopReason::EndTurn, + StopReason::MaxTokens, + StopReason::MaxTurnRequests, + StopReason::Refusal, + StopReason::Cancelled, + ]; + + for reason in reasons { + let message = stop_reason_message(reason); + assert!(!message.is_empty()); + } + } +} diff --git a/crates/dirigent_core/src/acp/protocol/streaming.rs b/crates/dirigent_core/src/acp/protocol/streaming.rs new file mode 100644 index 0000000..e8a5d77 --- /dev/null +++ b/crates/dirigent_core/src/acp/protocol/streaming.rs @@ -0,0 +1,482 @@ +//! Streaming update handling for ACP. +//! +//! This module implements session/update notification handling, including: +//! - Session update dispatcher +//! - Message and thought chunk accumulation +//! - Tool call tracking and updates +//! - Plan and command updates +//! - User message chunks (for session replay) + +use crate::acp::protocol::prompt::ContentBlock; +use serde::{Deserialize, Serialize}; + +/// Notification of a session update from the agent. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct SessionUpdateNotification { + /// Session ID this update belongs to. + pub session_id: String, + /// The update itself. + #[serde(flatten)] + pub update: SessionUpdate, +} + +/// Session update type. +/// +/// These are sent from the agent to the client during a prompt turn +/// to provide real-time updates on progress. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "sessionUpdate", rename_all = "snake_case")] +pub enum SessionUpdate { + /// Agent message chunk (assistant output). + AgentMessageChunk { + content: ContentBlock, + }, + /// Agent thought/reasoning chunk. + AgentThoughtChunk { + content: ContentBlock, + }, + /// User message chunk (for session replay). + UserMessageChunk { + content: ContentBlock, + }, + /// New tool call initiated. + ToolCall { + tool_call_id: String, + title: String, + kind: ToolKind, + status: ToolCallStatus, + #[serde(skip_serializing_if = "Option::is_none")] + location: Option, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option>, + }, + /// Update to existing tool call. + ToolCallUpdate { + tool_call_id: String, + status: ToolCallStatus, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option>, + }, + /// Agent execution plan. + Plan { + entries: Vec, + }, + /// Available commands update. + AvailableCommandsUpdate { + commands: Vec, + }, + /// Current mode update (agent-initiated mode change). + /// + /// Sent when the agent changes its own mode (e.g., exiting plan mode via a tool). + /// Per ACP spec, the field is `modeId` (camelCase). + CurrentModeUpdate { + #[serde(rename = "modeId")] + mode_id: String, + }, + /// Current model update (agent-initiated model change). + /// + /// UNSTABLE in ACP spec - may be sent when agent changes model. + /// Included for forward compatibility. + CurrentModelUpdate { + #[serde(rename = "modelId")] + model_id: String, + }, +} + +/// Kind of tool being invoked. +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] +#[serde(rename_all = "snake_case")] +pub enum ToolKind { + /// File read operations. + Read, + /// File write/modification operations. + Edit, + /// Search operations (glob, grep, ls). + Search, + /// Terminal/command execution. + Execute, + /// Agent internal reasoning. + Think, + /// Other/unknown tool type. + Other, +} + +/// Status of a tool call. +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)] +#[serde(rename_all = "snake_case")] +pub enum ToolCallStatus { + /// Waiting to start. + Pending, + /// Currently executing. + InProgress, + /// Completed successfully. + Completed, + /// Failed with error. + Failed, + /// Cancelled by user. + Cancelled, +} + +/// Location information for a tool call. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct ToolCallLocation { + /// File path. + pub path: String, + /// Line number (optional). + #[serde(skip_serializing_if = "Option::is_none")] + pub line: Option, +} + +/// Content of a tool call. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ToolCallContent { + /// Regular content block. + Content { + content: ContentBlock, + }, + /// File diff. + Diff { + path: String, + diff: String, + }, + /// Terminal output reference. + Terminal { + terminal_id: String, + }, +} + +/// Plan entry. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct PlanEntry { + /// Plan step content. + pub content: String, + /// Priority level. + pub priority: String, + /// Status of this step. + pub status: String, +} + +/// Command available to the agent. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct Command { + /// Command name. + pub name: String, + /// Optional description. + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, +} + +/// State for accumulating message chunks. +#[derive(Debug, Clone, Default)] +pub struct MessageAccumulator { + /// Accumulated content blocks. + pub blocks: Vec, + /// Whether this is a thought/reasoning block. + pub is_thought: bool, +} + +impl MessageAccumulator { + /// Create a new message accumulator. + pub fn new(is_thought: bool) -> Self { + Self { + blocks: Vec::new(), + is_thought, + } + } + + /// Add a content block to the accumulator. + pub fn add_block(&mut self, block: ContentBlock) { + self.blocks.push(block); + } + + /// Get accumulated text (concatenates all text blocks). + pub fn get_text(&self) -> String { + self.blocks + .iter() + .filter_map(|block| { + if let ContentBlock::Text { text, .. } = block { + Some(text.as_str()) + } else { + None + } + }) + .collect::>() + .join("") + } + + /// Clear the accumulator. + pub fn clear(&mut self) { + self.blocks.clear(); + } +} + +/// Tool call tracking state. +#[derive(Debug, Clone, PartialEq)] +pub struct ToolCallInfo { + /// Unique ID for this tool call. + pub tool_call_id: String, + /// Human-readable title. + pub title: String, + /// Kind of tool. + pub kind: ToolKind, + /// Current status. + pub status: ToolCallStatus, + /// Location (file/line) if applicable. + pub location: Option, + /// Content accumulated so far. + pub content: Vec, +} + +impl ToolCallInfo { + /// Create a new tool call tracking object. + pub fn new( + tool_call_id: String, + title: String, + kind: ToolKind, + status: ToolCallStatus, + location: Option, + content: Option>, + ) -> Self { + Self { + tool_call_id, + title, + kind, + status, + location, + content: content.unwrap_or_default(), + } + } + + /// Update the tool call with new status and optionally new content. + pub fn update(&mut self, status: ToolCallStatus, new_content: Option>) { + self.status = status; + if let Some(content) = new_content { + self.content.extend(content); + } + } + + /// Check if the tool call is in a terminal state. + pub fn is_terminal(&self) -> bool { + matches!( + self.status, + ToolCallStatus::Completed | ToolCallStatus::Failed | ToolCallStatus::Cancelled + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_session_update_agent_message_chunk() { + let update = SessionUpdate::AgentMessageChunk { + content: ContentBlock::Text { + text: "Hello".to_string(), + annotations: None, + }, + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains("\"sessionUpdate\":\"agent_message_chunk\"")); + assert!(json.contains("Hello")); + } + + #[test] + fn test_session_update_tool_call() { + let update = SessionUpdate::ToolCall { + tool_call_id: "tool-1".to_string(), + title: "Read file".to_string(), + kind: ToolKind::Read, + status: ToolCallStatus::Pending, + location: Some(ToolCallLocation { + path: "/path/to/file.txt".to_string(), + line: Some(42), + }), + content: None, + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains("\"sessionUpdate\":\"tool_call\"")); + assert!(json.contains("tool-1")); + assert!(json.contains("Read file")); + } + + #[test] + fn test_tool_kind_serialization() { + let kinds = vec![ + ToolKind::Read, + ToolKind::Edit, + ToolKind::Search, + ToolKind::Execute, + ToolKind::Think, + ToolKind::Other, + ]; + + for kind in kinds { + let json = serde_json::to_string(&kind).unwrap(); + let deserialized: ToolKind = serde_json::from_str(&json).unwrap(); + assert_eq!(kind, deserialized); + } + } + + #[test] + fn test_tool_call_status_serialization() { + let statuses = vec![ + ToolCallStatus::Pending, + ToolCallStatus::InProgress, + ToolCallStatus::Completed, + ToolCallStatus::Failed, + ToolCallStatus::Cancelled, + ]; + + for status in statuses { + let json = serde_json::to_string(&status).unwrap(); + let deserialized: ToolCallStatus = serde_json::from_str(&json).unwrap(); + assert_eq!(status, deserialized); + } + } + + #[test] + fn test_tool_call_content_diff() { + let content = ToolCallContent::Diff { + path: "/path/to/file.txt".to_string(), + diff: "@@ -1,1 +1,1 @@\n-old\n+new".to_string(), + }; + + let json = serde_json::to_string(&content).unwrap(); + assert!(json.contains("\"type\":\"diff\"")); + assert!(json.contains("/path/to/file.txt")); + } + + #[test] + fn test_message_accumulator() { + let mut acc = MessageAccumulator::new(false); + + acc.add_block(ContentBlock::Text { + text: "Hello ".to_string(), + annotations: None, + }); + acc.add_block(ContentBlock::Text { + text: "world".to_string(), + annotations: None, + }); + + assert_eq!(acc.get_text(), "Hello world"); + assert!(!acc.is_thought); + } + + #[test] + fn test_tool_call_info_update() { + let mut info = ToolCallInfo::new( + "tool-1".to_string(), + "Test".to_string(), + ToolKind::Read, + ToolCallStatus::Pending, + None, + None, + ); + + assert!(!info.is_terminal()); + + info.update(ToolCallStatus::InProgress, None); + assert_eq!(info.status, ToolCallStatus::InProgress); + assert!(!info.is_terminal()); + + info.update(ToolCallStatus::Completed, None); + assert_eq!(info.status, ToolCallStatus::Completed); + assert!(info.is_terminal()); + } + + #[test] + fn test_plan_entry() { + let entry = PlanEntry { + content: "Step 1: Do something".to_string(), + priority: "high".to_string(), + status: "pending".to_string(), + }; + + let json = serde_json::to_string(&entry).unwrap(); + let deserialized: PlanEntry = serde_json::from_str(&json).unwrap(); + assert_eq!(entry, deserialized); + } + + #[test] + fn test_command() { + let command = Command { + name: "/help".to_string(), + description: Some("Show help".to_string()), + }; + + let json = serde_json::to_string(&command).unwrap(); + let deserialized: Command = serde_json::from_str(&json).unwrap(); + assert_eq!(command, deserialized); + } + + #[test] + fn test_session_update_notification_current_mode() { + let notification = SessionUpdateNotification { + session_id: "session-123".to_string(), + update: SessionUpdate::CurrentModeUpdate { + mode_id: "code".to_string(), + }, + }; + + let json = serde_json::to_string(¬ification).unwrap(); + assert!(json.contains("session-123")); + assert!(json.contains("\"sessionUpdate\":\"current_mode_update\"")); + // Verify camelCase field name per ACP spec + assert!(json.contains("\"modeId\":\"code\"")); + } + + #[test] + fn test_session_update_notification_current_model() { + let notification = SessionUpdateNotification { + session_id: "session-123".to_string(), + update: SessionUpdate::CurrentModelUpdate { + model_id: "sonnet".to_string(), + }, + }; + + let json = serde_json::to_string(¬ification).unwrap(); + assert!(json.contains("session-123")); + assert!(json.contains("\"sessionUpdate\":\"current_model_update\"")); + // Verify camelCase field name per ACP spec + assert!(json.contains("\"modelId\":\"sonnet\"")); + } + + #[test] + fn test_current_mode_update_deserialization() { + // Test deserialization from ACP spec format + let json = r#"{ + "sessionUpdate": "current_mode_update", + "modeId": "plan" + }"#; + + let update: SessionUpdate = serde_json::from_str(json).unwrap(); + match update { + SessionUpdate::CurrentModeUpdate { mode_id } => { + assert_eq!(mode_id, "plan"); + } + _ => panic!("Expected CurrentModeUpdate"), + } + } + + #[test] + fn test_current_model_update_deserialization() { + // Test deserialization from ACP spec format (forward compatibility) + let json = r#"{ + "sessionUpdate": "current_model_update", + "modelId": "haiku" + }"#; + + let update: SessionUpdate = serde_json::from_str(json).unwrap(); + match update { + SessionUpdate::CurrentModelUpdate { model_id } => { + assert_eq!(model_id, "haiku"); + } + _ => panic!("Expected CurrentModelUpdate"), + } + } +} diff --git a/crates/dirigent_core/src/acp/transport/json_reader.rs b/crates/dirigent_core/src/acp/transport/json_reader.rs new file mode 100644 index 0000000..a8cc10d --- /dev/null +++ b/crates/dirigent_core/src/acp/transport/json_reader.rs @@ -0,0 +1,404 @@ +//! Multi-line JSON reader for stdio-based transports. +//! +//! Provides resilient JSON-RPC message reading that handles agents or clients +//! sending multi-line (pretty-printed) JSON instead of strict single-line +//! JSON-RPC. Uses `serde_json::StreamDeserializer` to correctly extract +//! complete JSON values from a buffer, even when they span multiple lines +//! or when multiple values are concatenated. +//! +//! # Usage +//! +//! Both sides of a JSON-RPC stdio connection can use this: +//! - **Client side** (`connectors::acp::transport::StdioTransport`): reading agent stdout +//! - **Server side** (`dirigate`): reading client stdin +//! +//! ```ignore +//! use dirigent_core::acp::transport::json_reader::{JsonLineReader, ReadResult}; +//! use tokio::io::BufReader; +//! +//! let stdin = tokio::io::stdin(); +//! let mut reader = BufReader::new(stdin); +//! let mut json_reader = JsonLineReader::new(); +//! +//! loop { +//! match json_reader.read_message(&mut reader).await { +//! Ok(ReadResult::Message(msg)) => println!("Got: {}", msg), +//! Ok(ReadResult::Eof) => break, +//! Err(e) => eprintln!("Error: {}", e), +//! } +//! } +//! ``` + +use serde_json::Value; +use tokio::io::{AsyncBufRead, AsyncBufReadExt}; + +/// Safety limit: max lines to buffer before giving up. +const MAX_BUFFER_LINES: usize = 50_000; +/// Safety limit: max buffer size in bytes (10 MB). +const MAX_BUFFER_BYTES: usize = 10 * 1024 * 1024; + +/// Result of reading one JSON message from the stream. +pub enum ReadResult { + /// Successfully read a complete JSON message. + Message(Value), + /// EOF reached (stream closed). + Eof, +} + +/// Persistent state for a multi-line JSON reader. +/// +/// Maintains a pending message buffer across `read_message` calls to handle +/// cases where the stream deserializer extracts two messages from one buffer. +pub struct JsonLineReader { + /// Pending message from a previous parse that found two values in one buffer. + pending: Option, + /// Incomplete remainder from a previous parse where a complete JSON message + /// was followed by the start of another message that wasn't yet complete. + pending_remainder: Option, +} + +impl JsonLineReader { + pub fn new() -> Self { + Self { + pending: None, + pending_remainder: None, + } + } + + /// Read the next complete JSON message from the given buffered reader. + /// + /// Handles: + /// - Single-line JSON (normal case) + /// - Multi-line / pretty-printed JSON + /// - Two concatenated JSON objects in the buffer + /// - Non-JSON garbage (returns error immediately) + /// + /// The reader is borrowed mutably for each call but state is maintained + /// in `self` across calls (for pending messages). + pub async fn read_message( + &mut self, + reader: &mut R, + ) -> Result { + // Return pending message from previous parse if available + if let Some(msg) = self.pending.take() { + tracing::debug!("Returning pending message from previous read_message() parse"); + return Ok(ReadResult::Message(msg)); + } + + let mut json_buffer = String::new(); + let mut buffered_lines: usize = 0; + + // Initialize buffer from pending remainder if available. + if let Some(remainder) = self.pending_remainder.take() { + tracing::debug!( + remainder_bytes = remainder.len(), + "Initializing buffer from pending remainder" + ); + json_buffer = remainder; + buffered_lines = 1; + } + + loop { + let mut line = String::new(); + let bytes_read = reader + .read_line(&mut line) + .await + .map_err(|e| format!("Failed to read line: {}", e))?; + + // EOF + if bytes_read == 0 { + if !json_buffer.is_empty() { + tracing::warn!( + buffer_lines = buffered_lines, + buffer_bytes = json_buffer.len(), + "EOF with incomplete multi-line JSON buffer" + ); + } + return Ok(ReadResult::Eof); + } + + // Skip empty lines when not buffering + if line.trim().is_empty() && json_buffer.is_empty() { + continue; + } + + // Append to buffer + if json_buffer.is_empty() { + json_buffer = line; + buffered_lines = 1; + } else { + json_buffer.push_str(&line); + buffered_lines += 1; + } + + // Try to extract a complete JSON value + if let Some((message, remainder)) = try_extract_json(&json_buffer) { + if buffered_lines > 1 { + tracing::warn!( + lines = buffered_lines, + bytes = json_buffer.len(), + "Recovered multi-line JSON message" + ); + } + + // If there's leftover data, try to parse as a second message + if !remainder.is_empty() { + tracing::debug!( + remainder_bytes = remainder.len(), + "Buffer has remaining data after complete JSON" + ); + if let Some((second_msg, second_remainder)) = try_extract_json(&remainder) { + self.pending = Some(second_msg); + if !second_remainder.is_empty() { + tracing::debug!( + remainder_bytes = second_remainder.len(), + "Carrying over third message fragment as pending remainder" + ); + self.pending_remainder = Some(second_remainder); + } + } else { + // Incomplete second message — carry over instead of discarding + tracing::debug!( + remainder_bytes = remainder.len(), + "Carrying over incomplete remainder for next read_message()" + ); + self.pending_remainder = Some(remainder); + } + } + + return Ok(ReadResult::Message(message)); + } + + // Buffer doesn't contain a complete JSON value yet. + // Check if it could ever become valid JSON. + let buffer_trimmed = json_buffer.trim_start(); + if !buffer_trimmed.starts_with('{') && !buffer_trimmed.starts_with('[') { + return Err(format!( + "Failed to parse JSON: expected JSON object. Line: {}", + &buffer_trimmed[..buffer_trimmed.len().min(200)] + )); + } + + // Safety limits + if buffered_lines >= MAX_BUFFER_LINES { + return Err(format!( + "Multi-line JSON message exceeded {} lines without completing", + MAX_BUFFER_LINES + )); + } + if json_buffer.len() > MAX_BUFFER_BYTES { + return Err(format!( + "Multi-line JSON message exceeded {} bytes without completing", + MAX_BUFFER_BYTES + )); + } + + if buffered_lines % 1000 == 0 { + tracing::debug!( + lines = buffered_lines, + bytes = json_buffer.len(), + "Still buffering multi-line JSON message..." + ); + } + } + } +} + +/// Try to extract a complete JSON value from the front of `buffer`. +/// +/// Uses `serde_json::StreamDeserializer` to parse the first complete JSON value. +/// Returns `Some((value, remaining))` if a value was found, where `remaining` +/// is the unparsed tail of the buffer. Returns `None` if the buffer doesn't +/// contain a complete JSON value yet. +pub fn try_extract_json(buffer: &str) -> Option<(Value, String)> { + let mut stream = serde_json::Deserializer::from_str(buffer).into_iter::(); + match stream.next() { + Some(Ok(value)) => { + let byte_offset = stream.byte_offset(); + let remaining = buffer[byte_offset..].trim_start().to_string(); + Some((value, remaining)) + } + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_extract_single_json() { + let input = r#"{"jsonrpc":"2.0","id":1,"result":{}}"#; + let (value, remainder) = try_extract_json(input).unwrap(); + assert_eq!(value["id"], 1); + assert!(remainder.is_empty()); + } + + #[test] + fn test_extract_incomplete_returns_none() { + assert!(try_extract_json(r#"{"id":1,"res"#).is_none()); + } + + #[test] + fn test_extract_two_concatenated() { + let input = r#"{"id":1,"result":{}} +{"id":2,"result":{}}"#; + let (first, remainder) = try_extract_json(input).unwrap(); + assert_eq!(first["id"], 1); + let (second, remainder2) = try_extract_json(&remainder).unwrap(); + assert_eq!(second["id"], 2); + assert!(remainder2.is_empty()); + } + + #[test] + fn test_extract_multiline_pretty_printed() { + let input = "{\n \"id\": 1,\n \"result\": {}\n}"; + let (value, remainder) = try_extract_json(input).unwrap(); + assert_eq!(value["id"], 1); + assert!(remainder.is_empty()); + } + + #[test] + fn test_extract_garbage_returns_none() { + assert!(try_extract_json("on-management\",\n\n div {").is_none()); + } + + #[test] + fn test_extract_empty_returns_none() { + assert!(try_extract_json("").is_none()); + assert!(try_extract_json(" ").is_none()); + } + + #[test] + fn test_extract_complete_followed_by_partial() { + let input = r#"{"id":1,"result":{}} +{"id":2,"res"#; + let (first, remainder) = try_extract_json(input).unwrap(); + assert_eq!(first["id"], 1); + assert!(try_extract_json(&remainder).is_none()); + } + + #[test] + fn test_extract_with_trailing_whitespace() { + let input = r#"{"id":1,"result":{}} "#; + let (value, remainder) = try_extract_json(input).unwrap(); + assert_eq!(value["id"], 1); + assert!(remainder.is_empty()); + } + + #[test] + fn test_extract_with_escaped_newlines() { + let input = r#"{"content":"line1\nline2\nline3"}"#; + let (value, remainder) = try_extract_json(input).unwrap(); + assert_eq!(value["content"], "line1\nline2\nline3"); + assert!(remainder.is_empty()); + } + + #[tokio::test] + async fn test_reader_single_line() { + let input = b"{\"id\":1,\"result\":{}}\n"; + let mut cursor = std::io::Cursor::new(input.to_vec()); + let mut reader = JsonLineReader::new(); + match reader.read_message(&mut cursor).await.unwrap() { + ReadResult::Message(msg) => assert_eq!(msg["id"], 1), + ReadResult::Eof => panic!("Expected message, got EOF"), + } + } + + #[tokio::test] + async fn test_reader_multiline() { + let input = b"{\n \"id\": 1,\n \"result\": {}\n}\n"; + let mut cursor = std::io::Cursor::new(input.to_vec()); + let mut reader = JsonLineReader::new(); + match reader.read_message(&mut cursor).await.unwrap() { + ReadResult::Message(msg) => assert_eq!(msg["id"], 1), + ReadResult::Eof => panic!("Expected message, got EOF"), + } + } + + #[tokio::test] + async fn test_reader_eof() { + let input = b""; + let mut cursor = std::io::Cursor::new(input.to_vec()); + let mut reader = JsonLineReader::new(); + match reader.read_message(&mut cursor).await.unwrap() { + ReadResult::Message(_) => panic!("Expected EOF"), + ReadResult::Eof => {} + } + } + + #[tokio::test] + async fn test_reader_garbage_returns_error() { + let input = b"not json at all\n"; + let mut cursor = std::io::Cursor::new(input.to_vec()); + let mut reader = JsonLineReader::new(); + let result = reader.read_message(&mut cursor).await; + assert!(result.is_err()); + // Error should contain the raw line for forensics + let err = result.err().unwrap(); + assert!(err.contains("not json at all"), "Error should contain raw line: {}", err); + } + + #[tokio::test] + async fn test_reader_carries_over_incomplete_remainder() { + // Line 1 has a complete JSON object concatenated with the start of a second + // (split at a JSON whitespace boundary — between tokens, not inside a string). + // read_line returns `{"id":1}{"id":2,\n` (up to first newline). + // try_extract_json extracts id:1, remainder `{"id":2,` is incomplete. + // With the fix, this remainder is carried over as pending_remainder. + // Second read_message() initializes json_buffer from it, reads the next line + // `"result":{}}\n`, assembles `{"id":2,\n"result":{}}\n`, and extracts id:2. + let input = b"{\"id\":1}{\"id\":2,\n\"result\":{}}\n"; + let mut cursor = std::io::Cursor::new(input.to_vec()); + let mut reader = JsonLineReader::new(); + + // First call: extracts id:1, carries over incomplete remainder + match reader.read_message(&mut cursor).await.unwrap() { + ReadResult::Message(msg) => assert_eq!(msg["id"], 1), + ReadResult::Eof => panic!("Expected message 1, got EOF"), + } + + // Second call: assembles id:2 from remainder + next line + match reader.read_message(&mut cursor).await.unwrap() { + ReadResult::Message(msg) => assert_eq!(msg["id"], 2), + ReadResult::Eof => panic!("Expected message 2, got EOF"), + } + } + + #[tokio::test] + async fn test_reader_remainder_multiline_recovery() { + // Simulates the bug scenario: agent sends streaming notifications rapidly, + // OS pipe buffering delivers a complete message concatenated with the start + // of the next on the same line. The second message spans multiple lines. + // + // Line 1: complete notification + start of second notification (no newline between) + // Line 2: continuation of second notification's JSON + // Line 3: closing of second notification + let msg1 = r#"{"jsonrpc":"2.0","method":"notify","params":{"ok":true}}"#; + // Second message is split across lines at JSON-whitespace boundaries + let input = format!( + "{}{{\n\"jsonrpc\":\"2.0\",\n\"method\":\"update\",\"params\":{{\"content\":\"card data\"}}}}\n", + msg1 + ); + let mut cursor = std::io::Cursor::new(input.into_bytes()); + let mut reader = JsonLineReader::new(); + + // First message + match reader.read_message(&mut cursor).await.unwrap() { + ReadResult::Message(msg) => { + assert_eq!(msg["method"], "notify"); + assert_eq!(msg["params"]["ok"], true); + } + ReadResult::Eof => panic!("Expected message 1"), + } + + // Second message (assembled from remainder + subsequent lines) + match reader.read_message(&mut cursor).await.unwrap() { + ReadResult::Message(msg) => { + assert_eq!(msg["method"], "update"); + assert_eq!(msg["params"]["content"], "card data"); + } + ReadResult::Eof => panic!("Expected message 2"), + } + } +} diff --git a/crates/dirigent_core/src/acp/transport/mod.rs b/crates/dirigent_core/src/acp/transport/mod.rs new file mode 100644 index 0000000..3ec27a3 --- /dev/null +++ b/crates/dirigent_core/src/acp/transport/mod.rs @@ -0,0 +1,326 @@ +//! Transport layer abstraction for ACP (Agent-Client Protocol). +//! +//! This module provides a unified interface for communicating with ACP agents +//! over different transports (stdio, HTTP+SSE). It handles JSON-RPC message +//! encoding/decoding and transport-specific connection management. + +use async_trait::async_trait; +use futures::Stream; +use serde::{Deserialize, Serialize}; +use std::pin::Pin; +use thiserror::Error; + +// Shared utilities for JSON-over-stdio transports +pub mod json_reader; + +/// JSON-RPC 2.0 request message. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcRequest { + pub jsonrpc: String, // Always "2.0" + pub id: serde_json::Value, // Request ID (number, string, or null) + pub method: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option, +} + +impl JsonRpcRequest { + /// Create a new JSON-RPC request with auto-generated numeric ID. + pub fn new(id: u64, method: impl Into, params: Option) -> Self { + Self { + jsonrpc: "2.0".to_string(), + id: serde_json::Value::Number(id.into()), + method: method.into(), + params, + } + } + + /// Create a new JSON-RPC request with string ID. + pub fn new_with_string_id( + id: impl Into, + method: impl Into, + params: Option, + ) -> Self { + Self { + jsonrpc: "2.0".to_string(), + id: serde_json::Value::String(id.into()), + method: method.into(), + params, + } + } + + /// Create a JSON-RPC notification (no ID field). + /// + /// Notifications are fire-and-forget - no response is expected. + pub fn notification(method: impl Into, params: Option) -> Self { + Self { + jsonrpc: "2.0".to_string(), + id: serde_json::Value::Null, + method: method.into(), + params, + } + } +} + +/// JSON-RPC 2.0 response message (success). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcResponse { + pub jsonrpc: String, // Always "2.0" + pub id: serde_json::Value, // Request ID (matches request) + pub result: serde_json::Value, +} + +/// JSON-RPC 2.0 error response. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcErrorResponse { + pub jsonrpc: String, // Always "2.0" + pub id: serde_json::Value, // Request ID (matches request) + pub error: JsonRpcError, +} + +/// JSON-RPC 2.0 error object. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcError { + pub code: i32, + pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +impl std::fmt::Display for JsonRpcError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "code {}: {}", self.code, self.message) + } +} + +/// JSON-RPC 2.0 notification message (no response expected). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcNotification { + pub jsonrpc: String, // Always "2.0" + pub method: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option, +} + +impl JsonRpcNotification { + /// Create a new JSON-RPC notification. + pub fn new(method: impl Into, params: Option) -> Self { + Self { + jsonrpc: "2.0".to_string(), + method: method.into(), + params, + } + } +} + +/// Transport-level errors. +#[derive(Debug, Error)] +pub enum TransportError { + #[error("Connection error: {0}")] + ConnectionError(String), + + #[error("I/O error: {0}")] + IoError(#[from] std::io::Error), + + #[error("JSON serialization error: {0}")] + SerializationError(#[from] serde_json::Error), + + #[error("JSON-RPC error: {0}")] + JsonRpcError(JsonRpcError), + + #[error("Transport closed")] + Closed, + + #[error("Timeout waiting for response")] + Timeout, + + #[error("Process exited unexpectedly")] + ProcessExited, + + #[error("HTTP error: {0}")] + HttpError(String), + + #[error("SSE error: {0}")] + SseError(String), + + #[error("Other error: {0}")] + Other(String), +} + +/// Transport state. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TransportState { + /// Not yet connected. + Disconnected, + /// Currently connecting. + Connecting, + /// Connected and operational. + Connected, + /// Disconnected due to error. + Error, +} + + +/// Result type for JSON-RPC operations. +/// +/// This wraps either a successful response or an error response. +#[derive(Debug, Clone)] +pub enum JsonRpcResult { + Success(JsonRpcResponse), + Error(JsonRpcErrorResponse), +} + +impl JsonRpcResult { + /// Convert to Result, mapping error responses to TransportError. + pub fn into_result(self) -> Result { + match self { + JsonRpcResult::Success(response) => Ok(response), + JsonRpcResult::Error(error_response) => Err(TransportError::JsonRpcError(error_response.error)), + } + } + + /// Get the result value if successful. + pub fn result(&self) -> Option<&serde_json::Value> { + match self { + JsonRpcResult::Success(response) => Some(&response.result), + JsonRpcResult::Error(_) => None, + } + } + + /// Get the error if this is an error response. + pub fn error(&self) -> Option<&JsonRpcError> { + match self { + JsonRpcResult::Success(_) => None, + JsonRpcResult::Error(response) => Some(&response.error), + } + } + + /// Check if this is a success response. + pub fn is_success(&self) -> bool { + matches!(self, JsonRpcResult::Success(_)) + } + + /// Check if this is an error response. + pub fn is_error(&self) -> bool { + matches!(self, JsonRpcResult::Error(_)) + } +} + +/// Notification stream type. +pub type NotificationStream = Pin + Send>>; + +/// Transport trait for ACP communication. +/// +/// This trait abstracts over stdio and HTTP transports, providing a unified +/// interface for sending JSON-RPC requests and receiving notifications. +#[async_trait] +pub trait Transport: Send + Sync { + /// Send a JSON-RPC request and wait for the response. + /// + /// This is a request-response operation. The transport will match the + /// response by request ID. + async fn send_request(&mut self, request: JsonRpcRequest) -> Result; + + /// Send a JSON-RPC notification (fire-and-forget). + /// + /// Notifications do not expect a response. + async fn send_notification(&mut self, notification: JsonRpcNotification) -> Result<(), TransportError>; + + /// Get a stream of notifications from the agent. + /// + /// This returns a stream that yields notifications as they arrive. + fn notification_stream(&mut self) -> NotificationStream; + + /// Close the transport. + /// + /// This gracefully shuts down the transport, cleaning up resources. + async fn close(&mut self) -> Result<(), TransportError>; + + /// Check if the transport is connected. + fn is_connected(&self) -> bool; + + /// Get the current transport state. + fn state(&self) -> TransportState; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_jsonrpc_request_new() { + let req = JsonRpcRequest::new(1, "test_method", None); + assert_eq!(req.jsonrpc, "2.0"); + assert_eq!(req.method, "test_method"); + assert_eq!(req.id, serde_json::Value::Number(1.into())); + assert!(req.params.is_none()); + } + + #[test] + fn test_jsonrpc_request_with_params() { + let params = serde_json::json!({"key": "value"}); + let req = JsonRpcRequest::new(2, "test_method", Some(params.clone())); + assert_eq!(req.params, Some(params)); + } + + #[test] + fn test_jsonrpc_request_serialize() { + let req = JsonRpcRequest::new(1, "test_method", None); + let json = serde_json::to_string(&req).unwrap(); + assert!(json.contains("\"jsonrpc\":\"2.0\"")); + assert!(json.contains("\"method\":\"test_method\"")); + assert!(json.contains("\"id\":1")); + } + + #[test] + fn test_jsonrpc_notification() { + let notif = JsonRpcNotification::new("test_notification", None); + assert_eq!(notif.jsonrpc, "2.0"); + assert_eq!(notif.method, "test_notification"); + assert!(notif.params.is_none()); + } + + #[test] + fn test_jsonrpc_error() { + let error = JsonRpcError { + code: -32600, + message: "Invalid Request".to_string(), + data: None, + }; + assert_eq!(error.code, -32600); + assert_eq!(error.message, "Invalid Request"); + } + + #[test] + fn test_jsonrpc_result_success() { + let response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::Value::Number(1.into()), + result: serde_json::json!({"status": "ok"}), + }; + let result = JsonRpcResult::Success(response); + + assert!(result.is_success()); + assert!(!result.is_error()); + assert!(result.result().is_some()); + assert!(result.error().is_none()); + } + + #[test] + fn test_jsonrpc_result_error() { + let error_response = JsonRpcErrorResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::Value::Number(1.into()), + error: JsonRpcError { + code: -32600, + message: "Invalid Request".to_string(), + data: None, + }, + }; + let result = JsonRpcResult::Error(error_response); + + assert!(!result.is_success()); + assert!(result.is_error()); + assert!(result.result().is_none()); + assert!(result.error().is_some()); + } +} diff --git a/crates/dirigent_core/src/bin/main.rs b/crates/dirigent_core/src/bin/main.rs new file mode 100644 index 0000000..ec4990a --- /dev/null +++ b/crates/dirigent_core/src/bin/main.rs @@ -0,0 +1,48 @@ +//! Dirigent Core Server +//! +//! Main server application that orchestrates ACP agents + +use axum::{response::Json, routing::get, Router}; +use dirigent_acp_api::create_api_router; +use dirigent_core::CoreConfig; +use std::net::SocketAddr; +use tower_http::cors::CorsLayer; +use tracing_subscriber; + +const DEFAULT_PORT: u16 = 3000; + +#[tokio::main] +async fn main() { + // Initialize tracing + tracing_subscriber::fmt::init(); + + let _config = CoreConfig::default(); + + tracing::info!("Starting Dirigent Core server on port {}", DEFAULT_PORT); + + // Create the main application router + let app = Router::new() + .route("/", get(root_handler)) + .route("/info", get(info_handler)) + .nest("/api", create_api_router()) + .layer(CorsLayer::permissive()); + + // Start the server + let addr = SocketAddr::from(([127, 0, 0, 1], DEFAULT_PORT)); + tracing::info!("Listening on {}", addr); + + let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); + axum::serve(listener, app).await.unwrap(); +} + +async fn root_handler() -> &'static str { + "Dirigent Core - ACP Agent Orchestrator\n\nEndpoints:\n GET /info - Server info\n GET /api - ACP API\n GET /api/health - Health check" +} + +async fn info_handler() -> Json { + Json(serde_json::json!({ + "name": "dirigent_core", + "version": env!("CARGO_PKG_VERSION"), + "description": "Orchestrates agentic clients with ACP", + })) +} diff --git a/crates/dirigent_core/src/config.rs b/crates/dirigent_core/src/config.rs new file mode 100644 index 0000000..ebbc7df --- /dev/null +++ b/crates/dirigent_core/src/config.rs @@ -0,0 +1,2265 @@ +//! Configuration types for the Dirigent core runtime +//! +//! This module defines configuration structures used throughout the dirigent_core +//! system for runtime setup and connector management. + +use crate::types::{ConnectorKind, UserId}; +use crate::CoreError; +use once_cell::sync::Lazy; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::path::PathBuf; +use tracing::{debug, info, warn}; + +/// Core runtime configuration +/// +/// Defines the configuration for the CoreRuntime, including connector +/// configurations, archive settings, and project management. +/// +/// # Example +/// +/// ``` +/// use dirigent_core::CoreConfig; +/// use std::path::PathBuf; +/// +/// let config = CoreConfig { +/// runtime_working_dir: PathBuf::from("/my/project"), +/// connectors: vec![], +/// ..CoreConfig::default() +/// }; +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CoreConfig { + /// Web server listen port. Read by the server at startup and set as the + /// `PORT` env var before Dioxus resolves the listen address. + /// Precedence: `PORT` env var > this field > Dioxus default (8080). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub port: Option, + + /// Default working directory for connector processes (e.g., spawning Claude). + /// Used when no project or connector-specific working directory is configured. + /// Relative paths are resolved against the config file's directory. + #[serde(alias = "project_dir")] + pub runtime_working_dir: PathBuf, + + /// Connector configurations to be loaded on startup + /// + /// These connectors will be automatically created and started when the + /// runtime initializes. This enables persistent connector setups across + /// server restarts. + pub connectors: Vec, + + /// Archive root directory (deprecated in config file — use DIRIGENT_DATA_DIR instead) + /// Kept deserializable for backward compatibility but no longer written. + #[serde(default, skip_serializing)] + pub archive_root: Option, + + /// Archive backend declarations. Phase 3+ replaces `archive_root` with + /// this `[[archives]]`-array config. When both are set, `archives` wins. + /// + /// Stored as typed `ArchiveConfig` on server builds and as raw + /// `serde_json::Value` on WASM/non-server builds (archivist types pull in + /// the full coordinator which is server-only). + #[cfg(feature = "server")] + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub archives: Vec, + + /// Archive backend declarations (opaque on non-server builds). + #[cfg(not(feature = "server"))] + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub archives: Vec, + + /// Root directory for project storage (deprecated in config file — use DIRIGENT_DATA_DIR instead) + /// Kept deserializable for backward compatibility but no longer written. + #[serde(default, skip_serializing)] + pub projects_root: Option, + + /// ACP Server configuration + /// + /// If specified, configures the ACP Server for accepting incoming + /// Agent-Client Protocol connections. When enabled, external ACP agents + /// can connect to this Dirigent instance. + /// + /// If None, no ACP Server will be started (disabled). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub acp_server: Option, + + /// List of Zed agent connector titles that the user has explicitly dismissed + /// + /// When a user removes a Zed-detected connector, its title is added + /// to this list so it won't be re-added on next startup by auto-detection. + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub dismissed_zed_agents: Vec, + + /// Task definitions for background process management + /// + /// These tasks can be configured via the UI or TOML file. + /// Tasks with `run_at_startup = true` will be automatically started. + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub tasks: Vec, + + /// Named accounts parsed from [accounts.*] sections. + /// + /// Each key is the account name used elsewhere in the config. + /// Identity, credentials, and connection properties live here. + /// + /// Example in dirigent.toml: + /// ```toml + /// [accounts.matrix-bot] + /// type = "matrix" + /// username = "dirigent_bot" + /// display_name = "Dirigent Bot" + /// homeserver = "https://matrix.example.com" + /// + /// [accounts.matrix-bot.credentials.password] + /// type = "env" + /// value = "DIRIGENT_MATRIX_PASSWORD" + /// ``` + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + pub accounts: HashMap, + + /// Matrix sharing behavior configuration. + /// + /// References an account by name for identity/credentials. + /// If None, Matrix session sharing is disabled. + /// + /// Stored as typed `MatrixBehaviorConfig` on server builds and as + /// raw `serde_json::Value` on WASM/non-server builds to avoid pulling + /// in the heavy `matrix-sdk` dependency. + /// + /// Example in dirigent.toml: + /// ```toml + /// [matrix] + /// account = "matrix-bot" + /// default_invite = ["@user:example.com"] + /// store_path = "matrix/bot/store" + /// ``` + #[cfg(feature = "server")] + #[serde(default, skip_serializing_if = "Option::is_none")] + pub matrix: Option, + + /// Matrix sharing behavior configuration (opaque on non-server builds). + #[cfg(not(feature = "server"))] + #[serde(default, skip_serializing_if = "Option::is_none")] + pub matrix: Option, + + /// `[[streams]]` config blocks — declarative stream attachments that the + /// runtime wires up at boot via `CoreRuntime::attach_stream`. + /// + /// `StreamsConfig` uses `#[serde(rename = "streams")]` on its `entries` + /// field, so `#[serde(flatten)]` lifts it into the top-level TOML as + /// `[[streams]]` tables. The inner `entries` field has its own + /// `#[serde(default)]`, so omitting `[[streams]]` deserialises fine. + #[serde(flatten)] + pub streams: crate::sharing::StreamsConfig, + + /// Path the config was loaded from (not serialized) + /// + /// Tracked so that `save_config` writes back to the same file that was + /// originally loaded, preserving the user's chosen format and location. + #[serde(skip)] + pub config_source_path: Option, + + /// Resolved runtime working directory (computed at load time, never serialized). + /// + /// Populated by `resolve_paths()`. Use `effective_working_dir()` to read — + /// it returns this value when set, falling back to `runtime_working_dir`. + #[serde(skip)] + pub resolved_working_dir: Option, +} + +/// Configuration for the ACP Server +/// +/// Defines settings for the ACP Server which accepts incoming ACP connections +/// from external agents. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AcpServerConfig { + /// Whether the ACP Server is enabled + pub enabled: bool, + + /// Port to listen on + /// - None: Integrated at /acp on main server (default) + /// - Some(port): Separate server on specified port + #[serde(default, skip_serializing_if = "Option::is_none")] + pub port: Option, + + /// Allowed origins for CORS (None = all origins allowed) + #[serde(skip_serializing_if = "Option::is_none")] + pub allowed_origins: Option>, + + /// Maximum concurrent connections (default: 100) + pub max_connections: usize, + + /// Default connector ID for routing incoming sessions (None = use Gateway) + #[serde(skip_serializing_if = "Option::is_none")] + pub default_connector_id: Option, +} + +/// Configuration for a background task (serialized to TOML) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskConfig { + /// Human-readable title for display in the UI + pub title: String, + /// Unique machine-friendly identifier + pub name: String, + /// The command to execute (e.g., "cargo", "npm") + pub command: String, + /// Arguments passed to the command + #[serde(default)] + pub args: Vec, + /// Working directory for the process (None = inherit runtime working dir) + #[serde(default, skip_serializing_if = "Option::is_none")] + pub working_directory: Option, + /// Whether to start this task automatically when the runtime starts + #[serde(default)] + pub run_at_startup: bool, + /// Maximum number of output lines to keep in memory + #[serde(default = "default_task_buffer_size")] + pub buffer_size: usize, + /// Whether to persist output to disk + #[serde(default = "default_task_persist")] + pub persist_to_disk: bool, + /// Whether to rotate (rename) previous output files on restart + #[serde(default)] + pub rotate_previous: bool, + /// Additional environment variables for the process + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub env: Vec<(String, String)>, +} + +fn default_task_buffer_size() -> usize { + 10000 +} + +fn default_task_persist() -> bool { + true +} + +impl Default for AcpServerConfig { + fn default() -> Self { + Self { + enabled: false, + port: None, // Default to integrated mode + allowed_origins: None, + max_connections: 100, + default_connector_id: None, + } + } +} + +impl Default for CoreConfig { + fn default() -> Self { + Self { + port: None, + runtime_working_dir: PathBuf::from("."), + connectors: vec![], + archive_root: None, + archives: Vec::new(), + projects_root: None, + acp_server: None, + dismissed_zed_agents: vec![], + tasks: vec![], + accounts: HashMap::new(), + matrix: None, + streams: crate::sharing::StreamsConfig::default(), + config_source_path: None, + resolved_working_dir: None, + } + } +} + +impl CoreConfig { + /// Returns the effective working directory for runtime use. + /// + /// Prefers `resolved_working_dir` (populated by `resolve_paths()` at load time). + /// Falls back to the raw `runtime_working_dir` if resolution hasn't been called. + pub fn effective_working_dir(&self) -> &std::path::Path { + self.resolved_working_dir + .as_deref() + .unwrap_or(&self.runtime_working_dir) + } + + /// Resolve relative paths against the config file's parent directory. + /// + /// Populates `resolved_working_dir` based on `runtime_working_dir`: + /// - `"."` or equivalent → `noproject_home_dir()` (no explicit project) + /// - Relative path → joined with `config_dir_parent` + /// - Absolute path → used as-is + /// + /// Also resolves deprecated `archive_root` and `projects_root` in-place + /// (these are skip_serializing so mutation is safe). + pub fn resolve_paths(&mut self, config_dir_parent: Option<&std::path::Path>) { + let is_dot = self.runtime_working_dir == PathBuf::from("."); + + let resolved = if is_dot { + dirigent_config::DirigentPaths::resolve() + .map(|p| p.noproject_home_dir()) + .unwrap_or_else(|_| { + config_dir_parent + .map(|p| p.to_path_buf()) + .unwrap_or_else(|| std::env::current_dir().unwrap_or_default()) + }) + } else if self.runtime_working_dir.is_relative() { + if let Some(parent) = config_dir_parent { + let joined = parent.join(&self.runtime_working_dir); + debug!( + original = ?self.runtime_working_dir, + resolved = ?joined, + "Resolved relative runtime_working_dir against config directory" + ); + joined + } else { + std::env::current_dir() + .unwrap_or_default() + .join(&self.runtime_working_dir) + } + } else { + self.runtime_working_dir.clone() + }; + + // If resolved path equals the config dir, fall back to noproject_home + let final_resolved = if let Some(parent) = config_dir_parent { + if resolved == parent { + dirigent_config::DirigentPaths::resolve() + .map(|p| p.noproject_home_dir()) + .unwrap_or(resolved) + } else { + resolved + } + } else { + resolved + }; + + self.resolved_working_dir = Some(final_resolved); + + // Resolve deprecated fields in-place (skip_serializing, safe from corruption) + if let Some(parent) = config_dir_parent { + if let Some(ref archive) = self.archive_root { + if archive.is_relative() { + let resolved = parent.join(archive); + debug!(original = ?archive, resolved = ?resolved, "Resolved relative archive_root against config directory"); + self.archive_root = Some(resolved); + } + } + if let Some(ref projects) = self.projects_root { + if projects.is_relative() { + let resolved = parent.join(projects); + debug!(original = ?projects, resolved = ?resolved, "Resolved relative projects_root against config directory"); + self.projects_root = Some(resolved); + } + } + } + } + + /// Populate runtime fields on accounts after deserialization. + /// + /// Sets `config_name` from the map key and derives a stable `user_id` + /// (UUID v5 from the account name) for any account that doesn't have one. + /// Call this immediately after deserializing or loading a `CoreConfig`. + pub fn populate_account_names(&mut self) { + for (name, account) in &mut self.accounts { + account.config_name = name.clone(); + if account.user_id.is_none() { + account.user_id = Some(uuid::Uuid::new_v5( + &uuid::Uuid::NAMESPACE_URL, + format!("dirigent:account:{}", name).as_bytes(), + )); + } + } + } + + /// Find the first local account, if any. + pub fn default_local_account(&self) -> Option<&dirigent_auth::Account> { + self.accounts + .values() + .find(|a| a.kind == dirigent_auth::AccountKind::Local) + } + + /// Get the UserId for the default local account, or Uuid::nil() if none. + pub fn default_owner(&self) -> dirigent_auth::UserId { + self.default_local_account() + .and_then(|a| a.user_id) + .unwrap_or(uuid::Uuid::nil()) + } + + /// Load configuration from a file + /// + /// This function attempts to load configuration from various sources in priority order: + /// 1. The provided `path` parameter (if Some) + /// 2. Config directory (platform-native, or overridden via `DIRIGENT_CONFIG_DIR` env var): + /// `%APPDATA%\dirigent\` on Windows, `~/.config/dirigent/` on Linux/macOS + /// 3. Default configuration (if no file found) + /// + /// The file format (TOML or JSON) is automatically detected based on the file extension. + /// + /// # Arguments + /// + /// * `path` - Optional explicit path to configuration file + /// + /// # Returns + /// + /// A `Result` containing the loaded `CoreConfig` or a `CoreError` + /// + /// # Errors + /// + /// Returns `CoreError::InvalidConfig` if: + /// - The file exists but cannot be read + /// - The file exists but contains invalid TOML/JSON + /// - The file has an unsupported extension + /// + /// # Example + /// + /// ```no_run + /// use dirigent_core::CoreConfig; + /// use std::path::PathBuf; + /// + /// // Load from explicit path + /// let config = CoreConfig::load_config(Some(PathBuf::from("my-config.toml"))).unwrap(); + /// + /// // Load from config directory or defaults + /// let config = CoreConfig::load_config(None).unwrap(); + /// ``` + pub fn load_config(path: Option) -> Result { + // Determine which file to load + let config_path = if let Some(p) = path { + debug!(path = ?p, "Loading config from provided path"); + Some(p) + } else { + // Check config directory (platform-native, or DIRIGENT_CONFIG_DIR override) + if let Ok(paths) = dirigent_config::DirigentPaths::resolve() { + let user_toml = paths.config_dir().join("dirigent.toml"); + let user_json = paths.config_dir().join("dirigent.json"); + + if user_toml.exists() { + debug!(path = ?user_toml, "Found dirigent.toml in config directory"); + Some(user_toml) + } else if user_json.exists() { + debug!(path = ?user_json, "Found dirigent.json in config directory"); + Some(user_json) + } else { + debug!("No config file found in config directory, using defaults"); + None + } + } else { + debug!("Could not resolve config directory, using defaults"); + None + } + }; + + // If we have a path, try to load it + if let Some(path) = config_path { + // Check if file exists + if !path.exists() { + warn!(path = ?path, "Config file not found, using default configuration"); + info!("Using default configuration"); + return Ok(CoreConfig::default()); + } + + // Read file contents + let contents = std::fs::read_to_string(&path).map_err(|e| { + CoreError::Internal(format!( + "Failed to read config file {}: {}", + path.display(), + e + )) + })?; + + // Determine format by extension + let mut config = match path.extension().and_then(|s| s.to_str()) { + Some("toml") => { + debug!("Parsing config as TOML"); + toml::from_str::(&contents).map_err(|e| { + CoreError::Internal(format!( + "Failed to parse TOML config file {}: {}", + path.display(), + e + )) + })? + } + Some("json") => { + debug!("Parsing config as JSON"); + serde_json::from_str::(&contents).map_err(|e| { + CoreError::Internal(format!( + "Failed to parse JSON config file {}: {}", + path.display(), + e + )) + })? + } + _ => { + let msg = format!( + "Unsupported config file extension for {}, expected .toml or .json", + path.display() + ); + warn!("{}", msg); + return Err(CoreError::Internal(msg)); + } + }; + + info!(path = ?path, "Successfully loaded configuration"); + + if config.archive_root.is_some() { + warn!("archive_root in config file is deprecated. Set DIRIGENT_DATA_DIR instead. Archives default to /archives/"); + } + if config.projects_root.is_some() { + warn!("projects_root in config file is deprecated. Set DIRIGENT_DATA_DIR instead. Projects default to /projects/"); + } + + // Deduplicate connectors by ID (in case of corrupted config file) + let original_count = config.connectors.len(); + let mut seen_ids = std::collections::HashSet::new(); + config.connectors.retain(|c| { + if let Some(ref id) = c.id { + if seen_ids.contains(id) { + warn!(connector_id = %id, "Removing duplicate connector from config"); + false + } else { + seen_ids.insert(id.clone()); + true + } + } else { + // Keep connectors without IDs (shouldn't happen, but be safe) + warn!("Found connector without ID in config, keeping it"); + true + } + }); + + let deduped_count = config.connectors.len(); + if deduped_count < original_count { + info!( + original = original_count, + deduped = deduped_count, + removed = original_count - deduped_count, + "Deduplicated connectors in config" + ); + } + + // Track which file we loaded from so save_config writes back to it + config.config_source_path = Some(path.clone()); + + // Resolve relative paths against config file location. + // This populates resolved_working_dir (never mutates runtime_working_dir) + // and resolves deprecated archive_root/projects_root in-place. + config.resolve_paths(path.parent()); + + // Populate runtime fields on accounts (config_name, user_id). + config.populate_account_names(); + + Ok(config) + } else { + // No config file found, use default + info!("Using default configuration"); + let mut config = CoreConfig::default(); + config.resolve_paths(None); + // No accounts to populate in default config, but call for consistency. + config.populate_account_names(); + Ok(config) + } + } +} + +/// Lenient deserializer for the `owner` field. +/// +/// Accepts a valid UUID string, or silently returns `None` for non-UUID +/// values (e.g. legacy `"DIRIGENT-USER"` strings) and missing fields. +fn deserialize_owner_lenient<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + // Accept either null/missing or a string + let opt: Option = Option::deserialize(deserializer)?; + match opt { + None => Ok(None), + Some(s) => match uuid::Uuid::parse_str(&s) { + Ok(id) => Ok(Some(id)), + Err(_) => { + tracing::debug!( + value = %s, + "Ignoring non-UUID owner value in connector config, treating as None" + ); + Ok(None) + } + }, + } +} + +/// Configuration for creating a connector +/// +/// This struct contains all the information needed to create and configure +/// a connector instance. The `params` field is a JSON value that will be +/// deserialized into the appropriate connector-specific config type. +/// +/// ConnectorConfig can be serialized and persisted to disk, enabling +/// connectors to be automatically recreated on server restart. +/// +/// # Example +/// +/// ```no_run +/// use dirigent_core::{ConnectorConfig, ConnectorKind}; +/// use serde_json::json; +/// +/// let config = ConnectorConfig { +/// id: Some("my-connector".to_string()), +/// kind: ConnectorKind::OpenCode, +/// owner: Some(uuid::Uuid::now_v7()), +/// title: Some("My OpenCode".to_string()), +/// working_directory: None, +/// params: json!({ +/// "base_url": "http://localhost:12225", +/// "title": "My OpenCode", +/// "initial_session": null +/// }), +/// ..Default::default() +/// }; +/// ``` +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub struct ConnectorConfig { + /// Optional connector ID (will be generated if not provided) + /// + /// When persisting connectors, an explicit ID ensures the same connector + /// can be referenced consistently across restarts. + pub id: Option, + + /// Type of connector to create + /// + /// Determines which connector implementation will be instantiated. + /// Each variant corresponds to a different agent system integration. + pub kind: ConnectorKind, + + /// Optional owner user ID (may be set by the runtime) + /// + /// Identifies which user owns this connector. Used for authorization + /// and filtering. If not provided during creation, the runtime will + /// set this based on the creating user. + /// + /// For backwards compatibility, non-UUID strings (e.g. `"DIRIGENT-USER"`) + /// are silently treated as `None` during deserialization. + #[serde(default, deserialize_with = "deserialize_owner_lenient")] + pub owner: Option, + + /// Optional title (may be provided in params instead) + /// + /// Human-readable title for the connector. If not provided here, + /// the connector-specific params should include a title field. + pub title: Option, + + /// Optional working directory for the connector + /// + /// If specified, this directory will be used as the working directory for + /// agent operations. If not specified, the working directory is resolved + /// using the following hierarchy: + /// 1. Explicit working_directory (if set) + /// 2. Global project_dir from CoreConfig + /// 3. Current working directory of the Dirigent process + /// + /// The resolved path is normalized and canonicalized where possible. + #[serde(skip_serializing_if = "Option::is_none")] + pub working_directory: Option, + + /// Connector-specific parameters as JSON + /// + /// These will be deserialized into the appropriate config type + /// based on the `kind` field. For example: + /// - OpenCode: deserialized to `OpenCodeConfig` + /// - Acp: deserialized to `AcpConfig` (future) + /// + /// Using serde_json::Value allows flexible persistence without + /// requiring the config module to know about all connector types. + pub params: serde_json::Value, + + /// T034: Optional custom icon path for this connector + /// + /// If specified, this path points to an image file that will be used + /// as the connector's icon in the UI instead of the default emoji. + /// The path can be absolute or relative to the project directory. + #[serde(skip_serializing_if = "Option::is_none")] + pub icon_path: Option, + + /// T035: Show connector type emoji as overlay on custom icon + /// + /// When true and a custom icon_path is provided, the connector type + /// emoji (e.g., for OpenCode, ACP) will appear as a small overlay + /// in the lower-right corner of the custom icon, similar to Windows + /// shortcut overlay icons. + #[serde(default)] + pub show_type_overlay: bool, + + /// T001: List of features this connector supports + /// + /// Orchestration-level field that declares which optional features + /// this connector implementation supports. Examples: + /// - "session_resume" - Can resume previous sessions + /// - "session_list" - Can list available sessions + /// - "message_history" - Can retrieve message history + /// - "cancellation" - Can cancel in-progress operations + /// - "model_selection" - Supports model selection + /// - "mode_selection" - Supports mode/persona selection + /// + /// This field should NOT appear in connector-specific params. + #[serde(default)] + pub supported_features: Vec, + + /// Tool configuration for this connector. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_configuration: Option, + + /// Plugin assignments for this connector. + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub plugin_assignments: Vec, + + /// Whether this connector is used in newly created projects by default. + #[serde(default = "default_true")] + pub use_in_new_projects: bool, + + /// Source of this connector (e.g., "user", "zed"). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub source: Option, + + /// Zed agent name this connector was created from. + /// + /// When set, this connector's binary path will be automatically refreshed + /// at startup to track Zed agent upgrades. The value is the Zed agent name + /// (e.g. "claude-agent-acp", "codex", "gemini") used to look up the + /// current binary in Zed's `external_agents/` directory. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub zed_agent_name: Option, +} + +fn default_true() -> bool { + true +} + +impl Default for ConnectorConfig { + fn default() -> Self { + Self { + id: None, + kind: ConnectorKind::Mock, + owner: None, + title: None, + working_directory: None, + params: serde_json::json!({}), + icon_path: None, + show_type_overlay: false, + supported_features: vec![], + tool_configuration: None, + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + } + } +} + +/// Connector configuration templates +/// +/// Provides pre-defined connector configurations with sensible defaults. +/// Templates can be referenced by (ConnectorKind, label) and customized +/// via JSON patches when creating connectors. +/// +/// # Available Templates +/// +/// ## OpenCode Templates +/// - `opencode/default` - Standard OpenCode connector with localhost URL +/// +/// ## ACP Templates +/// - `acp/claude-default` - Claude API connector with default model (stub) +/// +/// # Usage +/// +/// Templates are accessed via the `apply_template` function, which merges +/// a template with a user-provided patch to create a final ConnectorConfig. +static TEMPLATES: Lazy> = + Lazy::new(|| { + let mut m = HashMap::new(); + + // OpenCode default template + m.insert( + (ConnectorKind::OpenCode, "default"), + serde_json::json!({ + "base_url": "http://localhost:12225", + "initial_session": null + }), + ); + + // ACP Claude stdio template + // Note: Claude Code has limited ACP support: + // - NOT session_resume: generates ephemeral ACP session IDs + // - NOT cancellation: session/cancel not implemented + m.insert( + (ConnectorKind::Acp, "claude-stdio"), + serde_json::json!({ + "transport": { + "type": "stdio", + "command": "claude", + "args": ["--acp"] + }, + "protocol_version": 1, + "cwd": ".", + "retry": { + "max_retries": 5, + "retry_delays_ms": [1000, 3000, 5000, 5000, 5000] + } + }), + ); + + // ACP Mocker stdio template (for testing) + // Note: Mocker supports all features for testing purposes + m.insert( + (ConnectorKind::Acp, "mocker-stdio"), + serde_json::json!({ + "transport": { + "type": "stdio", + "command": "dirigate", + "args": ["serve", "--fixtures", "basic.yaml", "--stdio"] + }, + "protocol_version": 1, + "cwd": ".", + "retry": { + "max_retries": 5, + "retry_delays_ms": [1000, 3000, 5000, 5000, 5000] + } + }), + ); + + // ACP HTTP template + // Note: HTTP agent capabilities vary - user should configure supported_features + m.insert( + (ConnectorKind::Acp, "http"), + serde_json::json!({ + "transport": { + "type": "http", + "base_url": "http://localhost:3000", + "timeout_ms": 30000 + }, + "protocol_version": 1, + "cwd": ".", + "retry": { + "max_retries": 5, + "retry_delays_ms": [1000, 3000, 5000, 5000, 5000] + } + }), + ); + + m + }); + +/// Apply a connector template with optional patches +/// +/// This function retrieves a template by kind and label, then merges it with +/// the provided patch object to create a final ConnectorConfig. The patch +/// values override template values, allowing customization of defaults. +/// +/// # Arguments +/// +/// * `kind` - The type of connector (OpenCode, Acp, etc.) +/// * `label` - The template identifier (e.g., "default", "claude-default") +/// * `patch` - JSON object with values to override in the template +/// +/// # Returns +/// +/// A `Result` containing the merged `ConnectorConfig` or a `CoreError` +/// +/// # Errors +/// +/// Returns `CoreError::NotFound` if no template exists for the given (kind, label) pair. +/// +/// # Example +/// +/// ```no_run +/// use dirigent_core::{apply_template, ConnectorKind}; +/// use serde_json::json; +/// +/// // Use template with default values +/// let config = apply_template( +/// ConnectorKind::OpenCode, +/// "default", +/// json!({}) +/// ).unwrap(); +/// +/// // Override the base_url and title +/// let config = apply_template( +/// ConnectorKind::OpenCode, +/// "default", +/// json!({ +/// "base_url": "http://localhost:8080", +/// "title": "My Custom OpenCode" +/// }) +/// ).unwrap(); +/// ``` +pub fn apply_template( + kind: ConnectorKind, + label: &str, + patch: serde_json::Value, +) -> Result { + // Lookup template + let template = TEMPLATES + .get(&(kind.clone(), label)) + .ok_or(CoreError::NotFound)?; + + // Merge template with patch (patch overrides template) + let merged = merge_json(template.clone(), patch); + + // Extract orchestration-level fields from merged JSON + let title = merged + .get("title") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + let supported_features = merged + .get("supported_features") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect() + }) + .unwrap_or_default(); + + // Remove orchestration fields from params + let mut params = merged; + if let Some(obj) = params.as_object_mut() { + obj.remove("title"); + obj.remove("supported_features"); + obj.remove("icon_path"); + obj.remove("show_type_overlay"); + } + + // Build ConnectorConfig + Ok(ConnectorConfig { + id: None, // ID will be generated by runtime + kind, + owner: None, // Owner will be set by runtime + title, + working_directory: None, // Working directory will be resolved from global config + params, + icon_path: None, // Custom icons are not set via templates + show_type_overlay: false, // Default to no overlay + supported_features, + tool_configuration: None, + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + }) +} + +/// Merge two JSON values with right-hand side taking precedence +/// +/// This is a simple shallow merge that combines two JSON objects. +/// If both values are objects, their keys are merged with `patch` +/// values overriding `base` values. Otherwise, `patch` replaces `base`. +/// +/// # Arguments +/// +/// * `base` - The base JSON value (typically from a template) +/// * `patch` - The patch JSON value (user overrides) +/// +/// # Returns +/// +/// The merged JSON value +fn merge_json(base: serde_json::Value, patch: serde_json::Value) -> serde_json::Value { + use serde_json::Value; + + match (base, patch) { + (Value::Object(mut base_map), Value::Object(patch_map)) => { + // Merge objects: patch values override base values + for (key, value) in patch_map { + base_map.insert(key, value); + } + Value::Object(base_map) + } + (_, patch) => { + // For non-objects, patch completely replaces base + patch + } + } +} + +/// Resolve the working directory for a connector +/// +/// This function implements the resolution hierarchy for determining the +/// working directory to use for a connector's operations: +/// +/// 1. Explicit `working_directory` in connector config (if set) +/// 2. `runtime_working_dir` from global config +/// +/// The resolved path is normalized and, if it exists, canonicalized. +/// If the directory does not exist, a warning is logged but the path +/// is still returned (allowing for creation later). +/// +/// # Arguments +/// +/// * `connector_config` - The connector configuration +/// * `global_config` - The global runtime configuration +/// +/// # Returns +/// +/// The resolved working directory path. This path is normalized and +/// canonicalized where possible. +/// +/// # Example +/// +/// ```no_run +/// use dirigent_core::{resolve_default_runtime_working_directory, ConnectorConfig, CoreConfig}; +/// +/// let connector_config = ConnectorConfig { +/// id: None, +/// kind: dirigent_core::ConnectorKind::OpenCode, +/// owner: None, +/// title: None, +/// working_directory: Some("/path/to/workdir".into()), +/// params: serde_json::json!({}), +/// }; +/// let global_config = CoreConfig::default(); +/// +/// let workdir = resolve_default_runtime_working_directory(&connector_config, &global_config); +/// println!("Working directory: {}", workdir.display()); +/// ``` +pub fn resolve_default_runtime_working_directory( + connector_config: &ConnectorConfig, + global_config: &CoreConfig, +) -> PathBuf { + // Step 1: Check if explicit working_directory is set in connector config + if let Some(ref wd) = connector_config.working_directory { + debug!( + working_directory = ?wd, + "Using explicit working_directory from connector config" + ); + + // Normalize and canonicalize if path exists + return normalize_path(wd); + } + + // Step 2: Use runtime_working_dir from global config + let project_dir = &global_config.runtime_working_dir; + debug!( + working_directory = ?project_dir, + "Using runtime_working_dir from global config" + ); + + normalize_path(project_dir) +} + +/// Normalize and canonicalize a path +/// +/// This helper function normalizes a path and attempts to canonicalize it +/// if the path exists. If canonicalization fails (e.g., directory doesn't +/// exist), the normalized path is returned with a warning. +/// +/// # Arguments +/// +/// * `path` - The path to normalize +/// +/// # Returns +/// +/// The normalized (and canonicalized if possible) path +fn normalize_path(path: &PathBuf) -> PathBuf { + // Try to canonicalize (resolves symlinks and relative paths) + match path.canonicalize() { + Ok(canonical) => { + debug!( + original = ?path, + canonical = ?canonical, + "Successfully canonicalized path" + ); + canonical + } + Err(e) => { + // Path doesn't exist or can't be canonicalized + warn!( + path = ?path, + error = %e, + "Failed to canonicalize path (directory may not exist), using normalized path" + ); + + // Return the normalized path as-is + path.clone() + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::ConnectorKind; + use serde_json::json; + + #[test] + fn test_core_config_default() { + let config = CoreConfig::default(); + assert_eq!(config.runtime_working_dir, PathBuf::from(".")); + assert_eq!(config.connectors.len(), 0); + } + + #[test] + fn test_core_config_custom() { + let config = CoreConfig { + runtime_working_dir: PathBuf::from("/my/project"), + connectors: vec![], + archive_root: None, + archives: Vec::new(), + projects_root: None, + acp_server: None, + dismissed_zed_agents: vec![], + tasks: vec![], + accounts: HashMap::new(), + matrix: None, + streams: crate::sharing::StreamsConfig::default(), + config_source_path: None, + resolved_working_dir: None, + }; + + assert_eq!(config.runtime_working_dir, PathBuf::from("/my/project")); + } + + #[test] + fn test_core_config_serialization() { + let config = CoreConfig { + runtime_working_dir: PathBuf::from("/my/project"), + connectors: vec![], + archive_root: None, + archives: Vec::new(), + projects_root: None, + acp_server: None, + dismissed_zed_agents: vec![], + tasks: vec![], + accounts: HashMap::new(), + matrix: None, + streams: crate::sharing::StreamsConfig::default(), + config_source_path: None, + resolved_working_dir: None, + }; + + let json = serde_json::to_string(&config).expect("Failed to serialize"); + let deserialized: CoreConfig = serde_json::from_str(&json).expect("Failed to deserialize"); + + assert_eq!(deserialized.runtime_working_dir, config.runtime_working_dir); + } + + #[test] + fn test_connector_config_basic() { + let config = ConnectorConfig { + id: Some("conn-1".to_string()), + kind: ConnectorKind::OpenCode, + owner: Some(uuid::Uuid::nil()), + title: Some("My Connector".to_string()), + working_directory: None, + params: json!({ + "base_url": "http://localhost:12225", + "title": "My Connector", + }), + icon_path: None, + show_type_overlay: false, + supported_features: vec![], + tool_configuration: None, + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + }; + + assert_eq!(config.id, Some("conn-1".to_string())); + assert_eq!(config.kind, ConnectorKind::OpenCode); + assert_eq!(config.owner, Some(uuid::Uuid::nil())); + assert_eq!(config.title, Some("My Connector".to_string())); + } + + #[test] + fn test_connector_config_serialization() { + let config = ConnectorConfig { + id: Some("conn-1".to_string()), + kind: ConnectorKind::OpenCode, + owner: Some(uuid::Uuid::nil()), + title: Some("Test".to_string()), + working_directory: None, + params: json!({ + "base_url": "http://localhost:12225", + }), + icon_path: None, + show_type_overlay: false, + supported_features: vec![], + tool_configuration: None, + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + }; + + let json = serde_json::to_string(&config).expect("Failed to serialize"); + let deserialized: ConnectorConfig = + serde_json::from_str(&json).expect("Failed to deserialize"); + + assert_eq!(deserialized.id, config.id); + assert_eq!(deserialized.kind, config.kind); + assert_eq!(deserialized.owner, config.owner); + assert_eq!(deserialized.title, config.title); + } + + #[test] + fn test_connector_config_minimal() { + // Test with minimal fields (id, owner, title all optional) + let config = ConnectorConfig { + id: None, + kind: ConnectorKind::Mock, + owner: None, + title: None, + working_directory: None, + params: json!({}), + icon_path: None, + show_type_overlay: false, + supported_features: vec![], + tool_configuration: None, + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + }; + + assert!(config.id.is_none()); + assert_eq!(config.kind, ConnectorKind::Mock); + assert!(config.owner.is_none()); + assert!(config.title.is_none()); + } + + #[test] + fn test_core_config_with_connectors() { + let connector1 = ConnectorConfig { + id: Some("conn-1".to_string()), + kind: ConnectorKind::OpenCode, + owner: Some(uuid::Uuid::nil()), + title: Some("Connector 1".to_string()), + working_directory: None, + params: json!({"base_url": "http://localhost:12225"}), + icon_path: None, + show_type_overlay: false, + supported_features: vec![], + tool_configuration: None, + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + }; + + let connector2 = ConnectorConfig { + id: Some("conn-2".to_string()), + kind: ConnectorKind::Acp, + owner: Some(uuid::Uuid::nil()), + title: Some("Connector 2".to_string()), + working_directory: None, + params: json!({"endpoint": "http://localhost:8080"}), + icon_path: None, + show_type_overlay: false, + supported_features: vec![], + tool_configuration: None, + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + }; + + let config = CoreConfig { + connectors: vec![connector1, connector2], + ..CoreConfig::default() + }; + + assert_eq!(config.connectors.len(), 2); + assert_eq!(config.connectors[0].id, Some("conn-1".to_string())); + assert_eq!(config.connectors[1].id, Some("conn-2".to_string())); + } + + #[test] + fn test_load_config_no_file_returns_default() { + // When no file exists and no env var is set, should return default + let config = CoreConfig::load_config(None).unwrap(); + // project_dir may be resolved against config directory if a config file exists + assert!(!config.runtime_working_dir.as_os_str().is_empty()); + } + + #[test] + fn test_apply_template_opencode_default() { + // Test applying OpenCode default template with no patch + let config = apply_template(ConnectorKind::OpenCode, "default", json!({})).unwrap(); + + assert_eq!(config.kind, ConnectorKind::OpenCode); + assert!(config.id.is_none()); + assert!(config.owner.is_none()); + assert_eq!(config.title, None); // Title removed from templates (T001) + assert_eq!(config.params["base_url"], "http://localhost:12225"); + assert!(config.params["initial_session"].is_null()); + } + + #[test] + fn test_apply_template_opencode_with_patch() { + // Test applying OpenCode template with custom values + let config = apply_template( + ConnectorKind::OpenCode, + "default", + json!({ + "base_url": "http://localhost:8080", + "title": "My Custom OpenCode", + "initial_session": "session-123" + }), + ) + .unwrap(); + + assert_eq!(config.kind, ConnectorKind::OpenCode); + assert_eq!(config.title, Some("My Custom OpenCode".to_string())); + assert_eq!(config.params["base_url"], "http://localhost:8080"); + assert_eq!(config.params["initial_session"], "session-123"); + } + + #[test] + fn test_apply_template_acp_claude_stdio() { + // Test applying ACP Claude stdio template + let config = apply_template(ConnectorKind::Acp, "claude-stdio", json!({})).unwrap(); + + assert_eq!(config.kind, ConnectorKind::Acp); + assert!(config.id.is_none()); + assert!(config.owner.is_none()); + assert_eq!(config.title, None); // Title removed from templates (T001) + assert_eq!(config.params["transport"]["type"], "stdio"); + assert_eq!(config.params["transport"]["command"], "claude"); + assert_eq!(config.params["protocol_version"], 1); + + // supported_features removed from templates (T001) + assert!(config.params.get("supported_features").is_none()); + // supported_features now in ConnectorConfig, not params + assert_eq!(config.supported_features, Vec::::new()); + } + + #[test] + fn test_apply_template_acp_mocker_stdio() { + // Test applying ACP Mocker stdio template + let config = apply_template(ConnectorKind::Acp, "mocker-stdio", json!({})).unwrap(); + + assert_eq!(config.kind, ConnectorKind::Acp); + assert_eq!(config.title, None); // Title removed from templates (T001) + assert_eq!(config.params["transport"]["type"], "stdio"); + assert_eq!(config.params["transport"]["command"], "dirigate"); + + // supported_features removed from templates (T001) + assert!(config.params.get("supported_features").is_none()); + // supported_features now in ConnectorConfig, not params + assert_eq!(config.supported_features, Vec::::new()); + } + + #[test] + fn test_apply_template_acp_http() { + // Test applying ACP HTTP template + let config = apply_template(ConnectorKind::Acp, "http", json!({})).unwrap(); + + assert_eq!(config.kind, ConnectorKind::Acp); + assert_eq!(config.title, None); // Title removed from templates (T001) + assert_eq!(config.params["transport"]["type"], "http"); + assert_eq!( + config.params["transport"]["base_url"], + "http://localhost:3000" + ); + assert_eq!(config.params["transport"]["timeout_ms"], 30000); + + // supported_features removed from templates (T001) + assert!(config.params.get("supported_features").is_none()); + // supported_features now in ConnectorConfig, not params + assert_eq!(config.supported_features, Vec::::new()); + } + + #[test] + fn test_apply_template_acp_with_patch() { + // Test applying ACP template with custom command + let config = apply_template( + ConnectorKind::Acp, + "claude-stdio", + json!({ + "title": "My Custom Claude", + "transport": { + "type": "stdio", + "command": "/usr/local/bin/claude", + "args": ["--acp", "--verbose"] + } + }), + ) + .unwrap(); + + assert_eq!(config.kind, ConnectorKind::Acp); + assert_eq!(config.title, Some("My Custom Claude".to_string())); + assert_eq!( + config.params["transport"]["command"], + "/usr/local/bin/claude" + ); + } + + #[test] + fn test_apply_template_not_found() { + // Test that unknown template returns NotFound error + let result = apply_template(ConnectorKind::OpenCode, "nonexistent", json!({})); + + assert_eq!(result, Err(CoreError::NotFound)); + } + + #[test] + fn test_apply_template_wrong_kind() { + // Test that wrong kind returns NotFound error + let result = apply_template(ConnectorKind::Mock, "default", json!({})); + + assert_eq!(result, Err(CoreError::NotFound)); + } + + #[test] + fn test_merge_json_objects() { + // Test merging two JSON objects + let base = json!({ + "a": 1, + "b": 2, + "c": 3 + }); + let patch = json!({ + "b": 20, + "d": 4 + }); + + let merged = merge_json(base, patch); + + assert_eq!(merged["a"], 1); + assert_eq!(merged["b"], 20); // Overridden + assert_eq!(merged["c"], 3); + assert_eq!(merged["d"], 4); // Added + } + + #[test] + fn test_merge_json_empty_patch() { + // Test that empty patch returns base unchanged + let base = json!({ + "a": 1, + "b": 2 + }); + let patch = json!({}); + + let merged = merge_json(base.clone(), patch); + + assert_eq!(merged, base); + } + + #[test] + fn test_merge_json_non_objects() { + // Test that non-object patch replaces base + let base = json!({ + "a": 1, + "b": 2 + }); + let patch = json!("replacement"); + + let merged = merge_json(base, patch.clone()); + + assert_eq!(merged, patch); + } + + #[test] + fn test_merge_json_nested_not_recursive() { + // Test that merge is shallow, not recursive + let base = json!({ + "nested": { + "a": 1, + "b": 2 + } + }); + let patch = json!({ + "nested": { + "c": 3 + } + }); + + let merged = merge_json(base, patch); + + // Should replace entire nested object, not merge it + assert!(merged["nested"]["c"] == 3); + assert!(merged["nested"].get("a").is_none()); + assert!(merged["nested"].get("b").is_none()); + } + + #[test] + fn test_template_params_serialization() { + // Test that template-generated configs can be serialized + let config = + apply_template(ConnectorKind::OpenCode, "default", json!({"title": "Test"})).unwrap(); + + let json = serde_json::to_string(&config).expect("Failed to serialize"); + let deserialized: ConnectorConfig = + serde_json::from_str(&json).expect("Failed to deserialize"); + + assert_eq!(deserialized.kind, config.kind); + assert_eq!(deserialized.title, config.title); + assert_eq!(deserialized.params, config.params); + } + + // INTEG-008: Test ACP config serialization round-trip + #[test] + fn test_acp_config_serialization_toml_stdio() { + // Test ACP stdio config serialization to/from TOML + let config = ConnectorConfig { + id: Some("acp-stdio-test".to_string()), + kind: ConnectorKind::Acp, + owner: Some(uuid::Uuid::nil()), + title: Some("Test ACP Stdio".to_string()), + working_directory: None, + params: json!({ + "transport": { + "type": "stdio", + "command": "dirigate", + "args": ["serve", "--stdio"] + }, + "title": "Test ACP Stdio", + "protocol_version": 1, + "cwd": ".", + "retry": { + "max_retries": 5, + "retry_delays_ms": [1000, 3000, 5000, 5000, 5000] + } + }), + icon_path: None, + show_type_overlay: false, + supported_features: vec![], + tool_configuration: None, + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + }; + + // Serialize to TOML + let toml_str = toml::to_string(&config).expect("Failed to serialize to TOML"); + + // Deserialize from TOML + let deserialized: ConnectorConfig = + toml::from_str(&toml_str).expect("Failed to deserialize from TOML"); + + // Verify round-trip + assert_eq!(deserialized.id, config.id); + assert_eq!(deserialized.kind, config.kind); + assert_eq!(deserialized.owner, config.owner); + assert_eq!(deserialized.title, config.title); + assert_eq!(deserialized.params["transport"]["command"], "dirigate"); + assert_eq!(deserialized.params["protocol_version"], 1); + } + + #[test] + fn test_acp_config_serialization_json_http() { + // Test ACP HTTP config serialization to/from JSON + let config = ConnectorConfig { + id: Some("acp-http-test".to_string()), + kind: ConnectorKind::Acp, + owner: Some(uuid::Uuid::nil()), + title: Some("Test ACP HTTP".to_string()), + working_directory: None, + params: json!({ + "transport": { + "type": "http", + "base_url": "http://localhost:3000", + "timeout_ms": 30000 + }, + "title": "Test ACP HTTP", + "protocol_version": 1, + "cwd": ".", + "retry": { + "max_retries": 5, + "retry_delays_ms": [1000, 3000, 5000, 5000, 5000] + } + }), + icon_path: None, + show_type_overlay: false, + supported_features: vec![], + tool_configuration: None, + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + }; + + // Serialize to JSON + let json_str = serde_json::to_string(&config).expect("Failed to serialize to JSON"); + + // Deserialize from JSON + let deserialized: ConnectorConfig = + serde_json::from_str(&json_str).expect("Failed to deserialize from JSON"); + + // Verify round-trip + assert_eq!(deserialized.id, config.id); + assert_eq!(deserialized.kind, config.kind); + assert_eq!(deserialized.owner, config.owner); + assert_eq!(deserialized.title, config.title); + assert_eq!( + deserialized.params["transport"]["base_url"], + "http://localhost:3000" + ); + assert_eq!(deserialized.params["transport"]["timeout_ms"], 30000); + } + + #[test] + fn test_core_config_with_acp_connectors_toml() { + // Test full CoreConfig with ACP connectors serialized to TOML + let acp_stdio_config = ConnectorConfig { + id: Some("acp-1".to_string()), + kind: ConnectorKind::Acp, + owner: Some(uuid::Uuid::nil()), + title: Some("ACP Stdio".to_string()), + working_directory: None, + params: json!({ + "transport": { + "type": "stdio", + "command": "claude", + "args": ["--acp"] + }, + "title": "ACP Stdio", + "protocol_version": 1, + "cwd": ".", + "retry": { + "max_retries": 5, + "retry_delays_ms": [1000, 3000, 5000, 5000, 5000] + } + }), + icon_path: None, + show_type_overlay: false, + supported_features: vec![], + tool_configuration: None, + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + }; + + let acp_http_config = ConnectorConfig { + id: Some("acp-2".to_string()), + kind: ConnectorKind::Acp, + owner: Some(uuid::Uuid::nil()), + title: Some("ACP HTTP".to_string()), + working_directory: None, + params: json!({ + "transport": { + "type": "http", + "base_url": "http://localhost:3000", + "timeout_ms": 30000 + }, + "title": "ACP HTTP", + "protocol_version": 1, + "cwd": ".", + "retry": { + "max_retries": 5, + "retry_delays_ms": [1000, 3000, 5000, 5000, 5000] + } + }), + icon_path: None, + show_type_overlay: false, + supported_features: vec![], + tool_configuration: None, + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + }; + + let config = CoreConfig { + connectors: vec![acp_stdio_config, acp_http_config], + ..CoreConfig::default() + }; + + // Serialize to TOML + let toml_str = toml::to_string_pretty(&config).expect("Failed to serialize to TOML"); + + // Deserialize from TOML + let deserialized: CoreConfig = + toml::from_str(&toml_str).expect("Failed to deserialize from TOML"); + + // Verify round-trip + assert_eq!(deserialized.connectors.len(), 2); + assert_eq!(deserialized.connectors[0].kind, ConnectorKind::Acp); + assert_eq!(deserialized.connectors[1].kind, ConnectorKind::Acp); + assert_eq!( + deserialized.connectors[0].params["transport"]["type"], + "stdio" + ); + assert_eq!( + deserialized.connectors[1].params["transport"]["type"], + "http" + ); + } + + #[test] + fn test_core_config_with_acp_connectors_json() { + // Test full CoreConfig with ACP connectors serialized to JSON + let acp_config = ConnectorConfig { + id: Some("acp-test".to_string()), + kind: ConnectorKind::Acp, + owner: Some(uuid::Uuid::nil()), + title: Some("Test ACP".to_string()), + working_directory: None, + params: json!({ + "transport": { + "type": "stdio", + "command": "test-command", + "args": [] + }, + "title": "Test ACP", + "protocol_version": 1, + "cwd": ".", + "retry": { + "max_retries": 5, + "retry_delays_ms": [1000, 3000, 5000, 5000, 5000] + } + }), + icon_path: None, + show_type_overlay: false, + supported_features: vec![], + tool_configuration: None, + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + }; + + let config = CoreConfig { + connectors: vec![acp_config], + ..CoreConfig::default() + }; + + // Serialize to JSON + let json_str = serde_json::to_string_pretty(&config).expect("Failed to serialize to JSON"); + + // Deserialize from JSON + let deserialized: CoreConfig = + serde_json::from_str(&json_str).expect("Failed to deserialize from JSON"); + + // Verify round-trip + assert_eq!(deserialized.connectors.len(), 1); + assert_eq!(deserialized.connectors[0].kind, ConnectorKind::Acp); + assert_eq!( + deserialized.connectors[0].params["transport"]["command"], + "test-command" + ); + } + + /// Regression test: CoreConfig::default() must roundtrip through TOML. + /// This is the minimal config that gets saved on first launch. + #[test] + fn test_default_core_config_toml_roundtrip() { + let config = CoreConfig::default(); + let toml_str = toml::to_string_pretty(&config) + .expect("Default CoreConfig must serialize to TOML"); + let deserialized: CoreConfig = + toml::from_str(&toml_str).expect("Default CoreConfig must deserialize from TOML"); + assert_eq!(deserialized.runtime_working_dir, config.runtime_working_dir); + } + + /// Regression test: CoreConfig with tool_configuration must roundtrip through TOML. + /// ToolHandler unit variants (Agent, Deny, etc.) previously failed with + /// "unsupported unit type" when using adjacently-tagged serde representation. + #[test] + fn test_core_config_with_tool_configuration_toml() { + use crate::tools::{ToolConfiguration, ToolDirective, ToolHandler}; + + let mut tool_config = ToolConfiguration::new(); + tool_config.set(ToolDirective::passthrough("read_file")); + tool_config.set(ToolDirective::checked("shell_exec", ToolHandler::Deny)); + tool_config.set(ToolDirective::checked("editor_write", ToolHandler::Editor)); + tool_config.set(ToolDirective::checked( + "plugin_tool", + ToolHandler::Plugin { name: "my_plugin".to_string() }, + )); + + let connector = ConnectorConfig { + id: Some("test-acp".to_string()), + kind: ConnectorKind::Acp, + owner: None, + title: Some("Test ACP".to_string()), + working_directory: None, + params: json!({ + "transport": { "type": "stdio", "command": "echo", "args": [] }, + "title": "Test ACP" + }), + icon_path: None, + show_type_overlay: false, + supported_features: vec![], + tool_configuration: Some(tool_config), + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + }; + + let config = CoreConfig { + connectors: vec![connector], + ..CoreConfig::default() + }; + + let toml_str = toml::to_string_pretty(&config) + .expect("CoreConfig with tool_configuration must serialize to TOML"); + let deserialized: CoreConfig = + toml::from_str(&toml_str).expect("CoreConfig with tool_configuration must deserialize from TOML"); + assert_eq!(deserialized.connectors.len(), 1); + assert!(deserialized.connectors[0].tool_configuration.is_some()); + } + + /// Regression test: params containing JSON null values cause TOML serialization failure. + /// TOML has no null type — serde_json::Value::Null serializes as unit `()` + /// which toml crate rejects with "unsupported unit type". + /// The fix: config_manager::save_config strips null values before TOML serialization. + #[test] + fn test_core_config_with_null_params_toml_fails_without_sanitization() { + let connector = ConnectorConfig { + id: Some("test".to_string()), + kind: ConnectorKind::Acp, + owner: None, + title: Some("Test".to_string()), + working_directory: None, + params: json!({ + "transport": { "type": "stdio", "command": "echo", "args": [] }, + "initial_session": null, + "cwd": null + }), + icon_path: None, + show_type_overlay: false, + supported_features: vec![], + tool_configuration: None, + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + }; + + let config = CoreConfig { + connectors: vec![connector], + ..CoreConfig::default() + }; + + // Raw TOML serialization fails with null values — this documents the known limitation + let err = toml::to_string_pretty(&config).unwrap_err(); + assert!(err.to_string().contains("unsupported"), "Expected 'unsupported unit type' error, got: {}", err); + } + + /// Regression test: after stripping null values, TOML serialization succeeds. + #[test] + fn test_core_config_with_null_params_toml_succeeds_after_sanitization() { + use crate::runtime::config_manager::strip_json_nulls; + + let connector = ConnectorConfig { + id: Some("test".to_string()), + kind: ConnectorKind::Acp, + owner: None, + title: Some("Test".to_string()), + working_directory: None, + params: json!({ + "transport": { "type": "stdio", "command": "echo", "args": [] }, + "initial_session": null, + "cwd": null, + "nested": { "a": 1, "b": null } + }), + icon_path: None, + show_type_overlay: false, + supported_features: vec![], + tool_configuration: None, + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + }; + + let mut config = CoreConfig { + connectors: vec![connector], + ..CoreConfig::default() + }; + + // Strip nulls (same as config_manager::save_config does) + for c in &mut config.connectors { + strip_json_nulls(&mut c.params); + } + + let toml_str = toml::to_string_pretty(&config) + .expect("CoreConfig with sanitized params must serialize to TOML"); + let deserialized: CoreConfig = + toml::from_str(&toml_str).expect("Must deserialize from TOML"); + assert_eq!(deserialized.connectors.len(), 1); + // Verify nulls were stripped + assert!(deserialized.connectors[0].params.get("initial_session").is_none()); + assert!(deserialized.connectors[0].params.get("cwd").is_none()); + assert_eq!(deserialized.connectors[0].params["nested"]["a"], 1); + assert!(deserialized.connectors[0].params["nested"].get("b").is_none()); + } + + /// Regression test: CoreConfig with ACP server config must serialize to TOML. + /// archive_root and projects_root are deprecated — they are deserializable for + /// backward compatibility but must NOT appear in serialized output. + #[test] + fn test_core_config_with_acp_server_toml() { + let config = CoreConfig { + runtime_working_dir: PathBuf::from("."), + connectors: vec![], + archive_root: Some(PathBuf::from("/tmp/archive")), + archives: Vec::new(), + projects_root: Some(PathBuf::from("/tmp/projects")), + acp_server: Some(AcpServerConfig { + enabled: true, + port: Some(3001), + allowed_origins: None, + max_connections: 50, + default_connector_id: Some("default".to_string()), + }), + dismissed_zed_agents: vec!["old-agent".to_string()], + tasks: vec![], + accounts: HashMap::new(), + matrix: None, + streams: crate::sharing::StreamsConfig::default(), + config_source_path: None, + resolved_working_dir: None, + }; + + let toml_str = toml::to_string_pretty(&config) + .expect("CoreConfig with ACP server must serialize to TOML"); + + // Deprecated fields must NOT appear in serialized output + assert!( + !toml_str.contains("archive_root"), + "archive_root should be skipped during serialization (deprecated)" + ); + assert!( + !toml_str.contains("projects_root"), + "projects_root should be skipped during serialization (deprecated)" + ); + + let deserialized: CoreConfig = + toml::from_str(&toml_str).expect("CoreConfig with ACP server must deserialize from TOML"); + assert_eq!(deserialized.acp_server.as_ref().unwrap().port, Some(3001)); + // Deprecated fields are absent from serialized TOML, so they deserialize as None + assert_eq!(deserialized.archive_root, None); + assert_eq!(deserialized.projects_root, None); + assert_eq!(deserialized.dismissed_zed_agents, vec!["old-agent".to_string()]); + + // Verify backward compat: old TOML with archive_root still deserializes without error + let legacy_toml = r#" +runtime_working_dir = "." +connectors = [] +archive_root = "/tmp/archive" +projects_root = "/tmp/projects" +"#; + let legacy: CoreConfig = + toml::from_str(legacy_toml).expect("Legacy TOML with deprecated fields must deserialize"); + assert_eq!(legacy.archive_root, Some(PathBuf::from("/tmp/archive"))); + assert_eq!(legacy.projects_root, Some(PathBuf::from("/tmp/projects"))); + } + + /// Phase 4 Task 23: `[[streams]]` entries flatten onto CoreConfig.streams. + /// + /// The `streams` field uses `#[serde(flatten)]` and the inner + /// `StreamsConfig::entries` is renamed to `"streams"`, so `[[streams]]` + /// tables at the top level of dirigent.toml populate + /// `cfg.streams.entries`. + #[test] + fn test_core_config_parses_streams_block() { + let toml_str = r#" +runtime_working_dir = "." +connectors = [] + +[[streams]] +name = "lf" +type = "langfuse" + +[streams.scope] +kind = "archive_wide" +acknowledged = false + +[streams.params] +host = "http://example.com" +public_key = "pk" +secret_key = "sk" +"#; + let cfg: CoreConfig = + toml::from_str(toml_str).expect("CoreConfig with [[streams]] must parse"); + assert_eq!(cfg.streams.entries.len(), 1); + assert_eq!(cfg.streams.entries[0].name, "lf"); + assert_eq!(cfg.streams.entries[0].kind, "langfuse"); + assert!(cfg.streams.entries[0].enabled); + } + + /// Omitting `[[streams]]` entirely must still parse (empty vec default). + #[test] + fn test_core_config_no_streams_block_parses() { + let cfg: CoreConfig = toml::from_str( + r#" +runtime_working_dir = "." +connectors = [] +"#, + ) + .expect("CoreConfig without [[streams]] must parse"); + assert!(cfg.streams.entries.is_empty()); + } + + // Tests for working directory resolution + + #[test] + fn test_resolve_default_runtime_working_directory_explicit() { + // When working_directory is explicitly set, it should be used + let connector_config = ConnectorConfig { + id: None, + kind: ConnectorKind::Mock, + owner: None, + title: None, + working_directory: Some(PathBuf::from(".")), + params: json!({}), + icon_path: None, + show_type_overlay: false, + supported_features: vec![], + tool_configuration: None, + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + }; + + let global_config = CoreConfig { + runtime_working_dir: PathBuf::from("/some/other/path"), + ..CoreConfig::default() + }; + + let resolved = resolve_default_runtime_working_directory(&connector_config, &global_config); + + // Should resolve to current directory (canonicalized) + assert!(resolved.exists()); + assert!(resolved.is_absolute()); + } + + #[test] + fn test_resolve_default_runtime_working_directory_from_project_dir() { + // When working_directory is not set, should use project_dir from global config + let connector_config = ConnectorConfig { + id: None, + kind: ConnectorKind::Mock, + owner: None, + title: None, + working_directory: None, + params: json!({}), + icon_path: None, + show_type_overlay: false, + supported_features: vec![], + tool_configuration: None, + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + }; + + let global_config = CoreConfig::default(); + + let resolved = resolve_default_runtime_working_directory(&connector_config, &global_config); + + // Should resolve to project_dir (canonicalized) + assert!(resolved.exists()); + assert!(resolved.is_absolute()); + } + + #[test] + fn test_resolve_default_runtime_working_directory_nonexistent_path() { + // When path doesn't exist, should still return the path (with warning) + let connector_config = ConnectorConfig { + id: None, + kind: ConnectorKind::Mock, + owner: None, + title: None, + working_directory: Some(PathBuf::from("/this/path/does/not/exist")), + params: json!({}), + icon_path: None, + show_type_overlay: false, + supported_features: vec![], + tool_configuration: None, + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + }; + + let global_config = CoreConfig::default(); + + let resolved = resolve_default_runtime_working_directory(&connector_config, &global_config); + + // Should return the path even if it doesn't exist + assert_eq!(resolved, PathBuf::from("/this/path/does/not/exist")); + } + + #[test] + fn test_connector_config_serialization_with_workdir() { + // Test that working_directory serializes/deserializes correctly + let config = ConnectorConfig { + id: Some("test-conn".to_string()), + kind: ConnectorKind::OpenCode, + owner: Some(uuid::Uuid::nil()), + title: Some("Test".to_string()), + working_directory: Some(PathBuf::from("/path/to/workdir")), + params: json!({"base_url": "http://localhost:12225"}), + icon_path: None, + show_type_overlay: false, + supported_features: vec![], + tool_configuration: None, + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + }; + + let json = serde_json::to_string(&config).unwrap(); + let deserialized: ConnectorConfig = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.working_directory, config.working_directory); + } + + #[test] + fn test_connector_config_serialization_without_workdir() { + // Test that working_directory=None is skipped during serialization + let config = ConnectorConfig { + id: Some("test-conn".to_string()), + kind: ConnectorKind::OpenCode, + owner: Some(uuid::Uuid::nil()), + title: Some("Test".to_string()), + working_directory: None, + params: json!({"base_url": "http://localhost:12225"}), + icon_path: None, + show_type_overlay: false, + supported_features: vec![], + tool_configuration: None, + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + }; + + let json = serde_json::to_string(&config).unwrap(); + + // Should not contain "working_directory" field + assert!(!json.contains("working_directory")); + + let deserialized: ConnectorConfig = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.working_directory, None); + } + + #[test] + fn test_load_config_no_env_returns_ok() { + let config = CoreConfig::load_config(None); + assert!(config.is_ok()); + } + + #[test] + fn test_effective_working_dir_returns_resolved_when_set() { + let mut config = CoreConfig::default(); + config.resolved_working_dir = Some(PathBuf::from("/resolved/path")); + assert_eq!(config.effective_working_dir(), std::path::Path::new("/resolved/path")); + } + + #[test] + fn test_effective_working_dir_falls_back_to_raw_when_unresolved() { + let config = CoreConfig::default(); + assert_eq!(config.effective_working_dir(), std::path::Path::new(".")); + } + + #[test] + fn test_resolve_paths_relative_joins_config_dir() { + let mut config = CoreConfig { + runtime_working_dir: PathBuf::from("my-project"), + ..CoreConfig::default() + }; + config.resolve_paths(Some(std::path::Path::new("/some/config/dir"))); + let resolved = config.resolved_working_dir.as_ref().expect("should be resolved"); + assert_eq!(resolved, &PathBuf::from("/some/config/dir/my-project")); + } + + #[test] + fn test_resolve_paths_absolute_unchanged() { + let mut config = CoreConfig { + runtime_working_dir: PathBuf::from("/absolute/project"), + ..CoreConfig::default() + }; + config.resolve_paths(Some(std::path::Path::new("/some/config/dir"))); + let resolved = config.resolved_working_dir.as_ref().expect("should be resolved"); + assert_eq!(resolved, &PathBuf::from("/absolute/project")); + } + + #[test] + fn test_resolve_paths_dot_becomes_noproject_home() { + let mut config = CoreConfig::default(); // runtime_working_dir = "." + config.resolve_paths(Some(std::path::Path::new("/some/config/dir"))); + let resolved = config.resolved_working_dir.as_ref().expect("should be resolved"); + // "." should resolve to noproject_home, not to the config dir + assert!( + resolved.to_string_lossy().contains("noproject_home"), + "Expected noproject_home, got: {:?}", + resolved + ); + } + + #[test] + fn test_resolve_paths_no_config_dir_dot_still_noproject() { + let mut config = CoreConfig::default(); // runtime_working_dir = "." + config.resolve_paths(None); + let resolved = config.resolved_working_dir.as_ref().expect("should be resolved"); + assert!( + resolved.to_string_lossy().contains("noproject_home"), + "Expected noproject_home, got: {:?}", + resolved + ); + } + + #[test] + fn test_load_config_does_not_mutate_runtime_working_dir() { + let dir = std::env::temp_dir().join("dirigent_test_no_mutate"); + let _ = std::fs::create_dir_all(&dir); + let config_path = dir.join("dirigent.toml"); + std::fs::write(&config_path, "runtime_working_dir = \".\"\nconnectors = []\n").unwrap(); + + let config = CoreConfig::load_config(Some(config_path.clone())).unwrap(); + + // The raw field must remain "." — never mutated + assert_eq!( + config.runtime_working_dir, + PathBuf::from("."), + "runtime_working_dir should stay as stored value, got: {:?}", + config.runtime_working_dir + ); + + // But the resolved field should be populated + assert!( + config.resolved_working_dir.is_some(), + "resolved_working_dir should be populated by load_config" + ); + + let _ = std::fs::remove_dir_all(&dir); + } + + #[test] + fn test_save_config_preserves_stored_value_not_resolved() { + // Simulate: load a config with ".", resolve it, serialize back + let mut config = CoreConfig { + runtime_working_dir: PathBuf::from("."), + ..CoreConfig::default() + }; + config.resolve_paths(Some(std::path::Path::new("/some/config/dir"))); + + // resolved_working_dir should be set (noproject_home or similar) + assert!(config.resolved_working_dir.is_some()); + + // Serialize to TOML + let toml_str = toml::to_string_pretty(&config) + .expect("CoreConfig must serialize to TOML"); + + // The TOML must contain the STORED value ".", not the resolved path + assert!( + toml_str.contains(r#"runtime_working_dir = ".""#), + "TOML should contain stored value '.', got:\n{}", + toml_str + ); + assert!( + !toml_str.contains("noproject_home"), + "TOML must NOT contain resolved noproject_home path:\n{}", + toml_str + ); + + // Round-trip: deserialize should give back "." + let reloaded: CoreConfig = toml::from_str(&toml_str).unwrap(); + assert_eq!(reloaded.runtime_working_dir, PathBuf::from(".")); + // resolved_working_dir is skip-serialized, so it's None after deserialize + assert!(reloaded.resolved_working_dir.is_none()); + } +} diff --git a/crates/dirigent_core/src/connectors/acceptor/acp_acceptor.rs b/crates/dirigent_core/src/connectors/acceptor/acp_acceptor.rs new file mode 100644 index 0000000..1a98e78 --- /dev/null +++ b/crates/dirigent_core/src/connectors/acceptor/acp_acceptor.rs @@ -0,0 +1,561 @@ +//! ACP Acceptor implementation +//! +//! This module provides the `AcpAcceptor` struct which serves as the entry point +//! for incoming ACP connections. It tracks pending sessions from the ACP Server +//! and routes them to appropriate connectors. +//! +//! # Architecture +//! +//! The AcpAcceptor is a lightweight coordinator that: +//! - Maintains a list of pending incoming sessions +//! - Routes sessions to target connectors +//! - Optionally auto-routes new sessions to a default connector +//! +//! It does NOT: +//! - Process messages directly (delegates to connectors) +//! - Maintain direct connection to ACP clients (ACP Server does this) +//! - Store session history (connectors and archivist handle this) +//! +//! # Usage +//! +//! ```no_run +//! use dirigent_core::connectors::acceptor::AcpAcceptor; +//! +//! // Create acceptor with auto-routing disabled +//! let acceptor = AcpAcceptor::new("acp-acceptor-1".to_string(), "ACP Incoming".to_string()); +//! +//! // Or with auto-routing to a default connector +//! let acceptor = AcpAcceptor::builder("acp-acceptor-1".to_string(), "ACP Incoming".to_string()) +//! .with_default_connector("my-connector-id".to_string()) +//! .with_auto_routing(true) +//! .build(); +//! ``` + +use async_trait::async_trait; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; +use tracing::{debug, info, warn}; + +use super::{Acceptor, AcceptorError, AcceptorId, IncomingSession, SessionRouting}; + +/// Configuration for AcpAcceptor +#[derive(Clone, Debug)] +pub struct AcpAcceptorConfig { + /// Unique identifier for this acceptor + pub id: AcceptorId, + + /// Human-readable title + pub title: String, + + /// Default connector ID for auto-routing + pub default_connector_id: Option, + + /// Whether to automatically route new sessions + pub auto_routing: bool, +} + +impl AcpAcceptorConfig { + /// Create a new config with defaults + pub fn new(id: AcceptorId, title: String) -> Self { + Self { + id, + title, + default_connector_id: None, + auto_routing: false, + } + } +} + +/// Builder for AcpAcceptor +pub struct AcpAcceptorBuilder { + config: AcpAcceptorConfig, +} + +impl AcpAcceptorBuilder { + /// Create a new builder + pub fn new(id: AcceptorId, title: String) -> Self { + Self { + config: AcpAcceptorConfig::new(id, title), + } + } + + /// Set the default connector for auto-routing + pub fn with_default_connector(mut self, connector_id: String) -> Self { + self.config.default_connector_id = Some(connector_id); + self + } + + /// Enable or disable auto-routing + pub fn with_auto_routing(mut self, enabled: bool) -> Self { + self.config.auto_routing = enabled; + self + } + + /// Build the AcpAcceptor + pub fn build(self) -> AcpAcceptor { + AcpAcceptor::from_config(self.config) + } +} + +/// ACP Acceptor for handling incoming ACP connections +/// +/// This struct manages incoming sessions from the ACP Server and routes them +/// to appropriate connectors for processing. +pub struct AcpAcceptor { + /// Unique identifier for this acceptor + id: AcceptorId, + + /// Human-readable title + title: String, + + /// Pending sessions awaiting routing + /// + /// Key: session_id, Value: IncomingSession + pending_sessions: Arc>>, + + /// Session routings (for sessions that have been routed) + /// + /// Key: session_id, Value: SessionRouting + session_routings: Arc>>, + + /// Default connector ID for auto-routing and accept operations + default_connector_id: Arc>>, + + /// Whether auto-routing is enabled + auto_routing: Arc>, +} + +impl AcpAcceptor { + /// Create a new AcpAcceptor with basic settings + pub fn new(id: AcceptorId, title: String) -> Self { + Self { + id, + title, + pending_sessions: Arc::new(RwLock::new(HashMap::new())), + session_routings: Arc::new(RwLock::new(HashMap::new())), + default_connector_id: Arc::new(RwLock::new(None)), + auto_routing: Arc::new(RwLock::new(false)), + } + } + + /// Create a builder for more configuration options + pub fn builder(id: AcceptorId, title: String) -> AcpAcceptorBuilder { + AcpAcceptorBuilder::new(id, title) + } + + /// Create from config + fn from_config(config: AcpAcceptorConfig) -> Self { + Self { + id: config.id, + title: config.title, + pending_sessions: Arc::new(RwLock::new(HashMap::new())), + session_routings: Arc::new(RwLock::new(HashMap::new())), + default_connector_id: Arc::new(RwLock::new(config.default_connector_id)), + auto_routing: Arc::new(RwLock::new(config.auto_routing)), + } + } + + /// Add an incoming session to the pending list + /// + /// This is called by the ACP Server when a new session is created. + /// If auto-routing is enabled and a default connector is configured, + /// the session will be automatically routed. + /// + /// # Returns + /// + /// - `Ok(Some(SessionRouting))` if auto-routed + /// - `Ok(None)` if added to pending list + pub async fn add_incoming_session( + &self, + session: IncomingSession, + ) -> Result, AcceptorError> { + let session_id = session.session_id.clone(); + + // Check if auto-routing should happen + let auto_routing = *self.auto_routing.read().await; + let default_connector = self.default_connector_id.read().await.clone(); + + if auto_routing { + if let Some(connector_id) = default_connector { + info!( + session_id = %session_id, + connector_id = %connector_id, + "Auto-routing incoming session" + ); + + // Create routing directly (skip pending) + let routing = SessionRouting::new(session_id.clone(), connector_id); + let mut routings = self.session_routings.write().await; + routings.insert(session_id, routing.clone()); + + return Ok(Some(routing)); + } else { + warn!( + session_id = %session_id, + "Auto-routing enabled but no default connector configured" + ); + } + } + + // Add to pending sessions + info!( + session_id = %session_id, + client_id = %session.client_id, + "Added incoming session to pending list" + ); + + let mut pending = self.pending_sessions.write().await; + pending.insert(session_id, session); + + Ok(None) + } + + /// Remove a session from tracking entirely + /// + /// This is called when a session is closed or the client disconnects. + pub async fn remove_session(&self, session_id: &str) { + let mut pending = self.pending_sessions.write().await; + pending.remove(session_id); + + let mut routings = self.session_routings.write().await; + routings.remove(session_id); + + debug!(session_id = %session_id, "Removed session from acceptor tracking"); + } + + /// Set the default connector ID + pub async fn set_default_connector(&self, connector_id: Option) { + let mut default = self.default_connector_id.write().await; + *default = connector_id; + } + + /// Enable or disable auto-routing + pub async fn set_auto_routing(&self, enabled: bool) { + let mut auto = self.auto_routing.write().await; + *auto = enabled; + } + + /// Get the count of pending sessions + pub async fn pending_count(&self) -> usize { + self.pending_sessions.read().await.len() + } + + /// Get the count of routed sessions + pub async fn routed_count(&self) -> usize { + self.session_routings.read().await.len() + } +} + +#[async_trait] +impl Acceptor for AcpAcceptor { + fn id(&self) -> &AcceptorId { + &self.id + } + + fn title(&self) -> &str { + &self.title + } + + async fn incoming_sessions(&self) -> Vec { + let pending = self.pending_sessions.read().await; + let mut sessions: Vec = pending.values().cloned().collect(); + + // Sort by created_at (newest first) + sessions.sort_by(|a, b| b.created_at.cmp(&a.created_at)); + + sessions + } + + async fn accept_session(&self, session_id: &str) -> Result { + // Get the default connector + let default_connector = self.default_connector_id.read().await.clone(); + + match default_connector { + Some(connector_id) => self.route_to_connector(session_id, &connector_id).await, + None => Err(AcceptorError::Internal( + "No default connector configured for accept operation".to_string(), + )), + } + } + + async fn reject_session(&self, session_id: &str) -> Result<(), AcceptorError> { + let mut pending = self.pending_sessions.write().await; + + if pending.remove(session_id).is_some() { + info!(session_id = %session_id, "Rejected incoming session"); + Ok(()) + } else { + // Check if it was already routed + let routings = self.session_routings.read().await; + if routings.contains_key(session_id) { + Err(AcceptorError::AlreadyRouted(session_id.to_string())) + } else { + Err(AcceptorError::SessionNotFound(session_id.to_string())) + } + } + } + + async fn route_to_connector( + &self, + session_id: &str, + connector_id: &str, + ) -> Result { + // Remove from pending + let mut pending = self.pending_sessions.write().await; + + if pending.remove(session_id).is_none() { + // Check if already routed + let routings = self.session_routings.read().await; + if routings.contains_key(session_id) { + return Err(AcceptorError::AlreadyRouted(session_id.to_string())); + } + return Err(AcceptorError::SessionNotFound(session_id.to_string())); + } + + // Create and store routing + let routing = SessionRouting::new(session_id.to_string(), connector_id.to_string()); + + let mut routings = self.session_routings.write().await; + routings.insert(session_id.to_string(), routing.clone()); + + info!( + session_id = %session_id, + connector_id = %connector_id, + "Routed session to connector" + ); + + Ok(routing) + } + + async fn get_session_routing(&self, session_id: &str) -> Option { + let routings = self.session_routings.read().await; + routings.get(session_id).cloned() + } + + fn auto_routing_enabled(&self) -> bool { + // Use try_read for non-blocking access in sync context + self.auto_routing + .try_read() + .map(|guard| *guard) + .unwrap_or(false) + } + + fn default_connector_id(&self) -> Option<&str> { + // This is tricky because we can't return a reference to data behind RwLock + // In practice, callers should use the async methods or get_session_routing + // For now, return None - the async version should be preferred + None + } +} + +/// Extension methods for getting routing info asynchronously +impl AcpAcceptor { + /// Get the default connector ID asynchronously + pub async fn default_connector_id_async(&self) -> Option { + self.default_connector_id.read().await.clone() + } + + /// Check if auto-routing is enabled asynchronously + pub async fn auto_routing_enabled_async(&self) -> bool { + *self.auto_routing.read().await + } + + /// Transfer a session from one connector to another + /// + /// This updates the routing to point to a new connector. + /// The session must already be routed. + /// + /// # Arguments + /// + /// * `session_id` - The session to transfer + /// * `to_connector_id` - The new target connector + /// + /// # Returns + /// + /// The new routing information + pub async fn transfer_session( + &self, + session_id: &str, + to_connector_id: &str, + ) -> Result { + let mut routings = self.session_routings.write().await; + + // Verify session exists and is routed + if !routings.contains_key(session_id) { + // Check if it's pending + let pending = self.pending_sessions.read().await; + if pending.contains_key(session_id) { + return Err(AcceptorError::Internal( + "Session is pending, not routed. Use route_to_connector instead.".to_string(), + )); + } + return Err(AcceptorError::SessionNotFound(session_id.to_string())); + } + + // Update the routing + let routing = SessionRouting::new(session_id.to_string(), to_connector_id.to_string()); + routings.insert(session_id.to_string(), routing.clone()); + + info!( + session_id = %session_id, + to_connector_id = %to_connector_id, + "Transferred session to new connector" + ); + + Ok(routing) + } + + /// Get all session routings + pub async fn all_routings(&self) -> HashMap { + self.session_routings.read().await.clone() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_acp_acceptor_creation() { + let acceptor = AcpAcceptor::new("acceptor-1".to_string(), "Test Acceptor".to_string()); + + assert_eq!(acceptor.id(), "acceptor-1"); + assert_eq!(acceptor.title(), "Test Acceptor"); + assert_eq!(acceptor.pending_count().await, 0); + assert_eq!(acceptor.routed_count().await, 0); + } + + #[tokio::test] + async fn test_add_incoming_session() { + let acceptor = AcpAcceptor::new("acceptor-1".to_string(), "Test Acceptor".to_string()); + + let session = + IncomingSession::new("session-1".to_string(), "client-1".to_string()); + + let result = acceptor.add_incoming_session(session).await; + assert!(result.is_ok()); + assert!(result.unwrap().is_none()); // No auto-routing + + assert_eq!(acceptor.pending_count().await, 1); + + let pending = acceptor.incoming_sessions().await; + assert_eq!(pending.len(), 1); + assert_eq!(pending[0].session_id, "session-1"); + } + + #[tokio::test] + async fn test_auto_routing() { + let acceptor = AcpAcceptor::builder("acceptor-1".to_string(), "Test Acceptor".to_string()) + .with_default_connector("default-connector".to_string()) + .with_auto_routing(true) + .build(); + + let session = + IncomingSession::new("session-1".to_string(), "client-1".to_string()); + + let result = acceptor.add_incoming_session(session).await; + assert!(result.is_ok()); + + let routing = result.unwrap(); + assert!(routing.is_some()); + let routing = routing.unwrap(); + assert_eq!(routing.session_id, "session-1"); + assert_eq!(routing.connector_id, "default-connector"); + + // Session should not be in pending + assert_eq!(acceptor.pending_count().await, 0); + // Session should be in routings + assert_eq!(acceptor.routed_count().await, 1); + } + + #[tokio::test] + async fn test_route_to_connector() { + let acceptor = AcpAcceptor::new("acceptor-1".to_string(), "Test Acceptor".to_string()); + + let session = + IncomingSession::new("session-1".to_string(), "client-1".to_string()); + acceptor.add_incoming_session(session).await.unwrap(); + + let routing = acceptor + .route_to_connector("session-1", "my-connector") + .await + .unwrap(); + + assert_eq!(routing.session_id, "session-1"); + assert_eq!(routing.connector_id, "my-connector"); + + // Session should be removed from pending + assert_eq!(acceptor.pending_count().await, 0); + // Session should be in routings + assert_eq!(acceptor.routed_count().await, 1); + } + + #[tokio::test] + async fn test_reject_session() { + let acceptor = AcpAcceptor::new("acceptor-1".to_string(), "Test Acceptor".to_string()); + + let session = + IncomingSession::new("session-1".to_string(), "client-1".to_string()); + acceptor.add_incoming_session(session).await.unwrap(); + + assert_eq!(acceptor.pending_count().await, 1); + + acceptor.reject_session("session-1").await.unwrap(); + + assert_eq!(acceptor.pending_count().await, 0); + } + + #[tokio::test] + async fn test_transfer_session() { + let acceptor = AcpAcceptor::new("acceptor-1".to_string(), "Test Acceptor".to_string()); + + // Add and route a session + let session = + IncomingSession::new("session-1".to_string(), "client-1".to_string()); + acceptor.add_incoming_session(session).await.unwrap(); + acceptor.route_to_connector("session-1", "connector-1") + .await + .unwrap(); + + // Transfer to new connector + let new_routing = acceptor + .transfer_session("session-1", "connector-2") + .await + .unwrap(); + + assert_eq!(new_routing.connector_id, "connector-2"); + + // Verify the routing was updated + let routing = acceptor.get_session_routing("session-1").await.unwrap(); + assert_eq!(routing.connector_id, "connector-2"); + } + + #[tokio::test] + async fn test_error_cases() { + let acceptor = AcpAcceptor::new("acceptor-1".to_string(), "Test Acceptor".to_string()); + + // Route non-existent session + let result = acceptor.route_to_connector("non-existent", "connector").await; + assert!(matches!(result, Err(AcceptorError::SessionNotFound(_)))); + + // Reject non-existent session + let result = acceptor.reject_session("non-existent").await; + assert!(matches!(result, Err(AcceptorError::SessionNotFound(_)))); + + // Add session, route it, then try to route again + let session = + IncomingSession::new("session-1".to_string(), "client-1".to_string()); + acceptor.add_incoming_session(session).await.unwrap(); + acceptor.route_to_connector("session-1", "connector-1") + .await + .unwrap(); + + let result = acceptor.route_to_connector("session-1", "connector-2").await; + assert!(matches!(result, Err(AcceptorError::AlreadyRouted(_)))); + + // Reject already routed session + let result = acceptor.reject_session("session-1").await; + assert!(matches!(result, Err(AcceptorError::AlreadyRouted(_)))); + } +} diff --git a/crates/dirigent_core/src/connectors/acceptor/mod.rs b/crates/dirigent_core/src/connectors/acceptor/mod.rs new file mode 100644 index 0000000..c6c8c52 --- /dev/null +++ b/crates/dirigent_core/src/connectors/acceptor/mod.rs @@ -0,0 +1,402 @@ +//! Acceptor abstraction layer for incoming connections +//! +//! This module provides the `Acceptor` trait that represents an entry point for incoming +//! ACP connections. Unlike Connectors which initiate outbound connections to agents, +//! Acceptors accept inbound connections from external clients. +//! +//! # Architecture +//! +//! An Acceptor: +//! - Accepts incoming connections from ACP clients +//! - Tracks pending sessions awaiting routing decisions +//! - Routes sessions to target Connectors for processing +//! - Does NOT process messages itself (delegates to Connectors) +//! +//! # Session Lifecycle +//! +//! 1. Client connects to ACP Server and creates a session +//! 2. Session appears as "pending" in the Acceptor +//! 3. User (or auto-routing) assigns session to a Connector +//! 4. Once routed, the Connector handles all message processing +//! +//! # Usage +//! +//! ```no_run +//! use dirigent_core::connectors::acceptor::{Acceptor, IncomingSession}; +//! +//! async fn example(acceptor: impl Acceptor) { +//! // Get pending sessions +//! let pending = acceptor.incoming_sessions().await; +//! +//! for session in pending { +//! // Route to a connector +//! acceptor.route_to_connector(&session.session_id, "my-connector-id").await?; +//! } +//! } +//! ``` + +//! Note: This module requires the "server" feature flag. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use thiserror::Error; + +pub mod acp_acceptor; + +// Re-export AcpAcceptor for convenience +pub use acp_acceptor::AcpAcceptor; + +/// Unique identifier for an acceptor instance +pub type AcceptorId = String; + +/// Information about an incoming session awaiting routing +/// +/// When an ACP client connects and creates a session through the ACP Server, +/// that session is initially "pending" - it exists but hasn't been assigned +/// to a specific Connector for processing. +/// +/// This struct contains the metadata needed to display pending sessions +/// in the UI and make routing decisions. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub struct IncomingSession { + /// Unique session identifier (from the ACP Server) + pub session_id: String, + + /// Client identifier (who created this session) + pub client_id: String, + + /// Optional client metadata (capabilities, name, version, etc.) + /// + /// This is typically the client info provided during the ACP `initialize` + /// handshake. It can include: + /// - client_name: Name of the client application + /// - client_version: Version string + /// - capabilities: Supported features + #[serde(skip_serializing_if = "Option::is_none")] + pub client_info: Option, + + /// When this session was created + pub created_at: DateTime, + + /// Optional title for display + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, +} + +impl IncomingSession { + /// Create a new IncomingSession + pub fn new(session_id: String, client_id: String) -> Self { + Self { + session_id, + client_id, + client_info: None, + created_at: Utc::now(), + title: None, + } + } + + /// Create with client info + pub fn with_client_info(mut self, info: serde_json::Value) -> Self { + self.client_info = Some(info); + self + } + + /// Create with title + pub fn with_title(mut self, title: String) -> Self { + self.title = Some(title); + self + } +} + +/// Result of successfully routing a session +/// +/// When a session is routed to a connector, this struct provides +/// information about the mapping that was created. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct SessionRouting { + /// The incoming session ID (from client) + pub session_id: String, + + /// The connector this session was routed to + pub connector_id: String, + + /// When the routing was established + pub routed_at: DateTime, +} + +impl SessionRouting { + /// Create a new SessionRouting + pub fn new(session_id: String, connector_id: String) -> Self { + Self { + session_id, + connector_id, + routed_at: Utc::now(), + } + } +} + +/// Error type for acceptor operations +#[derive(Clone, Debug, Error)] +pub enum AcceptorError { + /// Session not found in pending list + #[error("Session not found: {0}")] + SessionNotFound(String), + + /// Connector not found or not available + #[error("Connector not available: {0}")] + ConnectorNotAvailable(String), + + /// Session already routed to a connector + #[error("Session already routed: {0}")] + AlreadyRouted(String), + + /// Internal error + #[error("Internal error: {0}")] + Internal(String), +} + +/// Trait for incoming connection handlers (Acceptors) +/// +/// An Acceptor represents an entry point for incoming ACP connections. +/// It tracks pending sessions and routes them to appropriate Connectors. +/// +/// # Object Safety +/// +/// This trait is designed to be object-safe, allowing for `dyn Acceptor` usage. +/// All async methods use `async_trait` for compatibility. +#[async_trait] +pub trait Acceptor: Send + Sync { + /// Get the unique identifier for this acceptor + fn id(&self) -> &AcceptorId; + + /// Get the human-readable title for this acceptor + fn title(&self) -> &str; + + /// Get all pending incoming sessions + /// + /// Returns sessions that have been created but not yet routed to a connector. + /// Once a session is routed, it no longer appears in this list. + async fn incoming_sessions(&self) -> Vec; + + /// Accept a session and route it to a connector + /// + /// This is a convenience method that routes the session to the default + /// connector (if one is configured) or marks it as accepted for manual routing. + /// + /// # Arguments + /// + /// * `session_id` - The ID of the pending session to accept + /// + /// # Returns + /// + /// The routing information if successful, or an error if the session + /// was not found or routing failed. + async fn accept_session(&self, session_id: &str) -> Result; + + /// Reject a session and remove it from the pending list + /// + /// This notifies the client that their session was rejected and + /// removes it from the pending sessions list. + /// + /// # Arguments + /// + /// * `session_id` - The ID of the pending session to reject + async fn reject_session(&self, session_id: &str) -> Result<(), AcceptorError>; + + /// Route a session to a specific connector + /// + /// Once routed, the connector will handle all message processing for this + /// session. The session is removed from the pending list. + /// + /// # Arguments + /// + /// * `session_id` - The ID of the pending session + /// * `connector_id` - The ID of the connector to route to + /// + /// # Returns + /// + /// The routing information if successful, or an error if: + /// - The session was not found + /// - The connector was not found or not available + /// - The session was already routed + async fn route_to_connector( + &self, + session_id: &str, + connector_id: &str, + ) -> Result; + + /// Get the current routing for a session (if routed) + /// + /// Returns None if the session is still pending, or the routing + /// information if it has been assigned to a connector. + async fn get_session_routing(&self, session_id: &str) -> Option; + + /// Check if auto-routing is enabled + /// + /// When auto-routing is enabled, new incoming sessions are automatically + /// routed to the default connector without manual intervention. + fn auto_routing_enabled(&self) -> bool; + + /// Get the default connector ID for auto-routing + /// + /// Returns the connector ID that new sessions should be automatically + /// routed to, if auto-routing is enabled. + fn default_connector_id(&self) -> Option<&str>; +} + +/// Summary information about an acceptor for UI display +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub struct AcceptorSummary { + /// Unique acceptor identifier + pub id: AcceptorId, + + /// Human-readable title for display + pub title: String, + + /// Number of pending incoming sessions + pub pending_sessions_count: usize, + + /// Whether auto-routing is enabled + pub auto_routing_enabled: bool, + + /// Default connector ID (if configured) + #[serde(skip_serializing_if = "Option::is_none")] + pub default_connector_id: Option, +} + +/// Handle to an acceptor instance +/// +/// This provides a cloneable reference to an acceptor that can be shared +/// across async tasks. +#[derive(Clone)] +pub struct AcceptorHandle { + /// The underlying acceptor implementation + acceptor: Arc, +} + +impl AcceptorHandle { + /// Create a new AcceptorHandle wrapping an acceptor implementation + pub fn new(acceptor: impl Acceptor + 'static) -> Self { + Self { + acceptor: Arc::new(acceptor), + } + } + + /// Create from an Arc-wrapped acceptor + pub fn from_arc(acceptor: Arc) -> Self { + Self { acceptor } + } + + /// Get a reference to the underlying acceptor + pub fn as_ref(&self) -> &dyn Acceptor { + self.acceptor.as_ref() + } + + /// Get the acceptor ID + pub fn id(&self) -> &AcceptorId { + self.acceptor.id() + } + + /// Get the acceptor title + pub fn title(&self) -> &str { + self.acceptor.title() + } + + /// Get pending incoming sessions + pub async fn incoming_sessions(&self) -> Vec { + self.acceptor.incoming_sessions().await + } + + /// Accept a session + pub async fn accept_session(&self, session_id: &str) -> Result { + self.acceptor.accept_session(session_id).await + } + + /// Reject a session + pub async fn reject_session(&self, session_id: &str) -> Result<(), AcceptorError> { + self.acceptor.reject_session(session_id).await + } + + /// Route a session to a connector + pub async fn route_to_connector( + &self, + session_id: &str, + connector_id: &str, + ) -> Result { + self.acceptor.route_to_connector(session_id, connector_id).await + } + + /// Get session routing + pub async fn get_session_routing(&self, session_id: &str) -> Option { + self.acceptor.get_session_routing(session_id).await + } + + /// Check if auto-routing is enabled + pub fn auto_routing_enabled(&self) -> bool { + self.acceptor.auto_routing_enabled() + } + + /// Get default connector ID + pub fn default_connector_id(&self) -> Option<&str> { + self.acceptor.default_connector_id() + } + + /// Get a summary of this acceptor for UI display + pub async fn summary(&self) -> AcceptorSummary { + let pending = self.incoming_sessions().await; + AcceptorSummary { + id: self.id().clone(), + title: self.title().to_string(), + pending_sessions_count: pending.len(), + auto_routing_enabled: self.auto_routing_enabled(), + default_connector_id: self.default_connector_id().map(String::from), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_incoming_session_creation() { + let session = IncomingSession::new("session-1".to_string(), "client-1".to_string()); + + assert_eq!(session.session_id, "session-1"); + assert_eq!(session.client_id, "client-1"); + assert!(session.client_info.is_none()); + assert!(session.title.is_none()); + } + + #[test] + fn test_incoming_session_with_info() { + let session = IncomingSession::new("session-1".to_string(), "client-1".to_string()) + .with_client_info(serde_json::json!({ + "client_name": "test-client", + "version": "1.0.0" + })) + .with_title("Test Session".to_string()); + + assert!(session.client_info.is_some()); + assert_eq!(session.title, Some("Test Session".to_string())); + } + + #[test] + fn test_session_routing() { + let routing = SessionRouting::new("session-1".to_string(), "connector-1".to_string()); + + assert_eq!(routing.session_id, "session-1"); + assert_eq!(routing.connector_id, "connector-1"); + } + + #[test] + fn test_acceptor_error_display() { + let error = AcceptorError::SessionNotFound("session-1".to_string()); + assert!(error.to_string().contains("session-1")); + + let error = AcceptorError::ConnectorNotAvailable("connector-1".to_string()); + assert!(error.to_string().contains("connector-1")); + } +} diff --git a/crates/dirigent_core/src/connectors/acp/config.rs b/crates/dirigent_core/src/connectors/acp/config.rs new file mode 100644 index 0000000..30d5014 --- /dev/null +++ b/crates/dirigent_core/src/connectors/acp/config.rs @@ -0,0 +1,509 @@ +//! ACP Connector configuration +//! +//! This module provides configuration types for ACP connectors, including +//! transport selection (Stdio vs HTTP), feature support, and validation logic. + +use dirigent_protocol::SessionOwnership; +use dirigent_tools::EmbeddingConfig; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; + +/// Well-known agent types for ACP connectors. +/// +/// This identifies the specific agent implementation behind an ACP connector, +/// enabling automatic mode/model mapping during session transfers. +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq, Copy)] +#[serde(rename_all = "lowercase")] +pub enum ConnectorAgentType { + /// Custom or unknown agent (no automatic mode/model mapping) + #[default] + Custom, + /// Anthropic Claude Code / Claude API + Claude, + /// OpenAI Codex / ChatGPT Code Interpreter + Codex, + /// Google Gemini Code Assist + Gemini, +} + +impl ConnectorAgentType { + /// Parse from string magic word (case-insensitive). + /// + /// Supports aliases like "openai" for Codex and "google" for Gemini. + pub fn from_magic_word(word: &str) -> Option { + match word.to_lowercase().as_str() { + "claude" => Some(Self::Claude), + "codex" | "openai" => Some(Self::Codex), + "gemini" | "google" => Some(Self::Gemini), + _ => None, + } + } +} + +/// Well-known ACP feature identifiers +/// +/// These constants define standard features that ACP agents may support. +/// Use these when configuring `supported_features` in `AcpConfig`. +pub mod features { + /// Agent supports `session/load` to resume existing sessions + /// + /// Without this feature, archived sessions cannot be loaded back into the agent. + /// The UI should disable "load from archive" functionality. + pub const SESSION_RESUME: &str = "session_resume"; + + /// Agent supports `session/list` to enumerate available sessions + /// + /// Without this feature, the connector cannot list sessions from the agent. + pub const SESSION_LIST: &str = "session_list"; + + /// Agent provides message history when loading sessions + /// + /// Without this feature, loaded sessions start empty (no replay). + pub const MESSAGE_HISTORY: &str = "message_history"; + + /// Agent supports cancellation via `session/cancel` + pub const CANCELLATION: &str = "cancellation"; + + /// Agent supports model selection + pub const MODEL_SELECTION: &str = "model_selection"; + + /// Agent supports mode selection (e.g., plan mode, bypass permissions) + pub const MODE_SELECTION: &str = "mode_selection"; +} + +/// Transport mechanism for ACP communication +/// +/// ACP supports multiple transport mechanisms for different deployment scenarios: +/// - Stdio: For locally spawned agent processes +/// - Http: For remote agents accessible via HTTP/SSE +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum TransportKind { + /// Stdio transport (spawn process, communicate via stdin/stdout) + /// + /// # Fields + /// + /// * `command` - Executable path or name (resolved via PATH) + /// * `args` - Command-line arguments to pass to the process + /// * `cwd` - Optional working directory for the process + /// * `env` - Optional environment variables + Stdio { + /// Command to execute (e.g., "dirigate", "/path/to/agent") + command: String, + /// Command-line arguments + #[serde(default)] + args: Vec, + /// Working directory (defaults to current directory) + #[serde(skip_serializing_if = "Option::is_none")] + cwd: Option, + /// Environment variables + #[serde(default, skip_serializing_if = "Vec::is_empty")] + env: Vec<(String, String)>, + }, + + /// HTTP transport (connect to remote agent via HTTP/SSE) + /// + /// # Fields + /// + /// * `base_url` - Base URL of the ACP agent HTTP endpoint + /// * `timeout_ms` - Optional request timeout in milliseconds + Http { + /// Base URL (e.g., "http://localhost:3000", "https://agent.example.com") + base_url: String, + /// Request timeout in milliseconds (default: 30000) + #[serde(skip_serializing_if = "Option::is_none")] + timeout_ms: Option, + }, +} + +/// Configuration for ACP connector +/// +/// Contains all the information needed to create and connect to an ACP agent, +/// including transport configuration and operational parameters. +/// +/// Note: Title and supported_features are now orchestration-level fields +/// stored in ConnectorConfig, not here. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AcpConfig { + /// Transport configuration + /// + /// Determines how to connect to the ACP agent (stdio process or HTTP endpoint). + pub transport: TransportKind, + + /// Protocol version to use (default: 1) + /// + /// Specifies which version of the ACP protocol to negotiate during initialization. + #[serde(default = "default_protocol_version")] + pub protocol_version: u32, + + /// Current working directory for sessions (default: ".") + /// + /// Passed to the agent when creating new sessions. Determines the file system + /// context for tools and operations. + #[serde(default = "default_cwd")] + pub cwd: String, + + /// Connection retry configuration + /// + /// Controls how the connector behaves when connection attempts fail. + #[serde(default)] + pub retry: RetryConfig, + + /// File embedding configuration + /// + /// Controls how files are embedded in prompts (size limits, redaction, etc.) + #[serde(default)] + pub embedding: EmbeddingConfig, + + /// Default ownership for capability negotiation (used during initialize) + /// + /// This ownership model determines what capabilities are advertised to the agent + /// during the initialize handshake. For UI-created connectors, this defaults to + /// internal ownership (empty capabilities). For connectors created by ACP Acceptor + /// (incoming external clients), this should be set to external_forwarded with the + /// client's capabilities. + #[serde(default)] + pub default_ownership: SessionOwnership, + + /// Override directory for ACP protocol logging (optional) + /// + /// All JSON-RPC messages (stdin/stdout) are logged to JSONL files. + /// Defaults to `data_dir/logs/acp/` when not set. Set this to override + /// the default location. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub acp_log_dir: Option, + + /// Agent type for automatic mode/model mapping. + /// + /// When set to a specific agent type (Claude, Codex, Gemini), the system + /// will automatically translate Gateway mode/model identifiers to + /// agent-specific identifiers during session transfers. + #[serde(default)] + pub agent_type: ConnectorAgentType, +} + +/// Retry configuration for connection failures +/// +/// Controls reconnection behavior when the transport fails or disconnects. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RetryConfig { + /// Maximum number of retry attempts (default: 5) + /// + /// After this many failed attempts, the connector enters Error state + /// and waits for manual Reconnect command. + #[serde(default = "default_max_retries")] + pub max_retries: usize, + + /// Retry delays in milliseconds (default: [60000, 60000, 60000, 60000, 60000]) + /// + /// Delays between retry attempts. Defaults to 60 seconds to avoid reconnect + /// spam when SSE connections drop. If retries exceed the length of this + /// array, the last delay is reused. + #[serde(default = "default_retry_delays")] + pub retry_delays_ms: Vec, +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + max_retries: default_max_retries(), + retry_delays_ms: default_retry_delays(), + } + } +} + +// Default value functions for serde + +fn default_protocol_version() -> u32 { + 1 +} + +fn default_cwd() -> String { + ".".to_string() +} + +fn default_max_retries() -> usize { + 5 +} + +fn default_retry_delays() -> Vec { + // First retry after 60 seconds to avoid spam when SSE disconnects + // Subsequent retries use same 60-second delay + vec![60000, 60000, 60000, 60000, 60000] +} + +impl AcpConfig { + /// Set the maximum number of retry attempts + /// + /// # Arguments + /// + /// * `max` - Maximum retry attempts before entering Error state + /// + /// # Returns + /// + /// Self for method chaining + pub fn with_retry_max_attempts(mut self, max: usize) -> Self { + self.retry.max_retries = max; + self + } + + /// Set the initial retry delay + /// + /// # Arguments + /// + /// * `delay` - Initial delay before first retry + /// + /// # Returns + /// + /// Self for method chaining + pub fn with_retry_initial_delay(mut self, delay: std::time::Duration) -> Self { + let delay_ms = delay.as_millis() as u64; + if !self.retry.retry_delays_ms.is_empty() { + self.retry.retry_delays_ms[0] = delay_ms; + } + self + } + + /// Set the request timeout for HTTP transport + /// + /// # Arguments + /// + /// * `timeout` - Request timeout duration + /// + /// # Returns + /// + /// Self for method chaining + /// + /// # Note + /// + /// This only applies to HTTP transport. For stdio transport, this is a no-op. + pub fn with_request_timeout(mut self, timeout: std::time::Duration) -> Self { + if let TransportKind::Http { timeout_ms, .. } = &mut self.transport { + *timeout_ms = Some(timeout.as_millis() as u64); + } + self + } + + /// Validate configuration and return errors if invalid + /// + /// Checks for common configuration mistakes like: + /// - Empty command for stdio transport + /// - Invalid URL for HTTP transport + /// - Negative or zero timeouts + /// + /// # Returns + /// + /// `Ok(())` if configuration is valid, `Err(String)` with error message otherwise. + pub fn validate(&self) -> Result<(), String> { + // Validate transport + match &self.transport { + TransportKind::Stdio { command, .. } => { + if command.trim().is_empty() { + return Err("Stdio transport requires non-empty command".to_string()); + } + } + TransportKind::Http { base_url, timeout_ms } => { + if base_url.trim().is_empty() { + return Err("HTTP transport requires non-empty base_url".to_string()); + } + // Basic URL validation + if !base_url.starts_with("http://") && !base_url.starts_with("https://") { + return Err("HTTP base_url must start with http:// or https://".to_string()); + } + if let Some(timeout) = timeout_ms { + if *timeout == 0 { + return Err("HTTP timeout must be greater than 0".to_string()); + } + } + } + } + + // Validate protocol version + if self.protocol_version == 0 { + return Err("Protocol version must be >= 1".to_string()); + } + + // Validate retry config + if self.retry.max_retries == 0 { + return Err("max_retries must be >= 1".to_string()); + } + if self.retry.retry_delays_ms.is_empty() { + return Err("retry_delays_ms cannot be empty".to_string()); + } + + Ok(()) + } + + // REMOVED: supports_feature() and with_features() methods + // supported_features is now in ConnectorConfig, not AcpConfig + + /// Create a default stdio configuration for testing + /// + /// # Arguments + /// + /// * `command` - Command to execute + /// * `args` - Command-line arguments + /// + /// # Returns + /// + /// AcpConfig with stdio transport and default values + pub fn stdio(command: impl Into, args: Vec) -> Self { + Self { + transport: TransportKind::Stdio { + command: command.into(), + args, + cwd: None, + env: vec![], + }, + protocol_version: 1, + cwd: ".".to_string(), + retry: RetryConfig::default(), + embedding: EmbeddingConfig::default(), + default_ownership: SessionOwnership::default(), + acp_log_dir: None, + agent_type: ConnectorAgentType::default(), + } + } + + /// Create a default HTTP configuration + /// + /// # Arguments + /// + /// * `base_url` - Base URL of the ACP agent + /// + /// # Returns + /// + /// AcpConfig with HTTP transport and default values + pub fn http(base_url: impl Into) -> Self { + Self { + transport: TransportKind::Http { + base_url: base_url.into(), + timeout_ms: Some(30_000), + }, + protocol_version: 1, + cwd: ".".to_string(), + retry: RetryConfig::default(), + embedding: EmbeddingConfig::default(), + default_ownership: SessionOwnership::default(), + acp_log_dir: None, + agent_type: ConnectorAgentType::default(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_transport_kind_stdio_serialization() { + let transport = TransportKind::Stdio { + command: "test-agent".to_string(), + args: vec!["--stdio".to_string()], + cwd: None, + env: vec![], + }; + + let json = serde_json::to_string(&transport).unwrap(); + let deserialized: TransportKind = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized, transport); + } + + #[test] + fn test_transport_kind_http_serialization() { + let transport = TransportKind::Http { + base_url: "http://localhost:3000".to_string(), + timeout_ms: Some(60_000), + }; + + let json = serde_json::to_string(&transport).unwrap(); + let deserialized: TransportKind = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized, transport); + } + + #[test] + fn test_acp_config_defaults() { + let config = AcpConfig::stdio("agent", vec![]); + + assert_eq!(config.protocol_version, 1); + assert_eq!(config.cwd, "."); + assert_eq!(config.retry.max_retries, 5); + assert_eq!(config.retry.retry_delays_ms.len(), 5); + } + + #[test] + fn test_acp_config_validation_stdio_empty_command() { + let config = AcpConfig::stdio("", vec![]); + + let result = config.validate(); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("non-empty command")); + } + + #[test] + fn test_acp_config_validation_http_empty_url() { + let config = AcpConfig::http(""); + + let result = config.validate(); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("non-empty base_url")); + } + + #[test] + fn test_acp_config_validation_http_invalid_url() { + let config = AcpConfig::http("not-a-url"); + + let result = config.validate(); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("http://")); + } + + // REMOVED: test_acp_config_validation_empty_title + // Title is now in ConnectorConfig, not AcpConfig + + #[test] + fn test_acp_config_validation_zero_protocol_version() { + let mut config = AcpConfig::stdio("agent", vec![]); + config.protocol_version = 0; + + let result = config.validate(); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Protocol version")); + } + + #[test] + fn test_acp_config_validation_valid() { + let config = AcpConfig::stdio("agent", vec!["--stdio".to_string()]); + + assert!(config.validate().is_ok()); + } + + #[test] + fn test_acp_config_http_valid() { + let config = AcpConfig::http("https://localhost:3000"); + + assert!(config.validate().is_ok()); + } + + #[test] + fn test_retry_config_default() { + let retry = RetryConfig::default(); + + assert_eq!(retry.max_retries, 5); + // Default delays: 60 seconds each to avoid spam when SSE disconnects + assert_eq!(retry.retry_delays_ms, vec![60000, 60000, 60000, 60000, 60000]); + } + + #[test] + fn test_acp_config_serialization() { + let config = AcpConfig::stdio("test-agent", vec!["--flag".to_string()]); + + let json = serde_json::to_string(&config).unwrap(); + let deserialized: AcpConfig = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.protocol_version, config.protocol_version); + assert_eq!(deserialized.cwd, config.cwd); + } +} diff --git a/crates/dirigent_core/src/connectors/acp/connector.rs b/crates/dirigent_core/src/connectors/acp/connector.rs new file mode 100644 index 0000000..e2dedd6 --- /dev/null +++ b/crates/dirigent_core/src/connectors/acp/connector.rs @@ -0,0 +1,3900 @@ +//! ACP Connector implementation +//! +//! This module provides the main connector implementation for the Agent-Client Protocol. +//! It integrates the transport and protocol layers with the Connector trait to provide +//! a complete, production-ready ACP client. +//! +//! # Architecture +//! +//! The AcpConnector follows the same pattern as OpenCodeConnector: +//! - Background task for event processing and reconnection +//! - Command channel for control operations +//! - Event broadcast for publishing events to subscribers +//! - State machine for lifecycle management +//! - Automatic reconnection with exponential backoff +//! +//! # Lifecycle +//! +//! 1. Create connector with `AcpConnector::new(id, owner, config)` +//! 2. Start the connector task with `start_task()` which returns a JoinHandle +//! 3. Send commands via the command channel +//! 4. Receive events via the event broadcast channel +//! 5. Stop with `Connector::stop()` which sends Shutdown command + +use crate::connectors::acp::{ + config::{AcpConfig, TransportKind}, + error::{AcpError, AcpResult}, + idle_detector::{check_idle_sessions, PendingCompletion, SessionState}, + protocol::{ + self, build_initialize_request, build_session_cancel_request, build_session_new_request, + build_session_prompt_request, ProtocolHandler, + }, + state::{InternalState, SessionInfo, SessionStatus}, + title_utils::derive_title_from_text, + transport::{AcpTransport, HttpTransport, StdioTransport}, +}; +use crate::connectors::{Connector, ConnectorCommand}; +use crate::sharing::bus::SharingBus; +use crate::types::{ConnectorId, ConnectorKind, ConnectorState, UserId}; +use chrono::Utc; +use dirigent_protocol::types::meta::Meta; +use dirigent_protocol::types::{ContentBlock, ToolOrigin}; +use dirigent_protocol::{ + Event, Message, MessagePart, MessageRole, MessageStatus, Session, SessionMetadata, + SessionOwnership, SessionUpdate, TurnCompleteTrigger, +}; +use serde_json::{json, Value}; +use std::collections::{HashMap, HashSet}; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::{broadcast, mpsc, Mutex, RwLock}; +use tokio::task::JoinHandle; +use tracing::{debug, error, info, trace, warn}; +use uuid::Uuid; + +/// Debug log to file macro +/// Writes to acp_connector_debug.log in the logs directory (no console output). +/// For deep protocol debugging. Use tracing for normal logging. +macro_rules! debug_log { + ($($arg:tt)*) => {{ + let msg = format!("[{}] {}\n", chrono::Utc::now().format("%H:%M:%S%.3f"), format!($($arg)*)); + let log_path = dirigent_config::DirigentPaths::resolve() + .map(|p| { + let dir = p.logs_dir(); + let _ = std::fs::create_dir_all(&dir); + dir.join("acp_connector_debug.log") + }) + .unwrap_or_else(|_| std::path::PathBuf::from("acp_connector_debug.log")); + if let Ok(mut file) = std::fs::OpenOptions::new() + .create(true) + .append(true) + .open(&log_path) + { + let _ = std::io::Write::write_all(&mut file, msg.as_bytes()); + } + // No console output - file-only logging for deep protocol debugging + }}; +} + +/// Register or update a session node in the inspector registry. +/// +/// Creates the node if it doesn't exist, or updates its properties and state. +/// Called after `upsert_session` and on session metadata changes. +#[cfg(feature = "server")] +async fn inspector_upsert_session( + inspector: &Arc, + connector_id: &str, + session: &SessionInfo, +) { + let sess_node_id = dirigent_inspector::NodeId::new(format!( + "dirigent/connectors/{}/sessions/{}", + connector_id, session.id + )); + let parent_id = + dirigent_inspector::NodeId::new(format!("dirigent/connectors/{}", connector_id)); + + let node_state = match session.status { + SessionStatus::Active => dirigent_inspector::NodeState::Running, + SessionStatus::Processing => dirigent_inspector::NodeState::Busy("Generating".to_string()), + SessionStatus::Idle => dirigent_inspector::NodeState::Idle, + SessionStatus::Ended => dirigent_inspector::NodeState::Stopped, + }; + + let label = session.title.as_deref().unwrap_or(&session.id); + + // Try to register; if already exists, update instead + let meta = + dirigent_inspector::NodeMetadata::new(dirigent_inspector::NodeKind::AsyncTask, label) + .with_state(node_state.clone()) + .with_property("session_id", serde_json::json!(&session.id)) + .with_property("status", serde_json::json!(format!("{:?}", session.status))); + + let meta = if let Some(ref model) = session.model { + meta.with_property("model", serde_json::json!(model)) + } else { + meta + }; + + let meta = if let Some(ref title) = session.title { + meta.with_property("title", serde_json::json!(title)) + } else { + meta + }; + + let meta = if session.message_count > 0 { + meta.with_property("message_count", serde_json::json!(session.message_count)) + } else { + meta + }; + + match inspector + .register(sess_node_id.clone(), &parent_id, meta, None) + .await + { + Ok(mut handle) => { + handle.detach(); + trace!(connector_id = %connector_id, session_id = %session.id, "Registered session with inspector"); + } + Err(_) => { + // Already registered — update state and properties + let _ = inspector.update_state(&sess_node_id, node_state).await; + let mut props = HashMap::new(); + props.insert( + "status".to_string(), + serde_json::json!(format!("{:?}", session.status)), + ); + if let Some(ref title) = session.title { + props.insert("title".to_string(), serde_json::json!(title)); + } + if let Some(ref model) = session.model { + props.insert("model".to_string(), serde_json::json!(model)); + } + if session.message_count > 0 { + props.insert( + "message_count".to_string(), + serde_json::json!(session.message_count), + ); + } + let _ = inspector.update_properties(&sess_node_id, props).await; + trace!(connector_id = %connector_id, session_id = %session.id, "Updated session in inspector"); + } + } +} + +/// Update only the inspector state for a session (lightweight, no property changes). +#[cfg(feature = "server")] +async fn inspector_update_session_state( + inspector: &Arc, + connector_id: &str, + session_id: &str, + state: dirigent_inspector::NodeState, +) { + let sess_node_id = dirigent_inspector::NodeId::new(format!( + "dirigent/connectors/{}/sessions/{}", + connector_id, session_id + )); + let _ = inspector.update_state(&sess_node_id, state).await; +} + +/// Deregister all session nodes for a connector from the inspector. +#[cfg(feature = "server")] +async fn inspector_deregister_all_sessions( + inspector: &Arc, + connector_id: &str, + internal_state: &InternalState, +) { + let sessions = internal_state.list_sessions().await; + for session in sessions { + let sess_node_id = dirigent_inspector::NodeId::new(format!( + "dirigent/connectors/{}/sessions/{}", + connector_id, session.id + )); + let _ = inspector.deregister_subtree(&sess_node_id).await; + } +} + +/// Helper that publishes an event to both the per-connector broadcast +/// channel and the global `SharingBus`. The bus publish goes first so any +/// subscriber reading from both sees the bus event no later than the +/// broadcast event. +async fn emit_event( + sharing_bus: &Arc, + events_tx: &broadcast::Sender, + connector_id: &ConnectorId, + connector_uid: Option, + event: Event, +) { + let bus_event = dirigent_protocol::streaming::BusEvent::from_connector_event( + event.clone(), + connector_uid, + connector_id.clone(), + ); + sharing_bus.publish(bus_event).await; + let _ = events_tx.send(event); +} + +/// ACP connector implementation +/// +/// Provides integration with ACP-compliant agents via stdio or HTTP transport. +/// Implements the `Connector` trait for use with CoreRuntime. +pub struct AcpConnector { + /// Unique identifier for this connector instance + id: ConnectorId, + + /// Optional connector UID used for BusEvent routing. + /// + /// Populated by the runtime when the connector is registered; legacy + /// constructor paths default this to `None` and later wiring can fill it in. + pub connector_uid: Option, + + /// User who owns this connector + owner: UserId, + + /// Human-readable title for this connector + /// + /// Used in UI display and logging. This is stored separately from + /// AcpConfig as it's an orchestration-level field. + title: String, + + /// Connector configuration + config: AcpConfig, + + /// Tool configuration inherited from the connector config. + /// + /// When present, tool calls are checked against these directives + /// before being emitted as events. Denied tools are marked as errors; + /// hidden tools are silently dropped. + tool_configuration: Option, + + /// Shared connector state (Initializing, Connecting, Ready, Error, Stopped) + state: Arc>, + + /// Sender for commands to the connector + cmd_tx: mpsc::Sender, + + /// Receiver for commands (internal to the connector task) + cmd_rx: Arc>>>, + + /// Broadcast sender for events + events_tx: broadcast::Sender, + + /// Shared event bus for direct-to-bus publishes. + /// + /// Every event emitted by this connector is published here in addition + /// to `events_tx`, eliminating the forwarder task tier in the runtime. + sharing_bus: Arc, + + /// Internal state (protocol version, capabilities, sessions) + internal_state: Arc, + + /// Pending agent requests awaiting external responses + /// + /// Tracks request_ids that are waiting for user response (e.g., permission prompts). + /// No timeout - the system waits indefinitely for the user to respond. + /// Used to validate that incoming AgentResponse commands match a real pending request. + pending_agent_requests: Arc>>, + + /// Optional inspector registry for PID tracking of stdio processes + #[cfg(feature = "server")] + inspector: Option>, + + /// Optional process group manager for lifecycle management of stdio processes. + /// + /// When set, each StdioTransport created by this connector receives a + /// per-process `ProcessLifecycle` handle so that spawned agents are + /// tracked in the platform job object / process group, and are shut down + /// gracefully on close. + #[cfg(feature = "server")] + process_manager: Option>, +} + +impl AcpConnector { + /// Create a new ACP connector + /// + /// Initializes the connector with the given configuration but does not + /// start it. Call `start_task()` to begin the connector's background task. + /// + /// # Arguments + /// + /// * `id` - Unique identifier for this connector + /// * `owner` - User ID of the connector owner + /// * `title` - Human-readable title for this connector + /// * `config` - ACP configuration (transport, protocol version, etc.) + /// + /// # Returns + /// + /// A new AcpConnector in Initializing state, or an error if configuration is invalid. + pub fn new( + id: ConnectorId, + owner: UserId, + title: String, + config: AcpConfig, + sharing_bus: Arc, + ) -> AcpResult { + // Validate configuration + config.validate().map_err(AcpError::config)?; + + // Create command channel (capacity 100 for buffering commands) + let (cmd_tx, cmd_rx) = mpsc::channel(100); + + // Create event broadcast channel (capacity 1000 for event buffering) + let (events_tx, _) = broadcast::channel(1000); + + // Create internal state + let internal_state = Arc::new(InternalState::new()); + + info!( + connector_id = %id, + owner = %owner, + title = %title, + transport = ?config.transport, + "Creating ACP connector" + ); + + Ok(Self { + id, + connector_uid: None, + owner, + title, + config, + tool_configuration: None, + state: Arc::new(RwLock::new(ConnectorState::Initializing)), + cmd_tx, + cmd_rx: Arc::new(RwLock::new(Some(cmd_rx))), + events_tx, + sharing_bus, + internal_state, + pending_agent_requests: Arc::new(Mutex::new(HashSet::new())), + #[cfg(feature = "server")] + inspector: None, + #[cfg(feature = "server")] + process_manager: None, + }) + } + + /// Wrap a raw connector `Event` in a `BusEvent` populated with this + /// connector's routing identity (`connector_uid` + `connector_id`). + /// + /// Call sites that currently broadcast a raw `Event` via `events_tx` can + /// use this helper when migrating to the `BusEvent` pipeline. This is an + /// additive helper — existing `events_tx.send(event)` emissions remain + /// unchanged. + pub fn to_bus_event( + &self, + event: dirigent_protocol::Event, + ) -> dirigent_protocol::streaming::BusEvent { + dirigent_protocol::streaming::BusEvent::from_connector_event( + event, + self.connector_uid, + self.id.clone(), + ) + } + + /// Set the tool configuration for this connector. + /// + /// When set, tool calls are checked against these directives before + /// being emitted as events. This configuration is also cloned into + /// each session's metadata at creation time. + pub fn with_tool_configuration( + mut self, + tool_config: Option, + ) -> Self { + self.tool_configuration = tool_config; + self + } + + /// Set the inspector registry for PID tracking of stdio processes. + #[cfg(feature = "server")] + pub fn with_inspector( + mut self, + inspector: Option>, + ) -> Self { + self.inspector = inspector; + self + } + + /// Set the process group manager for lifecycle management of stdio processes. + /// + /// When set, each `StdioTransport` spawned by this connector will receive a + /// per-process `ProcessLifecycle` handle so agents are tracked in the + /// platform job object / process group and are shut down gracefully. + #[cfg(feature = "server")] + pub fn with_process_manager( + mut self, + process_manager: Option>, + ) -> Self { + self.process_manager = process_manager; + self + } + + /// Get the events broadcast sender + pub fn events_sender(&self) -> broadcast::Sender { + self.events_tx.clone() + } + + /// Get the state Arc + pub fn state_arc(&self) -> Arc> { + Arc::clone(&self.state) + } + + /// Start the connector's background task + /// + /// This spawns an async task that: + /// - Establishes transport connection + /// - Performs ACP initialization handshake + /// - Processes incoming notifications + /// - Handles commands from the command channel + /// - Manages automatic reconnection on failures + /// + /// # Returns + /// + /// A JoinHandle for the background task + /// + /// # Panics + /// + /// Panics if called more than once (the command receiver can only be taken once). + pub async fn start_task(&self) -> JoinHandle<()> { + let id = self.id.clone(); + let connector_uid = self.connector_uid; + let config = self.config.clone(); + let tool_configuration = self.tool_configuration.clone(); + let state = Arc::clone(&self.state); + let events_tx = self.events_tx.clone(); + let sharing_bus = Arc::clone(&self.sharing_bus); + let internal_state = Arc::clone(&self.internal_state); + let pending_agent_requests = Arc::clone(&self.pending_agent_requests); + #[cfg(feature = "server")] + let inspector = self.inspector.clone(); + #[cfg(feature = "server")] + let process_manager = self.process_manager.clone(); + + // Create session states map for idle detection + let session_states = Arc::new(Mutex::new(HashMap::::new())); + + // Take the command receiver (this can only be done once) + let cmd_rx = self + .cmd_rx + .write() + .await + .take() + .expect("start_task() called more than once - command receiver already taken"); + + info!(connector_id = %id, "Starting ACP connector task"); + + tokio::spawn(async move { + Self::run_task( + id, + connector_uid, + config, + tool_configuration, + state, + events_tx, + sharing_bus, + internal_state, + pending_agent_requests, + session_states, + cmd_rx, + #[cfg(feature = "server")] + inspector, + #[cfg(feature = "server")] + process_manager, + ) + .await; + }) + } + + /// Main connector task loop + /// + /// This is the core of the connector's async behavior. It manages: + /// - Transport connection and reconnection + /// - ACP initialization handshake + /// - Notification processing + /// - Command handling + /// - Error handling and recovery + async fn run_task( + id: ConnectorId, + connector_uid: Option, + config: AcpConfig, + tool_configuration: Option, + state: Arc>, + events_tx: broadcast::Sender, + sharing_bus: Arc, + internal_state: Arc, + pending_agent_requests: Arc>>, + session_states: Arc>>, + mut cmd_rx: mpsc::Receiver, + #[cfg(feature = "server")] inspector: Option>, + #[cfg(feature = "server")] process_manager: Option>, + ) { + debug_log!("🚀 ACP connector {} task started", id); + info!(connector_id = %id, "ACP connector task started"); + + // Reconnection state + let mut retry_count = 0; + let mut last_init_success: Option = None; + let max_retries = config.retry.max_retries; + let retry_delays: Vec = config + .retry + .retry_delays_ms + .iter() + .map(|ms| Duration::from_millis(*ms)) + .collect(); + + // Main loop: connect, process events, handle reconnects + 'reconnect_loop: loop { + // Update state to Connecting + { + let mut state_guard = state.write().await; + *state_guard = ConnectorState::Connecting; + } + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::ConnectorStateChanged { + connector_id: id.clone(), + state: "Connecting".to_string(), + error_kind: None, + }, + ) + .await; + debug!(connector_id = %id, "Attempting to connect to ACP agent"); + + // Create transport based on config + let transport_result = Self::create_transport( + &config, + #[cfg(feature = "server")] + process_manager.as_ref(), + ) + .await; + + let mut transport = match transport_result { + Ok(t) => { + info!(connector_id = %id, "Successfully connected to ACP agent"); + if let Some(init_time) = last_init_success { + if init_time.elapsed() > Duration::from_secs(30) { + debug!(connector_id = %id, "Previous connection was stable, resetting retry count"); + retry_count = 0; + } + } + t + } + Err(e) => { + error!( + connector_id = %id, + error = %e, + "Failed to connect to ACP agent" + ); + + // Update state to Error + { + let mut state_guard = state.write().await; + *state_guard = ConnectorState::Error(format!("Connection failed: {}", e)); + } + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::ConnectorStateChanged { + connector_id: id.clone(), + state: format!("Error(Connection failed: {})", e), + error_kind: Some("connection_failed".to_string()), + }, + ) + .await; + // Emit error event + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::Error { + message: format!("Failed to connect: {}", e), + }, + ) + .await; + // Handle retry logic + if retry_count < max_retries { + let delay = retry_delays[retry_count.min(retry_delays.len() - 1)]; + warn!( + connector_id = %id, + retry_count = retry_count + 1, + max_retries = max_retries, + delay_secs = delay.as_secs(), + "Retrying connection" + ); + + retry_count += 1; + tokio::time::sleep(delay).await; + continue 'reconnect_loop; + } else { + error!( + connector_id = %id, + "Max retries exceeded, staying in Error state until manual Reconnect" + ); + + // Stay in Error state and wait for commands + loop { + match cmd_rx.recv().await { + Some(ConnectorCommand::Reconnect) => { + info!(connector_id = %id, "Received Reconnect command, resetting retry count"); + retry_count = 0; + internal_state.clear().await; + pending_agent_requests.lock().await.clear(); + continue 'reconnect_loop; + } + Some(ConnectorCommand::Shutdown) => { + info!(connector_id = %id, "Received Shutdown command"); + break 'reconnect_loop; + } + Some(cmd) => { + warn!( + connector_id = %id, + command = ?cmd, + "Received command while in Error state, ignoring" + ); + } + None => { + error!(connector_id = %id, "Command channel closed"); + break 'reconnect_loop; + } + } + } + } + } + }; + + // Perform ACP initialization + let protocol_handler = ProtocolHandler::new(); + + // Create ACP adapter for notification translation + let adapter = dirigent_protocol::adapters::AcpAdapter::new(); + + match Self::initialize_acp( + &id, + &mut transport, + &protocol_handler, + &config, + &internal_state, + ) + .await + { + Ok(()) => { + info!(connector_id = %id, "ACP initialization successful"); + retry_count = 0; + last_init_success = Some(tokio::time::Instant::now()); + + // Update state to Ready + { + let mut state_guard = state.write().await; + *state_guard = ConnectorState::Ready; + } + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::ConnectorStateChanged { + connector_id: id.clone(), + state: "Ready".to_string(), + error_kind: None, + }, + ) + .await; + // Emit Connected event + emit_event(&sharing_bus, &events_tx, &id, connector_uid, Event::Connected).await; + // Register stdio process PID with inspector (if available) + #[cfg(feature = "server")] + if let Some(ref inspector) = inspector { + if let Some(pid) = transport.pid().await { + let process_node_id = dirigent_inspector::NodeId::new(format!( + "dirigent/connectors/{}/process", + id + )); + let parent_node_id = dirigent_inspector::NodeId::new(format!( + "dirigent/connectors/{}", + id + )); + let meta = dirigent_inspector::NodeMetadata::new( + dirigent_inspector::NodeKind::Process, + "stdio-process", + ) + .with_state(dirigent_inspector::NodeState::Running) + .with_property("pid", serde_json::json!(pid)) + .with_property("transport", serde_json::json!("stdio")); + if let Ok(mut handle) = inspector + .register(process_node_id, &parent_node_id, meta, None) + .await + { + handle.detach(); + info!(connector_id = %id, pid = pid, "Registered stdio process with inspector"); + } + } + } + + // Auto-discover sessions if agent supports session/list + if internal_state.agent_supports_list_sessions().await { + info!(connector_id = %id, "Agent supports session/list — auto-discovering sessions"); + Self::execute_list_sessions( + &mut transport, + &protocol_handler, + &events_tx, &sharing_bus, connector_uid, + &id, + &config, + &internal_state, + &pending_agent_requests, + &mut cmd_rx, + ).await; + } else { + debug!(connector_id = %id, "Agent does not support session/list — skipping auto-discovery"); + } + } + Err(e) => { + error!( + connector_id = %id, + error = %e, + "ACP initialization failed" + ); + + // Close transport + let _ = transport.close().await; + + // Update state to Error + { + let mut state_guard = state.write().await; + *state_guard = + ConnectorState::Error(format!("Initialization failed: {}", e)); + } + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::ConnectorStateChanged { + connector_id: id.clone(), + state: format!("Error(Initialization failed: {})", e), + error_kind: Some("initialization_failed".to_string()), + }, + ) + .await; + // Emit error event + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::Error { + message: format!("Initialization failed: {}", e), + }, + ) + .await; + // Retry if retriable + if e.is_retriable() && retry_count < max_retries { + let delay = retry_delays[retry_count.min(retry_delays.len() - 1)]; + warn!( + connector_id = %id, + retry_count = retry_count + 1, + "Retrying after initialization failure" + ); + + retry_count += 1; + tokio::time::sleep(delay).await; + continue 'reconnect_loop; + } else { + // Enter Error state + loop { + match cmd_rx.recv().await { + Some(ConnectorCommand::Reconnect) => { + info!(connector_id = %id, "Received Reconnect command"); + retry_count = 0; + internal_state.clear().await; + pending_agent_requests.lock().await.clear(); + continue 'reconnect_loop; + } + Some(ConnectorCommand::Shutdown) => { + info!(connector_id = %id, "Received Shutdown command"); + break 'reconnect_loop; + } + Some(cmd) => { + warn!(connector_id = %id, command = ?cmd, "Ignoring command in Error state"); + } + None => { + error!(connector_id = %id, "Command channel closed"); + break 'reconnect_loop; + } + } + } + } + } + } + + // Event processing loop + // Get notification receiver from protocol handler + let mut notifications = protocol_handler + .take_notification_receiver() + .await + .expect("Notification receiver should be available"); + + // Idle check interval for housekeeping (50ms interval) + let mut idle_check_interval = tokio::time::interval(Duration::from_millis(50)); + // Don't let missed ticks accumulate + idle_check_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + 'event_loop: loop { + tokio::select! { + biased; // Enforce priority: transport > notifications > commands > housekeeping + + // PRIORITY 1: Handle incoming notifications from transport + message = transport.recv() => { + match message { + Ok(Some(msg)) => { + let masked_msg = dirigent_protocol::log_utils::format_for_log(&msg); + info!(connector_id = %id, message = %masked_msg, "📬 Received message from transport in event loop"); + + // Route message through protocol handler + use crate::connectors::acp::protocol::MessageHandlerResult; + match protocol_handler.handle_message(msg).await { + MessageHandlerResult::None => { + // No response needed (notification processed) + } + MessageHandlerResult::Response(response) => { + // Send immediate response back to agent + info!(connector_id = %id, "📤 Sending response to agent request"); + if let Err(e) = transport.send(response).await { + error!(connector_id = %id, error = %e, "Failed to send response to agent"); + } + } + MessageHandlerResult::AgentRequest { request_id, method, params } => { + // Agent is requesting something from the client (e.g., permission) + info!( + connector_id = %id, + method = %method, + request_id = %request_id, + "🔔 Agent request received - emitting Event::AgentRequest" + ); + + // Extract session_id from params (standard ACP location) + let session_id = params + .get("sessionId") + .and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(); + + // Determine if this is a forwarded (external) session + // Check session ownership to decide routing + let is_forwarded = if let Some(session_info) = internal_state.get_session(&session_id).await { + // Session found - check if it's external (owned by ACP client) + session_info.ownership.is_external() + } else { + // Session not found - default to false (internal) + // This shouldn't happen in normal operation, but we default to showing UI modal + warn!( + connector_id = %id, + session_id = %session_id, + "Session not found when processing AgentRequest - defaulting to internal (UI modal)" + ); + false + }; + + // Store the pending request (no timeout - wait indefinitely for user) + let request_id_str = request_id.to_string(); + { + let mut pending = pending_agent_requests.lock().await; + pending.insert(request_id_str.clone()); + } + info!( + connector_id = %id, + request_id = %request_id_str, + is_forwarded = is_forwarded, + "Stored pending agent request (waiting for user response)" + ); + + // Emit Event::AgentRequest for routing + // - If is_forwarded=true: EventBridge forwards to external client + // - If is_forwarded=false: Web UI shows permission modal + let event = Event::AgentRequest { + connector_id: id.clone(), + session_id, + request_id: request_id.clone(), + method, + params, + is_forwarded, + }; + + let bus_event_for_agent_request = dirigent_protocol::streaming::BusEvent::from_connector_event( + event.clone(), + connector_uid, + id.clone(), + ); + sharing_bus.publish(bus_event_for_agent_request).await; + if let Err(e) = events_tx.send(event) { + error!( + connector_id = %id, + error = %e, + "Failed to emit AgentRequest event" + ); + } + + // Response will be handled when AgentResponse command is received + // No timeout - the system waits indefinitely for user input + } + } + } + Ok(None) => { + warn!(connector_id = %id, "Transport stream closed"); + + // Update state to Connecting (will reconnect) + { + let mut state_guard = state.write().await; + *state_guard = ConnectorState::Connecting; + } + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::ConnectorStateChanged { + connector_id: id.clone(), + state: "Connecting".to_string(), + error_kind: None, + }, + ) + .await; + // Emit disconnected event + emit_event(&sharing_bus, &events_tx, &id, connector_uid, Event::Disconnected).await; + // Deregister session nodes (will be re-created on reconnect) + #[cfg(feature = "server")] + if let Some(ref inspector) = inspector { + inspector_deregister_all_sessions(inspector, &id, &internal_state).await; + } + + // Cancel all pending requests with crash context before closing + let crash_reason = Self::build_crash_reason(&transport, "Connection lost: transport closed").await; + protocol_handler.cancel_all_pending(&crash_reason).await; + + // Close transport and clear stale requests before reconnect + let _ = transport.close().await; + pending_agent_requests.lock().await.clear(); + break 'event_loop; + } + Err(e) => { + error!(connector_id = %id, error = %e, "Transport error"); + + // Update state to Error + { + let mut state_guard = state.write().await; + *state_guard = ConnectorState::Error(format!("Transport error: {}", e)); + } + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::ConnectorStateChanged { + connector_id: id.clone(), + state: format!("Error(Transport error: {})", e), + error_kind: Some("transport_error".to_string()), + }, + ) + .await; + // Emit error event + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::Error { + message: format!("Transport error: {}", e), + }, + ) + .await; + // Deregister session nodes (will be re-created on reconnect) + #[cfg(feature = "server")] + if let Some(ref inspector) = inspector { + inspector_deregister_all_sessions(inspector, &id, &internal_state).await; + } + + // Cancel all pending requests with crash context before closing + let crash_reason = Self::build_crash_reason(&transport, &format!("Connection lost: {}", e)).await; + protocol_handler.cancel_all_pending(&crash_reason).await; + + // Close transport and clear stale requests before reconnect + let _ = transport.close().await; + pending_agent_requests.lock().await.clear(); + break 'event_loop; + } + } + } + + // Handle notifications from protocol handler + notification = notifications.recv() => { + match notification { + Some(notif) => { + let masked_notif = dirigent_protocol::log_utils::format_for_log(¬if); + debug_log!("🔔 Notification received: {}", masked_notif); + info!(connector_id = %id, notification = %masked_notif, "🔔 Received ACP notification from protocol handler"); + + // Translate notification to Dirigent event using adapter + if let Err(e) = Self::handle_notification(&id, notif, &events_tx, &sharing_bus, connector_uid, &adapter, &internal_state, &session_states, &tool_configuration).await { + debug_log!("⚠️ Failed to handle notification: {}", e); + warn!(connector_id = %id, error = %e, "Failed to handle notification"); + } + } + None => { + warn!(connector_id = %id, "Notification channel closed"); + break 'event_loop; + } + } + } + + // Handle commands + cmd = cmd_rx.recv() => { + match cmd { + Some(ConnectorCommand::ListSessions) => { + debug!(connector_id = %id, "Processing ListSessions command"); + Self::execute_list_sessions( + &mut transport, + &protocol_handler, + &events_tx, &sharing_bus, connector_uid, + &id, + &config, + &internal_state, + &pending_agent_requests, + &mut cmd_rx, + ).await; + } + + Some(ConnectorCommand::CreateSession { cwd, project_id, ownership }) => { + debug!(connector_id = %id, "Processing CreateSession command"); + debug_log!("🆕 CreateSession command received, cwd={:?}, ownership={:?}", cwd, ownership); + + // Build session/new request (use provided cwd or default from config) + let session_cwd = cwd.as_ref().unwrap_or(&config.cwd); + let request = build_session_new_request(session_cwd, None); + debug_log!("📤 session/new request: {}", serde_json::to_string_pretty(&request).unwrap_or_default()); + + // Send request and wait for response + match Self::send_request(&mut transport, &protocol_handler, request, Some(&events_tx), Some(&sharing_bus), connector_uid, &id, Some(&pending_agent_requests), Some(&mut cmd_rx), Some(&internal_state)).await { + Ok(response) => { + debug_log!("📥 session/new response: {}", serde_json::to_string_pretty(&response).unwrap_or_default()); + + // Extract session ID and metadata from result + if let Some(result) = response.get("result") { + if let Some(session_id) = result.get("sessionId").and_then(|s| s.as_str()) { + debug_log!("✅ session/new returned sessionId: {}", session_id); + info!(connector_id = %id, session_id = session_id, "Session created successfully"); + + // Parse optional models field (UNSTABLE in ACP spec but Claude uses it) + let models: Option = result + .get("models") + .and_then(|m| { + serde_json::from_value(m.clone()) + .map_err(|e| { + warn!(connector_id = %id, "Failed to parse models from session/new: {}", e); + e + }) + .ok() + }); + + // Parse optional modes field (Stable in ACP spec) + let modes: Option = result + .get("modes") + .and_then(|m| { + serde_json::from_value(m.clone()) + .map_err(|e| { + warn!(connector_id = %id, "Failed to parse modes from session/new: {}", e); + e + }) + .ok() + }); + + // Parse optional configOptions field (ACP spec replacement for modes/models) + let config_options: Option> = result + .get("configOptions") + .and_then(|co| { + serde_json::from_value(co.clone()) + .map_err(|e| { + warn!(connector_id = %id, "Failed to parse configOptions from session/new: {}", e); + e + }) + .ok() + }); + + if models.is_some() || modes.is_some() || config_options.is_some() { + debug_log!("📊 session/new metadata: models={}, modes={}, configOptions={}", + models.as_ref().map(|m| format!("{} models, current={}", m.available_models.len(), m.current_model_id)).unwrap_or_else(|| "none".to_string()), + modes.as_ref().map(|m| format!("{} modes, current={}", m.available_modes.len(), m.current_mode_id)).unwrap_or_else(|| "none".to_string()), + config_options.as_ref().map(|co| format!("{} options", co.len())).unwrap_or_else(|| "none".to_string()) + ); + } + + // Extract current_mode_id for SessionMetadata + let current_mode_id = modes.as_ref().map(|m| m.current_mode_id.clone()); + + // Extract current model name for legacy model field + let current_model_name = models.as_ref().and_then(|m| { + m.available_models.iter() + .find(|model| model.model_id == m.current_model_id) + .map(|model| model.name.clone()) + }); + + // Update internal state with metadata + let now = Utc::now(); + let session_info = SessionInfo { + id: session_id.to_string(), + title: None, + cwd: session_cwd.clone(), + message_count: 0, + model: current_model_name.clone(), + created_at: now, + last_activity: now, + status: SessionStatus::Active, + models: models.clone(), + modes: modes.clone(), + config_options: config_options.clone(), + ownership: ownership.clone(), + }; + internal_state.upsert_session(session_info.clone()).await; + + // Register session with inspector + #[cfg(feature = "server")] + if let Some(ref inspector) = inspector { + inspector_upsert_session(inspector, &id, &session_info).await; + } + + // Build _meta with tool configuration if present + let session_meta = Self::build_session_meta(&tool_configuration); + + // Emit SessionCreated event + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionCreated { + connector_id: id.clone(), + session: dirigent_protocol::Session { + id: session_id.to_string(), + title: String::new(), + created_at: Utc::now(), + updated_at: Utc::now(), + metadata: dirigent_protocol::SessionMetadata { + project_path: session_cwd.clone(), + model: current_model_name, + total_messages: 0, + system_message: None, + current_mode_id: current_mode_id.clone(), + _meta: session_meta, + project_id: project_id.as_ref().and_then(|p| uuid::Uuid::parse_str(p).ok()), + }, + cwd: Some(session_cwd.clone()), + // Include ACP models/modes state in Session + models: models.clone(), + modes: modes.clone(), + config_options: config_options.clone(), + // ACP connector sessions don't have a client ID (that's for incoming connections) + acp_client_id: None, + }, + }, + ) + .await; + + // Emit SessionMetadataReceived event if we have any metadata + if models.is_some() || modes.is_some() || config_options.is_some() { + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionMetadataReceived { + connector_id: id.clone(), + session_id: session_id.to_string(), + models, + modes, + config_options, + }, + ) + .await; + } + + // Auto-trigger ListSessions after session creation. + // This populates the session list for any deferred + // obligations (e.g., a client called session/list on + // a gateway, got transferred here, and is waiting for + // the real session list from the upstream agent). + Self::execute_list_sessions( + &mut transport, + &protocol_handler, + &events_tx, &sharing_bus, connector_uid, + &id, + &config, + &internal_state, + &pending_agent_requests, + &mut cmd_rx, + ).await; + } else { + error!(connector_id = %id, "Session creation response missing sessionId"); + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::Error { + message: "Session creation failed: invalid response".to_string(), + }, + ) + .await; + } + } + } + Err(e) => { + error!(connector_id = %id, error = %e, "Failed to create session"); + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::Error { + message: format!("Failed to create session: {}", e), + }, + ) + .await; + } + } + } + + Some(ConnectorCommand::ListMessages { session_id }) => { + debug!(connector_id = %id, session_id = %session_id, "Processing ListMessages command"); + + // For ACP, message listing is not part of the core protocol + // This would need to be handled via custom agent implementation + warn!(connector_id = %id, "ListMessages not yet implemented for ACP"); + + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::MessagesListed { messages: vec![] }, + ) + .await; + } + + Some(ConnectorCommand::LoadSession { session_id, cwd, mcp_servers }) => { + debug!(connector_id = %id, session_id = %session_id, "Processing LoadSession command"); + + // Check agent capabilities and session state + let supports_load = internal_state.agent_supports_load_session().await; + let supports_resume = internal_state.agent_supports_session_resume().await; + let already_loaded = internal_state.is_session_loaded(&session_id).await; + + let request = if already_loaded && supports_resume { + // Session already loaded in current process — use resume (fast re-attach, no history replay) + debug!(connector_id = %id, session_id = %session_id, "Using session/resume (already loaded in process)"); + Some(protocol::build_session_resume_request(&session_id, Some(&cwd), mcp_servers.clone())) + } else if supports_load { + // Full restore from disk with history replay + debug!(connector_id = %id, session_id = %session_id, "Using session/load (full restore)"); + Some(protocol::build_session_load_request(&session_id, &cwd, mcp_servers)) + } else if supports_resume { + // Fallback: agent doesn't support load but supports resume + debug!(connector_id = %id, session_id = %session_id, "Using session/resume (load not supported, fallback)"); + Some(protocol::build_session_resume_request(&session_id, Some(&cwd), mcp_servers)) + } else { + warn!(connector_id = %id, "Agent supports neither session/load nor session/resume"); + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::Error { + message: "Session loading is not supported by this agent".to_string(), + }, + ) + .await; + None + }; + + if let Some(request) = request { + + // Send request and wait for response + match Self::send_request(&mut transport, &protocol_handler, request, Some(&events_tx), Some(&sharing_bus), connector_uid, &id, Some(&pending_agent_requests), Some(&mut cmd_rx), Some(&internal_state)).await { + Ok(response) => { + info!(connector_id = %id, session_id = %session_id, "Session loaded successfully"); + + // Extract result for metadata parsing + let result = response.get("result"); + + // Parse optional models field (UNSTABLE in ACP spec but Claude uses it) + let models: Option = result + .and_then(|r| r.get("models")) + .and_then(|m| { + serde_json::from_value(m.clone()) + .map_err(|e| { + warn!(connector_id = %id, "Failed to parse models from session/load: {}", e); + e + }) + .ok() + }); + + // Parse optional modes field (Stable in ACP spec) + let modes: Option = result + .and_then(|r| r.get("modes")) + .and_then(|m| { + serde_json::from_value(m.clone()) + .map_err(|e| { + warn!(connector_id = %id, "Failed to parse modes from session/load: {}", e); + e + }) + .ok() + }); + + // Parse optional configOptions field (ACP spec replacement for modes/models) + let config_options: Option> = result + .and_then(|r| r.get("configOptions")) + .and_then(|co| { + serde_json::from_value(co.clone()) + .map_err(|e| { + warn!(connector_id = %id, "Failed to parse configOptions from session/load: {}", e); + e + }) + .ok() + }); + + if models.is_some() || modes.is_some() || config_options.is_some() { + debug_log!("📊 session/load metadata: models={}, modes={}, configOptions={}", + models.as_ref().map(|m| format!("{} models, current={}", m.available_models.len(), m.current_model_id)).unwrap_or_else(|| "none".to_string()), + modes.as_ref().map(|m| format!("{} modes, current={}", m.available_modes.len(), m.current_mode_id)).unwrap_or_else(|| "none".to_string()), + config_options.as_ref().map(|co| format!("{} options", co.len())).unwrap_or_else(|| "none".to_string()) + ); + } + + // Extract current_mode_id for SessionMetadata + let current_mode_id = modes.as_ref().map(|m| m.current_mode_id.clone()); + + // Extract current model name for legacy model field + let current_model_name = models.as_ref().and_then(|m| { + m.available_models.iter() + .find(|model| model.model_id == m.current_model_id) + .map(|model| model.name.clone()) + }); + + // Update internal state with metadata + let now = Utc::now(); + let session_info = SessionInfo { + id: session_id.clone(), + title: None, + cwd: config.cwd.clone(), + message_count: 0, + model: current_model_name.clone(), + created_at: now, + last_activity: now, + status: SessionStatus::Active, + models: models.clone(), + modes: modes.clone(), + config_options: config_options.clone(), + ownership: SessionOwnership::default(), + }; + internal_state.upsert_session(session_info.clone()).await; + + // Register session with inspector + #[cfg(feature = "server")] + if let Some(ref inspector) = inspector { + inspector_upsert_session(inspector, &id, &session_info).await; + } + + // Build _meta with tool configuration if present + let session_meta = Self::build_session_meta(&tool_configuration); + + // Emit SessionCreated event so archivist can register the session + // This prevents resolve_session() failures when messages arrive + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionCreated { + connector_id: id.clone(), + session: dirigent_protocol::Session { + id: session_id.clone(), + title: String::new(), + created_at: now, + updated_at: now, + metadata: dirigent_protocol::SessionMetadata { + project_path: config.cwd.clone(), + model: current_model_name, + total_messages: 0, + system_message: None, + current_mode_id: current_mode_id.clone(), + _meta: session_meta, + project_id: None, + }, + cwd: Some(config.cwd.clone()), + // Include ACP models/modes state in Session + models: models.clone(), + modes: modes.clone(), + config_options: config_options.clone(), + // ACP connector sessions don't have a client ID (that's for incoming connections) + acp_client_id: None, + }, + }, + ) + .await; + + // Emit SessionMetadataReceived event if we have any metadata + if models.is_some() || modes.is_some() || config_options.is_some() { + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionMetadataReceived { + connector_id: id.clone(), + session_id: session_id.clone(), + models, + modes, + config_options, + }, + ) + .await; + } + + // Track that this session is now loaded in this process + // so subsequent loads use session/resume instead of session/load + internal_state.mark_session_loaded(&session_id).await; + + // Register session with idle detector so that SessionIdle + // fires after history replay notifications stop arriving. + // This triggers archivist finalization (MessageAccumulator.finalize()) + // which writes the replayed messages to disk. + { + let mut states = session_states.lock().await; + if let Some(state) = states.get_mut(&session_id) { + state.mark_awaiting_idle(); + } else { + let mut new_state = SessionState::new(session_id.clone()); + new_state.mark_awaiting_idle(); + states.insert(session_id.clone(), new_state); + } + } + + // History replay happens via session/update notifications + // which are handled by the notification handler + } + Err(e) => { + error!(connector_id = %id, session_id = %session_id, error = %e, "Failed to load session"); + + // Check if it's a "session not found" error (not recoverable) + // Other errors (transport, timeout) may be recoverable + let (error_message, is_recoverable) = if matches!(e, AcpError::AgentError { code, .. } if code == -32602 || code == -32001) { + (format!("Session not found: {}", session_id), false) + } else { + (format!("Failed to load session: {}", e), e.is_retriable()) + }; + + let error_code = if matches!(e, AcpError::AgentError { code, .. } if code == -32602 || code == -32001) { + "SESSION_NOT_FOUND" + } else { + "SESSION_LOAD_FAILED" + }; + + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionError { + connector_id: id.clone(), + session_id: session_id.clone(), + error_message, + is_recoverable, + error_code: Some(error_code.to_string()), + technical_details: Some(format!("{:?}", e)), + context: Some(serde_json::json!({ + "operation": "session/load", + "connector_id": id.clone(), + })), + }, + ) + .await; + debug_log!("📛 Emitted SessionError for LoadSession failure: session={}, error={}", session_id, e); + } + } + } + } + + Some(ConnectorCommand::CancelGeneration { session_id }) => { + debug!(connector_id = %id, session_id = %session_id, "Processing CancelGeneration command"); + + // Build session/cancel request + let request = build_session_cancel_request(&session_id); + + // Send request (fire and forget, don't wait for response) + match Self::send_request(&mut transport, &protocol_handler, request, Some(&events_tx), Some(&sharing_bus), connector_uid, &id, Some(&pending_agent_requests), Some(&mut cmd_rx), Some(&internal_state)).await { + Ok(_) => { + info!(connector_id = %id, session_id = %session_id, "Generation cancelled"); + } + Err(e) => { + warn!(connector_id = %id, session_id = %session_id, error = %e, "Failed to cancel generation"); + } + } + } + + Some(ConnectorCommand::SendMessage { session_id, text }) => { + debug_log!("📨 SendMessage command: session={}, text_len={}", session_id, text.len()); + info!(connector_id = %id, session_id = %session_id, text_len = text.len(), "Processing SendMessage command"); + + // Generate UUIDv7 for the user message (for temporal ordering) + let user_message_id = uuid::Uuid::now_v7().to_string(); + debug_log!("📝 Generated user message ID: {}", user_message_id); + info!(connector_id = %id, user_message_id = %user_message_id, "Generated user message ID"); + + // Build prompt request + let prompt = json!([{ + "type": "text", + "text": text + }]); + + let request = build_session_prompt_request(&session_id, prompt); + let masked_request = dirigent_protocol::log_utils::format_for_log(&request); + debug!(connector_id = %id, request = %masked_request, "Built prompt request"); + + // NON-BLOCKING: Prepare request and send it, then spawn task to handle response + // Prepare request (adds ID and creates response channel) + let (message_with_id, response_rx) = protocol_handler.prepare_request(request).await; + + let request_id = message_with_id.get("id").cloned().unwrap_or(json!(null)); + debug!(connector_id = %id, request_id = %request_id, "Prepared non-blocking prompt request"); + + // Send via transport (non-blocking - just writes to transport) + let send_result = transport.send(message_with_id).await; + debug!(connector_id = %id, request_id = %request_id, "Non-blocking send completed"); + + // Mark session as Processing in inspector + #[cfg(feature = "server")] + if let Some(ref inspector) = inspector { + inspector_update_session_state( + inspector, + &id, + &session_id, + dirigent_inspector::NodeState::Busy("Generating".to_string()), + ).await; + } + + // Emit user message IMMEDIATELY via SessionUpdate for instant UI feedback + // This ensures the user message appears BEFORE the agent starts responding + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionUpdate { + connector_id: id.clone(), + session_id: session_id.clone(), + update: SessionUpdate::UserMessageChunk { + message_id: user_message_id.clone(), + content: ContentBlock::Text { + text: text.clone(), + }, + _meta: None, + }, + }, + ) + .await; + debug!(connector_id = %id, session_id = %session_id, user_message_id = %user_message_id, "Emitted UserMessageChunk immediately after send"); + + // Clone values for spawned task + let session_id_clone = session_id.clone(); + let text_clone = text.clone(); + let user_message_id_clone = user_message_id.clone(); + let id_clone = id.clone(); + let events_tx_clone = events_tx.clone(); + let sharing_bus_clone = Arc::clone(&sharing_bus); + let internal_state_clone = Arc::clone(&internal_state); + let session_states_clone = Arc::clone(&session_states); + #[cfg(feature = "server")] + let inspector_clone = inspector.clone(); + + // Spawn task to wait for response and perform post-processing + // This allows the main event loop to continue polling notifications + tokio::spawn(async move { + match send_result { + Err(e) => { + error!(connector_id = %id_clone, session_id = %session_id_clone, error = %e, "Failed to send prompt (transport error)"); + + let is_recoverable = true; // Transport errors are generally recoverable + let error_string = format!("{}", e); + let debug_string = format!("{:?}", e); + + // Extract received content preview if present (for JSON parse errors) + let received_preview = if error_string.contains("Line:") { + error_string + .split("Line:") + .nth(1) + .map(|s| { + let trimmed = s.trim(); + if trimmed.len() > 500 { + format!("{}... (truncated, {} bytes total)", &trimmed[..500], trimmed.len()) + } else { + trimmed.to_string() + } + }) + } else { + None + }; + + emit_event( + &sharing_bus_clone, + &events_tx_clone, + &id_clone, + connector_uid, + Event::SessionError { + connector_id: id_clone.clone(), + session_id: session_id_clone.clone(), + error_message: "Failed to send message: Transport error".to_string(), + is_recoverable, + error_code: Some("TRANSPORT_ERROR".to_string()), + technical_details: Some(debug_string), + context: Some(serde_json::json!({ + "operation": "session/prompt", + "connector_id": id_clone.clone(), + "error_display": error_string, + "received_preview": received_preview, + "user_message_preview": if text_clone.len() > 100 { + format!("{}...", &text_clone[..100]) + } else { + text_clone.clone() + }, + })), + }, + ) + .await; + + // Mark session as awaiting idle even on error + // The idle check will emit SessionIdle after the threshold + { + let mut states = session_states_clone.lock().await; + if let Some(state) = states.get_mut(&session_id_clone) { + state.mark_awaiting_idle(); + } else { + // Create new state if not exists + states.insert(session_id_clone.clone(), SessionState::new(session_id_clone.clone())); + states.get_mut(&session_id_clone).unwrap().mark_awaiting_idle(); + } + } + return; + } + Ok(_) => { + debug!(connector_id = %id_clone, "Waiting for prompt response..."); + } + } + + // Wait for response via protocol handler correlation + match response_rx.await { + Ok(response) => { + let masked_response = dirigent_protocol::log_utils::format_for_log(&response); + debug_log!("✅ Prompt response received (non-blocking): {}", masked_response); + info!(connector_id = %id_clone, response = %masked_response, "Prompt sent successfully, got response"); + + // Check for JSON-RPC error in response + if let Some(error) = response.get("error") { + let message = error + .get("message") + .and_then(|m| m.as_str()) + .unwrap_or("Unknown error") + .to_string(); + + let error_code_num = error.get("code").and_then(|c| c.as_i64()); + let error_data = error.get("data").cloned(); + + // Detect cancel_all_pending responses (code -32000 = connection lost) + let (error_code_str, error_msg) = if error_code_num == Some(-32000) { + error!(connector_id = %id_clone, session_id = %session_id_clone, error = %message, "Request cancelled due to connection loss"); + ("CONNECTION_LOST".to_string(), message.clone()) + } else { + error!(connector_id = %id_clone, session_id = %session_id_clone, error = %message, "Prompt request failed with JSON-RPC error"); + ("JSONRPC_ERROR".to_string(), format!("Agent error: {}", message)) + }; + + emit_event( + &sharing_bus_clone, + &events_tx_clone, + &id_clone, + connector_uid, + Event::SessionError { + connector_id: id_clone.clone(), + session_id: session_id_clone.clone(), + error_message: error_msg, + is_recoverable: true, + error_code: Some(error_code_str), + technical_details: Some(format!("JSON-RPC error code: {:?}, message: {}", error_code_num, message)), + context: Some(serde_json::json!({ + "operation": "session/prompt", + "jsonrpc_error_code": error_code_num, + "jsonrpc_error_message": message, + "jsonrpc_error_data": error_data, + })), + }, + ) + .await; + + // Mark session as awaiting idle after error + { + let mut states = session_states_clone.lock().await; + if let Some(state) = states.get_mut(&session_id_clone) { + state.mark_awaiting_idle(); + } else { + states.insert(session_id_clone.clone(), SessionState::new(session_id_clone.clone())); + states.get_mut(&session_id_clone).unwrap().mark_awaiting_idle(); + } + } + return; + } + + // Emit user message event AFTER successful prompt send + // User messages are atomic (not streamed), so emit MessageCompleted directly + let user_message = Message { + id: user_message_id_clone.clone(), + session_id: session_id_clone.clone(), + role: MessageRole::User, + created_at: Utc::now(), + content: vec![ + MessagePart::Text { + text: text_clone.clone(), + } + ], + status: MessageStatus::Completed, + metadata: None, + }; + + emit_event( + &sharing_bus_clone, + &events_tx_clone, + &id_clone, + connector_uid, + Event::MessageCompleted { + connector_id: id_clone.clone(), + message: user_message, + }, + ) + .await; + debug_log!("✅ Emitted user MessageCompleted for {}", user_message_id_clone); + info!(connector_id = %id_clone, session_id = %session_id_clone, user_message_id = %user_message_id_clone, "✅ Emitted user MessageCompleted"); + + // Emit TurnComplete for user message so archivist can finalize and write it + // User messages are atomic (not streamed), so turn is complete immediately + emit_event( + &sharing_bus_clone, + &events_tx_clone, + &id_clone, + connector_uid, + Event::TurnComplete { + connector_id: id_clone.clone(), + session_id: session_id_clone.clone(), + message_id: user_message_id_clone.clone(), + trigger: TurnCompleteTrigger::ResponseReceived, + }, + ) + .await; + debug_log!("✅ Emitted user TurnComplete for {}", user_message_id_clone); + + // Derive title from first user message (Task 1.2: Title Propagation) + // Check if this is the first user message for this session + let should_set_title = { + let session_info = internal_state_clone.get_session(&session_id_clone).await; + match session_info { + Some(info) => info.title.is_none() && info.message_count == 0, + None => false, // Session not tracked, skip title derivation + } + }; + + if should_set_title { + // Derive title from text content (first 50 chars, trimmed, word boundary) + let title = derive_title_from_text(&text_clone); + + debug_log!("📝 Derived title for session {}: {}", session_id_clone, title); + info!(connector_id = %id_clone, session_id = %session_id_clone, title = %title, "📝 Derived session title from first user message"); + + // Update internal state with title + if let Some(mut info) = internal_state_clone.get_session(&session_id_clone).await { + info.title = Some(title.clone()); + internal_state_clone.upsert_session(info.clone()).await; + + // Update inspector with new title + #[cfg(feature = "server")] + if let Some(ref inspector) = inspector_clone { + inspector_upsert_session(inspector, &id_clone, &info).await; + } + } + + // Emit SessionMetadataUpdated event + emit_event( + &sharing_bus_clone, + &events_tx_clone, + &id_clone, + connector_uid, + Event::SessionMetadataUpdated { + connector_id: id_clone.clone(), + session_id: session_id_clone.clone(), + title: Some(title.clone()), + total_messages: None, + model: None, + }, + ) + .await; + debug_log!("✅ Emitted SessionMetadataUpdated with title: {}", title); + info!(connector_id = %id_clone, session_id = %session_id_clone, title = %title, "✅ Emitted SessionMetadataUpdated"); + } + + // Response streaming is handled via notifications + // Note: Notifications may arrive AFTER this response, so there might not be + // an active message yet. We still need to emit SessionIdle to clear is_generating. + + // Always try to emit MessageCompleted, even if no chunks arrived yet. + // This handles the race where the prompt response arrives before the first chunk. + let message_id = internal_state_clone.clear_active_message(&session_id_clone).await; + + // Build pending completion data if we have an active message + // MessageCompleted will be emitted when SessionIdle fires (after 300ms of inactivity) + // This ensures all notifications are processed before finalization + let pending = match message_id { + Some(message_id) => { + debug_log!("🎯 Preparing deferred MessageCompleted: session={}, message_id={}", session_id_clone, message_id); + info!(connector_id = %id_clone, session_id = %session_id_clone, message_id = %message_id, "🎯 Storing pending MessageCompleted (will emit on SessionIdle)"); + + // Get message start time (when first chunk arrived) + let created_at = internal_state_clone + .get_message_start_time(&message_id) + .await + .unwrap_or_else(Utc::now); // Fallback to now if not tracked + + // Create a minimal Message struct to signal completion + // The cache already has all the chunks accumulated + let message = Message { + id: message_id.clone(), + session_id: session_id_clone.clone(), + role: MessageRole::Assistant, + created_at, // Use message start time, not finalization time + content: vec![], // Chunks already accumulated in cache + status: MessageStatus::Completed, + metadata: None, + }; + + // Clean up message start time to prevent memory leak + internal_state_clone.clear_message_start_time(&message_id).await; + + Some(PendingCompletion { + connector_id: id_clone.clone(), + message, + }) + } + None => { + debug_log!("⚠️ No active message to finalize (normal for sessions with no streaming)"); + None + } + }; + + // Mark session as awaiting idle and store pending completion + // MessageCompleted + SessionIdle will be emitted by check_idle_sessions() + // after IDLE_THRESHOLD (300ms) of inactivity + { + let mut states = session_states_clone.lock().await; + if let Some(state) = states.get_mut(&session_id_clone) { + state.mark_awaiting_idle(); + state.pending_completion = pending; + } else { + // Create new state if not exists + let mut new_state = SessionState::new(session_id_clone.clone()); + new_state.mark_awaiting_idle(); + new_state.pending_completion = pending; + states.insert(session_id_clone.clone(), new_state); + } + } + debug_log!("✅ Marked session {} as awaiting idle with pending completion", session_id_clone); + info!(connector_id = %id_clone, session_id = %session_id_clone, "🎯 Marked session as awaiting idle after prompt completion (MessageCompleted deferred)"); + } + Err(_) => { + error!(connector_id = %id_clone, session_id = %session_id_clone, "Response channel dropped (connection lost)"); + + emit_event( + &sharing_bus_clone, + &events_tx_clone, + &id_clone, + connector_uid, + Event::SessionError { + connector_id: id_clone.clone(), + session_id: session_id_clone.clone(), + error_message: "Connection to agent was lost".to_string(), + is_recoverable: true, + error_code: Some("CONNECTION_LOST".to_string()), + technical_details: Some("Response channel was dropped — agent process likely exited".to_string()), + context: Some(serde_json::json!({ + "operation": "session/prompt", + "reason": "channel_dropped", + })), + }, + ) + .await; + + // Mark session as awaiting idle after timeout/cancellation + { + let mut states = session_states_clone.lock().await; + if let Some(state) = states.get_mut(&session_id_clone) { + state.mark_awaiting_idle(); + } else { + states.insert(session_id_clone.clone(), SessionState::new(session_id_clone.clone())); + states.get_mut(&session_id_clone).unwrap().mark_awaiting_idle(); + } + } + } + } + }); + + // CRITICAL: Return immediately to event loop! + // The spawned task will handle all response processing + // This allows notifications to be processed in real-time + } + + Some(ConnectorCommand::Reconnect) => { + info!(connector_id = %id, "Received Reconnect command, restarting connection"); + + // Close transport + let _ = transport.close().await; + + // Reset retry count and clear stale state + retry_count = 0; + internal_state.clear().await; + pending_agent_requests.lock().await.clear(); + break 'event_loop; + } + + Some(ConnectorCommand::AgentResponse { request_id, response }) => { + info!( + connector_id = %id, + request_id = %request_id, + "Received AgentResponse command" + ); + + // Convert request_id to string for HashSet lookup + let request_id_str = request_id.to_string(); + + // Check if this request is still pending and remove it + let was_pending = { + let mut pending = pending_agent_requests.lock().await; + pending.remove(&request_id_str) + }; + + if was_pending { + // Send response directly to agent via transport + info!( + connector_id = %id, + request_id = %request_id_str, + "Sending agent response to transport" + ); + + if let Err(e) = transport.send(response).await { + error!( + connector_id = %id, + request_id = %request_id_str, + error = %e, + "Failed to send agent response to transport" + ); + } else { + info!( + connector_id = %id, + request_id = %request_id_str, + "Successfully sent agent response to transport" + ); + } + } else { + warn!( + connector_id = %id, + request_id = %request_id_str, + "Received AgentResponse for unknown request_id (not found in pending requests)" + ); + } + } + + Some(ConnectorCommand::SetSessionMode { session_id, mode_id }) => { + info!( + connector_id = %id, + session_id = %session_id, + mode_id = %mode_id, + "Processing SetSessionMode command" + ); + + // Check if this session has config_options — if so, use the new API + let has_config_options = internal_state.get_session(&session_id).await + .and_then(|s| s.config_options.as_ref().map(|opts| !opts.is_empty())) + .unwrap_or(false); + + let request = if has_config_options { + debug!(connector_id = %id, "Using session/set_config_option for mode change"); + protocol::build_set_config_option_request(&session_id, "mode", &mode_id) + } else { + // Legacy path for agents that don't support configOptions + json!({ + "jsonrpc": "2.0", + "method": "session/set_mode", + "params": { + "sessionId": session_id, + "modeId": mode_id + } + }) + }; + + let using_config_option = has_config_options; + + // Send request and wait for response + match Self::send_request( + &mut transport, + &protocol_handler, + request, + Some(&events_tx), Some(&sharing_bus), connector_uid, + &id, + Some(&pending_agent_requests), + Some(&mut cmd_rx), + Some(&internal_state), + ).await { + Ok(response) => { + info!( + connector_id = %id, + session_id = %session_id, + mode_id = %mode_id, + "Mode changed successfully" + ); + + // When using set_config_option, parse configOptions response + if using_config_option { + if let Some(result) = response.get("result") { + let config_options: Option> = result + .get("configOptions") + .and_then(|co| serde_json::from_value(co.clone()).ok()); + + if let Some(ref opts) = config_options { + if let Some(mut session_info) = internal_state.get_session(&session_id).await { + session_info.config_options = Some(opts.clone()); + internal_state.upsert_session(session_info).await; + } + + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionMetadataReceived { + connector_id: id.clone(), + session_id: session_id.clone(), + models: None, + modes: None, + config_options, + }, + ) + .await; + } + } + } + } + Err(e) => { + error!( + connector_id = %id, + session_id = %session_id, + mode_id = %mode_id, + error = %e, + "Failed to change mode" + ); + + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionError { + connector_id: id.clone(), + session_id: session_id.clone(), + error_message: format!("Failed to set mode: {}", e), + is_recoverable: true, + error_code: Some("SET_MODE_FAILED".to_string()), + technical_details: Some(format!("{:?}", e)), + context: Some(serde_json::json!({ + "operation": "session/set_mode", + "mode_id": mode_id, + })), + }, + ) + .await; + } + } + } + + Some(ConnectorCommand::SetSessionModel { session_id, model_id }) => { + info!( + connector_id = %id, + session_id = %session_id, + model_id = %model_id, + "Processing SetSessionModel command" + ); + + // Check if this session has config_options — if so, use the new API + let has_config_options = internal_state.get_session(&session_id).await + .and_then(|s| s.config_options.as_ref().map(|opts| !opts.is_empty())) + .unwrap_or(false); + + let request = if has_config_options { + debug!(connector_id = %id, "Using session/set_config_option for model change"); + protocol::build_set_config_option_request(&session_id, "model", &model_id) + } else { + // Legacy path for agents that don't support configOptions + json!({ + "jsonrpc": "2.0", + "method": "session/set_model", + "params": { + "sessionId": session_id, + "modelId": model_id + } + }) + }; + + let using_config_option = has_config_options; + + // Send request and wait for response + match Self::send_request( + &mut transport, + &protocol_handler, + request, + Some(&events_tx), Some(&sharing_bus), connector_uid, + &id, + Some(&pending_agent_requests), + Some(&mut cmd_rx), + Some(&internal_state), + ).await { + Ok(response) => { + info!( + connector_id = %id, + session_id = %session_id, + model_id = %model_id, + "Model changed successfully" + ); + + // When using set_config_option, parse configOptions response + if using_config_option { + if let Some(result) = response.get("result") { + let config_options: Option> = result + .get("configOptions") + .and_then(|co| serde_json::from_value(co.clone()).ok()); + + if let Some(ref opts) = config_options { + if let Some(mut session_info) = internal_state.get_session(&session_id).await { + session_info.config_options = Some(opts.clone()); + internal_state.upsert_session(session_info).await; + } + + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionMetadataReceived { + connector_id: id.clone(), + session_id: session_id.clone(), + models: None, + modes: None, + config_options, + }, + ) + .await; + } + } + } + } + Err(e) => { + error!( + connector_id = %id, + session_id = %session_id, + model_id = %model_id, + error = %e, + "Failed to change model" + ); + + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionError { + connector_id: id.clone(), + session_id: session_id.clone(), + error_message: format!("Failed to set model: {}", e), + is_recoverable: true, + error_code: Some("SET_MODEL_FAILED".to_string()), + technical_details: Some(format!("{:?}", e)), + context: Some(serde_json::json!({ + "operation": "session/set_model", + "model_id": model_id, + })), + }, + ) + .await; + } + } + } + + Some(ConnectorCommand::SetConfigOption { session_id, config_id, value }) => { + debug!(connector_id = %id, session_id = %session_id, config_id = %config_id, value = %value, "Processing SetConfigOption command"); + + let request = protocol::build_set_config_option_request(&session_id, &config_id, &value); + match Self::send_request( + &mut transport, + &protocol_handler, + request, + Some(&events_tx), Some(&sharing_bus), connector_uid, + &id, + Some(&pending_agent_requests), + Some(&mut cmd_rx), + Some(&internal_state), + ).await { + Ok(response) => { + info!(connector_id = %id, session_id = %session_id, config_id = %config_id, "Config option set successfully"); + + // The response contains the complete updated configOptions array + if let Some(result) = response.get("result") { + let config_options: Option> = result + .get("configOptions") + .and_then(|co| serde_json::from_value(co.clone()).ok()); + + if let Some(ref opts) = config_options { + // Update internal state + if let Some(mut session_info) = internal_state.get_session(&session_id).await { + session_info.config_options = Some(opts.clone()); + internal_state.upsert_session(session_info).await; + } + + // Emit SessionMetadataReceived with updated config options + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionMetadataReceived { + connector_id: id.clone(), + session_id: session_id.clone(), + models: None, + modes: None, + config_options, + }, + ) + .await; + } + } + } + Err(e) => { + warn!(connector_id = %id, session_id = %session_id, error = %e, "Failed to set config option"); + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::Error { + message: format!("Failed to set config option: {}", e), + }, + ) + .await; + } + } + } + + Some(ConnectorCommand::CloseSession { session_id }) => { + debug!(connector_id = %id, session_id = %session_id, "Processing CloseSession command"); + + let supports_close = internal_state.agent_supports_session_close().await; + if !supports_close { + warn!(connector_id = %id, "Agent does not support session/close"); + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::Error { + message: "Agent does not support session close".to_string(), + }, + ) + .await; + } else { + let request = protocol::build_session_close_request(&session_id); + match Self::send_request( + &mut transport, &protocol_handler, request, + Some(&events_tx), Some(&sharing_bus), connector_uid, &id, Some(&pending_agent_requests), + Some(&mut cmd_rx), Some(&internal_state), + ).await { + Ok(_response) => { + info!(connector_id = %id, session_id = %session_id, "Session closed successfully"); + + // Remove from loaded sessions tracking + internal_state.unmark_session_loaded(&session_id).await; + + // Update internal state — mark session as idle + if let Some(mut session_info) = internal_state.get_session(&session_id).await { + session_info.status = SessionStatus::Idle; + internal_state.upsert_session(session_info).await; + } + + // Emit SessionClosed event + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionClosed { + connector_id: id.clone(), + session_id: session_id.clone(), + }, + ) + .await; + } + Err(e) => { + warn!(connector_id = %id, session_id = %session_id, error = %e, "Failed to close session"); + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::Error { + message: format!("Failed to close session: {}", e), + }, + ) + .await; + } + } + } + } + + Some(ConnectorCommand::Shutdown) => { + info!(connector_id = %id, "Received Shutdown command, stopping connector"); + + // Close transport + let _ = transport.close().await; + + // Deregister session nodes from inspector + #[cfg(feature = "server")] + if let Some(ref inspector) = inspector { + inspector_deregister_all_sessions(inspector, &id, &internal_state).await; + } + + // Update state to Stopped + { + let mut state_guard = state.write().await; + *state_guard = ConnectorState::Stopped; + } + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::ConnectorStateChanged { + connector_id: id.clone(), + state: "Offline".to_string(), + error_kind: None, + }, + ) + .await; + // Exit both loops + break 'reconnect_loop; + } + + None => { + error!(connector_id = %id, "Command channel closed, stopping connector"); + + // Close transport + let _ = transport.close().await; + + // Deregister session nodes from inspector + #[cfg(feature = "server")] + if let Some(ref inspector) = inspector { + inspector_deregister_all_sessions(inspector, &id, &internal_state).await; + } + + // Update state to Stopped + { + let mut state_guard = state.write().await; + *state_guard = ConnectorState::Stopped; + } + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::ConnectorStateChanged { + connector_id: id.clone(), + state: "Offline".to_string(), + error_kind: None, + }, + ) + .await; + // Exit both loops + break 'reconnect_loop; + } + } + } + + // PRIORITY 4: Housekeeping (idle detection and timeout cleanup) + _ = idle_check_interval.tick() => { + // Check for sessions that should emit SessionIdle + let idled = check_idle_sessions(&session_states, &events_tx, &sharing_bus, connector_uid, &id).await; + + // Update inspector state for sessions that just went idle + #[cfg(feature = "server")] + if !idled.is_empty() { + if let Some(ref inspector) = inspector { + for session_id in &idled { + inspector_update_session_state( + inspector, + &id, + session_id, + dirigent_inspector::NodeState::Idle, + ).await; + } + } + } + // TODO: Task 02 cleanup_timed_out_requests() can be added here if needed + } + } + } + + // If we exited the event loop (not due to Shutdown), attempt reconnect + info!(connector_id = %id, "Event loop exited, attempting reconnect"); + } + + info!(connector_id = %id, "ACP connector task stopped"); + } + + /// Create transport based on configuration + async fn create_transport( + config: &AcpConfig, + #[cfg(feature = "server")] process_manager: Option<&Arc>, + ) -> AcpResult> { + match &config.transport { + TransportKind::Stdio { + command, + args, + cwd, + env, + } => { + let mut transport = StdioTransport::new(command, args); + + // Set working directory if specified + if let Some(cwd_path) = cwd { + transport.set_cwd(cwd_path.clone()); + } + + // Set environment variables if specified + for (key, value) in env { + transport.set_env(key.clone(), value.clone()); + } + + // Set protocol logging directory: explicit config > data_dir/logs/acp/ + let log_dir = config.acp_log_dir.clone().or_else(|| { + #[cfg(feature = "server")] + { + Some(dirigent_config::DirigentPaths::resolve() + .map(|p| p.logs_dir().join("acp")) + .unwrap_or_else(|_| PathBuf::from("logs/acp"))) + } + #[cfg(not(feature = "server"))] + { None } + }); + if let Some(log_dir) = log_dir { + transport.set_log_dir(log_dir); + } + + // Wire process lifecycle for graceful shutdown management + #[cfg(feature = "server")] + if let Some(mgr) = process_manager { + transport.set_process_lifecycle(mgr.create_lifecycle()); + } + + transport.connect().await?; + Ok(Box::new(transport)) + } + TransportKind::Http { + base_url, + timeout_ms, + } => { + // Generate client_id for SSE subscription + // This ensures SSE connection can be established before RPC calls + let client_id = uuid::Uuid::new_v4().to_string(); + + info!( + base_url = %base_url, + client_id = %client_id, + "Generated client_id for HTTP transport" + ); + + let mut transport = HttpTransport::new(base_url, client_id); + + // Set timeout if specified + if let Some(timeout) = timeout_ms { + transport.set_timeout(Duration::from_millis(*timeout)); + } + + transport.connect().await?; + Ok(Box::new(transport)) + } + } + } + + /// Perform ACP initialization handshake + async fn initialize_acp( + id: &str, + transport: &mut Box, + protocol_handler: &ProtocolHandler, + config: &AcpConfig, + internal_state: &Arc, + ) -> AcpResult<()> { + // Build initialize request with capabilities from default ownership + let capabilities = config.default_ownership.capabilities_for_agent(); + let request = build_initialize_request(config.protocol_version, capabilities); + + // Send request and wait for response + // Note: During initialization, we don't have access to event channels, pending requests, cmd_rx, or internal_state, + // so we pass None for those parameters. Agent requests during init are unexpected. + let response = Self::send_request( + transport, + protocol_handler, + request, + None, + None, + None, + "init", + None, + None, + None, + ) + .await?; + + // Extract result + let result = response + .get("result") + .ok_or_else(|| AcpError::protocol("Initialize response missing result field"))?; + + // Extract protocol version + let protocol_version = result + .get("protocolVersion") + .and_then(|v| v.as_u64()) + .ok_or_else(|| AcpError::protocol("Initialize result missing protocolVersion"))? + as u32; + + // Extract agent capabilities + let capabilities = result + .get("agentCapabilities") + .cloned() + .unwrap_or(json!({})); + + info!( + connector_id = %id, + protocol_version = protocol_version, + "ACP initialization complete" + ); + + // Update internal state + internal_state.set_protocol_version(protocol_version).await; + internal_state.set_agent_capabilities(capabilities).await; + + Ok(()) + } + + /// Send a request via transport with protocol handler correlation + /// + /// This method properly integrates the protocol handler for request/response correlation. + /// It prepares the request with a unique ID, sends it via transport, and waits for the + /// correlated response with timeout. + /// + /// **IMPORTANT**: This method concurrently receives messages from the transport while waiting + /// for the response, avoiding deadlock when the transport requires active receiving. + async fn send_request( + transport: &mut Box, + protocol_handler: &ProtocolHandler, + request: Value, + events_sender: Option<&broadcast::Sender>, + sharing_bus: Option<&Arc>, + connector_uid: Option, + connector_id: &str, + pending_agent_requests: Option<&Arc>>>, + mut cmd_rx: Option<&mut mpsc::Receiver>, + internal_state: Option<&Arc>, + ) -> AcpResult { + // Prepare request (adds ID and creates response channel) + let (message_with_id, mut response_rx) = protocol_handler.prepare_request(request).await; + + let request_id = message_with_id.get("id").cloned().unwrap_or(json!(null)); + debug!(request_id = %request_id, "Prepared request with ID"); + + // Send via transport + let masked_message = dirigent_protocol::log_utils::format_for_log(&message_with_id); + debug!(request_id = %request_id, message = %masked_message, "Calling transport.send()"); + transport.send(message_with_id).await?; + debug!(request_id = %request_id, "transport.send() completed successfully"); + + // Wait for correlated response while also receiving messages from transport + // This prevents deadlock when the agent sends the response immediately + // Timeout set to 5 minutes to allow for tool use, thinking, and complex operations + let timeout_duration = Duration::from_millis(300_000); + let deadline = tokio::time::Instant::now() + timeout_duration; + + loop { + tokio::select! { + // Check if we got the response via protocol handler + response = &mut response_rx => { + match response { + Ok(response) => { + // Check for JSON-RPC error in response + if let Some(error) = response.get("error") { + let code = error.get("code").and_then(|c| c.as_i64()).unwrap_or(-1); + let message = error + .get("message") + .and_then(|m| m.as_str()) + .unwrap_or("Unknown error") + .to_string(); + let data = error.get("data").cloned(); + + return Err(AcpError::AgentError { + code, + message, + data, + }); + } + + return Ok(response); + } + Err(_) => { + return Err(AcpError::internal("Response channel dropped")); + } + } + } + + // Continue receiving messages from transport and route them + message = transport.recv() => { + match message { + Ok(Some(msg)) => { + // Route through protocol handler (may complete response_rx) + use crate::connectors::acp::protocol::MessageHandlerResult; + match protocol_handler.handle_message(msg).await { + MessageHandlerResult::None => { + // Notification processed, continue + } + MessageHandlerResult::Response(response) => { + debug!("📤 Sending response to agent request during send_request"); + if let Err(e) = transport.send(response).await { + error!(error = %e, "Failed to send response to agent during send_request"); + } + } + MessageHandlerResult::AgentRequest { request_id, method, params } => { + // Agent request during send_request - handle it if we have the channels + if let (Some(events_tx), Some(pending_reqs)) = (events_sender, pending_agent_requests) { + info!( + connector_id = connector_id, + request_id = %request_id, + method = %method, + "Received agent request during send_request, emitting event" + ); + + // Extract session_id from params (standard ACP location) + let session_id = params + .get("sessionId") + .and_then(|v| v.as_str()) + .unwrap_or("unknown") + .to_string(); + + // Determine if this is a forwarded (external) session + let is_forwarded = if let Some(state) = internal_state { + if let Some(session_info) = state.get_session(&session_id).await { + session_info.ownership.is_external() + } else { + warn!( + connector_id = connector_id, + session_id = %session_id, + "Session not found when processing AgentRequest in send_request - defaulting to internal" + ); + false + } + } else { + // No internal_state provided (e.g., during initialization) - default to internal + false + }; + + // Store the pending request (no timeout - wait indefinitely for user) + let request_id_str = request_id.to_string(); + { + let mut pending = pending_reqs.lock().await; + pending.insert(request_id_str.clone()); + } + info!( + connector_id = connector_id, + request_id = %request_id_str, + is_forwarded = is_forwarded, + "Stored pending agent request (waiting for user response)" + ); + + // Emit Event::AgentRequest for routing + // - If is_forwarded=true: EventBridge forwards to external client + // - If is_forwarded=false: Web UI shows permission modal + let event = Event::AgentRequest { + connector_id: connector_id.to_string(), + session_id, + request_id: request_id.clone(), + method, + params, + is_forwarded, + }; + + if let Some(bus) = sharing_bus { + let bus_event = dirigent_protocol::streaming::BusEvent::from_connector_event( + event.clone(), + connector_uid, + connector_id.to_string(), + ); + bus.publish(bus_event).await; + } + if let Err(e) = events_tx.send(event) { + error!( + connector_id = connector_id, + error = %e, + "Failed to emit AgentRequest event" + ); + } + + // Response will be handled when AgentResponse command is received + // No timeout - the system waits indefinitely for user input + } else { + warn!( + connector_id = connector_id, + "Received agent request during send_request, but no event channel available (ignoring)" + ); + } + } + } + // Continue loop to check response_rx again + } + Ok(None) => { + return Err(AcpError::transport("Transport closed while waiting for response")); + } + Err(e) => { + return Err(AcpError::transport(format!("Transport error: {}", e))); + } + } + } + + // Handle incoming commands (specifically AgentResponse to avoid deadlock) + cmd = async { + if let Some(ref mut rx) = cmd_rx { + rx.recv().await + } else { + std::future::pending().await // Never resolves if no cmd_rx + } + } => { + match cmd { + Some(ConnectorCommand::AgentResponse { request_id, response }) => { + info!( + connector_id = connector_id, + request_id = %request_id, + "Received AgentResponse command during send_request" + ); + + // Convert request_id to string for HashSet lookup + let request_id_str = request_id.to_string(); + + // Check if this request is still pending and remove it + if let Some(pending_reqs) = pending_agent_requests { + let was_pending = { + let mut pending = pending_reqs.lock().await; + pending.remove(&request_id_str) + }; + + if was_pending { + // Send response directly to agent via transport + info!( + connector_id = connector_id, + request_id = %request_id_str, + "Sending agent response to transport during send_request" + ); + + if let Err(e) = transport.send(response).await { + error!( + connector_id = connector_id, + request_id = %request_id_str, + error = %e, + "Failed to send agent response to transport during send_request" + ); + } else { + info!( + connector_id = connector_id, + request_id = %request_id_str, + "Successfully sent agent response to transport during send_request" + ); + } + } else { + warn!( + connector_id = connector_id, + request_id = %request_id_str, + "Received AgentResponse for unknown request_id during send_request" + ); + } + } + // Continue loop to check response_rx again + } + Some(other_cmd) => { + warn!( + connector_id = connector_id, + "Received non-AgentResponse command during send_request (ignoring): {:?}", + std::mem::discriminant(&other_cmd) + ); + // Put the command back or drop it - for now we'll drop it + // In the future, we might want to buffer these for later processing + } + None => { + warn!( + connector_id = connector_id, + "Command channel closed during send_request" + ); + // Continue waiting for response + } + } + } + + // Check for timeout + _ = tokio::time::sleep_until(deadline) => { + return Err(AcpError::timeout(timeout_duration.as_millis() as u64)); + } + } + } + } + + /// Execute session listing: fetch from upstream agent if supported, fall back to internal state. + /// + /// Emits `SessionsListed` (and optionally `SessionMetadataReceived` per session). + /// Called by both the `ListSessions` command handler and automatically after + /// `CreateSession` to populate the session list for deferred obligations. + async fn execute_list_sessions( + transport: &mut Box, + protocol_handler: &ProtocolHandler, + events_tx: &broadcast::Sender, + sharing_bus: &Arc, + connector_uid: Option, + connector_id: &str, + config: &AcpConfig, + internal_state: &Arc, + pending_agent_requests: &Arc>>, + cmd_rx: &mut mpsc::Receiver, + ) { + debug!(connector_id = %connector_id, "Executing ListSessions"); + + // Check if the upstream agent supports session/list + let supports_list = internal_state.agent_supports_list_sessions().await; + + // Try to fetch from upstream if supported + let upstream_sessions: Option> = if supports_list { + let mut all_upstream: Vec = Vec::new(); + let mut cursor: Option = None; + let mut fetch_ok = true; + + loop { + let request = protocol::build_session_list_request( + Some(&config.cwd), + cursor.as_deref(), + ); + + match Self::send_request( + transport, + protocol_handler, + request, + Some(events_tx), + Some(sharing_bus), + connector_uid, + connector_id, + Some(pending_agent_requests), + Some(cmd_rx), + Some(internal_state), + ) + .await + { + Ok(response) => { + if let Some(result) = response.get("result") { + if let Some(sessions) = + result.get("sessions").and_then(|s| s.as_array()) + { + all_upstream.extend(sessions.iter().cloned()); + } + cursor = result + .get("nextCursor") + .and_then(|c| c.as_str()) + .map(|s| s.to_string()); + if cursor.is_none() { + break; + } + } else { + warn!(connector_id = %connector_id, "session/list response missing result"); + fetch_ok = false; + break; + } + } + Err(e) => { + warn!(connector_id = %connector_id, error = %e, "session/list failed, falling back to internal state"); + fetch_ok = false; + break; + } + } + } + + if fetch_ok { + // Parse upstream sessions into SessionInfo + let mut session_map: HashMap = HashMap::new(); + + for session_json in &all_upstream { + let session_id = session_json + .get("sessionId") + .and_then(|v| v.as_str()) + .unwrap_or_default() + .to_string(); + + if session_id.is_empty() { + continue; + } + + let title = session_json + .get("title") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + + let cwd = session_json + .get("cwd") + .and_then(|v| v.as_str()) + .unwrap_or(&config.cwd) + .to_string(); + + let updated_at = session_json + .get("updatedAt") + .and_then(|v| v.as_str()) + .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok()) + .map(|dt| dt.with_timezone(&Utc)) + .unwrap_or_else(Utc::now); + + let meta = session_json.get("_meta"); + let message_count = meta + .and_then(|m| m.get("messageCount")) + .and_then(|v| v.as_u64()) + .unwrap_or(0) as u32; + + let info = SessionInfo { + id: session_id.clone(), + title, + cwd, + message_count, + model: None, + created_at: updated_at, + last_activity: updated_at, + status: SessionStatus::Idle, + models: None, + modes: None, + config_options: None, + ownership: SessionOwnership::default(), + }; + + session_map.insert(session_id, info); + } + + // Union merge: add internal state sessions that aren't already upstream + let internal_sessions = internal_state.list_sessions().await; + for internal in internal_sessions { + session_map.entry(internal.id.clone()).or_insert(internal); + } + + // Upsert newly discovered sessions into internal state + for info in session_map.values() { + internal_state.upsert_session(info.clone()).await; + } + + info!(connector_id = %connector_id, count = session_map.len(), "Listed sessions from upstream agent (merged with internal state)"); + Some(session_map.into_values().collect()) + } else { + None + } + } else { + None + }; + + // Use upstream sessions or fall back to internal state + let sessions = match upstream_sessions { + Some(s) => s, + None => internal_state.list_sessions().await, + }; + + // Convert SessionInfo to protocol::Session and emit metadata events + let mut protocol_sessions: Vec = Vec::with_capacity(sessions.len()); + + for info in sessions { + let current_mode_id = info.modes.as_ref().map(|m| m.current_mode_id.clone()); + + protocol_sessions.push(Session { + id: info.id.clone(), + title: info + .title + .clone() + .unwrap_or_else(|| "Untitled Session".to_string()), + created_at: info.created_at, + updated_at: info.last_activity, + cwd: Some(info.cwd.clone()), + metadata: SessionMetadata { + project_path: info.cwd, + model: info.model, + total_messages: info.message_count, + system_message: None, + current_mode_id, + _meta: None, + project_id: None, + }, + models: info.models.clone(), + modes: info.modes.clone(), + config_options: None, + acp_client_id: None, + }); + + if info.models.is_some() || info.modes.is_some() || info.config_options.is_some() { + emit_event( + sharing_bus, + events_tx, + &connector_id.to_string(), + connector_uid, + Event::SessionMetadataReceived { + connector_id: connector_id.to_string(), + session_id: info.id, + models: info.models, + modes: info.modes, + config_options: info.config_options, + }, + ) + .await; + } + } + + emit_event( + sharing_bus, + events_tx, + &connector_id.to_string(), + connector_uid, + Event::SessionsListed { + connector_id: connector_id.to_string(), + sessions: protocol_sessions, + }, + ) + .await; + } + + /// Prepare request without sending (for use in protocol handler) + // TODO: Request preparation - utility for building ACP requests with validation and defaults + #[allow(dead_code)] + async fn prepare_request( + protocol_handler: &ProtocolHandler, + request: Value, + ) -> (Value, tokio::sync::oneshot::Receiver) { + protocol_handler.prepare_request(request).await + } + + /// Handle a notification from the agent using ACP adapter + /// + /// Translates ACP JSON-RPC notifications to Dirigent protocol events. + async fn handle_notification( + id: &str, + notification: Value, + events_tx: &broadcast::Sender, + sharing_bus: &Arc, + connector_uid: Option, + adapter: &dirigent_protocol::adapters::AcpAdapter, + internal_state: &Arc, + session_states: &Arc>>, + tool_configuration: &Option, + ) -> AcpResult<()> { + // Extract session ID and update activity timestamp + if let Some(params) = notification.get("params") { + if let Some(notif_session_id) = params.get("sessionId").and_then(|s| s.as_str()) { + debug_log!( + "🔔 Notification sessionId: {} (method: {})", + notif_session_id, + notification + .get("method") + .and_then(|m| m.as_str()) + .unwrap_or("unknown") + ); + + // Update session activity timestamp + { + let mut states = session_states.lock().await; + if let Some(state) = states.get_mut(notif_session_id) { + state.touch(); + trace!(session_id = %notif_session_id, "Updated session activity timestamp"); + } + } + } + } + let masked_notification = dirigent_protocol::log_utils::format_for_log(¬ification); + info!(connector_id = %id, notification = %masked_notification, "🔄 Processing ACP notification in handle_notification"); + + // Save notification data for fallback before it's moved + let notification_clone = notification.clone(); + + // Translate notification to Dirigent event using adapter + match adapter.translate_notification(notification).await { + Ok(event) => { + info!(connector_id = %id, "✅ Successfully translated notification to event"); + + // Replace ACP message_id with Dirigent message_id for streaming chunks + // and inject connector_id + let event = match event { + Event::SessionUpdate { + connector_id: _, + session_id, + update, + } => { + match update { + SessionUpdate::AgentMessageChunk { + message_id: acp_message_id, + content, + _meta, + } => { + // Get or create the Dirigent message_id for this session + let dirigent_message_id = internal_state + .get_or_create_active_message(&session_id) + .await; + + debug_log!( + "🔄 Message ID translation: ACP={} → Dirigent={}", + acp_message_id, + dirigent_message_id + ); + info!( + connector_id = %id, + session_id = %session_id, + acp_message_id = %acp_message_id, + dirigent_message_id = %dirigent_message_id, + "🔄 Translating ACP message_id to Dirigent message_id" + ); + + // Store the original ACP message_id in metadata + let mut extra = HashMap::new(); + extra.insert("acp_message_id".to_string(), json!(acp_message_id)); + let meta = Meta { + provider: None, + extra, + }; + + Event::SessionUpdate { + connector_id: id.to_string(), + session_id, + update: SessionUpdate::AgentMessageChunk { + message_id: dirigent_message_id, + content, + _meta: Some(meta), + }, + } + } + SessionUpdate::ToolCall { + message_id: acp_message_id, + mut tool_call, + _meta, + } => { + // Use same Dirigent message_id as surrounding chunks + let dirigent_message_id = internal_state + .get_or_create_active_message(&session_id) + .await; + + // Mark as external - Claude Code executes these tools + tool_call.origin = Some(ToolOrigin::External); + + debug_log!( + "🔄 ToolCall message ID translation: ACP={} → Dirigent={}", + acp_message_id, + dirigent_message_id + ); + info!( + connector_id = %id, + session_id = %session_id, + acp_message_id = %acp_message_id, + dirigent_message_id = %dirigent_message_id, + tool_name = %tool_call.tool_name, + "🔄 Translating ToolCall message_id to Dirigent message_id" + ); + + Event::SessionUpdate { + connector_id: id.to_string(), + session_id, + update: SessionUpdate::ToolCall { + message_id: dirigent_message_id, + tool_call, + _meta, + }, + } + } + SessionUpdate::ToolCallUpdate { + message_id: acp_message_id, + tool_call_id, + mut tool_call, + _meta, + } => { + // Use same Dirigent message_id as surrounding chunks + let dirigent_message_id = internal_state + .get_or_create_active_message(&session_id) + .await; + + // Mark as external - Claude Code executes these tools + tool_call.origin = Some(ToolOrigin::External); + + debug_log!("🔄 ToolCallUpdate message ID translation: ACP={} → Dirigent={}", acp_message_id, dirigent_message_id); + info!( + connector_id = %id, + session_id = %session_id, + acp_message_id = %acp_message_id, + dirigent_message_id = %dirigent_message_id, + tool_call_id = %tool_call_id, + "🔄 Translating ToolCallUpdate message_id to Dirigent message_id" + ); + + Event::SessionUpdate { + connector_id: id.to_string(), + session_id, + update: SessionUpdate::ToolCallUpdate { + message_id: dirigent_message_id, + tool_call_id, + tool_call, + _meta, + }, + } + } + // Check for config_option_update, available_commands_update, etc. + SessionUpdate::Unknown { ref data } => { + // Check the sessionUpdate type for special handling + let merged_event = if let Some(session_update_type) = + data.get("sessionUpdate").and_then(|s| s.as_str()) + { + if session_update_type == "config_option_update" { + // Parse configOptions from the notification + if let Some(config_options_json) = data.get("configOptions") { + if let Ok(config_options) = serde_json::from_value::< + Vec, + >(config_options_json.clone()) { + info!( + connector_id = %id, + session_id = %session_id, + option_count = config_options.len(), + "📊 Received config_option_update notification" + ); + + // Update internal state with new config options + if let Some(mut session_info) = internal_state.get_session(&session_id).await { + session_info.config_options = Some(config_options.clone()); + session_info.last_activity = Utc::now(); + internal_state.upsert_session(session_info).await; + } + + // Emit SessionMetadataReceived event with config options + // This allows downstream consumers (event bridge, archivist, UI) to react + emit_event( + sharing_bus, + events_tx, + &id.to_string(), + connector_uid, + Event::SessionMetadataReceived { + connector_id: id.to_string(), + session_id: session_id.clone(), + models: None, + modes: None, + config_options: Some(config_options), + }, + ) + .await; + + // Also pass through as Unknown for the event bridge to forward + Some(Event::SessionUpdate { + connector_id: id.to_string(), + session_id: session_id.clone(), + update: SessionUpdate::Unknown { data: data.clone() }, + }) + } else { + warn!(connector_id = %id, "Failed to parse configOptions from config_option_update notification"); + None + } + } else { + warn!(connector_id = %id, "config_option_update notification missing configOptions field"); + None + } + } else if session_update_type == "available_commands_update" { + // Extract commands from the data + if let Some(commands_json) = data.get("availableCommands") { + // Parse commands and merge with transient commands + if let Ok(mut commands) = serde_json::from_value::< + Vec, + >( + commands_json.clone() + ) { + info!( + connector_id = %id, + session_id = %session_id, + original_count = commands.len(), + "🔄 Intercepted available_commands_update, merging transient commands" + ); + + // Merge transient commands + commands = crate::connectors::gateway::merge_with_transient_commands(commands); + + info!( + connector_id = %id, + session_id = %session_id, + merged_count = commands.len(), + "✅ Merged transient commands into available_commands_update" + ); + + // Reconstruct the update with merged commands + let mut merged_data = data.clone(); + if let Some(obj) = merged_data.as_object_mut() { + obj.insert( + "availableCommands".to_string(), + serde_json::to_value(&commands) + .unwrap_or_default(), + ); + } + + Some(Event::SessionUpdate { + connector_id: id.to_string(), + session_id: session_id.clone(), + update: SessionUpdate::Unknown { + data: merged_data, + }, + }) + } else { + None + } + } else { + None + } + } else { + None + } + } else { + None + }; + + // Return merged event if we processed it, otherwise pass through unchanged + merged_event.unwrap_or_else(|| Event::SessionUpdate { + connector_id: id.to_string(), + session_id, + update: SessionUpdate::Unknown { data: data.clone() }, + }) + } + // Pass through all other update types unchanged, but add connector_id + other_update => Event::SessionUpdate { + connector_id: id.to_string(), + session_id, + update: other_update, + }, + } + } + // Pass through all other event types unchanged + other_event => other_event, + }; + + // --- Tool directive intercept --- + // Check ToolCall events against the connector's tool configuration. + // Denied tools are marked as errors; hidden tools are silently dropped. + let event = if let Some(ref tool_config) = tool_configuration { + Self::apply_tool_directive(event, tool_config, id) + } else { + Some(event) + }; + + // Broadcast the event (if not hidden by tool directive) + if let Some(event) = event { + debug_log!("📢 Broadcasting event: {:?}", event); + let bus_event = dirigent_protocol::streaming::BusEvent::from_connector_event( + event.clone(), + connector_uid, + id.to_string(), + ); + sharing_bus.publish(bus_event).await; + match events_tx.send(event) { + Ok(receiver_count) => { + debug_log!("✅ Event sent to {} receivers", receiver_count); + info!(connector_id = %id, receiver_count = receiver_count, "📢 Event broadcast to {} receivers", receiver_count); + } + Err(e) => { + debug_log!("⚠️ No receivers for event"); + warn!(connector_id = %id, error = ?e, "⚠️ No receivers for event (this is normal if no clients are listening)"); + } + } + } + + Ok(()) + } + Err(e) => { + // Translation failed + match e { + dirigent_protocol::adapters::AcpTranslationError::Duplicate => { + // Duplicates are expected and should be silently skipped + debug!(connector_id = %id, "Skipped duplicate notification"); + Ok(()) + } + dirigent_protocol::adapters::AcpTranslationError::UnknownMethod(method) => { + // Unknown methods are logged at debug, not errors + debug!(connector_id = %id, method = method, "Unknown notification method"); + Ok(()) + } + dirigent_protocol::adapters::AcpTranslationError::InvalidValue { + ref field, + ref reason, + } => { + // Invalid values (like missing type field) are logged at debug + debug!(connector_id = %id, field = field, reason = reason, "Skipped notification with invalid value"); + Ok(()) + } + dirigent_protocol::adapters::AcpTranslationError::MissingField(ref field) => { + // Log warning when falling back to Unknown event + warn!( + connector_id = %id, + field = field, + "Missing field in notification, forwarding as Unknown" + ); + + // Extract session ID from notification + let session_id = notification_clone + .get("params") + .and_then(|p| p.get("sessionId")) + .and_then(|s| s.as_str()) + .unwrap_or("unknown") + .to_string(); + + // Create Unknown event with raw JSON for pass-through + let event = Event::SessionUpdate { + connector_id: id.to_string(), + session_id, + update: SessionUpdate::Unknown { + data: notification_clone, + }, + }; + + // Broadcast the event + debug_log!("📢 Broadcasting Unknown event: {:?}", event); + let bus_event = dirigent_protocol::streaming::BusEvent::from_connector_event( + event.clone(), + connector_uid, + id.to_string(), + ); + sharing_bus.publish(bus_event).await; + match events_tx.send(event) { + Ok(receiver_count) => { + debug_log!("✅ Unknown event sent to {} receivers", receiver_count); + info!(connector_id = %id, receiver_count = receiver_count, "📢 Unknown event broadcast to {} receivers", receiver_count); + } + Err(e) => { + debug_log!("⚠️ No receivers for Unknown event"); + warn!(connector_id = %id, error = ?e, "⚠️ No receivers for Unknown event (this is normal if no clients are listening)"); + } + } + + Ok(()) + } + } + } + } + } + + /// Build a `Meta` value containing the tool configuration (if present). + /// + /// Used when emitting `SessionCreated` events so the archivist can persist + /// the tool configuration in the session metadata. + fn build_session_meta( + tool_configuration: &Option, + ) -> Option { + let tool_config = tool_configuration.as_ref()?; + let tool_config_value = serde_json::to_value(tool_config).ok()?; + let mut extra = HashMap::new(); + extra.insert("tool_configuration".to_string(), tool_config_value); + Some(Meta { + provider: None, + extra, + }) + } + + /// Apply tool directive intercept to an event. + /// + /// Checks `SessionUpdate::ToolCall` events against the connector's tool + /// configuration: + /// - `Deny`: marks the tool call as failed with an error message + /// - `Hide`: returns `None` to suppress the event entirely + /// - Other handlers or unconfigured tools: pass through unchanged + fn apply_tool_directive( + event: Event, + tool_config: &crate::tools::ToolConfiguration, + connector_id: &str, + ) -> Option { + use crate::tools::ToolHandler; + use dirigent_protocol::types::ToolCallStatus; + + match event { + Event::SessionUpdate { + connector_id: cid, + session_id, + update: + SessionUpdate::ToolCall { + message_id, + mut tool_call, + _meta, + }, + } => { + if tool_config.should_intercept(&tool_call.tool_name) { + match tool_config.active_handler(&tool_call.tool_name) { + Some(ToolHandler::Hide) => { + info!( + connector_id = %connector_id, + tool_name = %tool_call.tool_name, + "Tool call hidden by directive" + ); + return None; + } + Some(ToolHandler::Deny) => { + info!( + connector_id = %connector_id, + tool_name = %tool_call.tool_name, + "Tool call denied by directive" + ); + tool_call.status = ToolCallStatus::Error; + tool_call.error = + Some("Tool denied by directive".to_string()); + } + _ => { + // Other interceptable handlers (Dirigent, Editor, + // Plugin) pass through for now -- future tasks will + // route them to the appropriate handler. + } + } + } + Some(Event::SessionUpdate { + connector_id: cid, + session_id, + update: SessionUpdate::ToolCall { + message_id, + tool_call, + _meta, + }, + }) + } + // All other events pass through unchanged. + other => Some(other), + } + } + + /// Build a crash reason string from transport crash context. + async fn build_crash_reason(transport: &Box, default_reason: &str) -> String { + if let Some(ctx) = transport.get_crash_context().await { + let mut reason = "Connection lost: agent process exited".to_string(); + if let Some(status) = ctx.exit_status { + reason.push_str(&format!(" ({})", status)); + } + if !ctx.recent_stderr.is_empty() { + let last_lines: Vec<&str> = ctx.recent_stderr + .iter() + .rev() + .take(5) + .map(|s| s.as_str()) + .collect::>() + .into_iter() + .rev() + .collect(); + reason.push_str(&format!(". Last stderr: {}", last_lines.join(" | "))); + } + reason + } else { + default_reason.to_string() + } + } +} + +impl Connector for AcpConnector { + fn id(&self) -> &ConnectorId { + &self.id + } + + fn kind(&self) -> ConnectorKind { + ConnectorKind::Acp + } + + fn owner(&self) -> &UserId { + &self.owner + } + + fn title(&self) -> &str { + &self.title + } + + fn state(&self) -> ConnectorState { + // Try to read the state without blocking + match self.state.try_read() { + Ok(state_guard) => state_guard.clone(), + Err(_) => { + // Lock is held, return Initializing as a safe default + ConnectorState::Initializing + } + } + } + + fn command_tx(&self) -> mpsc::Sender { + self.cmd_tx.clone() + } + + fn subscribe(&self) -> broadcast::Receiver { + self.events_tx.subscribe() + } + + fn stop(&self) { + // Send shutdown command + let cmd_tx = self.cmd_tx.clone(); + tokio::spawn(async move { + let _ = cmd_tx.send(ConnectorCommand::Shutdown).await; + }); + } +} + +// Note: ProtocolHandler's generate_request_id is private, +// so we'll use the pattern from the protocol layer properly + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_acp_connector_creation() { + let config = AcpConfig::stdio("test-agent", vec!["--stdio".to_string()]); + + let connector = AcpConnector::new( + "test-conn".to_string(), + uuid::Uuid::nil(), + "Test ACP Connector".to_string(), + config, + SharingBus::new(), + ) + .unwrap(); + + assert_eq!(connector.id(), "test-conn"); + assert_eq!(*connector.owner(), uuid::Uuid::nil()); + assert_eq!(connector.title(), "Test ACP Connector"); + assert_eq!(connector.kind(), ConnectorKind::Acp); + assert_eq!(connector.state(), ConnectorState::Initializing); + } + + #[tokio::test] + async fn test_acp_connector_invalid_config() { + let config = AcpConfig::stdio("", vec![]); // Empty command + + let result = AcpConnector::new( + "test-conn".to_string(), + uuid::Uuid::nil(), + "Test ACP Connector".to_string(), + config, + SharingBus::new(), + ); + + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_acp_connector_command_tx() { + let config = AcpConfig::stdio("test-agent", vec![]); + + let connector = AcpConnector::new( + "test-conn".to_string(), + uuid::Uuid::nil(), + "Test ACP Connector".to_string(), + config, + SharingBus::new(), + ) + .unwrap(); + + // Should be able to clone the command sender + let cmd_tx1 = connector.command_tx(); + let cmd_tx2 = connector.command_tx(); + + drop(cmd_tx1); + drop(cmd_tx2); + } + + #[tokio::test] + async fn test_acp_connector_subscribe() { + let config = AcpConfig::stdio("test-agent", vec![]); + + let connector = AcpConnector::new( + "test-conn".to_string(), + uuid::Uuid::nil(), + "Test ACP Connector".to_string(), + config, + SharingBus::new(), + ) + .unwrap(); + + // Should be able to create multiple subscriptions + let _rx1 = connector.subscribe(); + let _rx2 = connector.subscribe(); + } + + // ===================================================================== + // Tool directive intercept tests + // ===================================================================== + + mod tool_directive_tests { + use crate::tools::{ToolConfiguration, ToolDirective, ToolHandler}; + use dirigent_protocol::types::{ToolCall, ToolCallStatus, ToolOrigin}; + use dirigent_protocol::{Event, SessionUpdate}; + + use super::AcpConnector; + + /// Helper to create a ToolCall event for testing. + fn make_tool_call_event(tool_name: &str) -> Event { + Event::SessionUpdate { + connector_id: "test-conn".to_string(), + session_id: "sess-1".to_string(), + update: SessionUpdate::ToolCall { + message_id: "msg-1".to_string(), + tool_call: ToolCall { + id: "tc-1".to_string(), + tool_name: tool_name.to_string(), + status: ToolCallStatus::Running, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: Some(ToolOrigin::External), + }, + _meta: None, + }, + } + } + + #[test] + fn tool_directive_deny_marks_error() { + let mut config = ToolConfiguration::new(); + config.set(ToolDirective::checked("read", ToolHandler::Deny)); + + let event = make_tool_call_event("read"); + let result = AcpConnector::apply_tool_directive(event, &config, "test-conn"); + + let result = result.expect("Deny should still emit an event"); + if let Event::SessionUpdate { + update: SessionUpdate::ToolCall { tool_call, .. }, + .. + } = result + { + assert_eq!(tool_call.status, ToolCallStatus::Error); + assert_eq!( + tool_call.error.as_deref(), + Some("Tool denied by directive") + ); + } else { + panic!("Expected SessionUpdate::ToolCall event"); + } + } + + #[test] + fn tool_directive_hide_suppresses_event() { + let mut config = ToolConfiguration::new(); + config.set(ToolDirective::checked("write", ToolHandler::Hide)); + + let event = make_tool_call_event("write"); + let result = AcpConnector::apply_tool_directive(event, &config, "test-conn"); + + assert!(result.is_none(), "Hide should suppress the event entirely"); + } + + #[test] + fn tool_directive_unchecked_passes_through() { + let mut config = ToolConfiguration::new(); + config.set(ToolDirective::passthrough("bash")); + + let event = make_tool_call_event("bash"); + let result = AcpConnector::apply_tool_directive(event, &config, "test-conn"); + + assert!(result.is_some(), "Passthrough should not suppress the event"); + if let Some(Event::SessionUpdate { + update: SessionUpdate::ToolCall { tool_call, .. }, + .. + }) = result + { + assert_eq!(tool_call.status, ToolCallStatus::Running); + assert!(tool_call.error.is_none()); + } else { + panic!("Expected SessionUpdate::ToolCall event"); + } + } + + #[test] + fn tool_directive_unknown_tool_passes_through() { + let config = ToolConfiguration::new(); // Empty config + + let event = make_tool_call_event("unknown_tool"); + let result = AcpConnector::apply_tool_directive(event, &config, "test-conn"); + + assert!( + result.is_some(), + "Unknown tools should pass through unmodified" + ); + } + + #[test] + fn tool_directive_non_tool_event_passes_through() { + let config = ToolConfiguration::new(); + + // A non-ToolCall event should pass through unchanged + let event = Event::SessionUpdate { + connector_id: "test-conn".to_string(), + session_id: "sess-1".to_string(), + update: SessionUpdate::AgentMessageChunk { + message_id: "msg-1".to_string(), + content: dirigent_protocol::ContentBlock::Text { + text: "hello".to_string(), + }, + _meta: None, + }, + }; + let result = AcpConnector::apply_tool_directive(event, &config, "test-conn"); + + assert!( + result.is_some(), + "Non-ToolCall events should pass through" + ); + } + + #[test] + fn tool_directive_agent_handler_not_intercepted() { + let mut config = ToolConfiguration::new(); + // Agent handler with checked=true should NOT be intercepted + config.set(ToolDirective::checked("file_write", ToolHandler::Agent)); + + let event = make_tool_call_event("file_write"); + let result = AcpConnector::apply_tool_directive(event, &config, "test-conn"); + + assert!( + result.is_some(), + "Agent handler should not intercept" + ); + if let Some(Event::SessionUpdate { + update: SessionUpdate::ToolCall { tool_call, .. }, + .. + }) = result + { + assert_eq!(tool_call.status, ToolCallStatus::Running); + assert!(tool_call.error.is_none()); + } + } + + #[test] + fn build_session_meta_with_tool_config() { + let mut config = ToolConfiguration::new(); + config.set(ToolDirective::checked("shell_exec", ToolHandler::Deny)); + + let meta = AcpConnector::build_session_meta(&Some(config.clone())); + + assert!(meta.is_some()); + let meta = meta.unwrap(); + assert!(meta.extra.contains_key("tool_configuration")); + + // Verify the serialized value can be deserialized back + let restored: ToolConfiguration = + serde_json::from_value(meta.extra["tool_configuration"].clone()).unwrap(); + assert_eq!(restored, config); + } + + #[test] + fn build_session_meta_without_tool_config() { + let meta = AcpConnector::build_session_meta(&None); + assert!(meta.is_none()); + } + } +} diff --git a/crates/dirigent_core/src/connectors/acp/error.rs b/crates/dirigent_core/src/connectors/acp/error.rs new file mode 100644 index 0000000..3e1351f --- /dev/null +++ b/crates/dirigent_core/src/connectors/acp/error.rs @@ -0,0 +1,300 @@ +//! Error types for ACP connector +//! +//! This module defines error types specific to ACP connector operations, +//! including transport failures, protocol errors, and configuration issues. + +use thiserror::Error; + +/// Result type for ACP connector operations +pub type AcpResult = Result; + +/// Errors that can occur during ACP connector operations +/// +/// This error type covers all failure modes of the ACP connector, from +/// low-level transport issues to high-level protocol violations. +#[derive(Debug, Error)] +pub enum AcpError { + /// Transport layer error (connection failure, I/O error, etc.) + /// + /// Wraps errors from the transport layer (stdio or HTTP) that prevent + /// communication with the agent. + /// + /// # Examples + /// + /// - Process failed to spawn + /// - Network connection refused + /// - Broken pipe (process terminated) + /// - SSL/TLS handshake failed + #[error("Transport error: {0}")] + Transport(String), + + /// Protocol layer error (JSON-RPC format, invalid response, etc.) + /// + /// Indicates a violation of the ACP/JSON-RPC protocol, such as + /// malformed messages or unexpected response structure. + /// + /// # Examples + /// + /// - Invalid JSON syntax + /// - Missing required JSON-RPC fields + /// - Incorrect message format + /// - Protocol version mismatch + #[error("Protocol error: {0}")] + Protocol(String), + + /// Request timeout (no response received within timeout period) + /// + /// The agent did not respond to a request within the configured timeout. + /// This may indicate agent unresponsiveness, network issues, or very + /// slow operations. + #[error("Request timeout after {timeout_ms}ms")] + Timeout { + /// Timeout duration in milliseconds + timeout_ms: u64, + }, + + /// Configuration error (invalid settings, validation failure, etc.) + /// + /// The connector configuration is invalid or incomplete, preventing + /// connector creation or connection attempts. + /// + /// # Examples + /// + /// - Empty command for stdio transport + /// - Invalid URL for HTTP transport + /// - Zero timeout value + /// - Invalid protocol version + #[error("Configuration error: {0}")] + Config(String), + + /// Initialization error (agent initialization failed) + /// + /// The agent rejected the initialization request or returned an error + /// during the handshake phase. + /// + /// # Examples + /// + /// - Unsupported protocol version + /// - Agent capabilities mismatch + /// - Authentication failure + #[error("Initialization failed: {0}")] + Initialization(String), + + /// Session operation error (session/new, session/prompt, etc.) + /// + /// An error occurred during a session-related operation like creating + /// a new session or sending a prompt. + /// + /// # Examples + /// + /// - Session ID not found + /// - Session already terminated + /// - Invalid prompt format + /// - Resource exhaustion + #[error("Session operation failed: {0}")] + SessionOperation(String), + + /// Internal error (unexpected state, logic error, etc.) + /// + /// An internal error in the connector implementation. This typically + /// indicates a bug and should be reported. + /// + /// # Examples + /// + /// - Channel unexpectedly closed + /// - Lock poisoned + /// - Invalid state transition + #[error("Internal error: {0}")] + Internal(String), + + /// Agent returned a JSON-RPC error response + /// + /// The agent responded with a JSON-RPC error object instead of a + /// successful result. + #[error("Agent error: {message} (code: {code})")] + AgentError { + /// JSON-RPC error code + code: i64, + /// Error message from agent + message: String, + /// Optional additional error data + data: Option, + }, +} + +impl AcpError { + /// Create a transport error + pub fn transport(msg: impl Into) -> Self { + Self::Transport(msg.into()) + } + + /// Create a protocol error + pub fn protocol(msg: impl Into) -> Self { + Self::Protocol(msg.into()) + } + + /// Create a timeout error + pub fn timeout(timeout_ms: u64) -> Self { + Self::Timeout { timeout_ms } + } + + /// Create a configuration error + pub fn config(msg: impl Into) -> Self { + Self::Config(msg.into()) + } + + /// Create an initialization error + pub fn initialization(msg: impl Into) -> Self { + Self::Initialization(msg.into()) + } + + /// Create a session operation error + pub fn session_operation(msg: impl Into) -> Self { + Self::SessionOperation(msg.into()) + } + + /// Create an internal error + pub fn internal(msg: impl Into) -> Self { + Self::Internal(msg.into()) + } + + /// Create an agent error from JSON-RPC error object + pub fn agent_error(code: i64, message: impl Into, data: Option) -> Self { + Self::AgentError { + code, + message: message.into(), + data, + } + } + + /// Check if this error is retriable (should attempt reconnection) + /// + /// Returns true for transient errors like transport failures and timeouts. + /// Returns false for permanent errors like configuration issues. + pub fn is_retriable(&self) -> bool { + matches!( + self, + AcpError::Transport(_) | AcpError::Timeout { .. } | AcpError::Internal(_) + ) + } +} + +// Conversion from protocol handler errors +impl From for AcpError { + fn from(err: crate::connectors::acp::protocol::ProtocolError) -> Self { + use crate::connectors::acp::protocol::ProtocolError; + + match err { + ProtocolError::Timeout { timeout_ms } => AcpError::Timeout { timeout_ms }, + ProtocolError::JsonRpcError { code, message, data } => { + AcpError::AgentError { code, message, data } + } + ProtocolError::InvalidMessage(msg) => AcpError::Protocol(msg), + ProtocolError::Internal(msg) => AcpError::Internal(msg), + } + } +} + +// Conversion from transport errors (boxed error) +impl From> for AcpError { + fn from(err: Box) -> Self { + AcpError::Transport(err.to_string()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_transport_error() { + let err = AcpError::transport("Connection refused"); + assert_eq!(err.to_string(), "Transport error: Connection refused"); + } + + #[test] + fn test_protocol_error() { + let err = AcpError::protocol("Invalid JSON"); + assert_eq!(err.to_string(), "Protocol error: Invalid JSON"); + } + + #[test] + fn test_timeout_error() { + let err = AcpError::timeout(30000); + assert_eq!(err.to_string(), "Request timeout after 30000ms"); + } + + #[test] + fn test_config_error() { + let err = AcpError::config("Empty command"); + assert_eq!(err.to_string(), "Configuration error: Empty command"); + } + + #[test] + fn test_initialization_error() { + let err = AcpError::initialization("Protocol version mismatch"); + assert_eq!(err.to_string(), "Initialization failed: Protocol version mismatch"); + } + + #[test] + fn test_session_operation_error() { + let err = AcpError::session_operation("Session not found"); + assert_eq!(err.to_string(), "Session operation failed: Session not found"); + } + + #[test] + fn test_internal_error() { + let err = AcpError::internal("Channel closed"); + assert_eq!(err.to_string(), "Internal error: Channel closed"); + } + + #[test] + fn test_agent_error() { + let err = AcpError::agent_error(-32601, "Method not found", None); + assert_eq!(err.to_string(), "Agent error: Method not found (code: -32601)"); + } + + #[test] + fn test_is_retriable() { + assert!(AcpError::transport("test").is_retriable()); + assert!(AcpError::timeout(1000).is_retriable()); + assert!(AcpError::internal("test").is_retriable()); + + assert!(!AcpError::config("test").is_retriable()); + assert!(!AcpError::protocol("test").is_retriable()); + assert!(!AcpError::initialization("test").is_retriable()); + } + + #[test] + fn test_from_protocol_error_timeout() { + use crate::connectors::acp::protocol::ProtocolError; + + let protocol_err = ProtocolError::Timeout { timeout_ms: 5000 }; + let acp_err: AcpError = protocol_err.into(); + + match acp_err { + AcpError::Timeout { timeout_ms } => assert_eq!(timeout_ms, 5000), + _ => panic!("Expected Timeout variant"), + } + } + + #[test] + fn test_from_protocol_error_jsonrpc() { + use crate::connectors::acp::protocol::ProtocolError; + + let protocol_err = ProtocolError::JsonRpcError { + code: -32600, + message: "Invalid Request".to_string(), + data: None, + }; + let acp_err: AcpError = protocol_err.into(); + + match acp_err { + AcpError::AgentError { code, message, .. } => { + assert_eq!(code, -32600); + assert_eq!(message, "Invalid Request"); + } + _ => panic!("Expected AgentError variant"), + } + } +} diff --git a/crates/dirigent_core/src/connectors/acp/idle_detector.rs b/crates/dirigent_core/src/connectors/acp/idle_detector.rs new file mode 100644 index 0000000..f0ff5f5 --- /dev/null +++ b/crates/dirigent_core/src/connectors/acp/idle_detector.rs @@ -0,0 +1,251 @@ +//! Idle Session Detection +//! +//! This module provides utilities for detecting idle sessions and emitting +//! SessionIdle events after a period of inactivity. + +use std::collections::HashMap; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use crate::sharing::bus::SharingBus; +use dirigent_protocol::{Event, Message, TurnCompleteTrigger}; +use tokio::sync::{broadcast, Mutex}; +use tracing::{info, trace}; +use uuid::Uuid; + +/// Duration of notification silence before emitting SessionIdle +/// 300ms is a good balance between responsiveness and not being too aggressive +pub const IDLE_THRESHOLD: Duration = Duration::from_millis(300); + +/// Data needed to emit a deferred MessageCompleted event +/// Stored until SessionIdle is ready to emit, ensuring all notifications are processed +#[derive(Debug, Clone)] +pub struct PendingCompletion { + pub connector_id: String, + pub message: Message, +} + +/// Per-session state for activity tracking and idle detection +#[derive(Debug)] +pub struct SessionState { + /// Last time a notification was received for this session + last_activity: Instant, + /// True after JSON-RPC response received, false after SessionIdle emitted + awaiting_idle: bool, + /// Pending MessageCompleted to emit when SessionIdle fires + /// This delays MessageCompleted until all notifications are processed + pub pending_completion: Option, +} + +impl SessionState { + /// Create a new session state + pub fn new(_id: String) -> Self { + Self { + last_activity: Instant::now(), + awaiting_idle: false, + pending_completion: None, + } + } + + /// Update activity timestamp (call when notification received) + pub fn touch(&mut self) { + self.last_activity = Instant::now(); + } + + /// Mark session as awaiting idle emission + pub fn mark_awaiting_idle(&mut self) { + self.awaiting_idle = true; + self.last_activity = Instant::now(); + } + + /// Clear awaiting idle flag (after SessionIdle emitted) + pub fn clear_awaiting_idle(&mut self) { + self.awaiting_idle = false; + } + + /// Check if session should emit SessionIdle + pub fn should_emit_idle(&self, threshold: Duration) -> bool { + self.awaiting_idle && self.last_activity.elapsed() >= threshold + } + + /// Get time since last activity + pub fn elapsed(&self) -> Duration { + self.last_activity.elapsed() + } + + /// Check if session is awaiting idle + pub fn is_awaiting_idle(&self) -> bool { + self.awaiting_idle + } +} + +/// Check all sessions and emit SessionIdle for those with no recent activity +/// +/// For sessions with pending MessageCompleted, emits MessageCompleted BEFORE SessionIdle. +/// This ensures all notifications are processed before finalization signals are sent. +/// +/// Returns the list of session IDs that transitioned to idle. +pub async fn check_idle_sessions( + session_states: &Arc>>, + events_tx: &broadcast::Sender, + sharing_bus: &Arc, + connector_uid: Option, + connector_id: &str, +) -> Vec { + let now = Instant::now(); + let mut sessions_to_idle: Vec<(String, Option)> = Vec::new(); + + // Collect sessions that should emit idle, along with any pending completions + { + let mut states = session_states.lock().await; + for (session_id, state) in states.iter_mut() { + // Add trace logging for debugging idle detection + if state.awaiting_idle { + let elapsed = now.duration_since(state.last_activity); + trace!( + session_id = %session_id, + elapsed_ms = elapsed.as_millis(), + threshold_ms = IDLE_THRESHOLD.as_millis(), + has_pending = state.pending_completion.is_some(), + "Checking session idle status" + ); + } + + if state.should_emit_idle(IDLE_THRESHOLD) { + // Take pending_completion (moves ownership out) + let pending = state.pending_completion.take(); + sessions_to_idle.push((session_id.clone(), pending)); + } + } + } + + let mut idled_sessions = Vec::with_capacity(sessions_to_idle.len()); + + // Emit MessageCompleted (if pending) then TurnComplete then SessionIdle for each qualifying session + for (session_id, pending_completion) in sessions_to_idle { + // IMPORTANT: Emit MessageCompleted BEFORE TurnComplete and SessionIdle + // This ensures archivist receives complete message data before finalization + if let Some(pending) = pending_completion { + info!( + session_id = %session_id, + message_id = %pending.message.id, + "Emitting deferred MessageCompleted after {:?} of inactivity", + IDLE_THRESHOLD + ); + + let completed_event = Event::MessageCompleted { + connector_id: pending.connector_id.clone(), + message: pending.message.clone(), + }; + let bus_event = dirigent_protocol::streaming::BusEvent::from_connector_event( + completed_event.clone(), + connector_uid, + connector_id.to_string(), + ); + sharing_bus.publish(bus_event).await; + let _ = events_tx.send(completed_event); + + // Also emit TurnComplete with ResponseReceived trigger + // This signals that the turn is truly complete (ACP response received) + info!( + session_id = %session_id, + message_id = %pending.message.id, + "Emitting TurnComplete with ResponseReceived trigger", + ); + + let turn_event = Event::TurnComplete { + connector_id: pending.connector_id, + session_id: session_id.clone(), + message_id: pending.message.id, + trigger: TurnCompleteTrigger::ResponseReceived, + }; + let bus_event = dirigent_protocol::streaming::BusEvent::from_connector_event( + turn_event.clone(), + connector_uid, + connector_id.to_string(), + ); + sharing_bus.publish(bus_event).await; + let _ = events_tx.send(turn_event); + } + + // Then emit SessionIdle + info!( + session_id = %session_id, + "Emitting SessionIdle after {:?} of inactivity", + IDLE_THRESHOLD + ); + + let idle_event = Event::SessionIdle { + connector_id: connector_id.to_string(), + session_id: session_id.clone(), + }; + let bus_event = dirigent_protocol::streaming::BusEvent::from_connector_event( + idle_event.clone(), + connector_uid, + connector_id.to_string(), + ); + sharing_bus.publish(bus_event).await; + let _ = events_tx.send(idle_event); + + // Clear the awaiting_idle flag + { + let mut states = session_states.lock().await; + if let Some(state) = states.get_mut(&session_id) { + state.clear_awaiting_idle(); + } + } + + idled_sessions.push(session_id); + } + + idled_sessions +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_session_state_new() { + let state = SessionState::new("test".to_string()); + assert!(!state.is_awaiting_idle()); + assert!(state.pending_completion.is_none()); + } + + #[test] + fn test_session_state_touch() { + let mut state = SessionState::new("test".to_string()); + let before = state.elapsed(); + std::thread::sleep(Duration::from_millis(10)); + state.touch(); + let after = state.elapsed(); + assert!(after < before); + } + + #[test] + fn test_session_state_awaiting_idle() { + let mut state = SessionState::new("test".to_string()); + assert!(!state.is_awaiting_idle()); + + state.mark_awaiting_idle(); + assert!(state.is_awaiting_idle()); + + state.clear_awaiting_idle(); + assert!(!state.is_awaiting_idle()); + } + + #[test] + fn test_session_state_should_emit_idle() { + let mut state = SessionState::new("test".to_string()); + + // Not awaiting idle - should not emit + assert!(!state.should_emit_idle(Duration::ZERO)); + + // Awaiting idle but not enough time passed + state.mark_awaiting_idle(); + assert!(!state.should_emit_idle(Duration::from_secs(10))); + + // Awaiting idle with zero threshold - should emit immediately + assert!(state.should_emit_idle(Duration::ZERO)); + } +} diff --git a/crates/dirigent_core/src/connectors/acp/logging.rs b/crates/dirigent_core/src/connectors/acp/logging.rs new file mode 100644 index 0000000..ae2d362 --- /dev/null +++ b/crates/dirigent_core/src/connectors/acp/logging.rs @@ -0,0 +1,259 @@ +//! ACP Protocol Logging +//! +//! This module provides file-based logging of all ACP protocol messages (stdin/stdout) +//! for debugging and analysis. When enabled via configuration, all JSON-RPC messages +//! are logged to JSONL files in the specified directory. +//! +//! # Log Format +//! +//! Logs are written in JSONL (newline-delimited JSON) format: +//! ```json +//! {"ts":"2026-01-10T12:34:56.789Z","dir":"out","msg":{...json-rpc message...}} +//! {"ts":"2026-01-10T12:34:57.001Z","dir":"in","msg":{...json-rpc message...}} +//! {"ts":"2026-01-10T12:34:58.123Z","dir":"in","raw":"not valid json","parse_error":"..."} +//! ``` +//! +//! # File Naming +//! +//! Log files are named `{identifier}_{timestamp}.jsonl` +//! - identifier: Command name or connector identifier +//! - timestamp: ISO 8601 timestamp when logging started + +use chrono::Utc; +use serde_json::Value; +use std::fs::{self, File}; +use std::io::{BufWriter, Write}; +use std::path::PathBuf; +use tracing::{debug, error, info, warn}; + +/// ACP Protocol Logger +/// +/// Logs all JSON-RPC messages to files for debugging. +/// Owned by the transport - no Arc/Mutex needed. +pub struct AcpProtocolLogger { + /// Current log file handle + file: BufWriter, + /// Path to the log file (for logging) + path: PathBuf, +} + +impl AcpProtocolLogger { + /// Create a new logger and open the log file immediately + /// + /// # Arguments + /// * `log_dir` - Directory to store log files (will be created if it doesn't exist) + /// * `identifier` - Identifier for file naming (e.g., command name) + /// + /// # Returns + /// `Some(logger)` if file was created successfully, `None` otherwise + pub fn new(log_dir: PathBuf, identifier: &str) -> Option { + // Create directory if it doesn't exist + if let Err(e) = fs::create_dir_all(&log_dir) { + warn!( + "Failed to create ACP log directory {:?}: {}", + log_dir, e + ); + return None; + } + + Self::cleanup_old_logs(&log_dir, 100); + + let timestamp = Utc::now().format("%Y%m%d_%H%M%S%.3f").to_string(); + let safe_identifier = identifier + .replace('/', "_") + .replace('\\', "_") + .replace(':', "_"); + let filename = format!("{}_{}.jsonl", safe_identifier, timestamp); + let path = log_dir.join(filename); + + match File::create(&path) { + Ok(file) => { + info!("ACP protocol logging to {:?}", path); + Some(Self { + file: BufWriter::new(file), + path, + }) + } + Err(e) => { + error!("Failed to create ACP log file {:?}: {}", path, e); + None + } + } + } + + /// Log an outgoing message (to the agent) + pub fn log_outgoing(&mut self, message: &Value) { + let entry = serde_json::json!({ + "ts": Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true), + "dir": "out", + "msg": message + }); + self.write_entry(&entry); + } + + /// Log an incoming raw line + /// + /// Tries to parse as JSON. If successful, logs the parsed message. + /// If parsing fails, logs the raw line with the parse error. + pub fn log_incoming_raw(&mut self, raw_line: &str) { + match serde_json::from_str::(raw_line) { + Ok(message) => { + let entry = serde_json::json!({ + "ts": Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true), + "dir": "in", + "msg": message + }); + self.write_entry(&entry); + } + Err(e) => { + let entry = serde_json::json!({ + "ts": Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true), + "dir": "in", + "raw": raw_line, + "parse_error": e.to_string() + }); + self.write_entry(&entry); + } + } + } + + /// Log an incoming parsed message (when already parsed) + pub fn log_incoming(&mut self, message: &Value) { + let entry = serde_json::json!({ + "ts": Utc::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, true), + "dir": "in", + "msg": message + }); + self.write_entry(&entry); + } + + fn cleanup_old_logs(log_dir: &PathBuf, max_files: usize) { + let entries: Vec<_> = match fs::read_dir(log_dir) { + Ok(rd) => rd + .filter_map(|e| e.ok()) + .filter(|e| { + e.path() + .extension() + .map(|ext| ext == "jsonl") + .unwrap_or(false) + }) + .collect(), + Err(_) => return, + }; + + if entries.len() <= max_files { + return; + } + + let mut by_modified: Vec<_> = entries + .into_iter() + .filter_map(|e| e.metadata().ok().and_then(|m| m.modified().ok()).map(|t| (e, t))) + .collect(); + by_modified.sort_by_key(|(_, t)| *t); + + let to_remove = by_modified.len().saturating_sub(max_files); + let mut removed = 0; + for (entry, _) in by_modified.into_iter().take(to_remove) { + if fs::remove_file(entry.path()).is_ok() { + removed += 1; + } + } + if removed > 0 { + info!("Cleaned up {} old ACP log files from {:?}", removed, log_dir); + } + } + + fn write_entry(&mut self, entry: &Value) { + if let Err(e) = writeln!(self.file, "{}", entry.to_string()) { + error!("Failed to write ACP log entry: {}", e); + } else if let Err(e) = self.file.flush() { + debug!("Failed to flush ACP log: {}", e); + } + } + + /// Get the log file path + pub fn path(&self) -> &PathBuf { + &self.path + } +} + +impl Drop for AcpProtocolLogger { + fn drop(&mut self) { + if let Err(e) = self.file.flush() { + warn!("Failed to flush ACP log on close: {}", e); + } + info!("Closed ACP protocol log: {:?}", self.path); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::tempdir; + + #[test] + fn test_logger_creates_directory_and_file() { + let temp = tempdir().unwrap(); + let log_dir = temp.path().join("acp_logs"); + + let logger = AcpProtocolLogger::new(log_dir.clone(), "test-command"); + assert!(logger.is_some()); + assert!(log_dir.exists()); + + // Should have created a file + let entries: Vec<_> = fs::read_dir(&log_dir).unwrap().collect(); + assert_eq!(entries.len(), 1); + } + + #[test] + fn test_logger_writes_messages() { + let temp = tempdir().unwrap(); + let log_dir = temp.path().join("acp_logs"); + + let mut logger = AcpProtocolLogger::new(log_dir.clone(), "test-command").unwrap(); + + let msg = serde_json::json!({"jsonrpc": "2.0", "method": "test", "id": 1}); + logger.log_outgoing(&msg); + logger.log_incoming(&serde_json::json!({"jsonrpc": "2.0", "result": "ok", "id": 1})); + drop(logger); // Flush on drop + + // Read the file + let entries: Vec<_> = fs::read_dir(&log_dir).unwrap().collect(); + let file_path = entries[0].as_ref().unwrap().path(); + let content = fs::read_to_string(file_path).unwrap(); + let lines: Vec<&str> = content.lines().collect(); + + assert_eq!(lines.len(), 2); + assert!(lines[0].contains("\"dir\":\"out\"")); + assert!(lines[1].contains("\"dir\":\"in\"")); + } + + #[test] + fn test_logger_handles_raw_parse_error() { + let temp = tempdir().unwrap(); + let log_dir = temp.path().join("acp_logs"); + + let mut logger = AcpProtocolLogger::new(log_dir.clone(), "test-command").unwrap(); + + // Log a valid JSON + logger.log_incoming_raw(r#"{"jsonrpc":"2.0","result":"ok","id":1}"#); + // Log invalid JSON (raw output from crashed process) + logger.log_incoming_raw("Error: process crashed with SIGSEGV"); + drop(logger); + + // Read the file + let entries: Vec<_> = fs::read_dir(&log_dir).unwrap().collect(); + let file_path = entries[0].as_ref().unwrap().path(); + let content = fs::read_to_string(file_path).unwrap(); + let lines: Vec<&str> = content.lines().collect(); + + assert_eq!(lines.len(), 2); + // First line should be parsed JSON + assert!(lines[0].contains("\"msg\":")); + assert!(!lines[0].contains("\"parse_error\"")); + // Second line should have raw + parse_error + assert!(lines[1].contains("\"raw\":")); + assert!(lines[1].contains("\"parse_error\":")); + assert!(lines[1].contains("SIGSEGV")); + } +} diff --git a/crates/dirigent_core/src/connectors/acp/mod.rs b/crates/dirigent_core/src/connectors/acp/mod.rs new file mode 100644 index 0000000..07029da --- /dev/null +++ b/crates/dirigent_core/src/connectors/acp/mod.rs @@ -0,0 +1,100 @@ +//! ACP (Agent-Client Protocol) connector module +//! +//! This module provides a complete connector implementation for the Agent-Client Protocol, +//! which is a standardized protocol for communicating with AI agents. +//! +//! # Architecture +//! +//! The ACP connector consists of multiple layers: +//! +//! ## Transport Layer (`transport`) +//! Provides low-level communication mechanisms: +//! - **StdioTransport**: Line-delimited JSON over stdin/stdout (for spawned processes) +//! - **HttpTransport**: HTTP POST for requests + SSE for notifications +//! +//! ## Protocol Layer (`protocol`) +//! Provides JSON-RPC request/response correlation: +//! - **ProtocolHandler**: Correlates requests with responses by ID +//! - Request timeout handling (default 30s) +//! - Notification routing (server-initiated messages) +//! - Helper functions for building ACP messages +//! +//! ## Configuration (`config`) +//! Configuration types for connector creation: +//! - **AcpConfig**: Main configuration struct +//! - **TransportKind**: Transport selection (Stdio vs HTTP) +//! - **RetryConfig**: Reconnection behavior +//! +//! ## Error Handling (`error`) +//! Error types for all connector operations: +//! - **AcpError**: Comprehensive error enum +//! - **AcpResult**: Result type alias +//! +//! ## State Management (`state`) +//! Internal state tracking: +//! - **InternalState**: Thread-safe state container +//! - **SessionInfo**: Session metadata +//! +//! ## Connector Implementation (`connector`) +//! Main connector implementation: +//! - **AcpConnector**: Implements `Connector` trait +//! - Background task for event processing +//! - Automatic reconnection with exponential backoff +//! +//! # Usage +//! +//! ```no_run +//! use dirigent_core::connectors::acp::{AcpConnector, AcpConfig}; +//! use dirigent_core::connectors::Connector; +//! +//! # async fn example() -> anyhow::Result<()> { +//! // Create connector config +//! let config = AcpConfig::stdio("dirigate", vec!["serve".to_string(), "--stdio".to_string()]); +//! +//! // Create connector +//! let connector = AcpConnector::new( +//! "acp-1".to_string(), +//! uuid::Uuid::now_v7(), +//! "ACP Agent".to_string(), +//! config, +//! )?; +//! +//! // Start background task +//! let task_handle = connector.start_task().await; +//! +//! // Subscribe to events +//! let mut events = connector.subscribe(); +//! +//! // Send commands +//! let cmd_tx = connector.command_tx(); +//! // ... use connector ... +//! +//! // Stop connector +//! connector.stop(); +//! task_handle.await?; +//! # Ok(()) +//! # } +//! ``` + +pub mod config; +pub mod connector; +pub mod error; +pub mod idle_detector; +pub mod logging; +pub mod protocol; +pub mod state; +pub mod title_utils; +pub mod transport; + +// Re-export main types for convenience +pub use config::{features, AcpConfig, RetryConfig, TransportKind}; +pub use connector::AcpConnector; +pub use error::{AcpError, AcpResult}; +pub use logging::AcpProtocolLogger; +pub use protocol::{ + build_initialize_request, build_session_cancel_request, build_session_load_request, + build_session_new_request, build_session_prompt_request, ProtocolError, ProtocolHandler, + ProtocolResult, RequestId, +}; +pub use state::{InternalState, SessionInfo, SessionStatus}; +pub use transport::{AcpTransport, HttpTransport, StdioTransport, TransportResult}; diff --git a/crates/dirigent_core/src/connectors/acp/protocol.rs b/crates/dirigent_core/src/connectors/acp/protocol.rs new file mode 100644 index 0000000..d190ac3 --- /dev/null +++ b/crates/dirigent_core/src/connectors/acp/protocol.rs @@ -0,0 +1,1385 @@ +//! JSON-RPC protocol handler for ACP +//! +//! This module provides request/response correlation, timeout handling, and +//! notification routing for JSON-RPC 2.0 messages used in the Agent-Client Protocol. +//! +//! # Architecture +//! +//! The `ProtocolHandler` maintains a registry of pending requests and correlates +//! incoming responses by request ID. It distinguishes between: +//! +//! - **Requests**: Messages with an `id` field that expect a response +//! - **Responses**: Messages with an `id` field and a `result` or `error` field +//! - **Notifications**: Messages with a `method` field but no `id` (server-initiated) +//! +//! # Request/Response Correlation +//! +//! When sending a request: +//! 1. Generate unique request ID using atomic counter +//! 2. Create oneshot channel for response +//! 3. Store channel sender in pending requests map +//! 4. Send request to transport +//! 5. Wait for response with timeout +//! +//! When receiving a message: +//! 1. Check if message has `id` (response) or only `method` (notification) +//! 2. For responses: Look up pending request by ID, send response via oneshot +//! 3. For notifications: Route to notification channel +//! +//! # Example +//! +//! ```no_run +//! use dirigent_core::connectors::acp::protocol::ProtocolHandler; +//! use serde_json::json; +//! +//! # async fn example() -> anyhow::Result<()> { +//! let mut handler = ProtocolHandler::new(); +//! +//! // Build and send a request +//! let request = json!({ +//! "jsonrpc": "2.0", +//! "method": "initialize", +//! "params": { +//! "protocolVersion": 1, +//! "clientCapabilities": {} +//! } +//! }); +//! +//! let response = handler.send_request(request).await?; +//! println!("Response: {:?}", response); +//! # Ok(()) +//! # } +//! ``` + +use serde_json::{json, Value}; +use std::collections::HashMap; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; +use tokio::sync::{mpsc, oneshot, Mutex}; +use tokio::time::{timeout, Duration}; +use tracing::{debug, info, warn}; + +/// Default timeout for requests (30 seconds) +const DEFAULT_TIMEOUT_MS: u64 = 30_000; + +/// Request ID type (JSON-RPC 2.0 allows string, number, or null) +pub type RequestId = Value; + +/// Result type for protocol operations +pub type ProtocolResult = Result; + +/// Protocol handler errors +#[derive(Debug, thiserror::Error)] +pub enum ProtocolError { + /// Request timed out waiting for response + #[error("Request timeout after {timeout_ms}ms")] + Timeout { timeout_ms: u64 }, + + /// Received error response from server + #[error("JSON-RPC error: {message} (code: {code})")] + JsonRpcError { + code: i64, + message: String, + data: Option, + }, + + /// Invalid message format + #[error("Invalid message: {0}")] + InvalidMessage(String), + + /// Internal error + #[error("Internal error: {0}")] + Internal(String), +} + +/// Result type for message handling +/// +/// Represents the different outcomes when handling an incoming message from the agent. +/// This allows the protocol handler to distinguish between notifications (no response), +/// immediate responses, and agent requests that need external input. +#[derive(Debug)] +pub enum MessageHandlerResult { + /// No response needed (notification was processed) + None, + + /// Immediate response to send back to the agent + Response(Value), + + /// Agent-initiated request requiring external input (e.g., permission prompt) + /// + /// The connector should emit an Event::AgentRequest and wait for the client's + /// response before sending anything back to the agent. + AgentRequest { + /// Request ID from the agent (for correlating the response) + request_id: Value, + /// Method being requested (e.g., "session/request_permission") + method: String, + /// Request parameters + params: Value, + }, +} + +/// JSON-RPC protocol handler +/// +/// Manages request/response correlation, timeouts, and notification routing. +/// +/// # Thread Safety +/// +/// ProtocolHandler is thread-safe and can be shared across tasks using `Arc`. +/// All internal state is protected by async-aware locks. +pub struct ProtocolHandler { + /// Atomic counter for generating unique request IDs + next_request_id: Arc, + + /// Map of pending requests waiting for responses + /// + /// Key: Request ID + /// Value: Oneshot sender to deliver the response + pending_requests: Arc>>>, + + /// Channel for routing notifications (messages without request ID) + notification_tx: mpsc::UnboundedSender, + + /// Receiver for notifications (taken by consumer) + notification_rx: Arc>>>, + + /// Default timeout for requests + timeout_ms: u64, +} + +impl ProtocolHandler { + /// Create a new protocol handler + /// + /// # Example + /// + /// ``` + /// use dirigent_core::connectors::acp::protocol::ProtocolHandler; + /// + /// let handler = ProtocolHandler::new(); + /// ``` + pub fn new() -> Self { + Self::with_timeout(DEFAULT_TIMEOUT_MS) + } + + /// Create a new protocol handler with custom timeout + /// + /// # Arguments + /// + /// * `timeout_ms` - Request timeout in milliseconds + /// + /// # Example + /// + /// ``` + /// use dirigent_core::connectors::acp::protocol::ProtocolHandler; + /// + /// let handler = ProtocolHandler::with_timeout(60_000); // 60 second timeout + /// ``` + pub fn with_timeout(timeout_ms: u64) -> Self { + let (notification_tx, notification_rx) = mpsc::unbounded_channel(); + + Self { + next_request_id: Arc::new(AtomicU64::new(1)), + pending_requests: Arc::new(Mutex::new(HashMap::new())), + notification_tx, + notification_rx: Arc::new(Mutex::new(Some(notification_rx))), + timeout_ms, + } + } + + /// Generate a unique request ID + /// + /// Uses an atomic counter to ensure uniqueness across concurrent requests. + fn generate_request_id(&self) -> RequestId { + let id = self.next_request_id.fetch_add(1, Ordering::SeqCst); + json!(id) + } + + /// Add request ID to a message and prepare for response + /// + /// Returns the modified message and a oneshot receiver for the response. + /// + /// # Arguments + /// + /// * `message` - JSON-RPC request (without `id` field) + /// + /// # Returns + /// + /// Tuple of (message with ID, response receiver) + pub async fn prepare_request(&self, mut message: Value) -> (Value, oneshot::Receiver) { + let request_id = self.generate_request_id(); + + // Add ID to message + if let Some(obj) = message.as_object_mut() { + obj.insert("id".to_string(), request_id.clone()); + } + + // Create oneshot channel for response + let (tx, rx) = oneshot::channel(); + + // Store in pending requests + self.pending_requests.lock().await.insert(request_id, tx); + + (message, rx) + } + + /// Send a request and wait for response with timeout + /// + /// This is a placeholder that will be integrated with the transport layer. + /// For now, it demonstrates the correlation logic that will be used. + /// + /// # Arguments + /// + /// * `message` - JSON-RPC request (without `id` field) + /// + /// # Returns + /// + /// Response value or error + /// + /// # Errors + /// + /// - `ProtocolError::Timeout`: Request timed out + /// - `ProtocolError::JsonRpcError`: Server returned error response + /// - `ProtocolError::Internal`: Internal error occurred + /// + /// # Example + /// + /// ```no_run + /// # use dirigent_core::connectors::acp::protocol::ProtocolHandler; + /// # use serde_json::json; + /// # async fn example() -> anyhow::Result<()> { + /// let handler = ProtocolHandler::new(); + /// + /// let request = json!({ + /// "jsonrpc": "2.0", + /// "method": "initialize", + /// "params": { "protocolVersion": 1 } + /// }); + /// + /// let response = handler.send_request(request).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn send_request(&self, message: Value) -> ProtocolResult { + let (message_with_id, response_rx) = self.prepare_request(message).await; + + // Extract the request ID for cleanup on timeout + let request_id = message_with_id + .get("id") + .cloned() + .unwrap_or(json!(null)); + + // In a real implementation, this would send via transport + // For now, we just demonstrate the correlation logic + debug!(request_id = %request_id, "Request prepared (transport send would happen here)"); + + // Wait for response with timeout + match timeout(Duration::from_millis(self.timeout_ms), response_rx).await { + Ok(Ok(response)) => { + // Response received successfully + self.handle_response_result(response) + } + Ok(Err(_)) => { + // Oneshot channel was dropped (should not happen) + Err(ProtocolError::Internal( + "Response channel dropped".to_string(), + )) + } + Err(_) => { + // Timeout occurred + // Clean up pending request + self.pending_requests.lock().await.remove(&request_id); + + Err(ProtocolError::Timeout { + timeout_ms: self.timeout_ms, + }) + } + } + } + + /// Handle incoming message from transport + /// + /// Routes the message based on its type: + /// - Request (has `method` and `id`): Return AgentRequest or Response + /// - Response (has `id` and `result`/`error`): Route to pending request, return None + /// - Notification (has `method` but no `id`): Route to notification channel, return None + /// + /// # Arguments + /// + /// * `message` - JSON-RPC message from transport + /// + /// # Returns + /// + /// - `MessageHandlerResult::None`: Message processed, no response needed + /// - `MessageHandlerResult::Response(val)`: Immediate response to send to agent + /// - `MessageHandlerResult::AgentRequest`: Request needs external input (emit event) + /// + /// # Example + /// + /// ``` + /// # use dirigent_core::connectors::acp::protocol::{ProtocolHandler, MessageHandlerResult}; + /// # use serde_json::json; + /// # async fn example() { + /// let handler = ProtocolHandler::new(); + /// + /// // Handle a response + /// let response = json!({ + /// "jsonrpc": "2.0", + /// "id": 1, + /// "result": { "status": "ok" } + /// }); + /// let result = handler.handle_message(response).await; + /// // Returns MessageHandlerResult::None + /// + /// // Handle a notification + /// let notification = json!({ + /// "jsonrpc": "2.0", + /// "method": "session/update", + /// "params": { "sessionId": "123" } + /// }); + /// let result = handler.handle_message(notification).await; + /// // Returns MessageHandlerResult::None + /// # } + /// ``` + pub async fn handle_message(&self, message: Value) -> MessageHandlerResult { + // Determine message type according to JSON-RPC 2.0 spec: + // - Request: has "method" AND "id" (expects response) + // - Notification: has "method" but NO "id" (no response expected) + // - Response: has "result" or "error" AND "id", NO "method" (answer to our request) + + let has_id = message.get("id").is_some(); + let has_method = message.get("method").is_some(); + let has_result_or_error = message.get("result").is_some() || message.get("error").is_some(); + + if has_method && has_id { + // This is an INCOMING REQUEST from the agent - it expects a response from us + let method = message.get("method").and_then(|m| m.as_str()).unwrap_or(""); + let request_id = message.get("id").cloned().unwrap(); + + info!( + method = method, + request_id = %request_id, + "📨 Received request from agent (agent→client)" + ); + + // Handle known request types + match method { + "session/request_permission" => { + // Return agent request for external handling (permission UI in client) + info!( + request_id = %request_id, + "🔔 Permission request from agent - forwarding to client for user input" + ); + + let params = message.get("params").cloned().unwrap_or(json!({})); + + return MessageHandlerResult::AgentRequest { + request_id, + method: method.to_string(), + params, + }; + } + _ => { + warn!( + method = method, + request_id = %request_id, + "⚠️ Received unsupported request from agent, returning error" + ); + // Return error response for unsupported methods + return MessageHandlerResult::Response(json!({ + "jsonrpc": "2.0", + "id": request_id, + "error": { + "code": -32601, + "message": format!("Method not supported: {}", method) + } + })); + } + } + } else if has_id && has_result_or_error { + // This is a RESPONSE to one of OUR requests + let request_id = message.get("id").cloned().unwrap(); + + let sender = self.pending_requests.lock().await.remove(&request_id); + + if let Some(tx) = sender { + // Send response to waiting request + if tx.send(message).is_err() { + warn!( + request_id = %request_id, + "Failed to deliver response (receiver dropped)" + ); + } + } else { + // Unknown request ID (could be late response after timeout) + warn!( + request_id = %request_id, + "Received response for unknown request ID (possible timeout)" + ); + } + } else if has_method { + // This is a NOTIFICATION - route to notification channel + let method = message.get("method").and_then(|m| m.as_str()).unwrap_or(""); + + // Extract session_id from params if available (for session/update notifications) + let session_id = message.get("params") + .and_then(|p| p.get("sessionId")) + .and_then(|s| s.as_str()); + + // Use masked version for logging + let masked_message = dirigent_protocol::log_utils::format_for_log(&message); + + if let Some(sid) = session_id { + info!(method = method, session_id = sid, message = %masked_message, "🔔 Protocol handler received notification, routing to channel"); + } else { + info!(method = method, message = %masked_message, "🔔 Protocol handler received notification, routing to channel"); + } + + match self.notification_tx.send(message.clone()) { + Ok(_) => { + info!(method = method, "✅ Successfully forwarded notification to channel"); + } + Err(e) => { + warn!( + error = %e, + method = method, + "⚠️ Failed to route notification (no receivers)" + ); + } + } + } else { + // Invalid message format + warn!( + message = %message, + "Received message with neither valid request, response, nor notification format" + ); + } + + MessageHandlerResult::None + } + + /// Cancel all pending requests with a structured error response. + /// + /// Called before the protocol handler is dropped when the transport dies, + /// so that spawned tasks receive a meaningful error response instead of + /// a bare `RecvError` from the dropped oneshot sender. + pub async fn cancel_all_pending(&self, reason: &str) { + let mut pending = self.pending_requests.lock().await; + let count = pending.len(); + if count > 0 { + warn!(count, reason, "Cancelling pending requests"); + for (id, sender) in pending.drain() { + debug!(request_id = %id, "Cancelling pending request"); + let _ = sender.send(serde_json::json!({ + "jsonrpc": "2.0", + "id": id.to_string(), + "error": { + "code": -32000, + "message": reason + } + })); + } + } + } + + /// Get the number of currently pending requests. + pub async fn pending_request_count(&self) -> usize { + self.pending_requests.lock().await.len() + } + + /// Take the notification receiver + /// + /// This can only be called once. Subsequent calls will return None. + /// + /// # Returns + /// + /// The notification receiver, or None if already taken + /// + /// # Example + /// + /// ``` + /// # use dirigent_core::connectors::acp::protocol::ProtocolHandler; + /// # async fn example() { + /// let handler = ProtocolHandler::new(); + /// + /// let mut notifications = handler.take_notification_receiver().await.unwrap(); + /// + /// // Process notifications + /// while let Some(notif) = notifications.recv().await { + /// println!("Notification: {:?}", notif); + /// } + /// # } + /// ``` + pub async fn take_notification_receiver(&self) -> Option> { + self.notification_rx.lock().await.take() + } + + /// Handle a response value, checking for JSON-RPC errors + /// + /// # Arguments + /// + /// * `response` - Full JSON-RPC response message + /// + /// # Returns + /// + /// The result value or error + fn handle_response_result(&self, response: Value) -> ProtocolResult { + // Check for error field + if let Some(error) = response.get("error") { + let code = error.get("code").and_then(|c| c.as_i64()).unwrap_or(-1); + let message = error + .get("message") + .and_then(|m| m.as_str()) + .unwrap_or("Unknown error") + .to_string(); + let data = error.get("data").cloned(); + + return Err(ProtocolError::JsonRpcError { + code, + message, + data, + }); + } + + // Return the full response (caller can extract result field) + Ok(response) + } +} + +impl Default for ProtocolHandler { + fn default() -> Self { + Self::new() + } +} + +// Helper functions for building ACP JSON-RPC messages + +/// Build an initialize request +/// +/// # Arguments +/// +/// * `protocol_version` - ACP protocol version to use +/// * `capabilities` - Client capabilities object +/// +/// # Returns +/// +/// JSON-RPC initialize request (without `id` field) +/// +/// # Example +/// +/// ``` +/// use dirigent_core::connectors::acp::protocol::build_initialize_request; +/// use serde_json::json; +/// +/// let request = build_initialize_request(1, json!({})); +/// assert_eq!(request["method"], "initialize"); +/// assert_eq!(request["params"]["protocolVersion"], 1); +/// ``` +pub fn build_initialize_request(protocol_version: u32, capabilities: Value) -> Value { + json!({ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "protocolVersion": protocol_version, + "clientCapabilities": capabilities + } + }) +} + +/// Build a session/new request +/// +/// # Arguments +/// +/// * `cwd` - Current working directory for the session +/// * `mcp_servers` - Optional array of MCP server configurations +/// +/// # Returns +/// +/// JSON-RPC session/new request (without `id` field) +/// +/// # Example +/// +/// ``` +/// use dirigent_core::connectors::acp::protocol::build_session_new_request; +/// use serde_json::json; +/// +/// let request = build_session_new_request(".", None); +/// assert_eq!(request["method"], "session/new"); +/// assert_eq!(request["params"]["cwd"], "."); +/// ``` +pub fn build_session_new_request(cwd: impl Into, mcp_servers: Option) -> Value { + json!({ + "jsonrpc": "2.0", + "method": "session/new", + "params": { + "cwd": cwd.into(), + "mcpServers": mcp_servers.unwrap_or(json!([])) + } + }) +} + +/// Build a session/prompt request +/// +/// # Arguments +/// +/// * `session_id` - Session ID to send prompt to +/// * `prompt` - Prompt array (content blocks) +/// +/// # Returns +/// +/// JSON-RPC session/prompt request (without `id` field) +/// +/// # Example +/// +/// ``` +/// use dirigent_core::connectors::acp::protocol::build_session_prompt_request; +/// use serde_json::json; +/// +/// let prompt = json!([ +/// { +/// "type": "text", +/// "text": "Hello, world!" +/// } +/// ]); +/// +/// let request = build_session_prompt_request("session-123", prompt); +/// assert_eq!(request["method"], "session/prompt"); +/// assert_eq!(request["params"]["sessionId"], "session-123"); +/// ``` +pub fn build_session_prompt_request(session_id: impl Into, prompt: Value) -> Value { + json!({ + "jsonrpc": "2.0", + "method": "session/prompt", + "params": { + "sessionId": session_id.into(), + "prompt": prompt + } + }) +} + +/// Build a session/cancel request +/// +/// # Arguments +/// +/// * `session_id` - Session ID to cancel +/// +/// # Returns +/// +/// JSON-RPC session/cancel request (without `id` field) +/// +/// # Example +/// +/// ``` +/// use dirigent_core::connectors::acp::protocol::build_session_cancel_request; +/// +/// let request = build_session_cancel_request("session-123"); +/// assert_eq!(request["method"], "session/cancel"); +/// assert_eq!(request["params"]["sessionId"], "session-123"); +/// ``` +/// Build a session/load request +/// +/// Loads an existing session and replays its history (if the agent supports it). +/// Per the ACP spec, `sessionId`, `cwd`, and `mcpServers` are all required fields. +/// +/// # Arguments +/// +/// * `session_id` - The ID of the session to load +/// * `cwd` - The working directory for the session +/// * `mcp_servers` - Optional MCP server configuration (defaults to empty array) +/// +/// # Returns +/// +/// A JSON-RPC 2.0 request message for session/load +/// +/// # Example +/// +/// ``` +/// use dirigent_core::connectors::acp::protocol::build_session_load_request; +/// use serde_json::json; +/// +/// let request = build_session_load_request("session-123", "/home/user", None); +/// assert_eq!(request["method"], "session/load"); +/// assert_eq!(request["params"]["sessionId"], "session-123"); +/// assert_eq!(request["params"]["cwd"], "/home/user"); +/// ``` +pub fn build_session_load_request( + session_id: impl Into, + cwd: impl Into, + mcp_servers: Option, +) -> Value { + json!({ + "jsonrpc": "2.0", + "method": "session/load", + "params": { + "sessionId": session_id.into(), + "cwd": cwd.into(), + "mcpServers": mcp_servers.unwrap_or(json!([])) + } + }) +} + +pub fn build_session_cancel_request(session_id: impl Into) -> Value { + json!({ + "jsonrpc": "2.0", + "method": "session/cancel", + "params": { + "sessionId": session_id.into() + } + }) +} + +/// Build a session/list request +/// +/// Lists sessions available from the agent. Supports optional filtering +/// by working directory and cursor-based pagination. +/// +/// # Arguments +/// +/// * `cwd` - Optional working directory filter +/// * `cursor` - Optional pagination cursor from previous response +/// +/// # Returns +/// +/// JSON-RPC session/list request (without `id` field) +/// +/// # Example +/// +/// ``` +/// use dirigent_core::connectors::acp::protocol::build_session_list_request; +/// +/// let request = build_session_list_request(None, None); +/// assert_eq!(request["method"], "session/list"); +/// ``` +pub fn build_session_list_request( + cwd: Option<&str>, + cursor: Option<&str>, +) -> Value { + let mut params = serde_json::Map::new(); + if let Some(cwd) = cwd { + params.insert("cwd".to_string(), Value::String(cwd.to_string())); + } + if let Some(cursor) = cursor { + params.insert("cursor".to_string(), Value::String(cursor.to_string())); + } + json!({ + "jsonrpc": "2.0", + "method": "session/list", + "params": params + }) +} + +/// Build a session/resume request +/// +/// Resumes an existing session without replaying history. The agent restores +/// internal context but does not send `session/update` history notifications. +/// +/// # Arguments +/// +/// * `session_id` - The session to resume +/// * `cwd` - Optional working directory +/// * `mcp_servers` - Optional MCP server configuration +/// +/// # Example +/// +/// ``` +/// use dirigent_core::connectors::acp::protocol::build_session_resume_request; +/// +/// let request = build_session_resume_request("sess-123", None, None); +/// assert_eq!(request["method"], "session/resume"); +/// ``` +pub fn build_session_resume_request( + session_id: impl Into, + cwd: Option<&str>, + mcp_servers: Option, +) -> Value { + let mut params = serde_json::Map::new(); + params.insert("sessionId".to_string(), Value::String(session_id.into())); + if let Some(cwd) = cwd { + params.insert("cwd".to_string(), Value::String(cwd.to_string())); + } + if let Some(servers) = mcp_servers { + params.insert("mcpServers".to_string(), servers); + } + json!({ + "jsonrpc": "2.0", + "method": "session/resume", + "params": params + }) +} + +/// Build a session/close JSON-RPC request. +/// +/// Tells the agent to release resources for this session. The session +/// remains in session/list and can be session/load-ed again later. +/// +/// # Arguments +/// +/// * `session_id` - The session to close +/// +/// # Example +/// +/// ``` +/// use dirigent_core::connectors::acp::protocol::build_session_close_request; +/// +/// let request = build_session_close_request("session-abc"); +/// assert_eq!(request["method"], "session/close"); +/// ``` +pub fn build_session_close_request(session_id: impl Into) -> Value { + json!({ + "jsonrpc": "2.0", + "method": "session/close", + "params": { + "sessionId": session_id.into() + } + }) +} + +/// Build a session/set_config_option JSON-RPC request. +/// +/// Sets a configuration option for a session. The response contains +/// the complete updated configOptions array. +/// +/// # Arguments +/// +/// * `session_id` - The session to update +/// * `config_id` - The config option identifier (e.g., "mode", "model") +/// * `value` - The new value to set +/// +/// # Example +/// +/// ``` +/// use dirigent_core::connectors::acp::protocol::build_set_config_option_request; +/// +/// let request = build_set_config_option_request("session-abc", "mode", "code"); +/// assert_eq!(request["method"], "session/set_config_option"); +/// assert_eq!(request["params"]["configId"], "mode"); +/// assert_eq!(request["params"]["value"], "code"); +/// ``` +pub fn build_set_config_option_request( + session_id: impl Into, + config_id: impl Into, + value: impl Into, +) -> Value { + json!({ + "jsonrpc": "2.0", + "method": "session/set_config_option", + "params": { + "sessionId": session_id.into(), + "configId": config_id.into(), + "value": value.into() + } + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio::time::sleep; + + #[tokio::test] + async fn test_request_id_generation() { + let handler = ProtocolHandler::new(); + + let id1 = handler.generate_request_id(); + let id2 = handler.generate_request_id(); + let id3 = handler.generate_request_id(); + + // IDs should be unique + assert_ne!(id1, id2); + assert_ne!(id2, id3); + assert_ne!(id1, id3); + + // IDs should be numbers + assert!(id1.is_u64()); + assert!(id2.is_u64()); + assert!(id3.is_u64()); + } + + #[tokio::test] + async fn test_request_id_uniqueness_concurrent() { + let handler = Arc::new(ProtocolHandler::new()); + let mut handles = vec![]; + + // Generate 1000 IDs concurrently + for _ in 0..1000 { + let h = Arc::clone(&handler); + let handle = tokio::spawn(async move { h.generate_request_id() }); + handles.push(handle); + } + + let mut ids = vec![]; + for handle in handles { + ids.push(handle.await.unwrap()); + } + + // Check all IDs are unique + let unique_count = ids.iter().collect::>().len(); + assert_eq!(unique_count, 1000, "All IDs should be unique"); + } + + #[tokio::test] + async fn test_handle_response_with_matching_id() { + let handler = ProtocolHandler::new(); + + // Prepare a request + let request = json!({ + "jsonrpc": "2.0", + "method": "test", + "params": {} + }); + + let (message_with_id, response_rx) = handler.prepare_request(request).await; + let request_id = message_with_id.get("id").cloned().unwrap(); + + // Simulate receiving a response + let response = json!({ + "jsonrpc": "2.0", + "id": request_id, + "result": { "status": "ok" } + }); + + handler.handle_message(response.clone()).await; + + // Verify response was delivered + let received = response_rx.await.unwrap(); + assert_eq!(received, response); + } + + #[tokio::test] + async fn test_handle_notification_routing() { + let handler = ProtocolHandler::new(); + + let mut notifications = handler.take_notification_receiver().await.unwrap(); + + // Send a notification + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { "sessionId": "123" } + }); + + handler.handle_message(notification.clone()).await; + + // Verify notification was routed + let received = notifications.recv().await.unwrap(); + assert_eq!(received, notification); + } + + #[tokio::test] + async fn test_timeout_cleans_up_pending_request() { + let handler = ProtocolHandler::with_timeout(100); // 100ms timeout + + // Prepare request + let request = json!({ + "jsonrpc": "2.0", + "method": "test", + "params": {} + }); + + let (message_with_id, _response_rx) = handler.prepare_request(request).await; + let request_id = message_with_id.get("id").cloned().unwrap(); + + // Verify request is pending + assert!(handler + .pending_requests + .lock() + .await + .contains_key(&request_id)); + + // Wait for timeout + sleep(Duration::from_millis(150)).await; + + // Simulate timeout by trying to send response (should be cleaned up) + handler + .handle_message(json!({ + "jsonrpc": "2.0", + "id": request_id, + "result": {} + })) + .await; + + // Request should still be in pending (timeout cleanup happens in send_request) + // Let's verify by creating a real timeout scenario + } + + #[tokio::test] + async fn test_response_with_unknown_id() { + let handler = ProtocolHandler::new(); + + // Send response with unknown ID (should be logged and discarded) + let response = json!({ + "jsonrpc": "2.0", + "id": 99999, + "result": {} + }); + + // This should not panic + handler.handle_message(response).await; + + // Verify no pending requests + assert_eq!(handler.pending_requests.lock().await.len(), 0); + } + + #[tokio::test] + async fn test_concurrent_requests_correlation() { + let handler = Arc::new(ProtocolHandler::new()); + let mut handles = vec![]; + + // Send 10 concurrent requests + for i in 0..10 { + let h = Arc::clone(&handler); + let handle = tokio::spawn(async move { + let request = json!({ + "jsonrpc": "2.0", + "method": "test", + "params": { "value": i } + }); + + let (message_with_id, response_rx) = h.prepare_request(request).await; + let request_id = message_with_id.get("id").cloned().unwrap(); + + // Simulate response + h.handle_message(json!({ + "jsonrpc": "2.0", + "id": request_id, + "result": { "value": i } + })) + .await; + + // Receive response + response_rx.await.unwrap() + }); + handles.push(handle); + } + + // Verify all responses correlated correctly + for (i, handle) in handles.into_iter().enumerate() { + let response = handle.await.unwrap(); + assert_eq!(response["result"]["value"], i); + } + } + + #[tokio::test] + async fn test_json_rpc_error_response() { + let handler = ProtocolHandler::new(); + + let error_response = json!({ + "jsonrpc": "2.0", + "id": 1, + "error": { + "code": -32601, + "message": "Method not found", + "data": { "method": "unknown" } + } + }); + + let result = handler.handle_response_result(error_response); + + match result { + Err(ProtocolError::JsonRpcError { + code, + message, + data, + }) => { + assert_eq!(code, -32601); + assert_eq!(message, "Method not found"); + assert!(data.is_some()); + } + _ => panic!("Expected JsonRpcError"), + } + } + + #[tokio::test] + async fn test_invalid_message_no_id_no_method() { + let handler = ProtocolHandler::new(); + + // Message with neither ID nor method (should be logged and discarded) + let invalid = json!({ + "jsonrpc": "2.0", + "params": {} + }); + + // Should not panic + handler.handle_message(invalid).await; + } + + // Helper builder tests + + #[test] + fn test_build_initialize_request() { + let request = build_initialize_request(1, json!({})); + + assert_eq!(request["jsonrpc"], "2.0"); + assert_eq!(request["method"], "initialize"); + assert_eq!(request["params"]["protocolVersion"], 1); + assert!(request["params"]["clientCapabilities"].is_object()); + assert!(request.get("id").is_none()); // Should not have ID yet + } + + #[test] + fn test_build_session_new_request() { + let request = build_session_new_request(".", None); + + assert_eq!(request["jsonrpc"], "2.0"); + assert_eq!(request["method"], "session/new"); + assert_eq!(request["params"]["cwd"], "."); + assert!(request["params"]["mcpServers"].is_array()); + } + + #[test] + fn test_build_session_new_request_with_servers() { + let servers = json!([ + { "name": "server1", "url": "http://localhost:8080" } + ]); + + let request = build_session_new_request("/path", Some(servers.clone())); + + assert_eq!(request["params"]["mcpServers"], servers); + } + + #[test] + fn test_build_session_prompt_request() { + let prompt = json!([ + { + "type": "text", + "text": "Hello" + } + ]); + + let request = build_session_prompt_request("session-123", prompt.clone()); + + assert_eq!(request["jsonrpc"], "2.0"); + assert_eq!(request["method"], "session/prompt"); + assert_eq!(request["params"]["sessionId"], "session-123"); + assert_eq!(request["params"]["prompt"], prompt); + } + + #[test] + fn test_build_session_cancel_request() { + let request = build_session_cancel_request("session-456"); + + assert_eq!(request["jsonrpc"], "2.0"); + assert_eq!(request["method"], "session/cancel"); + assert_eq!(request["params"]["sessionId"], "session-456"); + } + + // ======================================================================== + // MessageHandlerResult Tests (T044) + // ======================================================================== + + #[tokio::test] + async fn test_message_handler_result_none_for_notification() { + let handler = ProtocolHandler::new(); + + // Notification (has method, no id) + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-123", + "update": { + "sessionUpdate": "agent_message_chunk", + "content": {"type": "text", "text": "Hello"} + } + } + }); + + let result = handler.handle_message(notification).await; + + // Should return None for notifications + assert!(matches!(result, MessageHandlerResult::None)); + } + + #[tokio::test] + async fn test_message_handler_result_response_for_unsupported_method() { + let handler = ProtocolHandler::new(); + + // Request with unsupported method (has method and id) + let request = json!({ + "jsonrpc": "2.0", + "id": 42, + "method": "unsupported/method", + "params": {} + }); + + let result = handler.handle_message(request).await; + + // Should return Response variant with error + match result { + MessageHandlerResult::Response(response) => { + assert_eq!(response.get("jsonrpc"), Some(&json!("2.0"))); + assert_eq!(response.get("id"), Some(&json!(42))); + assert!(response.get("error").is_some()); + + let error = response.get("error").unwrap(); + assert_eq!(error.get("code"), Some(&json!(-32601))); + assert!(error.get("message").unwrap().as_str().unwrap().contains("not supported")); + } + _ => panic!("Expected Response variant, got {:?}", result), + } + } + + #[tokio::test] + async fn test_message_handler_result_agent_request_for_permission() { + let handler = ProtocolHandler::new(); + + // Permission request (has method and id) + let request = json!({ + "jsonrpc": "2.0", + "id": 0, + "method": "session/request_permission", + "params": { + "question": "Allow Write tool?", + "options": [ + {"id": "allow", "label": "Allow"}, + {"id": "deny", "label": "Deny"} + ] + } + }); + + let result = handler.handle_message(request).await; + + // Should return AgentRequest variant + match result { + MessageHandlerResult::AgentRequest { request_id, method, params } => { + assert_eq!(request_id, json!(0)); + assert_eq!(method, "session/request_permission"); + assert_eq!(params.get("question").unwrap(), "Allow Write tool?"); + assert!(params.get("options").is_some()); + } + _ => panic!("Expected AgentRequest variant, got {:?}", result), + } + } + + #[tokio::test] + async fn test_message_handler_result_none_for_response_to_our_request() { + let handler = ProtocolHandler::new(); + + // Prepare a request first + let request = json!({ + "jsonrpc": "2.0", + "method": "test", + "params": {} + }); + + let (message_with_id, _response_rx) = handler.prepare_request(request).await; + let request_id = message_with_id.get("id").cloned().unwrap(); + + // Now simulate receiving a response + let response = json!({ + "jsonrpc": "2.0", + "id": request_id, + "result": {"status": "ok"} + }); + + let result = handler.handle_message(response).await; + + // Should return None (response was delivered to pending request) + assert!(matches!(result, MessageHandlerResult::None)); + } + + #[tokio::test] + async fn test_message_handler_result_invalid_message() { + let handler = ProtocolHandler::new(); + + // Invalid message (no method, no id, no result/error) + let invalid = json!({ + "jsonrpc": "2.0", + "params": {} + }); + + let result = handler.handle_message(invalid).await; + + // Should return None (message logged and discarded) + assert!(matches!(result, MessageHandlerResult::None)); + } + + #[test] + fn test_build_session_list_request_no_params() { + let request = build_session_list_request(None, None); + assert_eq!(request["jsonrpc"], "2.0"); + assert_eq!(request["method"], "session/list"); + assert!(request["params"].is_object()); + } + + #[test] + fn test_build_session_list_request_with_cwd() { + let request = build_session_list_request(Some("/home/user/project"), None); + assert_eq!(request["params"]["cwd"], "/home/user/project"); + } + + #[test] + fn test_build_session_list_request_with_cursor() { + let request = build_session_list_request(None, Some("cursor-token-123")); + assert_eq!(request["params"]["cursor"], "cursor-token-123"); + } + + #[test] + fn test_build_session_list_request_with_both() { + let request = build_session_list_request(Some("/tmp"), Some("page2")); + assert_eq!(request["params"]["cwd"], "/tmp"); + assert_eq!(request["params"]["cursor"], "page2"); + } + + #[test] + fn test_build_session_load_request_basic() { + let request = build_session_load_request("session-123", "/home/user", None); + assert_eq!(request["jsonrpc"], "2.0"); + assert_eq!(request["method"], "session/load"); + assert_eq!(request["params"]["sessionId"], "session-123"); + assert_eq!(request["params"]["cwd"], "/home/user"); + assert_eq!(request["params"]["mcpServers"], json!([])); + } + + #[test] + fn test_build_session_load_request_with_mcp_servers() { + let servers = json!([{"name": "test-server", "url": "http://localhost:3000"}]); + let request = build_session_load_request("session-123", "/tmp", Some(servers.clone())); + assert_eq!(request["params"]["sessionId"], "session-123"); + assert_eq!(request["params"]["cwd"], "/tmp"); + assert_eq!(request["params"]["mcpServers"], servers); + } + + #[test] + fn test_build_session_resume_request_basic() { + let request = build_session_resume_request("sess-123", None, None); + assert_eq!(request["jsonrpc"], "2.0"); + assert_eq!(request["method"], "session/resume"); + assert_eq!(request["params"]["sessionId"], "sess-123"); + assert!(request["params"].get("cwd").is_none()); + assert!(request["params"].get("mcpServers").is_none()); + } + + #[test] + fn test_build_session_resume_request_with_cwd() { + let request = build_session_resume_request("sess-123", Some("/home/user"), None); + assert_eq!(request["params"]["sessionId"], "sess-123"); + assert_eq!(request["params"]["cwd"], "/home/user"); + } + + #[test] + fn test_build_session_resume_request_with_mcp_servers() { + let servers = json!([{"name": "srv"}]); + let request = build_session_resume_request("sess-123", Some("/tmp"), Some(servers.clone())); + assert_eq!(request["params"]["sessionId"], "sess-123"); + assert_eq!(request["params"]["cwd"], "/tmp"); + assert_eq!(request["params"]["mcpServers"], servers); + } + + #[test] + fn test_build_session_close_request() { + let request = build_session_close_request("session-abc"); + assert_eq!(request["jsonrpc"], "2.0"); + assert_eq!(request["method"], "session/close"); + assert_eq!(request["params"]["sessionId"], "session-abc"); + } + + #[test] + fn test_build_set_config_option_request() { + let request = build_set_config_option_request("session-abc", "mode", "code"); + assert_eq!(request["jsonrpc"], "2.0"); + assert_eq!(request["method"], "session/set_config_option"); + assert_eq!(request["params"]["sessionId"], "session-abc"); + assert_eq!(request["params"]["configId"], "mode"); + assert_eq!(request["params"]["value"], "code"); + } +} diff --git a/crates/dirigent_core/src/connectors/acp/state.rs b/crates/dirigent_core/src/connectors/acp/state.rs new file mode 100644 index 0000000..68aacd6 --- /dev/null +++ b/crates/dirigent_core/src/connectors/acp/state.rs @@ -0,0 +1,977 @@ +//! Internal state management for ACP connector +//! +//! This module provides thread-safe state tracking for the ACP connector, +//! including protocol version, agent capabilities, and active sessions. + +use chrono::Utc; +use dirigent_protocol::{SessionModeState, SessionModelState, SessionOwnership}; +use serde_json::Value; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; +use tokio::sync::RwLock; + +/// Internal state for ACP connector +/// +/// Tracks connection state, protocol information, and active sessions. +/// All fields are protected by RwLock for thread-safe concurrent access. +/// +/// # Thread Safety +/// +/// This struct is designed to be shared across async tasks via `Arc`. +/// All mutations go through interior mutability patterns (RwLock). +#[derive(Debug)] +pub struct InternalState { + /// Negotiated protocol version + /// + /// Set during initialization handshake. None if not yet initialized. + protocol_version: Arc>>, + + /// Agent capabilities + /// + /// Populated from the initialize response. Contains information about + /// what the agent supports (streaming, tools, MCP servers, etc.). + agent_capabilities: Arc>>, + + /// Active sessions + /// + /// Map of session ID to session metadata. Updated as sessions are + /// created, updated, or terminated. + sessions: Arc>>, + + /// Active message IDs per session + /// + /// Tracks the current Dirigent message_id being accumulated for each session. + /// This allows us to translate many ACP chunks (each with different messageIds) + /// into a single Dirigent message with one consistent message_id. + active_messages: Arc>>, + + /// Message start times + /// + /// Tracks when each message started (when the first chunk arrived). + /// Maps message_id to timestamp. This allows us to provide accurate + /// created_at timestamps for messages instead of using finalization time. + message_start_times: Arc>>>, + + /// Loaded sessions + /// + /// Tracks which sessions have been successfully loaded (via session/load + /// or session/resume) during this connection. Used to decide whether to + /// load or resume a session: if already loaded, resume; otherwise, load first. + /// Cleared on reconnect. + loaded_sessions: Arc>>, +} + +/// Information about an active session +/// +/// Lightweight metadata tracked for each session. Does not include +/// full message history (that's stored by the agent). +#[derive(Debug, Clone)] +pub struct SessionInfo { + /// Unique session identifier + pub id: String, + + /// Session title (if available) + pub title: Option, + + /// Current working directory for this session + pub cwd: String, + + /// Number of messages in this session (approximate) + pub message_count: u32, + + /// Session model (if available) - legacy field for simple model name + pub model: Option, + + /// Creation timestamp + pub created_at: chrono::DateTime, + + /// Last activity timestamp + pub last_activity: chrono::DateTime, + + /// Session status + pub status: SessionStatus, + + // ======================================================================== + // ACP Session Metadata (from session/new and session/load responses) + // ======================================================================== + + /// Available models and current model selection + /// + /// Captured from ACP `session/new` or `session/load` response. + /// Note: This field is marked UNSTABLE in the ACP spec but Claude-ACP uses it. + pub models: Option, + + /// Available modes and current mode selection + /// + /// Captured from ACP `session/new` or `session/load` response. + /// This is part of the stable ACP specification. + pub modes: Option, + + /// ACP config options (replaces modes/models in future ACP versions) + /// + /// Captured from ACP `session/new` or `session/load` response `configOptions` field. + pub config_options: Option>, + + /// Ownership model for this session + /// + /// Tracks whether this session is internal to Dirigent or originates from + /// an external client, and how tool calls should be handled. + pub ownership: SessionOwnership, +} + +/// Status of a session +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SessionStatus { + /// Session is active and can accept prompts + Active, + + /// Session is processing a prompt (agent is responding) + Processing, + + /// Session is idle (no active processing) + Idle, + + /// Session has ended + Ended, +} + +impl InternalState { + /// Create a new empty state + /// + /// All fields are initialized to None/empty. State is populated + /// as the connector progresses through initialization and operation. + pub fn new() -> Self { + Self { + protocol_version: Arc::new(RwLock::new(None)), + agent_capabilities: Arc::new(RwLock::new(None)), + sessions: Arc::new(RwLock::new(HashMap::new())), + active_messages: Arc::new(RwLock::new(HashMap::new())), + message_start_times: Arc::new(RwLock::new(HashMap::new())), + loaded_sessions: Arc::new(RwLock::new(HashSet::new())), + } + } + + /// Get the negotiated protocol version + /// + /// Returns None if initialization has not completed. + pub async fn protocol_version(&self) -> Option { + *self.protocol_version.read().await + } + + /// Set the negotiated protocol version + /// + /// Called after successful initialization handshake. + pub async fn set_protocol_version(&self, version: u32) { + let mut guard = self.protocol_version.write().await; + *guard = Some(version); + } + + /// Get the agent capabilities + /// + /// Returns None if initialization has not completed. + pub async fn agent_capabilities(&self) -> Option { + self.agent_capabilities.read().await.clone() + } + + /// Set the agent capabilities + /// + /// Called after successful initialization handshake. + pub async fn set_agent_capabilities(&self, capabilities: Value) { + let mut guard = self.agent_capabilities.write().await; + *guard = Some(capabilities); + } + + /// Add or update a session + /// + /// If the session ID already exists, the info is updated. + /// Otherwise, a new entry is created. + pub async fn upsert_session(&self, info: SessionInfo) { + let mut sessions = self.sessions.write().await; + sessions.insert(info.id.clone(), info); + } + + /// Remove a session + /// + /// Called when a session is terminated or deleted. + pub async fn remove_session(&self, session_id: &str) { + let mut sessions = self.sessions.write().await; + sessions.remove(session_id); + } + + /// Get a specific session + /// + /// Returns None if the session is not tracked. + pub async fn get_session(&self, session_id: &str) -> Option { + let sessions = self.sessions.read().await; + sessions.get(session_id).cloned() + } + + /// List all tracked sessions + /// + /// Returns a snapshot of all session info. The returned Vec is + /// independent of the internal state and can be safely modified. + pub async fn list_sessions(&self) -> Vec { + let sessions = self.sessions.read().await; + sessions.values().cloned().collect() + } + + /// Get the number of tracked sessions + pub async fn session_count(&self) -> usize { + let sessions = self.sessions.read().await; + sessions.len() + } + + /// Update message count for a session + /// + /// Convenience method to increment or set message count without + /// replacing the entire SessionInfo. + pub async fn update_message_count(&self, session_id: &str, count: u32) { + let mut sessions = self.sessions.write().await; + if let Some(info) = sessions.get_mut(session_id) { + info.message_count = count; + info.last_activity = chrono::Utc::now(); + } + } + + /// Check if a session is active + /// + /// Returns true if the session exists and is not in Ended status. + pub async fn is_session_active(&self, session_id: &str) -> bool { + let sessions = self.sessions.read().await; + sessions + .get(session_id) + .map(|info| info.status != SessionStatus::Ended) + .unwrap_or(false) + } + + /// Update session status + /// + /// Convenience method to update session status and last activity time. + pub async fn update_session_status(&self, session_id: &str, status: SessionStatus) { + let mut sessions = self.sessions.write().await; + if let Some(info) = sessions.get_mut(session_id) { + info.status = status; + info.last_activity = Utc::now(); + } + } + + /// Touch session (update last activity timestamp) + /// + /// Called when any activity occurs on a session. + pub async fn touch_session(&self, session_id: &str) { + let mut sessions = self.sessions.write().await; + if let Some(info) = sessions.get_mut(session_id) { + info.last_activity = Utc::now(); + } + } + + /// Get or create active message ID for a session + /// + /// If no message is currently being accumulated for this session, generates + /// a new UUID-based Dirigent message_id and stores it. Returns the active + /// message_id (either existing or newly created). + /// + /// This allows us to translate multiple ACP chunks (each with different + /// ACP messageIds) into a single Dirigent message with one consistent ID. + pub async fn get_or_create_active_message(&self, session_id: &str) -> String { + let mut active_messages = self.active_messages.write().await; + + // Check if this is a new message BEFORE inserting + let is_new_message = !active_messages.contains_key(session_id); + + let message_id = active_messages + .entry(session_id.to_string()) + .or_insert_with(|| format!("msg-{}", uuid::Uuid::now_v7())) + .clone(); + + // Track message start time when creating new message + if is_new_message { + let mut start_times = self.message_start_times.write().await; + start_times.insert(message_id.clone(), Utc::now()); + } + + message_id + } + + /// Clear active message ID for a session + /// + /// Called when a message is completed (after receiving prompt response). + /// This allows the next message to get a fresh Dirigent message_id. + pub async fn clear_active_message(&self, session_id: &str) -> Option { + let mut active_messages = self.active_messages.write().await; + active_messages.remove(session_id) + } + + /// Get active message ID for a session (if any) + /// + /// Returns None if no message is currently being accumulated. + pub async fn get_active_message(&self, session_id: &str) -> Option { + let active_messages = self.active_messages.read().await; + active_messages.get(session_id).cloned() + } + + /// Get message start time + /// + /// Returns the timestamp when the message started (first chunk arrived). + /// Returns None if the message is not tracked. + pub async fn get_message_start_time(&self, message_id: &str) -> Option> { + let start_times = self.message_start_times.read().await; + start_times.get(message_id).copied() + } + + /// Clear message start time + /// + /// Called after a message is finalized and archived. + /// This prevents memory leaks from accumulating start times. + pub async fn clear_message_start_time(&self, message_id: &str) { + let mut start_times = self.message_start_times.write().await; + start_times.remove(message_id); + } + + /// Clear all state + /// + /// Resets the state to its initial empty condition. Called on + /// reconnection or when recovering from errors. + pub async fn clear(&self) { + let mut protocol_version = self.protocol_version.write().await; + *protocol_version = None; + + let mut capabilities = self.agent_capabilities.write().await; + *capabilities = None; + + let mut sessions = self.sessions.write().await; + sessions.clear(); + + let mut active_messages = self.active_messages.write().await; + active_messages.clear(); + + let mut start_times = self.message_start_times.write().await; + start_times.clear(); + + let mut loaded = self.loaded_sessions.write().await; + loaded.clear(); + } + + /// Check if the connector is initialized + /// + /// Returns true if we have completed the initialization handshake + /// (protocol version and capabilities are set). + pub async fn is_initialized(&self) -> bool { + self.protocol_version().await.is_some() + } + + /// Check if the upstream agent supports session/load + /// + /// Returns true if `agentCapabilities.loadSession` is `true`. + pub async fn agent_supports_load_session(&self) -> bool { + self.agent_capabilities + .read() + .await + .as_ref() + .and_then(|caps| caps.get("loadSession")) + .and_then(|v| v.as_bool()) + .unwrap_or(false) + } + + /// Check if the upstream agent supports session/list + /// + /// Checks `sessionCapabilities.list` (spec-correct, object presence) with + /// fallback to legacy `listSessions` (boolean). + pub async fn agent_supports_list_sessions(&self) -> bool { + self.agent_capabilities.read().await.as_ref() + .map(|caps| { + // Spec-correct: sessionCapabilities.list (object presence) + let nested = caps.get("sessionCapabilities") + .and_then(|sc| sc.get("list")) + .is_some(); + // Legacy compat: listSessions (boolean) + let flat = caps.get("listSessions") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + nested || flat + }) + .unwrap_or(false) + } + + /// Check if the upstream agent supports session/resume + /// + /// Returns true if `agentCapabilities.sessionCapabilities.resume` exists. + pub async fn agent_supports_session_resume(&self) -> bool { + self.agent_capabilities + .read() + .await + .as_ref() + .and_then(|caps| caps.get("sessionCapabilities")) + .and_then(|sc| sc.get("resume")) + .is_some() + } + + /// Check if the upstream agent supports session/close + /// + /// Returns true if `agentCapabilities.sessionCapabilities.close` exists. + pub async fn agent_supports_session_close(&self) -> bool { + self.agent_capabilities + .read() + .await + .as_ref() + .and_then(|caps| caps.get("sessionCapabilities")) + .and_then(|sc| sc.get("close")) + .is_some() + } + + /// Mark a session as loaded + /// + /// Called after a successful session/load or session/resume. Subsequent + /// operations on this session can use session/resume instead of session/load. + pub async fn mark_session_loaded(&self, session_id: &str) { + let mut loaded = self.loaded_sessions.write().await; + loaded.insert(session_id.to_string()); + } + + /// Check if a session has been loaded in this connection + /// + /// Returns true if the session was previously loaded via session/load or + /// session/resume. Used to decide whether to load or resume. + pub async fn is_session_loaded(&self, session_id: &str) -> bool { + let loaded = self.loaded_sessions.read().await; + loaded.contains(session_id) + } + + /// Unmark a session as loaded + /// + /// Called when a session is closed. The next access will require a + /// full session/load again. + pub async fn unmark_session_loaded(&self, session_id: &str) { + let mut loaded = self.loaded_sessions.write().await; + loaded.remove(session_id); + } +} + +impl Default for InternalState { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_state_creation() { + let state = InternalState::new(); + + assert_eq!(state.protocol_version().await, None); + assert_eq!(state.agent_capabilities().await, None); + assert_eq!(state.session_count().await, 0); + assert!(!state.is_initialized().await); + } + + #[tokio::test] + async fn test_protocol_version() { + let state = InternalState::new(); + + state.set_protocol_version(1).await; + + assert_eq!(state.protocol_version().await, Some(1)); + assert!(state.is_initialized().await); + } + + #[tokio::test] + async fn test_agent_capabilities() { + let state = InternalState::new(); + + let caps = serde_json::json!({ + "streaming": true, + "tools": ["bash", "edit"] + }); + + state.set_agent_capabilities(caps.clone()).await; + + let retrieved = state.agent_capabilities().await; + assert_eq!(retrieved, Some(caps)); + } + + #[tokio::test] + async fn test_session_management() { + let state = InternalState::new(); + + let now = Utc::now(); + let session1 = SessionInfo { + id: "session-1".to_string(), + title: Some("Test Session".to_string()), + cwd: ".".to_string(), + message_count: 0, + model: Some("claude-3".to_string()), + created_at: now, + last_activity: now, + status: SessionStatus::Active, + models: None, + modes: None, + config_options: None, + ownership: SessionOwnership::default(), + }; + + let session2 = SessionInfo { + id: "session-2".to_string(), + title: None, + cwd: "/tmp".to_string(), + message_count: 5, + model: None, + created_at: now, + last_activity: now, + status: SessionStatus::Idle, + models: None, + modes: None, + config_options: None, + ownership: SessionOwnership::default(), + }; + + // Add sessions + state.upsert_session(session1.clone()).await; + state.upsert_session(session2.clone()).await; + + assert_eq!(state.session_count().await, 2); + + // Get specific session + let retrieved = state.get_session("session-1").await; + assert!(retrieved.is_some()); + assert_eq!(retrieved.unwrap().id, "session-1"); + + // List all sessions + let all_sessions = state.list_sessions().await; + assert_eq!(all_sessions.len(), 2); + + // Remove a session + state.remove_session("session-1").await; + assert_eq!(state.session_count().await, 1); + assert!(state.get_session("session-1").await.is_none()); + } + + #[tokio::test] + async fn test_session_upsert() { + let state = InternalState::new(); + + let now = Utc::now(); + let session = SessionInfo { + id: "session-1".to_string(), + title: Some("Original".to_string()), + cwd: ".".to_string(), + message_count: 1, + model: None, + created_at: now, + last_activity: now, + status: SessionStatus::Active, + models: None, + modes: None, + config_options: None, + ownership: SessionOwnership::default(), + }; + + // Add session + state.upsert_session(session.clone()).await; + assert_eq!(state.session_count().await, 1); + + // Update session (same ID) + let updated = SessionInfo { + id: "session-1".to_string(), + title: Some("Updated".to_string()), + cwd: ".".to_string(), + message_count: 2, + model: Some("claude-3".to_string()), + created_at: now, + last_activity: now, + status: SessionStatus::Processing, + models: None, + modes: None, + config_options: None, + ownership: SessionOwnership::default(), + }; + + state.upsert_session(updated.clone()).await; + assert_eq!(state.session_count().await, 1); + + let retrieved = state.get_session("session-1").await.unwrap(); + assert_eq!(retrieved.title, Some("Updated".to_string())); + assert_eq!(retrieved.message_count, 2); + } + + #[tokio::test] + async fn test_update_message_count() { + let state = InternalState::new(); + + let now = Utc::now(); + let session = SessionInfo { + id: "session-1".to_string(), + title: None, + cwd: ".".to_string(), + message_count: 0, + model: None, + created_at: now, + last_activity: now, + status: SessionStatus::Active, + models: None, + modes: None, + config_options: None, + ownership: SessionOwnership::default(), + }; + + state.upsert_session(session).await; + + state.update_message_count("session-1", 10).await; + + let retrieved = state.get_session("session-1").await.unwrap(); + assert_eq!(retrieved.message_count, 10); + } + + #[tokio::test] + async fn test_clear_state() { + let state = InternalState::new(); + + // Populate state + state.set_protocol_version(1).await; + state.set_agent_capabilities(serde_json::json!({})).await; + let now = Utc::now(); + state.upsert_session(SessionInfo { + id: "session-1".to_string(), + title: None, + cwd: ".".to_string(), + message_count: 0, + model: None, + created_at: now, + last_activity: now, + status: SessionStatus::Active, + models: None, + modes: None, + config_options: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }).await; + + assert!(state.is_initialized().await); + assert_eq!(state.session_count().await, 1); + + // Clear everything + state.clear().await; + + assert!(!state.is_initialized().await); + assert_eq!(state.session_count().await, 0); + assert_eq!(state.protocol_version().await, None); + assert_eq!(state.agent_capabilities().await, None); + } + + #[tokio::test] + async fn test_concurrent_access() { + let state = Arc::new(InternalState::new()); + + let mut handles = vec![]; + + // Spawn 10 tasks that all add sessions + for i in 0..10 { + let state_clone = Arc::clone(&state); + let handle = tokio::spawn(async move { + let now = Utc::now(); + let session = SessionInfo { + id: format!("session-{}", i), + title: None, + cwd: ".".to_string(), + message_count: 0, + model: None, + created_at: now, + last_activity: now, + status: SessionStatus::Active, + models: None, + modes: None, + config_options: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }; + state_clone.upsert_session(session).await; + }); + handles.push(handle); + } + + // Wait for all tasks + for handle in handles { + handle.await.unwrap(); + } + + // Verify all sessions were added + assert_eq!(state.session_count().await, 10); + } + + #[tokio::test] + async fn test_is_session_active() { + let state = InternalState::new(); + let now = Utc::now(); + + // Active session + state.upsert_session(SessionInfo { + id: "active".to_string(), + title: None, + cwd: ".".to_string(), + message_count: 0, + model: None, + created_at: now, + last_activity: now, + status: SessionStatus::Active, + models: None, + modes: None, + config_options: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }).await; + + // Ended session + state.upsert_session(SessionInfo { + id: "ended".to_string(), + title: None, + cwd: ".".to_string(), + message_count: 0, + model: None, + created_at: now, + last_activity: now, + status: SessionStatus::Ended, + models: None, + modes: None, + config_options: None, + ownership: SessionOwnership::default(), + }).await; + + assert!(state.is_session_active("active").await); + assert!(!state.is_session_active("ended").await); + assert!(!state.is_session_active("nonexistent").await); + } + + #[tokio::test] + async fn test_update_session_status() { + let state = InternalState::new(); + let now = Utc::now(); + + state.upsert_session(SessionInfo { + id: "session-1".to_string(), + title: None, + cwd: ".".to_string(), + message_count: 0, + model: None, + created_at: now, + last_activity: now, + status: SessionStatus::Active, + models: None, + modes: None, + config_options: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }).await; + + // Change status to Processing + state.update_session_status("session-1", SessionStatus::Processing).await; + + let session = state.get_session("session-1").await.unwrap(); + assert_eq!(session.status, SessionStatus::Processing); + assert!(session.last_activity > now); + } + + #[tokio::test] + async fn test_touch_session() { + let state = InternalState::new(); + let now = Utc::now(); + + state.upsert_session(SessionInfo { + id: "session-1".to_string(), + title: None, + cwd: ".".to_string(), + message_count: 0, + model: None, + created_at: now, + last_activity: now, + status: SessionStatus::Active, + models: None, + modes: None, + config_options: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }).await; + + // Small delay to ensure timestamp difference + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + + // Touch the session + state.touch_session("session-1").await; + + let session = state.get_session("session-1").await.unwrap(); + assert!(session.last_activity > now); + } + + #[tokio::test] + async fn test_agent_supports_load_session_true() { + let state = InternalState::new(); + state.set_agent_capabilities(serde_json::json!({ + "loadSession": true + })).await; + assert!(state.agent_supports_load_session().await); + } + + #[tokio::test] + async fn test_agent_supports_load_session_false() { + let state = InternalState::new(); + state.set_agent_capabilities(serde_json::json!({ + "loadSession": false + })).await; + assert!(!state.agent_supports_load_session().await); + } + + #[tokio::test] + async fn test_agent_supports_load_session_missing() { + let state = InternalState::new(); + state.set_agent_capabilities(serde_json::json!({})).await; + assert!(!state.agent_supports_load_session().await); + } + + #[tokio::test] + async fn test_agent_supports_list_sessions_legacy_true() { + let state = InternalState::new(); + state.set_agent_capabilities(serde_json::json!({ + "listSessions": true + })).await; + assert!(state.agent_supports_list_sessions().await); + } + + #[tokio::test] + async fn test_agent_supports_list_sessions_legacy_false() { + let state = InternalState::new(); + state.set_agent_capabilities(serde_json::json!({ + "listSessions": false + })).await; + assert!(!state.agent_supports_list_sessions().await); + } + + #[tokio::test] + async fn test_agent_supports_list_sessions_nested() { + let state = InternalState::new(); + state.set_agent_capabilities(serde_json::json!({ + "sessionCapabilities": { + "list": {} + } + })).await; + assert!(state.agent_supports_list_sessions().await); + } + + #[tokio::test] + async fn test_agent_supports_list_sessions_nested_with_options() { + let state = InternalState::new(); + state.set_agent_capabilities(serde_json::json!({ + "sessionCapabilities": { + "list": { "supportsCwd": true } + } + })).await; + assert!(state.agent_supports_list_sessions().await); + } + + #[tokio::test] + async fn test_agent_supports_list_sessions_both_formats() { + let state = InternalState::new(); + state.set_agent_capabilities(serde_json::json!({ + "listSessions": true, + "sessionCapabilities": { + "list": {} + } + })).await; + assert!(state.agent_supports_list_sessions().await); + } + + #[tokio::test] + async fn test_agent_supports_list_sessions_missing() { + let state = InternalState::new(); + state.set_agent_capabilities(serde_json::json!({})).await; + assert!(!state.agent_supports_list_sessions().await); + } + + #[tokio::test] + async fn test_agent_supports_list_sessions_no_caps() { + let state = InternalState::new(); + assert!(!state.agent_supports_list_sessions().await); + } + + #[tokio::test] + async fn test_agent_supports_session_resume_true() { + let state = InternalState::new(); + state.set_agent_capabilities(serde_json::json!({ + "sessionCapabilities": { + "resume": {} + } + })).await; + assert!(state.agent_supports_session_resume().await); + } + + #[tokio::test] + async fn test_agent_supports_session_resume_missing() { + let state = InternalState::new(); + state.set_agent_capabilities(serde_json::json!({ + "sessionCapabilities": {} + })).await; + assert!(!state.agent_supports_session_resume().await); + } + + #[tokio::test] + async fn test_agent_supports_session_resume_no_session_caps() { + let state = InternalState::new(); + state.set_agent_capabilities(serde_json::json!({})).await; + assert!(!state.agent_supports_session_resume().await); + } + + #[tokio::test] + async fn test_agent_supports_session_close_true() { + let state = InternalState::new(); + state.set_agent_capabilities(serde_json::json!({ + "sessionCapabilities": { + "close": {} + } + })).await; + assert!(state.agent_supports_session_close().await); + } + + #[tokio::test] + async fn test_agent_supports_session_close_missing() { + let state = InternalState::new(); + state.set_agent_capabilities(serde_json::json!({ + "sessionCapabilities": {} + })).await; + assert!(!state.agent_supports_session_close().await); + } + + #[tokio::test] + async fn test_agent_supports_session_close_no_session_caps() { + let state = InternalState::new(); + state.set_agent_capabilities(serde_json::json!({})).await; + assert!(!state.agent_supports_session_close().await); + } + + #[tokio::test] + async fn test_loaded_session_tracking() { + let state = InternalState::new(); + + // Initially not loaded + assert!(!state.is_session_loaded("sess-1").await); + + // Mark as loaded + state.mark_session_loaded("sess-1").await; + assert!(state.is_session_loaded("sess-1").await); + assert!(!state.is_session_loaded("sess-2").await); + + // Unmark + state.unmark_session_loaded("sess-1").await; + assert!(!state.is_session_loaded("sess-1").await); + } + + #[tokio::test] + async fn test_loaded_sessions_cleared_on_reconnect() { + let state = InternalState::new(); + + state.mark_session_loaded("sess-1").await; + state.mark_session_loaded("sess-2").await; + assert!(state.is_session_loaded("sess-1").await); + assert!(state.is_session_loaded("sess-2").await); + + // Clear simulates reconnect + state.clear().await; + + assert!(!state.is_session_loaded("sess-1").await); + assert!(!state.is_session_loaded("sess-2").await); + } +} diff --git a/crates/dirigent_core/src/connectors/acp/title_utils.rs b/crates/dirigent_core/src/connectors/acp/title_utils.rs new file mode 100644 index 0000000..de37cd9 --- /dev/null +++ b/crates/dirigent_core/src/connectors/acp/title_utils.rs @@ -0,0 +1,134 @@ +//! Session Title Derivation Utilities +//! +//! This module provides utilities for deriving session titles from user messages. + +/// Maximum length for derived titles +pub const MAX_TITLE_LENGTH: usize = 50; + +/// Derive a session title from user message text +/// +/// Takes the first 50 characters, trimming whitespace and trying to break at word boundaries. +/// Sanitizes by removing newlines and control characters. +/// +/// # Arguments +/// * `text` - The user message text to derive title from +/// +/// # Returns +/// Derived title (max 50 chars, trimmed, sanitized) +pub fn derive_title_from_text(text: &str) -> String { + // Sanitize: replace newlines and control chars with space + let sanitized = text + .chars() + .map(|c| if c.is_control() { ' ' } else { c }) + .collect::(); + + // Trim whitespace + let trimmed = sanitized.trim(); + + // Handle empty case + if trimmed.is_empty() { + return "Untitled Session".to_string(); + } + + // If text is short enough, return as-is + if trimmed.len() <= MAX_TITLE_LENGTH { + return trimmed.to_string(); + } + + // Text is longer than MAX_TITLE_LENGTH - truncate at a valid char boundary + let end = trimmed.floor_char_boundary(MAX_TITLE_LENGTH); + let truncated = &trimmed[..end]; + + // Try to break at word boundary within the truncated text + if let Some(last_space) = truncated.rfind(' ') { + // Break at word boundary + trimmed[..last_space].trim_end().to_string() + } else { + // No spaces found - just truncate + truncated.trim_end().to_string() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_derive_title_from_text_short() { + let text = "Hello world"; + let title = derive_title_from_text(text); + assert_eq!(title, "Hello world"); + } + + #[test] + fn test_derive_title_from_text_long_with_word_boundary() { + let text = "This is a very long message that should be truncated at word boundary"; + let title = derive_title_from_text(text); + assert!(title.len() <= 50); + assert!(title.starts_with("This is a very long message")); + // Should break at word boundary - the exact result depends on word positions + // Just verify it doesn't exceed 50 chars and doesn't end mid-word + assert!(!title.ends_with("at")); // Shouldn't break "at word" + // The title should end on a complete word + let last_char = title.chars().last().unwrap(); + assert!(last_char.is_alphanumeric() || last_char == ' '); + } + + #[test] + fn test_derive_title_from_text_long_no_spaces() { + let text = "Thisisaverylongmessagewithoutanyspacesthatcannotbetruncatedatwordboundary"; + let title = derive_title_from_text(text); + assert_eq!(title.len(), 50); + // Without spaces, it will just truncate at 50 chars + assert_eq!(title, &text[..50]); + } + + #[test] + fn test_derive_title_from_text_with_newlines() { + let text = "First line\nSecond line\nThird line"; + let title = derive_title_from_text(text); + // Newlines should be replaced with spaces + assert!(!title.contains('\n')); + assert!(title.contains("First line Second line")); + } + + #[test] + fn test_derive_title_from_text_empty() { + let text = ""; + let title = derive_title_from_text(text); + assert_eq!(title, "Untitled Session"); + } + + #[test] + fn test_derive_title_from_text_whitespace_only() { + let text = " \n\t "; + let title = derive_title_from_text(text); + assert_eq!(title, "Untitled Session"); + } + + #[test] + fn test_derive_title_from_text_with_control_chars() { + let text = "Hello\x00world\x01test\x02"; + let title = derive_title_from_text(text); + // Control chars should be replaced with spaces + assert!(!title.chars().any(|c| c.is_control() && c != ' ')); + assert_eq!(title, "Hello world test"); + } + + #[test] + fn test_derive_title_exact_max_length() { + // Create a string of exactly MAX_TITLE_LENGTH characters + let text = "a".repeat(MAX_TITLE_LENGTH); + let title = derive_title_from_text(&text); + assert_eq!(title.len(), MAX_TITLE_LENGTH); + assert_eq!(title, text); + } + + #[test] + fn test_derive_title_one_over_max_length() { + // Create a string just over MAX_TITLE_LENGTH + let text = "a".repeat(MAX_TITLE_LENGTH + 1); + let title = derive_title_from_text(&text); + assert_eq!(title.len(), MAX_TITLE_LENGTH); + } +} diff --git a/crates/dirigent_core/src/connectors/acp/transport/http.rs b/crates/dirigent_core/src/connectors/acp/transport/http.rs new file mode 100644 index 0000000..c177e32 --- /dev/null +++ b/crates/dirigent_core/src/connectors/acp/transport/http.rs @@ -0,0 +1,654 @@ +//! HTTP transport for ACP +//! +//! This module provides a transport implementation that communicates with ACP agents +//! via HTTP POST for requests and Server-Sent Events (SSE) for notifications. +//! +//! # Protocol +//! +//! - **Requests**: HTTP POST to `/jsonrpc` endpoint with JSON-RPC payload +//! - **Notifications**: SSE stream from `/events` endpoint for real-time updates +//! +//! # Architecture +//! +//! The transport maintains two connections: +//! 1. HTTP client for sending requests (reqwest::Client) +//! 2. SSE client for receiving notifications (eventsource stream) +//! +//! Responses and notifications are multiplexed into a single receive channel +//! using a background task that routes SSE events based on their ID. +//! +//! # Lifecycle +//! +//! 1. Create HttpTransport with base URL +//! 2. Call `connect()` to establish SSE connection +//! 3. Use `send()` for requests, `recv()` for responses and notifications +//! 4. Call `close()` to shut down connections + +use super::{AcpTransport, TransportResult}; +use async_trait::async_trait; +use futures::StreamExt; +use reqwest::Client; +use serde_json::Value; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::{mpsc, Mutex}; +use tokio::task::JoinHandle; +use tracing::{debug, error, info, warn}; + +/// HTTP transport implementation +/// +/// Communicates with ACP agents via HTTP POST (requests) and SSE (notifications). +/// +/// # Example +/// +/// ```no_run +/// use dirigent_core::connectors::acp::transport::{AcpTransport, HttpTransport}; +/// use serde_json::json; +/// +/// # async fn example() -> anyhow::Result<()> { +/// // Connect to ACP agent HTTP server +/// let mut transport = HttpTransport::new("http://localhost:8080"); +/// +/// // Connect (establishes SSE stream) +/// transport.connect().await?; +/// +/// // Send initialize request +/// transport.send(json!({ +/// "jsonrpc": "2.0", +/// "method": "initialize", +/// "id": 1, +/// "params": { +/// "protocolVersion": 1, +/// "clientCapabilities": {} +/// } +/// })).await?; +/// +/// // Receive response +/// if let Some(response) = transport.recv().await? { +/// println!("Initialized: {:?}", response); +/// } +/// +/// // Clean up +/// transport.close().await?; +/// # Ok(()) +/// # } +/// ``` +pub struct HttpTransport { + /// Base URL of the ACP server (e.g., "http://localhost:8080") + base_url: String, + + /// Client ID for SSE subscription and RPC identification + /// + /// This UUID is generated when connecting to the ACP server and is used: + /// - As query parameter for SSE endpoint: `/events?client_id=...` + /// - In RPC requests to identify this client (future: via X-Client-Id header) + client_id: String, + + /// Request timeout + timeout: Duration, + + /// HTTP client for sending requests + client: Client, + + /// Channel for receiving messages (responses + notifications) + /// + /// This channel is fed by the SSE background task and HTTP responses. + /// The recv() method reads from this channel. + rx: Arc>>>, + + /// Sender for the receive channel (used internally) + tx: Arc>>>, + + /// Join handle for the SSE background task + sse_task: Arc>>>, + + /// Flag to indicate if the transport is connected + connected: Arc>, +} + +impl HttpTransport { + /// Create a new HttpTransport + /// + /// # Arguments + /// + /// * `base_url` - Base URL of the ACP server (e.g., "http://localhost:8080") + /// * `client_id` - Client ID for SSE subscription (UUID string) + /// + /// # Returns + /// + /// A new HttpTransport that is not yet connected. Call `connect()` to + /// establish the SSE connection. + /// + /// # Example + /// + /// ```no_run + /// use dirigent_core::connectors::acp::transport::HttpTransport; + /// + /// let client_id = uuid::Uuid::new_v4().to_string(); + /// let transport = HttpTransport::new("http://localhost:8080", client_id); + /// ``` + pub fn new(base_url: impl Into, client_id: impl Into) -> Self { + let base_url = base_url.into(); + let client_id = client_id.into(); + + info!( + base_url = %base_url, + client_id = %client_id, + "Creating HTTP transport" + ); + + let (tx, rx) = mpsc::channel(1000); // Buffer up to 1000 messages + + Self { + base_url, + client_id, + timeout: Duration::from_secs(30), // Default 30 seconds + client: Client::new(), + rx: Arc::new(Mutex::new(Some(rx))), + tx: Arc::new(Mutex::new(Some(tx))), + sse_task: Arc::new(Mutex::new(None)), + connected: Arc::new(Mutex::new(false)), + } + } + + /// Set the request timeout + /// + /// Must be called before `connect()`. + pub fn set_timeout(&mut self, timeout: Duration) { + self.timeout = timeout; + } + + /// Start the SSE background task + /// + /// This task connects to the `/events` endpoint and forwards all SSE events + /// to the receive channel. + async fn start_sse_task(&self) -> TransportResult<()> { + let url = format!("{}/events?client_id={}", self.base_url, self.client_id); + let tx = self + .tx + .lock() + .await + .clone() + .ok_or("Transport already closed")?; + + info!(url = %url, client_id = %self.client_id, "Starting SSE subscription"); + + // Create SSE client using reqwest + let client = self.client.clone(); + let sse_url = url.clone(); + + let handle = tokio::spawn(async move { + debug!("SSE task started"); + + // Connect to SSE endpoint + let response = match client.get(&sse_url).send().await { + Ok(r) => r, + Err(e) => { + error!(error = %e, "Failed to connect to SSE endpoint"); + return; + } + }; + + if !response.status().is_success() { + error!(status = %response.status(), "SSE endpoint returned error"); + return; + } + + // Read SSE stream using eventsource-client + use eventsource_client::{Client as EventSourceClient, ClientBuilder, SSE}; + + let sse_client = match ClientBuilder::for_url(&sse_url) { + Ok(c) => c.build(), + Err(e) => { + error!(error = %e, "Failed to create SSE client"); + return; + } + }; + + let mut stream = sse_client.stream(); + + // Process SSE events + while let Some(event_result) = stream.next().await { + match event_result { + Ok(SSE::Connected(_)) => { + info!("SSE stream connected"); + } + Ok(SSE::Event(event)) => { + debug!(event = ?event, "Received SSE event"); + + // Parse the event data as JSON + match serde_json::from_str::(&event.data) { + Ok(message) => { + let masked_message = dirigent_protocol::log_utils::format_for_log(&message); + debug!(message = %masked_message, "Parsed SSE message"); + + // Send to receive channel + if let Err(e) = tx.send(message).await { + warn!(error = %e, "Failed to forward SSE message (receiver dropped)"); + break; + } + } + Err(e) => { + warn!(error = %e, data = %event.data, "Failed to parse SSE event data"); + } + } + } + Ok(SSE::Comment(comment)) => { + debug!(comment = %comment, "Received SSE comment"); + } + Err(e) => { + error!(error = %e, "SSE stream error"); + break; + } + } + } + + info!("SSE task ended"); + }); + + // Store the task handle + *self.sse_task.lock().await = Some(handle); + + Ok(()) + } +} + +#[async_trait] +impl AcpTransport for HttpTransport { + async fn connect(&mut self) -> TransportResult<()> { + let mut connected = self.connected.lock().await; + + if *connected { + return Err("Transport already connected".into()); + } + + info!(base_url = %self.base_url, "Connecting HTTP transport"); + + // Start SSE background task + self.start_sse_task().await?; + + *connected = true; + + info!("HTTP transport connected"); + + Ok(()) + } + + async fn send(&mut self, message: Value) -> TransportResult<()> { + let connected = self.connected.lock().await; + if !*connected { + return Err("Transport not connected".into()); + } + drop(connected); // Release lock + + // Send HTTP POST to /jsonrpc + let url = format!("{}/jsonrpc", self.base_url); + + let masked_message = dirigent_protocol::log_utils::format_for_log(&message); + debug!( + url = %url, + client_id = %self.client_id, + message = %masked_message, + "Sending HTTP request with X-Client-Id header" + ); + + let response = self + .client + .post(&url) + .header("X-Client-ID", &self.client_id) + .json(&message) + .send() + .await + .map_err(|e| format!("Failed to send HTTP request: {}", e))?; + + let status = response.status(); + + if !status.is_success() { + let error_body = response + .text() + .await + .unwrap_or_else(|_| "".to_string()); + return Err(format!("HTTP request failed with status {}: {}", status, error_body).into()); + } + + // Parse the response + let response_json: Value = response + .json() + .await + .map_err(|e| format!("Failed to parse response JSON: {}", e))?; + + let masked_response = dirigent_protocol::log_utils::format_for_log(&response_json); + debug!(response = %masked_response, "Received HTTP response"); + + // Forward the response to the receive channel + let tx = self + .tx + .lock() + .await + .clone() + .ok_or("Transport closed")?; + tx.send(response_json) + .await + .map_err(|e| format!("Failed to forward response: {}", e))?; + + Ok(()) + } + + async fn recv(&mut self) -> TransportResult> { + let mut rx_lock = self.rx.lock().await; + let rx = rx_lock + .as_mut() + .ok_or("Transport not connected (no receiver)")?; + + // Receive from the channel (blocks until message available) + match rx.recv().await { + Some(message) => { + let masked_message = dirigent_protocol::log_utils::format_for_log(&message); + debug!(message = %masked_message, "Received message"); + Ok(Some(message)) + } + None => { + // Channel closed (SSE task ended) + debug!("Receive channel closed"); + Ok(None) + } + } + } + + async fn close(&mut self) -> TransportResult<()> { + info!("Closing HTTP transport"); + + // Mark as disconnected + *self.connected.lock().await = false; + + // Drop the sender to close the channel + *self.tx.lock().await = None; + + // Cancel the SSE task + let mut sse_task_lock = self.sse_task.lock().await; + if let Some(handle) = sse_task_lock.take() { + debug!("Aborting SSE task"); + handle.abort(); + + // Wait for task to finish (with timeout) + match tokio::time::timeout(std::time::Duration::from_secs(2), handle).await { + Ok(Ok(())) => { + info!("SSE task finished gracefully"); + } + Ok(Err(e)) if e.is_cancelled() => { + info!("SSE task was cancelled"); + } + Ok(Err(e)) => { + warn!(error = %e, "SSE task panicked"); + } + Err(_) => { + warn!("SSE task did not finish within timeout"); + } + } + } + + // Drop the receiver + *self.rx.lock().await = None; + + info!("HTTP transport closed"); + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + use std::net::SocketAddr; + use axum::{ + routing::{get, post}, + Json, Router, + }; + use axum::response::{sse::Event, Sse}; + use futures::stream; + use std::convert::Infallible; + use std::time::Duration; + + /// Helper to find a free port for the test server + async fn find_free_port() -> u16 { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("Failed to bind to random port"); + listener + .local_addr() + .expect("Failed to get local addr") + .port() + } + + /// Helper to start a mock ACP HTTP server for testing + async fn start_mock_server() -> (String, tokio::task::JoinHandle<()>) { + + let port = find_free_port().await; + let addr: SocketAddr = format!("127.0.0.1:{}", port).parse().unwrap(); + + // Create router + let app = Router::new() + .route("/jsonrpc", post(handle_jsonrpc)) + .route("/events", get(handle_sse)); + + // Spawn server + let handle = tokio::spawn(async move { + let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); + axum::serve(listener, app).await.unwrap(); + }); + + // Wait a bit for server to start + tokio::time::sleep(Duration::from_millis(100)).await; + + (format!("http://127.0.0.1:{}", port), handle) + } + + async fn handle_jsonrpc(Json(request): Json) -> Json { + // Echo back a successful response + let id = request.get("id").cloned().unwrap_or(json!(null)); + let method = request + .get("method") + .and_then(|m| m.as_str()) + .unwrap_or(""); + + let result = match method { + "initialize" => json!({ + "protocolVersion": 1, + "agentCapabilities": {} + }), + "session/new" => json!({ + "sessionId": "test-session-123" + }), + _ => json!({}), + }; + + Json(json!({ + "jsonrpc": "2.0", + "id": id, + "result": result + })) + } + + async fn handle_sse() -> Sse>> { + // Send a few test events + let stream = stream::iter(vec![ + Ok(Event::default().data( + json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session-123", + "update": { + "type": "messageStarted" + } + } + }) + .to_string(), + )), + Ok(Event::default().data( + json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session-123", + "update": { + "type": "messageCompleted" + } + } + }) + .to_string(), + )), + ]); + + Sse::new(stream) + } + + #[tokio::test] + async fn test_http_transport_basic_flow() { + let (base_url, _server_handle) = start_mock_server().await; + + // Create transport + let mut transport = HttpTransport::new(&base_url, "test-client"); + + // Connect + transport.connect().await.expect("Failed to connect"); + + // Test 1: Initialize + let init_request = json!({ + "jsonrpc": "2.0", + "method": "initialize", + "id": 1, + "params": { + "protocolVersion": 1, + "clientCapabilities": {} + } + }); + + transport + .send(init_request) + .await + .expect("Failed to send init request"); + + let init_response = transport + .recv() + .await + .expect("Failed to receive init response") + .expect("Got None instead of response"); + + assert_eq!(init_response["jsonrpc"], "2.0"); + assert_eq!(init_response["id"], 1); + assert!(init_response["result"]["protocolVersion"].is_number()); + + // Test 2: Create session + let session_request = json!({ + "jsonrpc": "2.0", + "method": "session/new", + "id": 2, + "params": { + "cwd": ".", + "mcpServers": [] + } + }); + + transport + .send(session_request) + .await + .expect("Failed to send session request"); + + let session_response = transport + .recv() + .await + .expect("Failed to receive session response") + .expect("Got None instead of response"); + + assert_eq!(session_response["jsonrpc"], "2.0"); + assert_eq!(session_response["id"], 2); + + // Test 3: Receive SSE notifications + // The mock server sends 2 notifications + let notif1 = transport + .recv() + .await + .expect("Failed to receive notification") + .expect("Got None instead of notification"); + assert_eq!(notif1["method"], "session/update"); + + let notif2 = transport + .recv() + .await + .expect("Failed to receive notification") + .expect("Got None instead of notification"); + assert_eq!(notif2["method"], "session/update"); + + // Close + transport.close().await.expect("Failed to close"); + } + + #[tokio::test] + async fn test_http_transport_sse_notifications() { + let (base_url, _server_handle) = start_mock_server().await; + + let mut transport = HttpTransport::new(&base_url, "test-client"); + transport.connect().await.expect("Failed to connect"); + + // Receive SSE notifications without sending any requests + let notif1 = transport + .recv() + .await + .expect("Failed to receive notification") + .expect("Got None instead of notification"); + assert_eq!(notif1["method"], "session/update"); + + transport.close().await.expect("Failed to close"); + } + + #[tokio::test] + async fn test_http_transport_concurrent_requests() { + let (base_url, _server_handle) = start_mock_server().await; + + let mut transport = HttpTransport::new(&base_url, "test-client"); + transport.connect().await.expect("Failed to connect"); + + // Send multiple requests concurrently + let transport = Arc::new(Mutex::new(transport)); + + let mut handles = vec![]; + for i in 0..5 { + let t = Arc::clone(&transport); + let handle = tokio::spawn(async move { + let mut transport = t.lock().await; + transport + .send(json!({ + "jsonrpc": "2.0", + "method": "session/new", + "id": i + 10, + "params": { + "cwd": ".", + "mcpServers": [] + } + })) + .await + }); + handles.push(handle); + } + + // All sends should succeed + for handle in handles { + handle + .await + .expect("Task panicked") + .expect("Send failed"); + } + + // Close + let mut transport = transport.lock().await; + transport.close().await.expect("Failed to close"); + } + + #[tokio::test] + async fn test_http_transport_error_response() { + // This test would require a mock server that returns errors + // For now, we'll skip it as the basic mock server always succeeds + } +} diff --git a/crates/dirigent_core/src/connectors/acp/transport/mod.rs b/crates/dirigent_core/src/connectors/acp/transport/mod.rs new file mode 100644 index 0000000..94a23f2 --- /dev/null +++ b/crates/dirigent_core/src/connectors/acp/transport/mod.rs @@ -0,0 +1,217 @@ +//! ACP Transport Layer +//! +//! This module provides transport abstractions for the Agent-Client Protocol (ACP). +//! Transports handle the low-level communication with ACP agents via different +//! mechanisms (stdio, HTTP/SSE, WebSocket). +//! +//! # Architecture +//! +//! The `AcpTransport` trait defines a standard interface for sending and receiving +//! JSON-RPC messages. Concrete implementations handle the protocol-specific details: +//! +//! - **StdioTransport**: Line-delimited JSON over stdin/stdout (for process spawning) +//! - **HttpTransport**: HTTP POST for requests + SSE for notifications +//! +//! # Usage +//! +//! ```no_run +//! use dirigent_core::connectors::acp::transport::{AcpTransport, StdioTransport}; +//! use serde_json::json; +//! +//! # async fn example() -> anyhow::Result<()> { +//! // Create and connect a transport +//! let mut transport = StdioTransport::new("dirigate", &["serve", "--stdio"]); +//! transport.connect().await?; +//! +//! // Send a request +//! let request = json!({ +//! "jsonrpc": "2.0", +//! "method": "initialize", +//! "id": 1, +//! "params": { +//! "protocolVersion": 1, +//! "clientCapabilities": {} +//! } +//! }); +//! transport.send(request).await?; +//! +//! // Receive response +//! if let Some(response) = transport.recv().await? { +//! println!("Response: {:?}", response); +//! } +//! +//! // Clean up +//! transport.close().await?; +//! # Ok(()) +//! # } +//! ``` + +use async_trait::async_trait; +use serde_json::Value; +use std::error::Error; + +pub mod http; +pub mod stdio; + +pub use http::HttpTransport; +pub use stdio::CrashContext; +pub use stdio::StdioTransport; + +/// Result type for transport operations +pub type TransportResult = Result>; + +/// Transport abstraction for ACP communication +/// +/// This trait defines the interface for sending and receiving JSON-RPC messages +/// over various transport mechanisms. Implementations must be thread-safe (Send + Sync) +/// to allow concurrent operations. +/// +/// # Object Safety +/// +/// This trait is object-safe and can be used as `Box` for +/// dynamic dispatch when the concrete transport type is not known at compile time. +/// +/// # Message Format +/// +/// All messages are JSON-RPC 2.0 format, represented as `serde_json::Value`. +/// The transport layer does not interpret message semantics - it only handles +/// serialization and transmission. +/// +/// # Lifecycle +/// +/// 1. Create the transport instance with configuration +/// 2. Call `connect()` to establish the connection +/// 3. Use `send()` and `recv()` for bidirectional communication +/// 4. Call `close()` to clean up resources gracefully +#[async_trait] +pub trait AcpTransport: Send + Sync { + /// Establish the connection to the ACP agent + /// + /// This method must be called before any `send()` or `recv()` operations. + /// For stdio transports, this spawns the process. For network transports, + /// this establishes the TCP/HTTP connection. + /// + /// # Errors + /// + /// Returns an error if: + /// - The connection cannot be established (network unreachable, process failed to start) + /// - The transport is already connected + /// - Required resources are unavailable + /// + /// # Example + /// + /// ```no_run + /// # use dirigent_core::connectors::acp::transport::{AcpTransport, StdioTransport}; + /// # async fn example() -> anyhow::Result<()> { + /// let mut transport = StdioTransport::new("my-agent", &[]); + /// transport.connect().await?; + /// # Ok(()) + /// # } + /// ``` + async fn connect(&mut self) -> TransportResult<()>; + + /// Send a JSON-RPC message to the agent + /// + /// The message is serialized and transmitted according to the transport's + /// protocol (line-delimited JSON for stdio, HTTP POST for HTTP transport). + /// + /// # Arguments + /// + /// * `message` - JSON-RPC 2.0 message (request or response) + /// + /// # Errors + /// + /// Returns an error if: + /// - The transport is not connected + /// - The message cannot be serialized + /// - The underlying I/O operation fails + /// + /// # Example + /// + /// ```no_run + /// # use dirigent_core::connectors::acp::transport::{AcpTransport, StdioTransport}; + /// # use serde_json::json; + /// # async fn example(mut transport: StdioTransport) -> anyhow::Result<()> { + /// let request = json!({ + /// "jsonrpc": "2.0", + /// "method": "session/new", + /// "id": 1, + /// "params": {} + /// }); + /// transport.send(request).await?; + /// # Ok(()) + /// # } + /// ``` + async fn send(&mut self, message: Value) -> TransportResult<()>; + + /// Receive a JSON-RPC message from the agent + /// + /// This method blocks until a message is available or the connection is closed. + /// Returns `None` if the connection is closed gracefully (EOF on stdio, SSE + /// stream ended, etc.). + /// + /// # Returns + /// + /// - `Ok(Some(message))` - A message was received + /// - `Ok(None)` - The connection was closed gracefully (no more messages) + /// - `Err(e)` - An error occurred during receive + /// + /// # Errors + /// + /// Returns an error if: + /// - The transport is not connected + /// - The received data cannot be parsed as JSON + /// - The underlying I/O operation fails + /// + /// # Example + /// + /// ```no_run + /// # use dirigent_core::connectors::acp::transport::{AcpTransport, StdioTransport}; + /// # async fn example(mut transport: StdioTransport) -> anyhow::Result<()> { + /// while let Some(message) = transport.recv().await? { + /// println!("Received: {:?}", message); + /// } + /// # Ok(()) + /// # } + /// ``` + async fn recv(&mut self) -> TransportResult>; + + /// Close the connection gracefully + /// + /// This method cleans up resources and closes the connection. After calling + /// `close()`, no further `send()` or `recv()` operations are allowed. + /// + /// For stdio transports, this terminates the child process. For network + /// transports, this closes the socket/connection. + /// + /// # Errors + /// + /// Returns an error if the cleanup operation fails. This is typically + /// non-fatal and can be ignored in most cases. + /// + /// # Example + /// + /// ```no_run + /// # use dirigent_core::connectors::acp::transport::{AcpTransport, StdioTransport}; + /// # async fn example(mut transport: StdioTransport) -> anyhow::Result<()> { + /// // ... use transport ... + /// transport.close().await?; + /// # Ok(()) + /// # } + /// ``` + async fn close(&mut self) -> TransportResult<()>; + + /// Get crash context if the transport detected a child process crash. + /// + /// Only meaningful for transports that manage child processes (e.g., StdioTransport). + /// Returns None by default. + async fn get_crash_context(&self) -> Option { + None + } + + /// Get the process ID, if applicable (e.g., for stdio transports that spawn a child process). + /// Returns None for transports that don't manage OS processes. + async fn pid(&self) -> Option { + None + } +} diff --git a/crates/dirigent_core/src/connectors/acp/transport/stdio.rs b/crates/dirigent_core/src/connectors/acp/transport/stdio.rs new file mode 100644 index 0000000..0be83f2 --- /dev/null +++ b/crates/dirigent_core/src/connectors/acp/transport/stdio.rs @@ -0,0 +1,1174 @@ +//! Stdio transport for ACP +//! +//! This module provides a transport implementation that communicates with ACP agents +//! via stdin/stdout using line-delimited JSON. This is the standard transport for +//! agents that run as child processes. +//! +//! # Protocol +//! +//! Messages are sent as line-delimited JSON: +//! - Each JSON-RPC message is serialized to a single line +//! - Lines are terminated with `\n` +//! - The process's stdout is read line-by-line to receive messages +//! +//! # Lifecycle +//! +//! 1. Create StdioTransport with command and arguments +//! 2. Call `connect()` to spawn the process +//! 3. Use `send()` and `recv()` for communication +//! 4. Call `close()` to terminate the process gracefully +//! +//! # Thread Safety +//! +//! StdioTransport uses `Arc>` for thread-safe access to stdin/stdout handles. +//! This allows multiple tasks to share the transport, though typically only one task +//! should be reading (recv) and one writing (send) at a time. + +use super::{AcpTransport, TransportResult}; +use crate::connectors::acp::logging::AcpProtocolLogger; +use async_trait::async_trait; +use serde_json::Value; +use std::collections::VecDeque; +use std::path::PathBuf; +use std::process::Stdio; +use std::sync::Arc; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command}; +use tokio::sync::Mutex; +use tracing::{debug, error, info, warn}; + +/// Context captured when a child process crashes. +/// +/// Assembled from three data sources that are normally siloed: +/// stderr output, process exit status, and partial stdout data. +#[derive(Debug, Clone)] +pub struct CrashContext { + /// Recent stderr lines from the child process (last 50 lines) + pub recent_stderr: Vec, + /// Exit status if available at crash detection time + pub exit_status: Option, + /// Partial stdout data that was being written when the process crashed + pub partial_stdout: Option, +} + +/// Stdio transport implementation +/// +/// Spawns a child process and communicates via stdin/stdout using line-delimited JSON. +/// +/// # Example +/// +/// ```no_run +/// use dirigent_core::connectors::acp::transport::{AcpTransport, StdioTransport}; +/// use serde_json::json; +/// +/// # async fn example() -> anyhow::Result<()> { +/// // Spawn dirigate process +/// let mut transport = StdioTransport::new( +/// "dirigate", +/// &["serve", "--fixtures", "./fixtures", "--stdio"] +/// ); +/// +/// // Connect (spawns process) +/// transport.connect().await?; +/// +/// // Send initialize request +/// transport.send(json!({ +/// "jsonrpc": "2.0", +/// "method": "initialize", +/// "id": 1, +/// "params": { +/// "protocolVersion": 1, +/// "clientCapabilities": {} +/// } +/// })).await?; +/// +/// // Receive response +/// if let Some(response) = transport.recv().await? { +/// println!("Initialized: {:?}", response); +/// } +/// +/// // Clean up +/// transport.close().await?; +/// # Ok(()) +/// # } +/// ``` +pub struct StdioTransport { + /// Command to execute (e.g., "dirigate" or "/path/to/agent") + command: String, + + /// Arguments to pass to the command + args: Vec, + + /// Working directory for the process (optional) + cwd: Option, + + /// Environment variables to set (key-value pairs) + env: Vec<(String, String)>, + + /// Child process handle + /// + /// Wrapped in Arc for thread-safe access. This allows the + /// transport to be used across multiple async tasks. + child: Arc>>, + + /// Stdin handle for sending messages + /// + /// Wrapped in Arc for thread-safe writes. Only one task should + /// write at a time to avoid message interleaving. + stdin: Arc>>, + + /// Stdout handle for receiving messages + /// + /// Wrapped in BufReader for line-by-line reading, then in Arc + /// for thread-safe access. Only one task should read at a time. + stdout: Arc>>>, + + /// Stderr handle for draining logs + /// + /// Wrapped in BufReader for line-by-line reading, then in Arc + /// for thread-safe access. A background task continuously drains this. + stderr: Arc>>>, + + /// Background task handle for stderr draining + /// + /// Continuously reads stderr and logs it to prevent buffer overflow. + stderr_task: Arc>>>, + + /// Directory for protocol logging (optional) + /// + /// When set, all JSON-RPC messages are logged to JSONL files in this directory. + log_dir: Option, + + /// Protocol logger (created in connect() if log_dir is set) + logger: Option, + + /// Timestamp of the last write (send) operation + last_io_write: Arc>>, + + /// Timestamp of the last read (recv) operation + last_io_read: Arc>>, + + /// Ring buffer of recent stderr lines from the child process. + /// Capped at 50 lines. Used to assemble crash context. + recent_stderr: Arc>>, + + /// Crash context captured when recv() detects a child process crash. + /// Set by recv() when it detects a partial read (crash mid-write). + last_crash_context: Arc>>, + + /// Pending message from a previous multi-message parse. + /// + /// When `try_extract_json` finds two JSON values in one buffer + /// (e.g., a complete multi-line message followed by the start of the next), + /// the second value is stored here and returned on the next `recv()` call. + /// + /// See [`crate::acp::transport::json_reader::try_extract_json`]. + pending_message: Arc>>, + + /// Incomplete remainder from a previous recv that contained a complete JSON + /// message followed by the start of another message that wasn't yet complete. + /// + /// Without this, the partial data would be discarded and the next `read_line()` + /// would pick up mid-message, causing a parse failure. + pending_remainder: Arc>>, + + /// Optional process lifecycle manager for graceful shutdown. + /// + /// When set, `connect()` registers the child PID with the lifecycle manager + /// and `close()` performs a graceful shutdown (SIGTERM + timeout) before + /// force-killing the process. Without it, the original hard-kill behavior + /// is preserved. + #[cfg(feature = "server")] + process_lifecycle: Option>, +} + +impl StdioTransport { + /// Create a new StdioTransport + /// + /// # Arguments + /// + /// * `command` - Command to execute (e.g., "dirigate") + /// * `args` - Arguments to pass to the command + /// + /// # Returns + /// + /// A new StdioTransport that is not yet connected. Call `connect()` to + /// spawn the process. + /// + /// # Example + /// + /// ```no_run + /// use dirigent_core::connectors::acp::transport::StdioTransport; + /// + /// let transport = StdioTransport::new( + /// "dirigate", + /// &["serve", "--stdio"] + /// ); + /// ``` + pub fn new(command: impl Into, args: &[impl AsRef]) -> Self { + let command = command.into(); + let args = args.iter().map(|s| s.as_ref().to_string()).collect(); + + info!( + command = %command, + args = ?args, + "Creating stdio transport" + ); + + Self { + command, + args, + cwd: None, + env: vec![], + child: Arc::new(Mutex::new(None)), + stdin: Arc::new(Mutex::new(None)), + stdout: Arc::new(Mutex::new(None)), + stderr: Arc::new(Mutex::new(None)), + stderr_task: Arc::new(Mutex::new(None)), + log_dir: None, + logger: None, + last_io_write: Arc::new(Mutex::new(None)), + last_io_read: Arc::new(Mutex::new(None)), + recent_stderr: Arc::new(Mutex::new(VecDeque::new())), + last_crash_context: Arc::new(Mutex::new(None)), + pending_message: Arc::new(Mutex::new(None)), + pending_remainder: Arc::new(Mutex::new(None)), + #[cfg(feature = "server")] + process_lifecycle: None, + } + } + + /// Get the PID of the child process, if connected. + pub async fn pid(&self) -> Option { + let child_lock = self.child.lock().await; + child_lock.as_ref().and_then(|c| c.id()) + } + + /// Get the crash context captured by the last recv() call, if any. + pub async fn get_crash_context(&self) -> Option { + self.last_crash_context.lock().await.clone() + } + + /// Get recent stderr lines from the child process. + pub async fn get_recent_stderr(&self) -> Vec { + self.recent_stderr.lock().await.iter().cloned().collect() + } + + /// Get the command used to spawn this transport. + pub fn command(&self) -> &str { + &self.command + } + + /// Get the arguments passed to the command. + pub fn args(&self) -> &[String] { + &self.args + } + + /// Get the timestamp of the last write (send) operation. + pub async fn last_io_write(&self) -> Option { + *self.last_io_write.lock().await + } + + /// Get the timestamp of the last read (recv) operation. + pub async fn last_io_read(&self) -> Option { + *self.last_io_read.lock().await + } + + /// Set the working directory for the spawned process + /// + /// Must be called before `connect()`. + pub fn set_cwd(&mut self, cwd: PathBuf) { + self.cwd = Some(cwd); + } + + /// Add an environment variable for the spawned process + /// + /// Must be called before `connect()`. + pub fn set_env(&mut self, key: String, value: String) { + self.env.push((key, value)); + } + + /// Set the directory for protocol logging + /// + /// When set, all JSON-RPC messages (stdin/stdout) will be logged to JSONL files + /// in this directory. Useful for debugging protocol issues. + /// + /// Must be called before `connect()`. + pub fn set_log_dir(&mut self, log_dir: PathBuf) { + self.log_dir = Some(log_dir); + } + + /// Set the process lifecycle manager for graceful shutdown. + /// + /// When set, the transport will: + /// - Call `configure_async_command` on the command before spawning + /// - Register the child PID immediately after spawn + /// - Perform a graceful shutdown (SIGTERM → wait → SIGKILL) in `close()` + /// + /// Must be called before `connect()`. + #[cfg(feature = "server")] + pub fn set_process_lifecycle(&mut self, lifecycle: Box) { + self.process_lifecycle = Some(lifecycle); + } + + /// Log a parsed message and return it, updating I/O timestamp. + async fn log_and_return_message(&mut self, message: Value) -> TransportResult> { + let msg_type = if message.get("id").is_some() { + if message.get("method").is_some() { + "request" + } else { + "response" + } + } else if message.get("method").is_some() { + "notification" + } else { + "unknown" + }; + + info!( + msg_type = msg_type, + method = message.get("method").and_then(|m| m.as_str()).unwrap_or(""), + id = message + .get("id") + .map(|id| id.to_string()) + .unwrap_or_default(), + "📦 Parsed message type" + ); + + *self.last_io_read.lock().await = Some(std::time::Instant::now()); + + Ok(Some(message)) + } + + /// Core recv implementation that reads from stdout with three-tier parsing: + /// + /// 1. **Happy path**: `serde_json::from_str` on the complete line — fast, + /// no risk of partial extraction. + /// 2. **Fallback**: `StreamDeserializer` via `try_extract_json` — handles + /// multi-line JSON, concatenated messages, and carries over remainders. + /// 3. **Recovery**: If both fail, log the bad data and skip it instead of + /// killing the connection. Only EOF and I/O errors are fatal. + async fn recv_from_stdout(&mut self) -> TransportResult> { + // Buffer for assembling multi-line JSON messages. + let mut json_buffer = String::new(); + let mut buffered_lines: usize = 0; + + // Safety limit: max lines to buffer before giving up on this message. + const MAX_BUFFER_LINES: usize = 50_000; + // Safety limit: max buffer size in bytes (10 MB) + const MAX_BUFFER_BYTES: usize = 10 * 1024 * 1024; + + // Initialize buffer from pending remainder if available. + // This carries over incomplete data from a previous recv() where we + // extracted one complete JSON message but the remainder was an + // incomplete second message. + if let Some(remainder) = self.pending_remainder.lock().await.take() { + warn!( + remainder_bytes = remainder.len(), + remainder_preview = %&remainder[..remainder.len().min(120)], + "Initializing buffer from pending remainder" + ); + json_buffer = remainder; + buffered_lines = 1; + } + + loop { + let mut stdout_lock = self.stdout.lock().await; + let stdout = stdout_lock + .as_mut() + .ok_or("Transport not connected (no stdout)")?; + + let mut line = String::new(); + let bytes_read = stdout + .read_line(&mut line) + .await + .map_err(|e| format!("Failed to read from stdout: {}", e))?; + + // EOF — this IS fatal (process closed stdout) + if bytes_read == 0 { + if !json_buffer.is_empty() { + warn!( + buffer_lines = buffered_lines, + buffer_bytes = json_buffer.len(), + "EOF with incomplete multi-line JSON buffer" + ); + } + debug!("Received EOF from child process"); + return Ok(None); + } + + // Detect partial read (crash mid-write signature) — fatal + if bytes_read > 0 && !line.ends_with('\n') { + warn!( + bytes_read, + partial_data = %line, + "Partial read from child process (likely crash mid-write)" + ); + drop(stdout_lock); + + let exit_status = { + let mut child_lock = self.child.lock().await; + child_lock.as_mut().and_then(|c| c.try_wait().ok().flatten()) + }; + let recent_stderr = self.recent_stderr.lock().await.iter().cloned().collect(); + let crash_ctx = CrashContext { + recent_stderr, + exit_status, + partial_stdout: Some(line), + }; + info!( + exit_status = ?crash_ctx.exit_status, + stderr_lines = crash_ctx.recent_stderr.len(), + "Crash context assembled for child process crash" + ); + *self.last_crash_context.lock().await = Some(crash_ctx); + return Ok(None); + } + + // Log every line read for framing diagnostics + // TODO: downgrade to debug once framing bug is resolved + warn!( + bytes_read, + line_starts_with = %&line[..line.len().min(40)].escape_debug(), + buffer_lines = buffered_lines, + buffer_bytes = json_buffer.len(), + "📥 read_line from child stdout" + ); + + // Skip empty lines when not buffering + if line.trim().is_empty() && json_buffer.is_empty() { + warn!("Received empty line from child process, skipping"); + drop(stdout_lock); + continue; + } + + // Release stdout lock before doing work + drop(stdout_lock); + + // ── Tier 1: Happy path — from_str on fresh single line ── + // + // If we have no accumulated buffer, try parsing the line directly. + // This is the fast path for well-formed single-line JSON (99% of messages). + // Unlike StreamDeserializer, from_str is all-or-nothing: it cannot + // extract a truncated value from a line with escaping issues. + if json_buffer.is_empty() { + if let Ok(message) = serde_json::from_str::(line.trim()) { + info!( + method = message.get("method").and_then(|m| m.as_str()).unwrap_or(""), + id = message.get("id").map(|i| i.to_string()).unwrap_or_default(), + bytes = bytes_read, + "📨 Received message (from_str fast path)" + ); + if let Some(ref mut logger) = self.logger { + logger.log_incoming_raw(line.trim()); + } + return self.log_and_return_message(message).await; + } + // from_str failed — fall through to tier 2 (buffer + StreamDeserializer) + } + + // ── Tier 2: Fallback — multi-line buffering with StreamDeserializer ── + // + // Append to buffer and try StreamDeserializer, which handles: + // - Multi-line / pretty-printed JSON + // - Concatenated messages (extracts first, stores remainder) + // - Incomplete JSON (returns None, we keep buffering) + + // Append to buffer + if json_buffer.is_empty() { + json_buffer = line; + buffered_lines = 1; + } else { + json_buffer.push_str(&line); + buffered_lines += 1; + } + + if let Some((message, remainder)) = crate::acp::transport::json_reader::try_extract_json(&json_buffer) { + let method = message.get("method").and_then(|m| m.as_str()).unwrap_or(""); + let msg_id = message.get("id").map(|i| i.to_string()).unwrap_or_default(); + if buffered_lines > 1 { + warn!( + lines = buffered_lines, + bytes = json_buffer.len(), + method, + id = msg_id, + remainder_bytes = remainder.len(), + "Recovered multi-line JSON message from child process" + ); + } else { + info!(method, id = msg_id, "📨 Received message (StreamDeserializer fallback)"); + } + + // Log raw content to protocol log + if let Some(ref mut logger) = self.logger { + let parsed_end = json_buffer.len() - remainder.len(); + logger.log_incoming_raw(&json_buffer[..parsed_end]); + } + + // Handle leftover data from concatenated messages + if !remainder.is_empty() { + warn!( + remainder_bytes = remainder.len(), + remainder_preview = %&remainder[..remainder.len().min(120)], + "Buffer has remaining data after complete JSON, checking for second message" + ); + if let Some((second_msg, second_remainder)) = crate::acp::transport::json_reader::try_extract_json(&remainder) + { + warn!( + second_msg_method = second_msg.get("method").and_then(|m| m.as_str()).unwrap_or("?"), + second_remainder_bytes = second_remainder.len(), + "Extracted second complete message from remainder" + ); + *self.pending_message.lock().await = Some(second_msg); + if !second_remainder.is_empty() { + warn!( + remainder_bytes = second_remainder.len(), + remainder_preview = %&second_remainder[..second_remainder.len().min(120)], + "Carrying over third message fragment as pending remainder" + ); + *self.pending_remainder.lock().await = Some(second_remainder); + } + } else { + warn!( + remainder_bytes = remainder.len(), + remainder_preview = %&remainder[..remainder.len().min(120)], + "Carrying over incomplete remainder for next recv()" + ); + *self.pending_remainder.lock().await = Some(remainder); + } + } + + return self.log_and_return_message(message).await; + } + + // ── Tier 3: Recovery — skip unparseable data ── + // + // If the buffer doesn't start with a JSON token, it's garbage. + // Instead of killing the connection, log the bad data and reset. + // One lost notification is better than a dead session. + let buffer_trimmed = json_buffer.trim_start(); + if !buffer_trimmed.starts_with('{') && !buffer_trimmed.starts_with('[') { + warn!( + target: "acp_protocol", + buffered_lines, + buffer_bytes = json_buffer.len(), + line_bytes = bytes_read, + raw_preview = %&buffer_trimmed[..buffer_trimmed.len().min(300)], + "Skipping non-JSON data from child process (connection preserved)" + ); + if let Some(ref mut logger) = self.logger { + logger.log_incoming_raw(buffer_trimmed); + } + // Reset and keep reading + json_buffer.clear(); + buffered_lines = 0; + continue; + } + + // Safety limits — these ARE fatal (runaway buffering) + if buffered_lines >= MAX_BUFFER_LINES { + error!( + target: "acp_protocol", + lines = buffered_lines, + bytes = json_buffer.len(), + "Multi-line JSON buffer exceeded line limit, aborting" + ); + return Err(format!( + "Multi-line JSON message exceeded {} lines without completing", + MAX_BUFFER_LINES + ) + .into()); + } + if json_buffer.len() > MAX_BUFFER_BYTES { + error!( + target: "acp_protocol", + lines = buffered_lines, + bytes = json_buffer.len(), + "Multi-line JSON buffer exceeded size limit, aborting" + ); + return Err(format!( + "Multi-line JSON message exceeded {} bytes without completing", + MAX_BUFFER_BYTES + ) + .into()); + } + + if buffered_lines % 1000 == 0 { + debug!( + lines = buffered_lines, + bytes = json_buffer.len(), + "Still buffering multi-line JSON message..." + ); + } + } + } +} + +#[async_trait] +impl AcpTransport for StdioTransport { + async fn connect(&mut self) -> TransportResult<()> { + let mut child_lock = self.child.lock().await; + + if child_lock.is_some() { + return Err("Transport already connected".into()); + } + + info!( + command = %self.command, + args = ?self.args, + "Spawning child process" + ); + + // Spawn the child process + let mut command = Command::new(&self.command); + command + .args(&self.args) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) // Pipe stderr to prevent buffer overflow + .kill_on_drop(true); // Ensure process is killed when dropped + + // Set working directory if specified + if let Some(cwd) = &self.cwd { + command.current_dir(cwd); + } + + // Set environment variables + for (key, value) in &self.env { + command.env(key, value); + } + + // Allow the process lifecycle manager to configure platform-specific + // flags (e.g. Windows Job Object, Linux process group) before spawning. + #[cfg(feature = "server")] + if let Some(ref lifecycle) = self.process_lifecycle { + lifecycle.configure_async_command(&mut command); + } + + let mut child = command + .spawn() + .map_err(|e| format!("Failed to spawn process '{}': {}", self.command, e))?; + + // Register the child PID with the lifecycle manager so it is included + // in the platform-level process group / job object. + #[cfg(feature = "server")] + if let Some(ref lifecycle) = self.process_lifecycle { + if let Some(pid) = child.id() { + if let Err(e) = lifecycle.register_child(pid) { + warn!(error = %e, pid, "Failed to register child with process lifecycle"); + } + } + } + + // Take stdin/stdout/stderr handles + let stdin = child.stdin.take().ok_or("Failed to open child stdin")?; + let stdout = child.stdout.take().ok_or("Failed to open child stdout")?; + let stderr = child.stderr.take().ok_or("Failed to open child stderr")?; + + // Store stdin/stdout handles + *child_lock = Some(child); + *self.stdin.lock().await = Some(stdin); + *self.stdout.lock().await = Some(BufReader::new(stdout)); + + // Store stderr handle and spawn draining task + let stderr_reader = BufReader::new(stderr); + *self.stderr.lock().await = Some(stderr_reader); + + // Spawn background task to drain stderr continuously + let stderr_arc = Arc::clone(&self.stderr); + let recent_stderr_arc = Arc::clone(&self.recent_stderr); + let command_name = self.command.clone(); + let stderr_handle = tokio::spawn(async move { + debug!(target: "child_stderr", command = %command_name, "Starting stderr draining task"); + + loop { + let mut stderr_lock = stderr_arc.lock().await; + if let Some(stderr_reader) = stderr_lock.as_mut() { + let mut line = String::new(); + match stderr_reader.read_line(&mut line).await { + Ok(0) => { + // EOF - process closed stderr + debug!(target: "child_stderr", command = %command_name, "Stderr EOF reached, task exiting"); + break; + } + Ok(_) => { + // Got a line - log it + let trimmed = line.trim(); + if !trimmed.is_empty() { + // Non-empty stderr from child process is often indicative of problems + // Use warn level so it's visible in console by default + warn!(target: "child_stderr", command = %command_name, line = %trimmed, "Child process stderr output"); + { + let mut buffer = recent_stderr_arc.lock().await; + buffer.push_back(trimmed.to_string()); + while buffer.len() > 50 { + buffer.pop_front(); + } + } + } + } + Err(e) => { + // Error reading - log and exit + warn!(target: "child_stderr", command = %command_name, error = %e, "Error reading stderr, task exiting"); + break; + } + } + } else { + // Stderr handle was taken (during close), exit task + debug!(target: "child_stderr", command = %command_name, "Stderr handle removed, task exiting"); + break; + } + // Release lock between reads to allow close() to proceed + drop(stderr_lock); + } + + debug!(target: "child_stderr", command = %command_name, "Stderr draining task finished"); + }); + + *self.stderr_task.lock().await = Some(stderr_handle); + + // Create protocol logger if log_dir is configured + if let Some(log_dir) = &self.log_dir { + self.logger = AcpProtocolLogger::new(log_dir.clone(), &self.command); + } + + info!("Child process spawned successfully with stderr draining"); + + Ok(()) + } + + async fn send(&mut self, message: Value) -> TransportResult<()> { + let mut stdin_lock = self.stdin.lock().await; + let stdin = stdin_lock + .as_mut() + .ok_or("Transport not connected (no stdin)")?; + + // Serialize to JSON and add newline + let json_str = serde_json::to_string(&message) + .map_err(|e| format!("Failed to serialize message: {}", e))?; + + // Determine message type for logging + let msg_type = if message.get("method").is_some() { + if message.get("id").is_some() { + "request" + } else { + "notification" + } + } else { + "response" + }; + let method = message.get("method").and_then(|m| m.as_str()).unwrap_or(""); + let id = message.get("id").map(|i| i.to_string()).unwrap_or_default(); + + info!( + msg_type = msg_type, + method = method, + id = id, + "📤 Sending message to child process" + ); + let masked_message = dirigent_protocol::log_utils::mask_json_string(&json_str); + debug!(message = %masked_message, "Full message JSON"); + + // Log to protocol log file if enabled + if let Some(ref mut logger) = self.logger { + logger.log_outgoing(&message); + } + + // Write to stdin + stdin + .write_all(json_str.as_bytes()) + .await + .map_err(|e| format!("Failed to write to stdin: {}", e))?; + stdin + .write_all(b"\n") + .await + .map_err(|e| format!("Failed to write newline to stdin: {}", e))?; + stdin + .flush() + .await + .map_err(|e| format!("Failed to flush stdin: {}", e))?; + + debug!("Message sent and flushed successfully"); + + // Record I/O timestamp + *self.last_io_write.lock().await = Some(std::time::Instant::now()); + + Ok(()) + } + + async fn recv(&mut self) -> TransportResult> { + // Check if we have a leftover message from a previous multi-message parse. + // Clone the Arc to avoid holding an immutable borrow of self across the + // mutable call to recv_from_stdout/log_and_return_message. + let pending_arc = Arc::clone(&self.pending_message); + let pending = pending_arc.lock().await.take(); + + if let Some(msg) = pending { + debug!("Returning pending message from previous recv() parse"); + self.log_and_return_message(msg).await + } else { + self.recv_from_stdout().await + } + } + + async fn close(&mut self) -> TransportResult<()> { + info!("Closing stdio transport"); + + // Abort stderr draining task first + if let Some(task) = self.stderr_task.lock().await.take() { + debug!("Aborting stderr draining task"); + task.abort(); + // Wait briefly for task to finish with timeout + let _ = tokio::time::timeout(std::time::Duration::from_millis(100), task).await; + } + + // Drop stdin/stdout/stderr to close pipes (signals process to exit) + *self.stdin.lock().await = None; + *self.stdout.lock().await = None; + *self.stderr.lock().await = None; + + // Terminate the child process + let mut child_lock = self.child.lock().await; + if let Some(mut child) = child_lock.take() { + debug!("Terminating child process"); + + #[cfg(feature = "server")] + let used_lifecycle = { + if let Some(ref lifecycle) = self.process_lifecycle { + if child.id().is_some() { + // Graceful shutdown: SIGTERM/CTRL_BREAK → wait → force kill + dirigent_process::graceful_shutdown_async( + lifecycle.as_ref(), + &mut child, + std::time::Duration::from_secs(5), + ) + .await; + true + } else { + // Process already exited + true + } + } else { + false + } + }; + + #[cfg(not(feature = "server"))] + let used_lifecycle = false; + + if !used_lifecycle { + // Fallback: hard kill (original behavior, belt-and-suspenders) + if let Err(e) = child.start_kill() { + warn!(error = %e, "Failed to send kill signal (process may have already exited)"); + } + + // Wait briefly with timeout for process to exit + let wait_result = + tokio::time::timeout(std::time::Duration::from_millis(500), child.wait()).await; + + match wait_result { + Ok(Ok(status)) => { + info!(status = ?status, "Child process exited"); + } + Ok(Err(e)) => { + warn!(error = %e, "Error waiting for child process"); + } + Err(_) => { + warn!("Child process did not exit within timeout, leaving to OS cleanup"); + } + } + } + } + + Ok(()) + } + + async fn get_crash_context(&self) -> Option { + self.last_crash_context.lock().await.clone() + } + + async fn pid(&self) -> Option { + let child_lock = self.child.lock().await; + child_lock.as_ref().and_then(|c| c.id()) + } +} + +impl Drop for StdioTransport { + fn drop(&mut self) { + // Attempt to kill the child process if it's still running + // This is a best-effort cleanup - we can't use async in Drop + if let Ok(mut child_lock) = self.child.try_lock() { + if let Some(mut child) = child_lock.take() { + // Try to kill synchronously (blocking) + let _ = child.start_kill(); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + /// Helper to find the mocker binary + fn mocker_binary() -> std::path::PathBuf { + let manifest_dir = env!("CARGO_MANIFEST_DIR"); + std::path::Path::new(manifest_dir) + .parent() + .unwrap() + .parent() + .unwrap() + .join("target") + .join("debug") + .join(if cfg!(windows) { + "dirigate.exe" + } else { + "dirigate" + }) + } + + /// Helper to find the fixture file + fn fixture_path() -> std::path::PathBuf { + let manifest_dir = env!("CARGO_MANIFEST_DIR"); + std::path::Path::new(manifest_dir) + .parent() + .unwrap() + .join("dirigate") + .join("examples") + .join("basic.yaml") + } + + #[tokio::test] + async fn test_stdio_transport_basic_flow() { + let binary = mocker_binary(); + if !binary.exists() { + eprintln!( + "Skipping test: mocker binary not found at {:?}. Run 'cargo build' first.", + binary + ); + return; + } + + let fixture = fixture_path(); + assert!(fixture.exists(), "Fixture file not found at {:?}", fixture); + + // Create transport + let mut transport = StdioTransport::new( + binary.to_str().unwrap(), + &["serve", "--fixtures", fixture.to_str().unwrap(), "--stdio"], + ); + + // Connect + transport.connect().await.expect("Failed to connect"); + + // Test 1: Initialize + let init_request = json!({ + "jsonrpc": "2.0", + "method": "initialize", + "id": 1, + "params": { + "protocolVersion": 1, + "clientCapabilities": {} + } + }); + + transport + .send(init_request) + .await + .expect("Failed to send init request"); + + let init_response = transport + .recv() + .await + .expect("Failed to receive init response") + .expect("Got None instead of response"); + + assert_eq!(init_response["jsonrpc"], "2.0"); + assert_eq!(init_response["id"], 1); + assert!(init_response["result"]["protocolVersion"].is_number()); + + // Test 2: Create session + let session_request = json!({ + "jsonrpc": "2.0", + "method": "session/new", + "id": 2, + "params": { + "cwd": ".", + "mcpServers": [] + } + }); + + transport + .send(session_request) + .await + .expect("Failed to send session request"); + + let session_response = transport + .recv() + .await + .expect("Failed to receive session response") + .expect("Got None instead of response"); + + assert_eq!(session_response["jsonrpc"], "2.0"); + assert_eq!(session_response["id"], 2); + let session_id = session_response["result"]["sessionId"] + .as_str() + .expect("sessionId should be a string"); + assert!(!session_id.is_empty()); + + // Test 3: Send a prompt + let prompt_request = json!({ + "jsonrpc": "2.0", + "method": "session/prompt", + "id": 3, + "params": { + "sessionId": session_id, + "prompt": [ + { + "type": "text", + "text": "hello" + } + ] + } + }); + + transport + .send(prompt_request) + .await + .expect("Failed to send prompt request"); + + // Receive response (skip any notifications) + let mut prompt_response = None; + for _ in 0..10 { + // Max 10 messages to avoid infinite loop + if let Some(msg) = transport.recv().await.expect("Failed to receive") { + // Check if this is the response we're looking for + if msg.get("id") == Some(&json!(3)) { + prompt_response = Some(msg); + break; + } + // Skip notifications + } else { + break; + } + } + + let prompt_response = prompt_response.expect("Did not receive response to prompt request"); + assert_eq!(prompt_response["jsonrpc"], "2.0"); + assert_eq!(prompt_response["id"], 3); + assert!(prompt_response["result"]["stopReason"].is_string()); + + // Close + transport.close().await.expect("Failed to close"); + } + + #[tokio::test] + async fn test_stdio_transport_process_exit() { + let binary = mocker_binary(); + if !binary.exists() { + eprintln!("Skipping test: mocker binary not found"); + return; + } + + let fixture = fixture_path(); + if !fixture.exists() { + eprintln!("Skipping test: fixture not found"); + return; + } + + let mut transport = StdioTransport::new( + binary.to_str().unwrap(), + &["serve", "--fixtures", fixture.to_str().unwrap(), "--stdio"], + ); + + transport.connect().await.expect("Failed to connect"); + + // Close the transport (kills the process) + transport.close().await.expect("Failed to close"); + + // Try to send after close - should fail + let result = transport.send(json!({"test": "message"})).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_stdio_transport_malformed_json() { + // This test would require a mock process that sends malformed JSON + // For now, we'll skip it as it requires more infrastructure + // In a real implementation, you might use a test harness process + } + + // try_extract_json unit tests are in crate::acp::transport::json_reader::tests + + #[tokio::test] + async fn test_stdio_transport_concurrent_sends() { + let binary = mocker_binary(); + if !binary.exists() { + eprintln!("Skipping test: mocker binary not found"); + return; + } + + let fixture = fixture_path(); + if !fixture.exists() { + eprintln!("Skipping test: fixture not found"); + return; + } + + let mut transport = StdioTransport::new( + binary.to_str().unwrap(), + &["serve", "--fixtures", fixture.to_str().unwrap(), "--stdio"], + ); + + transport.connect().await.expect("Failed to connect"); + + // Initialize first + transport + .send(json!({ + "jsonrpc": "2.0", + "method": "initialize", + "id": 1, + "params": { + "protocolVersion": 1, + "clientCapabilities": {} + } + })) + .await + .expect("Failed to send init"); + + // Drain the response + transport.recv().await.ok(); + + // Send multiple messages in quick succession + let transport = Arc::new(Mutex::new(transport)); + + let mut handles = vec![]; + for i in 0..5 { + let t = Arc::clone(&transport); + let handle = tokio::spawn(async move { + let mut transport = t.lock().await; + transport + .send(json!({ + "jsonrpc": "2.0", + "method": "session/new", + "id": i + 10, + "params": { + "cwd": ".", + "mcpServers": [] + } + })) + .await + }); + handles.push(handle); + } + + // All sends should succeed + for handle in handles { + handle.await.expect("Task panicked").expect("Send failed"); + } + + // Close + let mut transport = transport.lock().await; + transport.close().await.expect("Failed to close"); + } +} diff --git a/crates/dirigent_core/src/connectors/fingerprint.rs b/crates/dirigent_core/src/connectors/fingerprint.rs new file mode 100644 index 0000000..5ba24b3 --- /dev/null +++ b/crates/dirigent_core/src/connectors/fingerprint.rs @@ -0,0 +1,210 @@ +//! Connector fingerprint computation +//! +//! Computes deterministic fingerprints for connectors based on their kind and +//! connection parameters. Fingerprints are used by the archivist to re-associate +//! archived data with connectors across restarts, even when connector IDs change. +//! +//! # Fingerprint Format +//! +//! Each connector kind produces fingerprints in a specific format: +//! - ACP stdio: `acp/stdio:` +//! - ACP HTTP: `acp/http:` +//! - OpenCode: `opencode/http:` +//! - Gateway: `gateway:` +//! - Mock / Acceptor: no fingerprint (returns `None`) + +use crate::types::ConnectorKind; + +/// Compute a deterministic fingerprint for a connector based on its kind and connection params. +/// +/// The fingerprint captures the essential identity of a connector -- the parameters that +/// make it "the same connector" across restarts. For example, an ACP stdio connector +/// pointing to `/usr/bin/claude` will always produce the same fingerprint regardless +/// of what connector ID is assigned. +/// +/// Returns `None` for Mock and Acceptor connectors, which do not have meaningful +/// persistent identity. +pub fn compute_fingerprint(kind: &ConnectorKind, params: &serde_json::Value) -> Option<String> { + match kind { + ConnectorKind::Acp => compute_acp_fingerprint(params), + ConnectorKind::OpenCode => compute_opencode_fingerprint(params), + ConnectorKind::Gateway => compute_gateway_fingerprint(params), + ConnectorKind::Mock | ConnectorKind::Acceptor => None, + } +} + +fn compute_acp_fingerprint(params: &serde_json::Value) -> Option<String> { + let transport = params.get("transport")?; + let transport_type = transport.get("type")?.as_str()?; + match transport_type { + "stdio" => { + let command = transport.get("command")?.as_str()?; + let resolved = resolve_command_path(command); + Some(format!("acp/stdio:{}", resolved)) + } + "http" => { + let base_url = transport.get("base_url")?.as_str()?; + Some(format!("acp/http:{}", base_url)) + } + _ => None, + } +} + +fn compute_opencode_fingerprint(params: &serde_json::Value) -> Option<String> { + let base_url = params.get("base_url")?.as_str()?; + Some(format!("opencode/http:{}", base_url)) +} + +fn compute_gateway_fingerprint(params: &serde_json::Value) -> Option<String> { + let title = params + .get("title") + .and_then(|v| v.as_str()) + .unwrap_or("Gateway"); + Some(format!("gateway:{}", title)) +} + +/// Resolve a command name to its absolute path. +/// +/// If the command is already an absolute path, it is returned as-is. +/// Otherwise, uses platform-specific lookup (`which` on Unix, `where` on Windows) +/// to find the full path. Falls back to the raw command string if resolution fails. +fn resolve_command_path(command: &str) -> String { + let path = std::path::Path::new(command); + if path.is_absolute() { + return command.to_string(); + } + + #[cfg(unix)] + let result = std::process::Command::new("which").arg(command).output(); + + #[cfg(windows)] + let result = std::process::Command::new("where").arg(command).output(); + + match result { + Ok(output) if output.status.success() => String::from_utf8(output.stdout) + .ok() + .and_then(|s| s.lines().next().map(|l| l.trim().to_string())) + .unwrap_or_else(|| command.to_string()), + _ => command.to_string(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_acp_stdio_fingerprint() { + let params = serde_json::json!({ + "transport": { + "type": "stdio", + "command": "/usr/bin/claude", + "args": ["--acp"] + } + }); + assert_eq!( + compute_fingerprint(&ConnectorKind::Acp, ¶ms), + Some("acp/stdio:/usr/bin/claude".to_string()) + ); + } + + #[test] + fn test_acp_http_fingerprint() { + let params = serde_json::json!({ + "transport": { + "type": "http", + "base_url": "http://localhost:3000" + } + }); + assert_eq!( + compute_fingerprint(&ConnectorKind::Acp, ¶ms), + Some("acp/http:http://localhost:3000".to_string()) + ); + } + + #[test] + fn test_opencode_fingerprint() { + let params = serde_json::json!({"base_url": "http://localhost:12225"}); + assert_eq!( + compute_fingerprint(&ConnectorKind::OpenCode, ¶ms), + Some("opencode/http:http://localhost:12225".to_string()) + ); + } + + #[test] + fn test_gateway_fingerprint_with_title() { + let params = serde_json::json!({"title": "My Gateway"}); + assert_eq!( + compute_fingerprint(&ConnectorKind::Gateway, ¶ms), + Some("gateway:My Gateway".to_string()) + ); + } + + #[test] + fn test_gateway_fingerprint_default_title() { + let params = serde_json::json!({}); + assert_eq!( + compute_fingerprint(&ConnectorKind::Gateway, ¶ms), + Some("gateway:Gateway".to_string()) + ); + } + + #[test] + fn test_mock_no_fingerprint() { + assert_eq!( + compute_fingerprint(&ConnectorKind::Mock, &serde_json::json!({})), + None + ); + } + + #[test] + fn test_acceptor_no_fingerprint() { + assert_eq!( + compute_fingerprint(&ConnectorKind::Acceptor, &serde_json::json!({})), + None + ); + } + + #[test] + fn test_acp_unknown_transport_type() { + let params = serde_json::json!({ + "transport": { + "type": "websocket", + "url": "ws://localhost:8080" + } + }); + assert_eq!(compute_fingerprint(&ConnectorKind::Acp, ¶ms), None); + } + + #[test] + fn test_acp_missing_transport() { + let params = serde_json::json!({"base_url": "http://localhost:3000"}); + assert_eq!(compute_fingerprint(&ConnectorKind::Acp, ¶ms), None); + } + + #[test] + fn test_opencode_missing_base_url() { + let params = serde_json::json!({"title": "My OpenCode"}); + assert_eq!(compute_fingerprint(&ConnectorKind::OpenCode, ¶ms), None); + } + + #[test] + fn test_resolve_command_path_absolute() { + // Absolute paths should be returned as-is + #[cfg(unix)] + assert_eq!(resolve_command_path("/usr/bin/claude"), "/usr/bin/claude"); + + #[cfg(windows)] + assert_eq!( + resolve_command_path("C:\\Program Files\\claude.exe"), + "C:\\Program Files\\claude.exe" + ); + } + + #[test] + fn test_resolve_command_path_nonexistent() { + // A command that definitely doesn't exist should fall back to raw string + let result = resolve_command_path("definitely_not_a_real_command_12345"); + assert_eq!(result, "definitely_not_a_real_command_12345"); + } +} diff --git a/crates/dirigent_core/src/connectors/gateway/commands.rs b/crates/dirigent_core/src/connectors/gateway/commands.rs new file mode 100644 index 0000000..c4a8883 --- /dev/null +++ b/crates/dirigent_core/src/connectors/gateway/commands.rs @@ -0,0 +1,786 @@ +//! Command parsing and execution for the Gateway connector +//! +//! This module handles the built-in commands that can be invoked in +//! Gateway connector sessions using the `/command` syntax. + +use super::{ConnectorListCallback, ConnectorSummaryInfo, GatewaySession, SessionTransferCallback}; +use serde::{Deserialize, Serialize}; + +/// Information about a connector for transfer display +#[derive(Clone, Debug)] +pub struct ConnectorTransferInfo { + pub kind: String, // "ACP", "OpenCode", "Gateway" + pub title: String, + pub model: Option<String>, +} + +/// Commands that can be executed in a Gateway session +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub enum Command { + /// Enable or disable echo mode + Echo(bool), + + /// List all available connectors + ListConnectors, + + /// Select a specific connector to transfer the session to + /// Tuple: (connector_id, optional_session_id) + SelectConnector(String, Option<String>), + + /// Shortcut to transfer to a Claude connector + Claude, + + /// Show help for available commands + Help, +} + +/// Result of executing a command +#[derive(Clone, Debug)] +pub enum CommandResult { + /// A message to display to the user + Message(String), + + /// An error message + Error(String), + + /// Session was transferred to another connector + SessionTransferred { + to_connector: String, + /// Optional confirmation message + message: Option<String>, + /// Information about the connector for rich display + connector_info: Option<ConnectorTransferInfo>, + }, +} + +/// Parse a command from message text +/// +/// Commands must start with `/` and may have arguments. +/// Returns None if the message is not a command. +/// +/// # Examples +/// +/// ``` +/// use dirigent_core::connectors::gateway::commands::parse_command; +/// +/// assert!(parse_command("/help").is_some()); +/// assert!(parse_command("/echo on").is_some()); +/// assert!(parse_command("hello").is_none()); +/// ``` +pub fn parse_command(text: &str) -> Option<Command> { + let trimmed = text.trim(); + + // Commands must start with / + if !trimmed.starts_with('/') { + return None; + } + + // Split into command and arguments + let parts: Vec<&str> = trimmed[1..].splitn(2, char::is_whitespace).collect(); + let command_name = parts[0].to_lowercase(); + let args = parts.get(1).map(|s| s.trim()).unwrap_or(""); + + match command_name.as_str() { + "echo" => { + let arg = parse_arguments(args).into_iter().next().unwrap_or_default().to_lowercase(); + match arg.as_str() { + "on" | "true" | "1" | "enable" | "yes" => Some(Command::Echo(true)), + "off" | "false" | "0" | "disable" | "no" => Some(Command::Echo(false)), + "" => { + // No argument - invalid usage + None + } + _ => None, + } + } + "list-connectors" | "listconnectors" | "connectors" | "list" => { + Some(Command::ListConnectors) + } + "select-connector" | "selectconnector" | "select" | "use" => { + let args = parse_arguments(args); + let connector_id = args.first().map(|s| sanitize_connector_id(s)); + let session_id = args.get(1).map(|s| s.trim().to_string()); + connector_id.map(|id| Command::SelectConnector(id, session_id)) + } + "claude" => Some(Command::Claude), + "help" | "?" => Some(Command::Help), + _ => None, + } +} + +/// Parse command arguments, handling quoted strings +/// +/// # Examples +/// +/// - `arg1 arg2` -> ["arg1", "arg2"] +/// - `"arg with spaces" arg2` -> ["arg with spaces", "arg2"] +fn parse_arguments(args: &str) -> Vec<&str> { + let mut result = Vec::new(); + let mut chars = args.char_indices().peekable(); + let mut in_quotes = false; + let mut start = 0; + let mut current_quote = ' '; + + while let Some((i, c)) = chars.next() { + if in_quotes { + if c == current_quote { + result.push(&args[start..i]); + in_quotes = false; + // Skip whitespace after closing quote + while chars.peek().map(|(_, c)| c.is_whitespace()).unwrap_or(false) { + chars.next(); + } + if let Some((next_i, _)) = chars.peek() { + start = *next_i; + } + } + } else if c == '"' || c == '\'' { + in_quotes = true; + current_quote = c; + start = i + 1; + } else if c.is_whitespace() { + if start < i { + result.push(&args[start..i]); + } + // Skip additional whitespace + while chars.peek().map(|(_, c)| c.is_whitespace()).unwrap_or(false) { + chars.next(); + } + if let Some((next_i, _)) = chars.peek() { + start = *next_i; + } else { + start = args.len(); + } + } + } + + // Handle remaining text + if start < args.len() && !in_quotes { + let remaining = args[start..].trim(); + if !remaining.is_empty() { + result.push(remaining); + } + } + + result +} + +/// Sanitize a connector ID by keeping only alphanumeric, dash, underscore +/// Takes the first whitespace-delimited token and filters characters. +fn sanitize_connector_id(id: &str) -> String { + id.split_whitespace() + .next() + .unwrap_or("") + .chars() + .filter(|c| c.is_alphanumeric() || *c == '-' || *c == '_') + .collect() +} + +/// Execute a command and return the result +pub async fn execute_command( + command: Command, + gateway_connector_id: &str, + session_id: &str, + session: &mut GatewaySession, + connector_list_callback: Option<&ConnectorListCallback>, + session_transfer_callback: Option<&SessionTransferCallback>, +) -> CommandResult { + match command { + Command::Echo(enabled) => { + execute_echo(enabled, session) + } + Command::ListConnectors => { + execute_list_connectors(connector_list_callback) + } + Command::SelectConnector(connector_id, target_session_id) => { + execute_select_connector( + gateway_connector_id, + session_id, + &connector_id, + target_session_id.as_deref(), + &session.current_mode_id, + &session.current_model_id, + connector_list_callback, + session_transfer_callback, + ).await + } + Command::Claude => { + execute_claude( + gateway_connector_id, + session_id, + &session.current_mode_id, + &session.current_model_id, + connector_list_callback, + session_transfer_callback, + ).await + } + Command::Help => { + execute_help() + } + } +} + +/// Execute the /echo command +fn execute_echo(enabled: bool, session: &mut GatewaySession) -> CommandResult { + session.echo_enabled = enabled; + if enabled { + CommandResult::Message("Echo mode enabled. Your messages will be echoed back.".to_string()) + } else { + CommandResult::Message("Echo mode disabled.".to_string()) + } +} + +/// Execute the /list-connectors command +fn execute_list_connectors(callback: Option<&ConnectorListCallback>) -> CommandResult { + match callback { + Some(cb) => { + let connectors: Vec<ConnectorSummaryInfo> = cb() + .into_iter() + .filter(|c| c.supports_session_transfer) + .collect(); + if connectors.is_empty() { + CommandResult::Message("No connectors available.".to_string()) + } else { + let mut message = String::from("Available connectors:\n\n"); + for (i, conn) in connectors.iter().enumerate() { + message.push_str(&format!( + "{}. **{}** ({})\n ID: `{}`\n State: {}\n\n", + i + 1, + conn.title, + conn.kind, + conn.id, + conn.state + )); + } + message.push_str("Use `/select-connector <id>` to transfer this session to a connector."); + CommandResult::Message(message) + } + } + None => { + CommandResult::Error("Connector listing is not available.".to_string()) + } + } +} + +/// Execute the /select-connector command +async fn execute_select_connector( + gateway_connector_id: &str, + session_id: &str, + connector_id: &str, + target_session_id: Option<&str>, + current_mode_id: &str, + current_model_id: &str, + list_callback: Option<&ConnectorListCallback>, + transfer_callback: Option<&SessionTransferCallback>, +) -> CommandResult { + use super::SessionTransferRequest; + use super::SessionTransferResult; + use tokio::sync::oneshot; + + match transfer_callback { + Some(cb) => { + let (result_tx, result_rx) = oneshot::channel(); + + let request = SessionTransferRequest { + gateway_connector_id: gateway_connector_id.to_string(), + gateway_session_id: session_id.to_string(), + target_connector_id: connector_id.to_string(), + target_session_id: target_session_id.map(String::from), + current_mode_id: current_mode_id.to_string(), + current_model_id: current_model_id.to_string(), + result_tx, + }; + + // Fire the callback (non-blocking) + cb(request); + + // Wait for result with timeout + match tokio::time::timeout( + std::time::Duration::from_secs(30), + result_rx + ).await { + Ok(Ok(SessionTransferResult::Transferred { connector_id, session_id: _, is_new, .. })) => { + // Look up connector info if list callback is available + let connector_info = list_callback.and_then(|list_cb| { + let connectors = list_cb(); + connectors.into_iter() + .find(|c| c.id == connector_id) + .map(|summary| ConnectorTransferInfo { + kind: summary.kind.clone(), + title: summary.title.clone(), + model: None, // Model comes later via SessionMetadataReceived + }) + }); + + let mode = if is_new { "new session" } else { "loaded session" }; + CommandResult::SessionTransferred { + to_connector: connector_id.clone(), + message: Some(format!("Transferred to {} ({})", connector_id, mode)), + connector_info, + } + } + Ok(Ok(SessionTransferResult::Failed(reason))) => { + CommandResult::Error(format!("Transfer failed: {}", reason)) + } + Ok(Err(_)) => { + CommandResult::Error("Transfer was cancelled".to_string()) + } + Err(_) => { + CommandResult::Error("Transfer timed out".to_string()) + } + } + } + None => { + CommandResult::Error("Session transfer is not available.".to_string()) + } + } +} + +/// Execute the /claude command +async fn execute_claude( + gateway_connector_id: &str, + session_id: &str, + current_mode_id: &str, + current_model_id: &str, + list_callback: Option<&ConnectorListCallback>, + transfer_callback: Option<&SessionTransferCallback>, +) -> CommandResult { + use crate::connectors::acp::config::ConnectorAgentType; + + // First, find a Claude connector by agent_type, with fallback to name matching + let claude_connector = match list_callback { + Some(cb) => { + let connectors = cb(); + // Priority 1: Find by agent_type = Claude + connectors.iter().find(|c| c.agent_type == Some(ConnectorAgentType::Claude)) + .cloned() + // Priority 2: Fallback to name matching (for backwards compatibility) + .or_else(|| { + connectors.into_iter().find(|c| { + c.title.to_lowercase().contains("claude") || + c.id.to_lowercase().contains("claude") + }) + }) + } + None => None, + }; + + match claude_connector { + Some(connector) => { + execute_select_connector( + gateway_connector_id, + session_id, + &connector.id, + None, + current_mode_id, + current_model_id, + list_callback, + transfer_callback, + ).await + } + None => { + CommandResult::Error( + "No Claude connector found. Use `/list-connectors` to see available connectors.".to_string() + ) + } + } +} + +/// Execute the /help command +fn execute_help() -> CommandResult { + let help_text = r#"**Available Commands** + +`/echo on|off` +Enable or disable echo mode. When enabled, your messages will be echoed back. + +`/list-connectors` +Show all available connectors that you can transfer this session to. + +`/select-connector <id>` +Transfer this session to the specified connector. Use the connector ID from `/list-connectors`. + +`/claude` +Shortcut to transfer this session to a Claude connector (finds a connector with "claude" in its name). + +`/help` +Show this help message. + +**Notes** +- Commands start with `/` +- Arguments can be quoted if they contain spaces +- Echo mode lets you test the connection without using an external agent"#; + + CommandResult::Message(help_text.to_string()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_command_help() { + assert_eq!(parse_command("/help"), Some(Command::Help)); + assert_eq!(parse_command("/HELP"), Some(Command::Help)); + assert_eq!(parse_command("/?"), Some(Command::Help)); + assert_eq!(parse_command("/help extra args"), Some(Command::Help)); + } + + #[test] + fn test_parse_command_echo() { + assert_eq!(parse_command("/echo on"), Some(Command::Echo(true))); + assert_eq!(parse_command("/echo off"), Some(Command::Echo(false))); + assert_eq!(parse_command("/echo ON"), Some(Command::Echo(true))); + assert_eq!(parse_command("/echo true"), Some(Command::Echo(true))); + assert_eq!(parse_command("/echo false"), Some(Command::Echo(false))); + assert_eq!(parse_command("/echo enable"), Some(Command::Echo(true))); + assert_eq!(parse_command("/echo disable"), Some(Command::Echo(false))); + assert_eq!(parse_command("/echo"), None); // Missing argument + assert_eq!(parse_command("/echo invalid"), None); // Invalid argument + } + + #[test] + fn test_parse_command_list_connectors() { + assert_eq!(parse_command("/list-connectors"), Some(Command::ListConnectors)); + assert_eq!(parse_command("/listconnectors"), Some(Command::ListConnectors)); + assert_eq!(parse_command("/connectors"), Some(Command::ListConnectors)); + assert_eq!(parse_command("/list"), Some(Command::ListConnectors)); + } + + #[test] + fn test_parse_command_select_connector() { + assert_eq!( + parse_command("/select-connector my-connector"), + Some(Command::SelectConnector("my-connector".to_string(), None)) + ); + assert_eq!( + parse_command("/select my-connector"), + Some(Command::SelectConnector("my-connector".to_string(), None)) + ); + assert_eq!( + parse_command("/use my-connector"), + Some(Command::SelectConnector("my-connector".to_string(), None)) + ); + assert_eq!(parse_command("/select-connector"), None); // Missing ID + } + + #[test] + fn test_parse_command_claude() { + assert_eq!(parse_command("/claude"), Some(Command::Claude)); + } + + #[test] + fn test_parse_command_not_a_command() { + assert_eq!(parse_command("hello"), None); + assert_eq!(parse_command("not a command"), None); + assert_eq!(parse_command(""), None); + assert_eq!(parse_command(" "), None); + } + + #[test] + fn test_parse_command_unknown_command() { + assert_eq!(parse_command("/unknown"), None); + assert_eq!(parse_command("/foobar"), None); + } + + #[test] + fn test_parse_arguments_simple() { + assert_eq!(parse_arguments("arg1 arg2"), vec!["arg1", "arg2"]); + assert_eq!(parse_arguments("single"), vec!["single"]); + assert_eq!(parse_arguments(""), Vec::<&str>::new()); + } + + #[test] + fn test_parse_arguments_quoted() { + assert_eq!( + parse_arguments("\"arg with spaces\" arg2"), + vec!["arg with spaces", "arg2"] + ); + assert_eq!( + parse_arguments("'single quoted' arg2"), + vec!["single quoted", "arg2"] + ); + } + + #[test] + fn test_parse_arguments_whitespace() { + assert_eq!(parse_arguments(" arg1 arg2 "), vec!["arg1", "arg2"]); + } + + #[test] + fn test_execute_echo_on() { + let mut session = GatewaySession::new( + "test".to_string(), + "Test".to_string(), + false, + None, + ); + let result = execute_echo(true, &mut session); + assert!(session.echo_enabled); + match result { + CommandResult::Message(msg) => assert!(msg.contains("enabled")), + _ => panic!("Expected Message result"), + } + } + + #[test] + fn test_execute_echo_off() { + let mut session = GatewaySession::new( + "test".to_string(), + "Test".to_string(), + true, + None, + ); + let result = execute_echo(false, &mut session); + assert!(!session.echo_enabled); + match result { + CommandResult::Message(msg) => assert!(msg.contains("disabled")), + _ => panic!("Expected Message result"), + } + } + + #[test] + fn test_execute_help() { + let result = execute_help(); + match result { + CommandResult::Message(msg) => { + assert!(msg.contains("/echo")); + assert!(msg.contains("/help")); + assert!(msg.contains("/list-connectors")); + assert!(msg.contains("/select-connector")); + assert!(msg.contains("/claude")); + } + _ => panic!("Expected Message result"), + } + } + + #[test] + fn test_execute_list_connectors_no_callback() { + let result = execute_list_connectors(None); + match result { + CommandResult::Error(msg) => assert!(msg.contains("not available")), + _ => panic!("Expected Error result"), + } + } + + #[tokio::test] + async fn test_execute_select_connector_no_callback() { + let result = execute_select_connector( + "gateway-1", + "session-1", + "connector-1", + None, + "ask", + "default", + None, + None, + ).await; + match result { + CommandResult::Error(msg) => assert!(msg.contains("not available")), + _ => panic!("Expected Error result"), + } + } + + #[test] + fn test_execute_list_connectors_with_callback() { + use std::sync::Arc; + + // Create a callback that returns test connectors + let callback: super::ConnectorListCallback = Arc::new(|| { + use crate::connectors::acp::config::ConnectorAgentType; + vec![ + ConnectorSummaryInfo { + id: "opencode-1".to_string(), + title: "OpenCode Local".to_string(), + kind: "OpenCode".to_string(), + state: "Ready".to_string(), + supports_session_transfer: true, + agent_type: None, + }, + ConnectorSummaryInfo { + id: "claude-acp".to_string(), + title: "Claude (ACP)".to_string(), + kind: "Acp".to_string(), + state: "Ready".to_string(), + supports_session_transfer: true, + agent_type: Some(ConnectorAgentType::Claude), + }, + ConnectorSummaryInfo { + id: "gateway-1".to_string(), + title: "Gateway".to_string(), + kind: "Gateway".to_string(), + state: "Ready".to_string(), + supports_session_transfer: true, + agent_type: None, + }, + ] + }); + + let result = execute_list_connectors(Some(&callback)); + match result { + CommandResult::Message(msg) => { + // Verify the output contains all connectors + assert!(msg.contains("Available connectors:")); + assert!(msg.contains("opencode-1")); + assert!(msg.contains("OpenCode Local")); + assert!(msg.contains("claude-acp")); + assert!(msg.contains("Claude (ACP)")); + assert!(msg.contains("gateway-1")); + assert!(msg.contains("Gateway")); + assert!(msg.contains("/select-connector")); + } + _ => panic!("Expected Message result, got {:?}", result), + } + } + + #[test] + fn test_execute_list_connectors_empty() { + use std::sync::Arc; + + // Create a callback that returns empty list + let callback: super::ConnectorListCallback = Arc::new(|| Vec::new()); + + let result = execute_list_connectors(Some(&callback)); + match result { + CommandResult::Message(msg) => { + assert!(msg.contains("No connectors available")); + } + _ => panic!("Expected Message result, got {:?}", result), + } + } + + #[test] + fn test_sanitize_connector_id() { + assert_eq!(sanitize_connector_id("valid-id_123"), "valid-id_123"); + assert_eq!(sanitize_connector_id("has spaces"), "has"); + assert_eq!(sanitize_connector_id("weird@#$chars"), "weirdchars"); + assert_eq!(sanitize_connector_id(""), ""); + } + + #[test] + fn test_parse_select_connector_with_session_id() { + assert_eq!( + parse_command("/select-connector conn-1 session-abc"), + Some(Command::SelectConnector("conn-1".to_string(), Some("session-abc".to_string()))) + ); + assert_eq!( + parse_command("/select conn-1"), + Some(Command::SelectConnector("conn-1".to_string(), None)) + ); + } + + #[tokio::test] + async fn test_execute_select_connector_with_connector_info() { + use std::sync::Arc; + use crate::connectors::gateway::{SessionTransferCallback, SessionTransferResult}; + + // Create a list callback that returns test connectors + let list_callback: ConnectorListCallback = Arc::new(|| { + use crate::connectors::acp::config::ConnectorAgentType; + vec![ + ConnectorSummaryInfo { + id: "opencode-1".to_string(), + title: "OpenCode Local".to_string(), + kind: "OpenCode".to_string(), + state: "Ready".to_string(), + supports_session_transfer: true, + agent_type: None, + }, + ConnectorSummaryInfo { + id: "claude-acp".to_string(), + title: "Claude (ACP)".to_string(), + kind: "Acp".to_string(), + state: "Ready".to_string(), + supports_session_transfer: true, + agent_type: Some(ConnectorAgentType::Claude), + }, + ] + }); + + // Create a transfer callback that simulates a successful transfer + let transfer_callback: SessionTransferCallback = Arc::new(|request| { + let result = SessionTransferResult::Transferred { + connector_id: request.target_connector_id.clone(), + session_id: "new-session-123".to_string(), + is_new: true, + models: None, + modes: None, + }; + let _ = request.result_tx.send(result); + }); + + let result = execute_select_connector( + "gateway-1", + "session-1", + "opencode-1", + None, + "ask", + "default", + Some(&list_callback), + Some(&transfer_callback), + ) + .await; + + match result { + CommandResult::SessionTransferred { + to_connector, + message, + connector_info, + } => { + assert_eq!(to_connector, "opencode-1"); + assert!(message.is_some()); + + // Verify connector info is populated + let info = connector_info.expect("connector_info should be populated"); + assert_eq!(info.kind, "OpenCode"); + assert_eq!(info.title, "OpenCode Local"); + assert_eq!(info.model, None); // Model comes later via SessionMetadataReceived + } + _ => panic!("Expected SessionTransferred result, got {:?}", result), + } + } + + #[tokio::test] + async fn test_execute_select_connector_without_list_callback() { + use std::sync::Arc; + use crate::connectors::gateway::{SessionTransferCallback, SessionTransferResult}; + + // Create a transfer callback that simulates a successful transfer + let transfer_callback: SessionTransferCallback = Arc::new(|request| { + let result = SessionTransferResult::Transferred { + connector_id: request.target_connector_id.clone(), + session_id: "new-session-123".to_string(), + is_new: true, + models: None, + modes: None, + }; + let _ = request.result_tx.send(result); + }); + + let result = execute_select_connector( + "gateway-1", + "session-1", + "opencode-1", + None, + "ask", + "default", + None, // No list callback + Some(&transfer_callback), + ) + .await; + + match result { + CommandResult::SessionTransferred { + to_connector, + message, + connector_info, + } => { + assert_eq!(to_connector, "opencode-1"); + assert!(message.is_some()); + + // Without list callback, connector info should be None + assert!(connector_info.is_none(), "connector_info should be None when list_callback is not available"); + } + _ => panic!("Expected SessionTransferred result, got {:?}", result), + } + } +} diff --git a/crates/dirigent_core/src/connectors/gateway/echo.rs b/crates/dirigent_core/src/connectors/gateway/echo.rs new file mode 100644 index 0000000..d55d17a --- /dev/null +++ b/crates/dirigent_core/src/connectors/gateway/echo.rs @@ -0,0 +1,332 @@ +//! Echo mode implementation for the Gateway connector +//! +//! This module handles the echo response generation and optional +//! streaming simulation for realistic response behavior. + +use crate::sharing::bus::SharingBus; +use dirigent_protocol::types::ContentBlock; +use dirigent_protocol::{Event, Message, MessageRole, MessageStatus, SessionUpdate}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use tokio::sync::broadcast; +use tracing::debug; +use uuid::Uuid; + +/// Configuration for echo behavior +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct EchoConfig { + /// Whether to simulate streaming by breaking the response into chunks + #[serde(default)] + pub simulate_streaming: bool, + + /// Delay between chunks in milliseconds (when simulate_streaming is true) + /// Default: 0 (instant) + #[serde(default)] + pub chunk_delay_ms: u64, + + /// Approximate size of each chunk in characters (when simulate_streaming is true) + /// Default: 10 + #[serde(default = "default_chunk_size")] + pub chunk_size: usize, + + /// Prefix to add to echoed messages (optional) + #[serde(default)] + pub prefix: Option<String>, + + /// Suffix to add to echoed messages (optional) + #[serde(default)] + pub suffix: Option<String>, +} + +fn default_chunk_size() -> usize { + 10 +} + +impl Default for EchoConfig { + fn default() -> Self { + Self { + simulate_streaming: false, + chunk_delay_ms: 0, + chunk_size: default_chunk_size(), + prefix: None, + suffix: None, + } + } +} + +/// Generate an echo response for the given input text +/// +/// # Arguments +/// +/// * `text` - The user's message text to echo +/// * `config` - Echo configuration for formatting +/// +/// # Returns +/// +/// The formatted echo response string +pub fn generate_echo_response(text: &str, config: &EchoConfig) -> String { + let mut response = String::new(); + + if let Some(ref prefix) = config.prefix { + response.push_str(prefix); + } + + response.push_str(text); + + if let Some(ref suffix) = config.suffix { + response.push_str(suffix); + } + + response +} + +/// Stream an echo response as multiple chunks +/// +/// This function simulates streaming behavior by breaking the response +/// into chunks and emitting them with optional delays. +/// +/// # Arguments +/// +/// * `connector_id` - The connector ID for events +/// * `session_id` - The session ID for events +/// * `message_id` - The message ID for the response +/// * `response` - The full response text to stream +/// * `config` - Echo configuration for chunk sizes and delays +/// * `events_tx` - The broadcast sender for emitting events +pub async fn stream_echo_response( + connector_id: &str, + connector_uid: Option<Uuid>, + session_id: &str, + message_id: &str, + response: &str, + config: &EchoConfig, + events_tx: &broadcast::Sender<Event>, + sharing_bus: &Arc<SharingBus>, +) { + debug!( + connector_id = %connector_id, + session_id = %session_id, + message_id = %message_id, + response_len = response.len(), + chunk_size = config.chunk_size, + delay_ms = config.chunk_delay_ms, + "Streaming echo response" + ); + + // Create a placeholder message for MessageStarted + let placeholder_message = Message { + id: message_id.to_string(), + session_id: session_id.to_string(), + role: MessageRole::Assistant, + content: vec![], + created_at: chrono::Utc::now(), + status: MessageStatus::Streaming, + metadata: None, + }; + + // Emit MessageStarted (bus + broadcast) + let started_event = Event::MessageStarted { + connector_id: connector_id.to_string(), + message: placeholder_message, + }; + let bus_event = dirigent_protocol::streaming::BusEvent::from_connector_event( + started_event.clone(), + connector_uid, + connector_id.to_string(), + ); + sharing_bus.publish(bus_event).await; + let _ = events_tx.send(started_event); + + // Break response into chunks and stream them + let chunks = break_into_chunks(response, config.chunk_size); + + for chunk in chunks { + // Emit the chunk (bus + broadcast) + let chunk_event = Event::SessionUpdate { + connector_id: connector_id.to_string(), + session_id: session_id.to_string(), + update: SessionUpdate::AgentMessageChunk { + message_id: message_id.to_string(), + content: ContentBlock::Text { text: chunk }, + _meta: None, + }, + }; + let bus_event = dirigent_protocol::streaming::BusEvent::from_connector_event( + chunk_event.clone(), + connector_uid, + connector_id.to_string(), + ); + sharing_bus.publish(bus_event).await; + let _ = events_tx.send(chunk_event); + + // Apply delay if configured + if config.chunk_delay_ms > 0 { + tokio::time::sleep(tokio::time::Duration::from_millis(config.chunk_delay_ms)).await; + } + } + + debug!( + connector_id = %connector_id, + message_id = %message_id, + "Finished streaming echo response" + ); +} + +/// Break a string into chunks of approximately the given size +/// +/// Tries to break at word boundaries when possible. +fn break_into_chunks(text: &str, chunk_size: usize) -> Vec<String> { + if chunk_size == 0 || text.is_empty() { + return vec![text.to_string()]; + } + + let mut chunks = Vec::new(); + let mut current_chunk = String::new(); + + for word in text.split_inclusive(char::is_whitespace) { + if current_chunk.len() + word.len() > chunk_size && !current_chunk.is_empty() { + chunks.push(current_chunk); + current_chunk = String::new(); + } + current_chunk.push_str(word); + } + + if !current_chunk.is_empty() { + chunks.push(current_chunk); + } + + // If no chunks were created (e.g., single long word), just return the whole text + if chunks.is_empty() { + chunks.push(text.to_string()); + } + + chunks +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_echo_config_default() { + let config = EchoConfig::default(); + assert!(!config.simulate_streaming); + assert_eq!(config.chunk_delay_ms, 0); + assert_eq!(config.chunk_size, 10); + assert!(config.prefix.is_none()); + assert!(config.suffix.is_none()); + } + + #[test] + fn test_echo_config_serialization() { + let config = EchoConfig { + simulate_streaming: true, + chunk_delay_ms: 50, + chunk_size: 20, + prefix: Some("Echo: ".to_string()), + suffix: None, + }; + + let json = serde_json::to_string(&config).unwrap(); + let deserialized: EchoConfig = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.simulate_streaming, config.simulate_streaming); + assert_eq!(deserialized.chunk_delay_ms, config.chunk_delay_ms); + assert_eq!(deserialized.chunk_size, config.chunk_size); + assert_eq!(deserialized.prefix, config.prefix); + } + + #[test] + fn test_generate_echo_response_simple() { + let config = EchoConfig::default(); + let response = generate_echo_response("Hello, World!", &config); + assert_eq!(response, "Hello, World!"); + } + + #[test] + fn test_generate_echo_response_with_prefix() { + let config = EchoConfig { + prefix: Some("Echo: ".to_string()), + ..Default::default() + }; + let response = generate_echo_response("Hello", &config); + assert_eq!(response, "Echo: Hello"); + } + + #[test] + fn test_generate_echo_response_with_suffix() { + let config = EchoConfig { + suffix: Some(" [echoed]".to_string()), + ..Default::default() + }; + let response = generate_echo_response("Hello", &config); + assert_eq!(response, "Hello [echoed]"); + } + + #[test] + fn test_generate_echo_response_with_prefix_and_suffix() { + let config = EchoConfig { + prefix: Some("> ".to_string()), + suffix: Some(" <".to_string()), + ..Default::default() + }; + let response = generate_echo_response("test", &config); + assert_eq!(response, "> test <"); + } + + #[test] + fn test_break_into_chunks_simple() { + let chunks = break_into_chunks("hello world", 5); + assert!(!chunks.is_empty()); + let combined: String = chunks.join(""); + assert_eq!(combined, "hello world"); + } + + #[test] + fn test_break_into_chunks_single_word() { + let chunks = break_into_chunks("superlongword", 5); + // Single word longer than chunk size - should still work + assert!(!chunks.is_empty()); + let combined: String = chunks.join(""); + assert_eq!(combined, "superlongword"); + } + + #[test] + fn test_break_into_chunks_empty() { + let chunks = break_into_chunks("", 5); + assert_eq!(chunks.len(), 1); + assert_eq!(chunks[0], ""); + } + + #[test] + fn test_break_into_chunks_zero_size() { + let chunks = break_into_chunks("hello world", 0); + assert_eq!(chunks.len(), 1); + assert_eq!(chunks[0], "hello world"); + } + + #[test] + fn test_break_into_chunks_large_size() { + let text = "hello world"; + let chunks = break_into_chunks(text, 1000); + assert_eq!(chunks.len(), 1); + assert_eq!(chunks[0], "hello world"); + } + + #[test] + fn test_break_into_chunks_preserves_whitespace() { + let text = "word1 word2 word3"; + let chunks = break_into_chunks(text, 10); + let combined: String = chunks.join(""); + assert_eq!(combined, text); + } + + #[test] + fn test_break_into_chunks_multiple() { + let text = "This is a longer piece of text that should be broken into multiple chunks."; + let chunks = break_into_chunks(text, 15); + assert!(chunks.len() > 1); + let combined: String = chunks.join(""); + assert_eq!(combined, text); + } +} diff --git a/crates/dirigent_core/src/connectors/gateway/mappings.rs b/crates/dirigent_core/src/connectors/gateway/mappings.rs new file mode 100644 index 0000000..92b50e9 --- /dev/null +++ b/crates/dirigent_core/src/connectors/gateway/mappings.rs @@ -0,0 +1,502 @@ +//! Mode and model mapping module for Gateway connector +//! +//! This module translates between Gateway's standardized modes/models and agent-specific values. +//! Gateway advertises generic modes like "simple", "dailydriver", "high" which map to agent-specific +//! values like "haiku", "sonnet", "opus" for Claude. +//! +//! # Architecture +//! +//! The Gateway connector serves as a session entry point for incoming ACP connections. When +//! sessions are transferred to target agents (Claude, Codex, Gemini), the Gateway's standardized +//! mode/model identifiers must be translated to agent-specific identifiers. +//! +//! # Mapping Direction +//! +//! - **Forward mapping** (`map_mode`, `map_model`): Gateway → Agent +//! Used when setting mode/model on a target agent after session transfer. +//! Example: User clicks "High" in editor → send "opus" to Claude. +//! +//! - **Reverse mapping** (`reverse_map_mode`, `reverse_map_model`): Agent → Gateway +//! Used in legacy mode when reporting agent state back to editor. +//! Example: Claude reports "opus" → display "High" in editor. +//! +//! # Legacy vs Config Options +//! +//! These mappings only apply in legacy mode (without Zed's `acp-beta` flag). +//! When using the new `config_options` system, values pass through unchanged. +//! +//! # Implementation +//! +//! All vendor-specific mappings are delegated to the vendor registry at +//! `dirigent_core::vendors::VENDOR_REGISTRY`. Each vendor implements the +//! `VendorInfo` trait with its specific mappings. +//! +//! # Example +//! +//! ```rust +//! use dirigent_core::connectors::gateway::mappings::{map_mode, map_model, gateway_modes, gateway_models}; +//! use dirigent_core::connectors::gateway::mappings::{reverse_map_mode, reverse_map_model}; +//! use dirigent_core::connectors::acp::config::ConnectorAgentType; +//! +//! // Forward mapping: Gateway mode to Claude mode +//! let result = map_mode(gateway_modes::WRITE, ConnectorAgentType::Claude); +//! assert_eq!(result.mapped_id, "acceptEdits"); +//! assert!(result.warning.is_none()); +//! +//! // Forward mapping: Gateway model to Claude model +//! let result = map_model(gateway_models::SIMPLE, ConnectorAgentType::Claude); +//! assert_eq!(result.mapped_id, "haiku"); +//! assert!(result.warning.is_none()); +//! +//! // Reverse mapping: Claude model to Gateway model +//! let result = reverse_map_model("opus", ConnectorAgentType::Claude); +//! assert_eq!(result.mapped_id, "high"); +//! assert!(result.warning.is_none()); +//! ``` + +use crate::connectors::acp::config::ConnectorAgentType; +use crate::vendors::VENDOR_REGISTRY; + +/// Gateway mode identifiers advertised to clients +pub mod gateway_modes { + /// Plan mode - agent creates execution plans but doesn't execute + pub const PLAN: &str = "plan"; + /// Read-only mode - agent can read but not modify files + pub const READONLY: &str = "readonly"; + /// Ask mode - agent asks before every action + pub const ASK: &str = "ask"; + /// Write mode - agent can modify files with explicit permission + pub const WRITE: &str = "write"; + /// YOLO mode - agent can perform any action without asking + pub const YOLO: &str = "yolo"; +} + +/// Gateway model identifiers advertised to clients +pub mod gateway_models { + /// Simple model - fast, lightweight, lower capability + pub const SIMPLE: &str = "simple"; + /// Daily driver model - balanced performance and capability, the go-to model for everyday tasks + pub const DAILYDRIVER: &str = "dailydriver"; + /// High model - most capable, slower, higher cost + pub const HIGH: &str = "high"; +} + +/// Claude-specific mode identifiers +/// +/// Note: These are re-exported for backwards compatibility. +/// Prefer using `dirigent_core::vendors::claude::modes` for new code. +pub mod claude_modes { + pub use crate::vendors::claude::modes::*; +} + +/// Claude-specific model identifiers +/// +/// Note: These are re-exported for backwards compatibility. +/// Prefer using `dirigent_core::vendors::claude::models` for new code. +pub mod claude_models { + pub use crate::vendors::claude::models::*; +} + +/// Result of a mode/model mapping operation. +/// +/// This type is re-exported from the vendors module for backwards compatibility. +pub use crate::vendors::MappingResult; + +// ============================================================================= +// FORWARD MAPPING: Gateway → Agent +// ============================================================================= + +/// Map a Gateway mode to an agent-specific mode. +/// +/// # Arguments +/// +/// * `gateway_mode` - The Gateway mode identifier to map +/// * `agent_type` - The target agent type +/// +/// # Returns +/// +/// A `MappingResult` containing the mapped agent-specific mode and optional warning. +/// +/// # Example +/// +/// ```rust +/// use dirigent_core::connectors::gateway::mappings::{map_mode, gateway_modes}; +/// use dirigent_core::connectors::acp::config::ConnectorAgentType; +/// +/// let result = map_mode(gateway_modes::WRITE, ConnectorAgentType::Claude); +/// assert_eq!(result.mapped_id, "acceptEdits"); +/// ``` +pub fn map_mode(gateway_mode: &str, agent_type: ConnectorAgentType) -> MappingResult { + VENDOR_REGISTRY + .get_by_agent_type(agent_type) + .map(|vendor| vendor.map_mode(gateway_mode)) + .unwrap_or_else(|| MappingResult::exact(gateway_mode)) +} + +/// Map a Gateway model to an agent-specific model. +/// +/// # Arguments +/// +/// * `gateway_model` - The Gateway model identifier to map +/// * `agent_type` - The target agent type +/// +/// # Returns +/// +/// A `MappingResult` containing the mapped agent-specific model and optional warning. +/// +/// # Example +/// +/// ```rust +/// use dirigent_core::connectors::gateway::mappings::{map_model, gateway_models}; +/// use dirigent_core::connectors::acp::config::ConnectorAgentType; +/// +/// let result = map_model(gateway_models::SIMPLE, ConnectorAgentType::Claude); +/// assert_eq!(result.mapped_id, "haiku"); +/// ``` +pub fn map_model(gateway_model: &str, agent_type: ConnectorAgentType) -> MappingResult { + VENDOR_REGISTRY + .get_by_agent_type(agent_type) + .map(|vendor| vendor.map_model(gateway_model)) + .unwrap_or_else(|| MappingResult::exact(gateway_model)) +} + +// ============================================================================= +// REVERSE MAPPING: Agent → Gateway +// ============================================================================= + +/// Reverse map an agent-specific mode to a Gateway mode. +/// +/// Used in legacy mode when reporting agent state back to the editor. +/// The editor only understands Gateway modes (plan, readonly, ask, write, yolo), +/// so agent-specific modes must be translated back. +/// +/// # Arguments +/// +/// * `agent_mode` - The agent-specific mode identifier to map +/// * `agent_type` - The agent type to map from +/// +/// # Returns +/// +/// A `MappingResult` containing the mapped Gateway mode and optional warning. +/// +/// # Example +/// +/// ```rust +/// use dirigent_core::connectors::gateway::mappings::reverse_map_mode; +/// use dirigent_core::connectors::acp::config::ConnectorAgentType; +/// +/// let result = reverse_map_mode("acceptEdits", ConnectorAgentType::Claude); +/// assert_eq!(result.mapped_id, "write"); +/// ``` +pub fn reverse_map_mode(agent_mode: &str, agent_type: ConnectorAgentType) -> MappingResult { + VENDOR_REGISTRY + .get_by_agent_type(agent_type) + .map(|vendor| vendor.reverse_map_mode(agent_mode)) + .unwrap_or_else(|| MappingResult::exact(agent_mode)) +} + +/// Reverse map an agent-specific model to a Gateway model. +/// +/// Used in legacy mode when reporting agent state back to the editor. +/// The editor only understands Gateway models (simple, daily, high), +/// so agent-specific models must be translated back. +/// +/// # Arguments +/// +/// * `agent_model` - The agent-specific model identifier to map +/// * `agent_type` - The agent type to map from +/// +/// # Returns +/// +/// A `MappingResult` containing the mapped Gateway model and optional warning. +/// +/// # Example +/// +/// ```rust +/// use dirigent_core::connectors::gateway::mappings::reverse_map_model; +/// use dirigent_core::connectors::acp::config::ConnectorAgentType; +/// +/// let result = reverse_map_model("opus", ConnectorAgentType::Claude); +/// assert_eq!(result.mapped_id, "high"); +/// ``` +pub fn reverse_map_model(agent_model: &str, agent_type: ConnectorAgentType) -> MappingResult { + VENDOR_REGISTRY + .get_by_agent_type(agent_type) + .map(|vendor| vendor.reverse_map_model(agent_model)) + .unwrap_or_else(|| MappingResult::exact(agent_model)) +} + +#[cfg(test)] +mod tests { + use super::*; + + // ========================================================================== + // MappingResult tests + // ========================================================================== + + #[test] + fn test_mapping_result_exact() { + let result = MappingResult::exact("test"); + assert_eq!(result.mapped_id, "test"); + assert!(result.warning.is_none()); + } + + #[test] + fn test_mapping_result_approximate() { + let result = MappingResult::approximate("test", "warning message"); + assert_eq!(result.mapped_id, "test"); + assert_eq!(result.warning, Some("warning message".to_string())); + } + + // ========================================================================== + // Forward mapping tests: Gateway → Agent + // ========================================================================== + + #[test] + fn test_map_mode_claude_exact() { + let result = map_mode(gateway_modes::ASK, ConnectorAgentType::Claude); + assert_eq!(result.mapped_id, claude_modes::DEFAULT); + assert!(result.warning.is_none()); + + let result = map_mode(gateway_modes::PLAN, ConnectorAgentType::Claude); + assert_eq!(result.mapped_id, claude_modes::PLAN); + assert!(result.warning.is_none()); + + let result = map_mode(gateway_modes::WRITE, ConnectorAgentType::Claude); + assert_eq!(result.mapped_id, claude_modes::ACCEPT_EDITS); + assert!(result.warning.is_none()); + + let result = map_mode(gateway_modes::YOLO, ConnectorAgentType::Claude); + assert_eq!(result.mapped_id, claude_modes::BYPASS_PERMISSIONS); + assert!(result.warning.is_none()); + } + + #[test] + fn test_map_mode_claude_approximate() { + let result = map_mode(gateway_modes::READONLY, ConnectorAgentType::Claude); + assert_eq!(result.mapped_id, claude_modes::PLAN); + assert!(result.warning.is_some()); + assert!(result.warning.unwrap().contains("readonly")); + } + + #[test] + fn test_map_mode_claude_unknown() { + let result = map_mode("unknown-mode", ConnectorAgentType::Claude); + assert_eq!(result.mapped_id, claude_modes::DEFAULT); + assert!(result.warning.is_some()); + } + + #[test] + fn test_map_mode_custom_passthrough() { + let result = map_mode("custom-mode", ConnectorAgentType::Custom); + assert_eq!(result.mapped_id, "custom-mode"); + assert!(result.warning.is_none()); + } + + #[test] + fn test_map_mode_codex_passthrough() { + // Codex currently passes through with warning + let result = map_mode(gateway_modes::ASK, ConnectorAgentType::Codex); + assert_eq!(result.mapped_id, gateway_modes::ASK); + assert!(result.warning.is_some()); + } + + #[test] + fn test_map_mode_gemini_passthrough() { + // Gemini currently passes through with warning + let result = map_mode(gateway_modes::ASK, ConnectorAgentType::Gemini); + assert_eq!(result.mapped_id, gateway_modes::ASK); + assert!(result.warning.is_some()); + } + + #[test] + fn test_map_model_claude_exact() { + let result = map_model(gateway_models::SIMPLE, ConnectorAgentType::Claude); + assert_eq!(result.mapped_id, claude_models::HAIKU); + assert!(result.warning.is_none()); + + let result = map_model(gateway_models::DAILYDRIVER, ConnectorAgentType::Claude); + assert_eq!(result.mapped_id, claude_models::SONNET); + assert!(result.warning.is_none()); + + let result = map_model(gateway_models::HIGH, ConnectorAgentType::Claude); + assert_eq!(result.mapped_id, claude_models::OPUS); + assert!(result.warning.is_none()); + } + + #[test] + fn test_map_model_claude_unknown() { + let result = map_model("unknown-model", ConnectorAgentType::Claude); + assert_eq!(result.mapped_id, claude_models::SONNET); + assert!(result.warning.is_some()); + } + + #[test] + fn test_map_model_custom_passthrough() { + let result = map_model("custom-model", ConnectorAgentType::Custom); + assert_eq!(result.mapped_id, "custom-model"); + assert!(result.warning.is_none()); + } + + #[test] + fn test_map_model_codex_passthrough() { + let result = map_model(gateway_models::SIMPLE, ConnectorAgentType::Codex); + assert_eq!(result.mapped_id, gateway_models::SIMPLE); + assert!(result.warning.is_some()); + } + + #[test] + fn test_map_model_gemini_passthrough() { + let result = map_model(gateway_models::SIMPLE, ConnectorAgentType::Gemini); + assert_eq!(result.mapped_id, gateway_models::SIMPLE); + assert!(result.warning.is_some()); + } + + // ========================================================================== + // Reverse mapping tests: Agent → Gateway + // ========================================================================== + + #[test] + fn test_reverse_map_mode_claude_exact() { + let result = reverse_map_mode(claude_modes::DEFAULT, ConnectorAgentType::Claude); + assert_eq!(result.mapped_id, gateway_modes::ASK); + assert!(result.warning.is_none()); + + let result = reverse_map_mode(claude_modes::PLAN, ConnectorAgentType::Claude); + assert_eq!(result.mapped_id, gateway_modes::PLAN); + assert!(result.warning.is_none()); + + let result = reverse_map_mode(claude_modes::ACCEPT_EDITS, ConnectorAgentType::Claude); + assert_eq!(result.mapped_id, gateway_modes::WRITE); + assert!(result.warning.is_none()); + + let result = reverse_map_mode(claude_modes::BYPASS_PERMISSIONS, ConnectorAgentType::Claude); + assert_eq!(result.mapped_id, gateway_modes::YOLO); + assert!(result.warning.is_none()); + } + + #[test] + fn test_reverse_map_mode_claude_unknown() { + let result = reverse_map_mode("unknown-mode", ConnectorAgentType::Claude); + assert_eq!(result.mapped_id, gateway_modes::ASK); + assert!(result.warning.is_some()); + } + + #[test] + fn test_reverse_map_mode_custom_passthrough() { + let result = reverse_map_mode("custom-mode", ConnectorAgentType::Custom); + assert_eq!(result.mapped_id, "custom-mode"); + assert!(result.warning.is_none()); + } + + #[test] + fn test_reverse_map_model_claude_exact() { + let result = reverse_map_model(claude_models::HAIKU, ConnectorAgentType::Claude); + assert_eq!(result.mapped_id, gateway_models::SIMPLE); + assert!(result.warning.is_none()); + + let result = reverse_map_model(claude_models::SONNET, ConnectorAgentType::Claude); + assert_eq!(result.mapped_id, gateway_models::DAILYDRIVER); + assert!(result.warning.is_none()); + + let result = reverse_map_model(claude_models::OPUS, ConnectorAgentType::Claude); + assert_eq!(result.mapped_id, gateway_models::HIGH); + assert!(result.warning.is_none()); + + // Claude's "default" also maps to dailydriver + let result = reverse_map_model(claude_models::DEFAULT, ConnectorAgentType::Claude); + assert_eq!(result.mapped_id, gateway_models::DAILYDRIVER); + assert!(result.warning.is_none()); + } + + #[test] + fn test_reverse_map_model_claude_unknown() { + let result = reverse_map_model("unknown-model", ConnectorAgentType::Claude); + assert_eq!(result.mapped_id, gateway_models::DAILYDRIVER); + assert!(result.warning.is_some()); + } + + #[test] + fn test_reverse_map_model_custom_passthrough() { + let result = reverse_map_model("custom-model", ConnectorAgentType::Custom); + assert_eq!(result.mapped_id, "custom-model"); + assert!(result.warning.is_none()); + } + + // ========================================================================== + // Round-trip tests: Gateway → Agent → Gateway + // ========================================================================== + + #[test] + fn test_mode_round_trip_claude() { + // Test that we can map Gateway → Claude → Gateway without losing meaning + for (gateway_mode, expected_claude, expected_gateway) in [ + ( + gateway_modes::ASK, + claude_modes::DEFAULT, + gateway_modes::ASK, + ), + (gateway_modes::PLAN, claude_modes::PLAN, gateway_modes::PLAN), + ( + gateway_modes::WRITE, + claude_modes::ACCEPT_EDITS, + gateway_modes::WRITE, + ), + ( + gateway_modes::YOLO, + claude_modes::BYPASS_PERMISSIONS, + gateway_modes::YOLO, + ), + ] { + let forward = map_mode(gateway_mode, ConnectorAgentType::Claude); + assert_eq!( + forward.mapped_id, expected_claude, + "Forward mapping failed for {}", + gateway_mode + ); + + let reverse = reverse_map_mode(&forward.mapped_id, ConnectorAgentType::Claude); + assert_eq!( + reverse.mapped_id, expected_gateway, + "Reverse mapping failed for {}", + gateway_mode + ); + } + } + + #[test] + fn test_model_round_trip_claude() { + // Test that we can map Gateway → Claude → Gateway without losing meaning + for (gateway_model, expected_claude, expected_gateway) in [ + ( + gateway_models::SIMPLE, + claude_models::HAIKU, + gateway_models::SIMPLE, + ), + ( + gateway_models::DAILYDRIVER, + claude_models::SONNET, + gateway_models::DAILYDRIVER, + ), + ( + gateway_models::HIGH, + claude_models::OPUS, + gateway_models::HIGH, + ), + ] { + let forward = map_model(gateway_model, ConnectorAgentType::Claude); + assert_eq!( + forward.mapped_id, expected_claude, + "Forward mapping failed for {}", + gateway_model + ); + + let reverse = reverse_map_model(&forward.mapped_id, ConnectorAgentType::Claude); + assert_eq!( + reverse.mapped_id, expected_gateway, + "Reverse mapping failed for {}", + gateway_model + ); + } + } +} diff --git a/crates/dirigent_core/src/connectors/gateway/mod.rs b/crates/dirigent_core/src/connectors/gateway/mod.rs new file mode 100644 index 0000000..9cc5019 --- /dev/null +++ b/crates/dirigent_core/src/connectors/gateway/mod.rs @@ -0,0 +1,1413 @@ +//! Gateway connector implementation +//! +//! This module provides a connector that handles messages locally with +//! configurable behavior, including echo mode and command processing. +//! It serves as the default connector for incoming ACP sessions before +//! they are routed to an external agent. +//! +//! # Features +//! +//! - **Echo Mode**: Echoes user messages back as assistant responses +//! - **Command System**: Built-in commands for session management +//! - **Session Routing**: Transfer sessions to other connectors via commands +//! +//! # Commands +//! +//! The GatewayConnector supports these commands: +//! - `/echo on|off` - Enable/disable echo mode +//! - `/list-connectors` - Show available connectors +//! - `/select-connector <id>` - Transfer session to a connector +//! - `/claude` - Shortcut to transfer to a Claude connector +//! - `/help` - Show available commands +//! +//! # Usage +//! +//! ```no_run +//! use dirigent_core::connectors::gateway::{GatewayConnector, GatewayConfig}; +//! use dirigent_core::connectors::{Connector, ConnectorCommand}; +//! +//! # async fn example() { +//! let config = GatewayConfig::default(); +//! let connector = GatewayConnector::new( +//! "gateway-1".to_string(), +//! uuid::Uuid::now_v7(), +//! config, +//! ); +//! +//! // Start the connector task +//! let task_handle = connector.start_task().await; +//! +//! // Subscribe to events +//! let mut events = connector.subscribe(); +//! +//! // Send a message (will echo if echo mode is on) +//! let cmd_tx = connector.command_tx(); +//! cmd_tx.send(ConnectorCommand::SendMessage { +//! session_id: "session-1".to_string(), +//! text: "Hello!".to_string(), +//! }).await.ok(); +//! # } +//! ``` + +pub mod commands; +pub mod echo; +pub mod mappings; + +use crate::connectors::{Connector, ConnectorCommand}; +use crate::types::{ConnectorId, ConnectorKind, ConnectorState, UserId}; +use chrono::Utc; +use dirigent_protocol::events::TurnCompleteTrigger; +use dirigent_protocol::session::{ModelInfo, SessionMode, SessionModeState, SessionModelState}; +use dirigent_protocol::types::ContentBlock; +use dirigent_protocol::{ + Event, Message, MessageRole, MessageStatus, Session, SessionMetadata, SessionUpdate, +}; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::{broadcast, mpsc, RwLock}; +use tokio::task::JoinHandle; +use tracing::{debug, info, warn}; +use uuid::Uuid; + +pub use commands::{Command, CommandResult}; +pub use echo::EchoConfig; + +use tokio::sync::oneshot; + +/// Callback type for listing connectors +/// +/// This is used to allow the GatewayConnector to query available +/// connectors without having a direct dependency on CoreHandle. +pub type ConnectorListCallback = Arc<dyn Fn() -> Vec<ConnectorSummaryInfo> + Send + Sync>; + +/// Request to transfer a session to another connector +#[derive(Debug)] +pub struct SessionTransferRequest { + /// ID of the Gateway connector + pub gateway_connector_id: String, + /// ID of the session being transferred (in Gateway) + pub gateway_session_id: String, + /// Target connector ID + pub target_connector_id: String, + /// Optional: specific session ID to load in target connector + pub target_session_id: Option<String>, + /// Current mode ID from the Gateway session (e.g., "yolo", "write", "ask") + pub current_mode_id: String, + /// Current model ID from the Gateway session (e.g., "high", "default", "simple") + pub current_model_id: String, + /// Channel to send the result back + pub result_tx: oneshot::Sender<SessionTransferResult>, +} + +/// Result of a session transfer attempt +#[derive(Debug, Clone)] +pub enum SessionTransferResult { + /// Transfer succeeded + Transferred { + /// ID of the connector transferred to + connector_id: String, + /// Session ID in the target connector + session_id: String, + /// Whether a new session was created + is_new: bool, + /// Model state from the target agent (if available) + models: Option<SessionModelState>, + /// Mode state from the target agent (if available) + modes: Option<SessionModeState>, + }, + /// Transfer failed + Failed(String), +} + +/// Callback type for transferring a session to another connector +/// +/// The callback receives a transfer request and should process it asynchronously. +/// Results are sent back via the oneshot channel in the request. +pub type SessionTransferCallback = Arc<dyn Fn(SessionTransferRequest) + Send + Sync>; + +/// Summary info about a connector (lightweight, no CoreHandle dependency) +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ConnectorSummaryInfo { + pub id: String, + pub title: String, + pub kind: String, + pub state: String, + /// Whether sessions can be transferred to this connector + pub supports_session_transfer: bool, + /// Agent type (for ACP connectors only, None otherwise) + #[serde(skip_serializing_if = "Option::is_none")] + pub agent_type: Option<crate::connectors::acp::config::ConnectorAgentType>, +} + +/// Session state tracked by the GatewayConnector +#[derive(Clone, Debug)] +pub struct GatewaySession { + /// Session ID + pub id: String, + /// Session title + pub title: String, + /// Whether echo mode is enabled for this session + pub echo_enabled: bool, + /// Messages in this session + pub messages: Vec<Message>, + /// Created timestamp + pub created_at: chrono::DateTime<Utc>, + /// Updated timestamp + pub updated_at: chrono::DateTime<Utc>, + /// Current mode ID (Gateway naming) - default: "ask" + pub current_mode_id: String, + /// Current model ID (Gateway naming) - default: "default" + pub current_model_id: String, + /// Optional project ID associated with this session + pub project_id: Option<String>, +} + +impl GatewaySession { + /// Create a new session + pub fn new(id: String, title: String, echo_enabled: bool, project_id: Option<String>) -> Self { + let now = Utc::now(); + Self { + id, + title, + echo_enabled, + messages: Vec::new(), + created_at: now, + updated_at: now, + current_mode_id: mappings::gateway_modes::ASK.to_string(), + current_model_id: mappings::gateway_models::DAILYDRIVER.to_string(), + project_id, + } + } + + /// Convert to a dirigent_protocol::Session + pub fn to_protocol_session(&self) -> Session { + Session { + id: self.id.clone(), + title: self.title.clone(), + created_at: self.created_at, + updated_at: self.updated_at, + metadata: SessionMetadata { + project_path: ".".to_string(), + model: Some("gateway".to_string()), + total_messages: self.messages.len() as u32, + system_message: None, + current_mode_id: None, + _meta: None, + project_id: self.project_id.as_ref().and_then(|p| uuid::Uuid::parse_str(p).ok()), + }, + cwd: None, + models: Some(build_gateway_model_state(&self.current_model_id)), + modes: Some(build_gateway_mode_state(&self.current_mode_id)), + // Gateway connector doesn't have ACP client ID + config_options: None, + acp_client_id: None, + } + } +} + +/// Build the default mode state advertised by Gateway. +pub fn build_gateway_mode_state(current_mode_id: &str) -> SessionModeState { + SessionModeState { + current_mode_id: current_mode_id.to_string(), + available_modes: vec![ + SessionMode { + id: mappings::gateway_modes::PLAN.to_string(), + name: "Plan Mode".to_string(), + description: Some("Analyze and plan without modifying files".to_string()), + }, + SessionMode { + id: mappings::gateway_modes::READONLY.to_string(), + name: "Read Only".to_string(), + description: Some("Read files but don't modify anything".to_string()), + }, + SessionMode { + id: mappings::gateway_modes::ASK.to_string(), + name: "Ask Permission".to_string(), + description: Some("Ask before making changes (default)".to_string()), + }, + SessionMode { + id: mappings::gateway_modes::WRITE.to_string(), + name: "Write".to_string(), + description: Some("Automatically accept file edits".to_string()), + }, + SessionMode { + id: mappings::gateway_modes::YOLO.to_string(), + name: "YOLO".to_string(), + description: Some("Bypass all permission prompts".to_string()), + }, + ], + } +} + +/// Build the default model state advertised by Gateway. +pub fn build_gateway_model_state(current_model_id: &str) -> SessionModelState { + SessionModelState { + current_model_id: current_model_id.to_string(), + available_models: vec![ + ModelInfo { + model_id: mappings::gateway_models::SIMPLE.to_string(), + name: "Simple".to_string(), + description: Some("Fastest, best for quick answers".to_string()), + }, + ModelInfo { + model_id: mappings::gateway_models::DAILYDRIVER.to_string(), + name: "Daily Driver".to_string(), + description: Some("Balanced for everyday tasks (recommended)".to_string()), + }, + ModelInfo { + model_id: mappings::gateway_models::HIGH.to_string(), + name: "High".to_string(), + description: Some("Most capable for complex work".to_string()), + }, + ], + } +} + +/// Configuration for the Gateway connector +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct GatewayConfig { + /// Human-readable title for this connector + pub title: String, + + /// Whether echo mode is enabled by default for new sessions + #[serde(default)] + pub default_echo_enabled: bool, + + /// Configuration for echo behavior + #[serde(default)] + pub echo: EchoConfig, +} + +impl Default for GatewayConfig { + fn default() -> Self { + Self { + title: "Gateway".to_string(), + default_echo_enabled: false, + echo: EchoConfig::default(), + } + } +} + +/// Register a Gateway session node in the inspector registry. +#[cfg(feature = "server")] +async fn inspector_register_gateway_session( + inspector: &Arc<dirigent_inspector::InspectorRegistry>, + connector_id: &str, + session: &GatewaySession, +) { + let sess_node_id = dirigent_inspector::NodeId::new(format!( + "dirigent/connectors/{}/sessions/{}", + connector_id, session.id + )); + let parent_id = + dirigent_inspector::NodeId::new(format!("dirigent/connectors/{}", connector_id)); + + let meta = dirigent_inspector::NodeMetadata::new( + dirigent_inspector::NodeKind::AsyncTask, + &session.title, + ) + .with_state(dirigent_inspector::NodeState::Running) + .with_property("session_id", serde_json::json!(&session.id)) + .with_property("status", serde_json::json!("Active")) + .with_property("message_count", serde_json::json!(session.messages.len())); + + match inspector + .register(sess_node_id, &parent_id, meta, None) + .await + { + Ok(mut handle) => { + handle.detach(); + } + Err(_) => { + // Already registered — that's fine + } + } +} + +/// Deregister all Gateway session nodes from the inspector. +#[cfg(feature = "server")] +async fn inspector_deregister_gateway_sessions( + inspector: &Arc<dirigent_inspector::InspectorRegistry>, + connector_id: &str, + sessions: &HashMap<String, GatewaySession>, +) { + for session_id in sessions.keys() { + let sess_node_id = dirigent_inspector::NodeId::new(format!( + "dirigent/connectors/{}/sessions/{}", + connector_id, session_id + )); + let _ = inspector.deregister_subtree(&sess_node_id).await; + } +} + +/// Gateway connector implementation +/// +/// A connector that handles messages locally with configurable behavior. +/// It supports echo mode and built-in commands for session management. +pub struct GatewayConnector { + /// Unique identifier for this connector instance + id: ConnectorId, + + /// Optional connector UID used for BusEvent routing. + /// + /// Populated by the runtime when the connector is registered; legacy + /// constructor paths default this to `None` and later wiring can fill it in. + pub connector_uid: Option<Uuid>, + + /// User who owns this connector + owner: UserId, + + /// Connector configuration + config: GatewayConfig, + + /// Active sessions tracked by this connector + sessions: Arc<RwLock<HashMap<String, GatewaySession>>>, + + /// Shared connector state + state: Arc<RwLock<ConnectorState>>, + + /// Sender for commands to the connector + cmd_tx: mpsc::Sender<ConnectorCommand>, + + /// Receiver for commands (internal to the connector task) + cmd_rx: Arc<RwLock<Option<mpsc::Receiver<ConnectorCommand>>>>, + + /// Broadcast sender for events + events_tx: broadcast::Sender<Event>, + + /// Shared event bus for direct-to-bus publishes. + sharing_bus: Arc<crate::sharing::bus::SharingBus>, + + /// Optional callback for listing available connectors + connector_list_callback: Option<ConnectorListCallback>, + + /// Optional callback for transferring sessions + session_transfer_callback: Option<SessionTransferCallback>, + + /// Optional inspector registry for session tracking + #[cfg(feature = "server")] + inspector: Option<Arc<dirigent_inspector::InspectorRegistry>>, +} + +/// Helper that publishes an event to both the per-connector broadcast +/// channel and the global `SharingBus`. +async fn emit_event( + sharing_bus: &Arc<crate::sharing::bus::SharingBus>, + events_tx: &broadcast::Sender<Event>, + connector_id: &ConnectorId, + connector_uid: Option<Uuid>, + event: Event, +) { + let bus_event = dirigent_protocol::streaming::BusEvent::from_connector_event( + event.clone(), + connector_uid, + connector_id.clone(), + ); + sharing_bus.publish(bus_event).await; + let _ = events_tx.send(event); +} + +impl GatewayConnector { + /// Create a new Gateway connector + /// + /// # Arguments + /// + /// * `id` - Unique identifier for this connector + /// * `owner` - User ID of the connector owner + /// * `config` - Gateway configuration + pub fn new( + id: ConnectorId, + owner: UserId, + config: GatewayConfig, + sharing_bus: Arc<crate::sharing::bus::SharingBus>, + ) -> Self { + let (cmd_tx, cmd_rx) = mpsc::channel(100); + let (events_tx, _) = broadcast::channel(1000); + + info!( + connector_id = %id, + owner = %owner, + title = %config.title, + default_echo = config.default_echo_enabled, + "Creating Gateway connector" + ); + + Self { + id, + connector_uid: None, + owner, + config, + sessions: Arc::new(RwLock::new(HashMap::new())), + state: Arc::new(RwLock::new(ConnectorState::Initializing)), + cmd_tx, + cmd_rx: Arc::new(RwLock::new(Some(cmd_rx))), + events_tx, + sharing_bus, + connector_list_callback: None, + session_transfer_callback: None, + #[cfg(feature = "server")] + inspector: None, + } + } + + /// Wrap a raw connector `Event` in a `BusEvent` populated with this + /// connector's routing identity (`connector_uid` + `connector_id`). + /// + /// Call sites that currently broadcast a raw `Event` via `events_tx` can + /// use this helper when migrating to the `BusEvent` pipeline. This is an + /// additive helper — existing `events_tx.send(event)` emissions remain + /// unchanged. + pub fn to_bus_event( + &self, + event: dirigent_protocol::Event, + ) -> dirigent_protocol::streaming::BusEvent { + dirigent_protocol::streaming::BusEvent::from_connector_event( + event, + self.connector_uid, + self.id.clone(), + ) + } + + /// Set the inspector registry for session tracking. + #[cfg(feature = "server")] + pub fn set_inspector(&mut self, inspector: Option<Arc<dirigent_inspector::InspectorRegistry>>) { + self.inspector = inspector; + } + + /// Set the callback for listing connectors + /// + /// This callback is invoked when the `/list-connectors` command is used. + pub fn set_connector_list_callback(&mut self, callback: ConnectorListCallback) { + self.connector_list_callback = Some(callback); + } + + /// Set the callback for transferring sessions + /// + /// This callback is invoked when the `/select-connector` or `/claude` commands are used. + pub fn set_session_transfer_callback(&mut self, callback: SessionTransferCallback) { + self.session_transfer_callback = Some(callback); + } + + /// Get the events broadcast sender + pub fn events_sender(&self) -> broadcast::Sender<Event> { + self.events_tx.clone() + } + + /// Get the state Arc + pub fn state_arc(&self) -> Arc<RwLock<ConnectorState>> { + Arc::clone(&self.state) + } + + /// Start the connector's background task + pub async fn start_task(&self) -> JoinHandle<()> { + let id = self.id.clone(); + let connector_uid = self.connector_uid; + let config = self.config.clone(); + let sessions = Arc::clone(&self.sessions); + let state = Arc::clone(&self.state); + let events_tx = self.events_tx.clone(); + let sharing_bus = Arc::clone(&self.sharing_bus); + let connector_list_callback = self.connector_list_callback.clone(); + let session_transfer_callback = self.session_transfer_callback.clone(); + #[cfg(feature = "server")] + let inspector = self.inspector.clone(); + + let cmd_rx = self + .cmd_rx + .write() + .await + .take() + .expect("start_task() called more than once - command receiver already taken"); + + info!(connector_id = %id, "Starting Gateway connector task"); + + tokio::spawn(async move { + Self::run_task( + id, + connector_uid, + config, + sessions, + state, + events_tx, + sharing_bus, + cmd_rx, + connector_list_callback, + session_transfer_callback, + #[cfg(feature = "server")] + inspector, + ) + .await; + }) + } + + /// Main connector task loop + async fn run_task( + id: ConnectorId, + connector_uid: Option<Uuid>, + config: GatewayConfig, + sessions: Arc<RwLock<HashMap<String, GatewaySession>>>, + state: Arc<RwLock<ConnectorState>>, + events_tx: broadcast::Sender<Event>, + sharing_bus: Arc<crate::sharing::bus::SharingBus>, + mut cmd_rx: mpsc::Receiver<ConnectorCommand>, + connector_list_callback: Option<ConnectorListCallback>, + session_transfer_callback: Option<SessionTransferCallback>, + #[cfg(feature = "server")] inspector: Option<Arc<dirigent_inspector::InspectorRegistry>>, + ) { + info!(connector_id = %id, "Gateway connector task started"); + + // Immediately transition to Ready state (no external connection needed) + { + let mut state_guard = state.write().await; + *state_guard = ConnectorState::Ready; + } + emit_event(&sharing_bus, &events_tx, &id, connector_uid, Event::Connected).await; + // Process commands + while let Some(cmd) = cmd_rx.recv().await { + match cmd { + ConnectorCommand::ListSessions => { + debug!(connector_id = %id, "Processing ListSessions command"); + let sessions_guard = sessions.read().await; + let session_list: Vec<Session> = sessions_guard + .values() + .map(|s| s.to_protocol_session()) + .collect(); + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionsListed { + connector_id: id.clone(), + sessions: session_list, + }, + ) + .await; + } + + ConnectorCommand::CreateSession { cwd, project_id, ownership } => { + debug!(connector_id = %id, cwd = ?cwd, ownership = ?ownership, "Processing CreateSession command"); + let session_id = uuid::Uuid::new_v4().to_string(); + let title = format!("Session {}", &session_id[..8]); + let session = GatewaySession::new( + session_id.clone(), + title.clone(), + config.default_echo_enabled, + project_id, + ); + let protocol_session = session.to_protocol_session(); + + // Register session with inspector + #[cfg(feature = "server")] + if let Some(ref inspector) = inspector { + inspector_register_gateway_session(inspector, &id, &session).await; + } + + sessions.write().await.insert(session_id.clone(), session); + + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionCreated { + connector_id: id.clone(), + session: protocol_session, + }, + ) + .await; + } + + ConnectorCommand::LoadSession { session_id, .. } => { + debug!(connector_id = %id, session_id = %session_id, "Processing LoadSession command"); + let sessions_guard = sessions.read().await; + if let Some(session) = sessions_guard.get(&session_id) { + // Session exists, emit updated event + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionUpdated { + connector_id: id.clone(), + session: session.to_protocol_session(), + }, + ) + .await; + } else { + // Session doesn't exist - create it + drop(sessions_guard); + let title = format!("Session {}", &session_id[..8.min(session_id.len())]); + let session = GatewaySession::new( + session_id.clone(), + title, + config.default_echo_enabled, + None, + ); + let protocol_session = session.to_protocol_session(); + + // Register session with inspector + #[cfg(feature = "server")] + if let Some(ref inspector) = inspector { + inspector_register_gateway_session(inspector, &id, &session).await; + } + + sessions.write().await.insert(session_id.clone(), session); + + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionCreated { + connector_id: id.clone(), + session: protocol_session, + }, + ) + .await; + } + } + + ConnectorCommand::ListMessages { session_id } => { + debug!(connector_id = %id, session_id = %session_id, "Processing ListMessages command"); + let sessions_guard = sessions.read().await; + if let Some(session) = sessions_guard.get(&session_id) { + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::MessagesListed { + messages: session.messages.clone(), + }, + ) + .await; + } else { + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::MessagesListed { messages: vec![] }, + ) + .await; + } + } + + ConnectorCommand::SendMessage { session_id, text } => { + debug!(connector_id = %id, session_id = %session_id, text_len = text.len(), "Processing SendMessage command"); + + // Get or create session + let mut sessions_guard = sessions.write().await; + let session = sessions_guard.entry(session_id.clone()).or_insert_with(|| { + GatewaySession::new( + session_id.clone(), + format!("Session {}", &session_id[..8.min(session_id.len())]), + config.default_echo_enabled, + None, + ) + }); + + // Check for system error message from runtime + if text.starts_with("__SYSTEM_ERROR__:") { + let error_content = text.strip_prefix("__SYSTEM_ERROR__:").unwrap_or(&text); + + // Create error message in session + let assistant_message_id = uuid::Uuid::new_v4().to_string(); + let assistant_message = Message { + id: assistant_message_id.clone(), + session_id: session_id.clone(), + role: MessageRole::Assistant, + content: vec![dirigent_protocol::MessagePart::Text { + text: error_content.to_string(), + }], + created_at: Utc::now(), + status: MessageStatus::Completed, + metadata: None, + }; + session.messages.push(assistant_message.clone()); + session.updated_at = Utc::now(); + + // Emit as assistant message + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::MessageStarted { + connector_id: id.clone(), + message: assistant_message.clone(), + }, + ) + .await; + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::MessageCompleted { + connector_id: id.clone(), + message: assistant_message.clone(), + }, + ) + .await; + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::TurnComplete { + connector_id: id.clone(), + session_id: session_id.clone(), + message_id: assistant_message_id.clone(), + trigger: TurnCompleteTrigger::ExplicitSignal, + }, + ) + .await; + // Don't process as command or echo + continue; + } + + // Create user message + let user_message_id = uuid::Uuid::new_v4().to_string(); + let user_message = Message { + id: user_message_id.clone(), + session_id: session_id.clone(), + role: MessageRole::User, + content: vec![dirigent_protocol::MessagePart::Text { text: text.clone() }], + created_at: Utc::now(), + status: MessageStatus::Completed, + metadata: None, + }; + session.messages.push(user_message.clone()); + session.updated_at = Utc::now(); + + // Emit user message completed + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::MessageCompleted { + connector_id: id.clone(), + message: user_message, + }, + ) + .await; + // Track if session was transferred to skip further events + let mut session_transferred = false; + + // Check for commands first + if let Some(command) = commands::parse_command(&text) { + let result = commands::execute_command( + command, + &id, + &session_id, + session, + connector_list_callback.as_ref(), + session_transfer_callback.as_ref(), + ) + .await; + + // Check if session was transferred + session_transferred = + matches!(result, CommandResult::SessionTransferred { .. }); + + // For transfers, the EventBridge handles the confirmation message + // and SessionIdle. We only emit messages for non-transfer commands. + if !session_transferred { + let assistant_message_id = uuid::Uuid::new_v4().to_string(); + let response_text = match result { + CommandResult::Message(msg) => msg, + CommandResult::Error(err) => format!("Error: {}", err), + CommandResult::SessionTransferred { .. } => { + // This branch won't be reached due to the if guard above + unreachable!() + } + }; + + let assistant_message = Message { + id: assistant_message_id.clone(), + session_id: session_id.clone(), + role: MessageRole::Assistant, + content: vec![dirigent_protocol::MessagePart::Text { + text: response_text.clone(), + }], + created_at: Utc::now(), + status: MessageStatus::Completed, + metadata: None, + }; + session.messages.push(assistant_message.clone()); + session.updated_at = Utc::now(); + + // Stream the response + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::MessageStarted { + connector_id: id.clone(), + message: assistant_message.clone(), + }, + ) + .await; + // Emit chunk update + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionUpdate { + connector_id: id.clone(), + session_id: session_id.clone(), + update: SessionUpdate::AgentMessageChunk { + message_id: assistant_message_id.clone(), + content: ContentBlock::Text { + text: response_text, + }, + _meta: None, + }, + }, + ) + .await; + + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::MessageCompleted { + connector_id: id.clone(), + message: assistant_message.clone(), + }, + ) + .await; + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::TurnComplete { + connector_id: id.clone(), + session_id: session_id.clone(), + message_id: assistant_message_id.clone(), + trigger: TurnCompleteTrigger::ExplicitSignal, + }, + ) + .await; + } + // For transfers, the EventBridge handles confirmation + SessionIdle + } else if session.echo_enabled { + // Echo mode is on - echo the message back + let echo_response = echo::generate_echo_response(&text, &config.echo); + + let assistant_message_id = uuid::Uuid::new_v4().to_string(); + let assistant_message = Message { + id: assistant_message_id.clone(), + session_id: session_id.clone(), + role: MessageRole::Assistant, + content: vec![dirigent_protocol::MessagePart::Text { + text: echo_response.clone(), + }], + created_at: Utc::now(), + status: MessageStatus::Completed, + metadata: None, + }; + session.messages.push(assistant_message.clone()); + session.updated_at = Utc::now(); + + // Check if streaming simulation is enabled + if config.echo.simulate_streaming { + // Stream the echo response with delays between chunks + echo::stream_echo_response( + &id, + connector_uid, + &session_id, + &assistant_message_id, + &echo_response, + &config.echo, + &events_tx, + &sharing_bus, + ) + .await; + } else { + // Send message events without streaming delays + // Pattern: MessageStarted (empty) -> AgentMessageChunk (content) -> MessageCompleted + // This matches OpenCode behavior and tests chunk handling + let placeholder_message = Message { + id: assistant_message_id.clone(), + session_id: session_id.clone(), + role: MessageRole::Assistant, + content: vec![], // Empty - content comes via AgentMessageChunk + created_at: Utc::now(), + status: MessageStatus::Streaming, + metadata: None, + }; + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::MessageStarted { + connector_id: id.clone(), + message: placeholder_message, + }, + ) + .await; + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionUpdate { + connector_id: id.clone(), + session_id: session_id.clone(), + update: SessionUpdate::AgentMessageChunk { + message_id: assistant_message_id, + content: ContentBlock::Text { + text: echo_response, + }, + _meta: None, + }, + }, + ) + .await; + } + + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::MessageCompleted { + connector_id: id.clone(), + message: assistant_message.clone(), + }, + ) + .await; + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::TurnComplete { + connector_id: id.clone(), + session_id: session_id.clone(), + message_id: assistant_message.id.clone(), + trigger: TurnCompleteTrigger::ExplicitSignal, + }, + ) + .await; + } else { + // Echo mode is off - provide helpful response + let assistant_message_id = uuid::Uuid::new_v4().to_string(); + let response_text = "Echo mode is disabled. Use `/echo on` to enable, or `/help` for available commands.".to_string(); + let assistant_message = Message { + id: assistant_message_id.clone(), + session_id: session_id.clone(), + role: MessageRole::Assistant, + content: vec![dirigent_protocol::MessagePart::Text { + text: response_text.clone(), + }], + created_at: Utc::now(), + status: MessageStatus::Completed, + metadata: None, + }; + session.messages.push(assistant_message.clone()); + session.updated_at = Utc::now(); + + // Send message events using OpenCode-compatible pattern: + // MessageStarted (empty) -> AgentMessageChunk (content) -> MessageCompleted + let placeholder_message = Message { + id: assistant_message_id.clone(), + session_id: session_id.clone(), + role: MessageRole::Assistant, + content: vec![], // Empty - content comes via AgentMessageChunk + created_at: Utc::now(), + status: MessageStatus::Streaming, + metadata: None, + }; + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::MessageStarted { + connector_id: id.clone(), + message: placeholder_message, + }, + ) + .await; + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionUpdate { + connector_id: id.clone(), + session_id: session_id.clone(), + update: SessionUpdate::AgentMessageChunk { + message_id: assistant_message_id.clone(), + content: ContentBlock::Text { + text: response_text, + }, + _meta: None, + }, + }, + ) + .await; + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::MessageCompleted { + connector_id: id.clone(), + message: assistant_message.clone(), + }, + ) + .await; + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::TurnComplete { + connector_id: id.clone(), + session_id: session_id.clone(), + message_id: assistant_message_id, + trigger: TurnCompleteTrigger::ExplicitSignal, + }, + ) + .await; + } + + // Emit session idle (unless session was transferred) + if !session_transferred { + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionIdle { + connector_id: id.to_string(), + session_id: session_id.clone(), + }, + ) + .await; + } + } + + ConnectorCommand::CancelGeneration { session_id } => { + debug!(connector_id = %id, session_id = %session_id, "Processing CancelGeneration (no-op for Gateway)"); + // No-op for Gateway connector since responses are instant + } + + ConnectorCommand::Reconnect => { + debug!(connector_id = %id, "Processing Reconnect (no-op for Gateway)"); + // No-op since we don't have external connections + } + + ConnectorCommand::AgentResponse { .. } => { + // Gateway connector does not support agent-initiated requests + warn!( + connector_id = %id, + "Received AgentResponse command but Gateway connector does not support agent-initiated requests" + ); + } + + ConnectorCommand::SetSessionMode { + session_id, + mode_id, + } => { + if let Some(session) = sessions.write().await.get_mut(&session_id) { + session.current_mode_id = mode_id.clone(); + session.updated_at = Utc::now(); + + // Emit SessionMetadataReceived event to notify clients + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionMetadataReceived { + connector_id: id.clone(), + session_id: session_id.clone(), + models: Some(build_gateway_model_state(&session.current_model_id)), + modes: Some(build_gateway_mode_state(&session.current_mode_id)), + config_options: None, + }, + ) + .await; + debug!(session_id = %session_id, mode_id = %mode_id, "Gateway session mode updated"); + } else { + warn!(session_id = %session_id, "SetSessionMode: session not found"); + } + } + + ConnectorCommand::CloseSession { session_id } => { + warn!(connector_id = %id, session_id = %session_id, "CloseSession not supported by Gateway connector"); + } + + ConnectorCommand::SetSessionModel { + session_id, + model_id, + } => { + if let Some(session) = sessions.write().await.get_mut(&session_id) { + session.current_model_id = model_id.clone(); + session.updated_at = Utc::now(); + + // Emit SessionMetadataReceived event to notify clients + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionMetadataReceived { + connector_id: id.clone(), + session_id: session_id.clone(), + models: Some(build_gateway_model_state(&session.current_model_id)), + modes: Some(build_gateway_mode_state(&session.current_mode_id)), + config_options: None, + }, + ) + .await; + debug!(session_id = %session_id, model_id = %model_id, "Gateway session model updated"); + } else { + warn!(session_id = %session_id, "SetSessionModel: session not found"); + } + } + + ConnectorCommand::SetConfigOption { .. } => { + warn!(connector_id = %id, "SetConfigOption not supported by Gateway connector"); + } + + ConnectorCommand::Shutdown => { + info!(connector_id = %id, "Received Shutdown command, stopping connector"); + + // Deregister session nodes from inspector + #[cfg(feature = "server")] + if let Some(ref inspector) = inspector { + let sessions_guard = sessions.read().await; + inspector_deregister_gateway_sessions(inspector, &id, &sessions_guard) + .await; + } + + { + let mut state_guard = state.write().await; + *state_guard = ConnectorState::Stopped; + } + break; + } + } + } + + info!(connector_id = %id, "Gateway connector task stopped"); + } +} + +impl Connector for GatewayConnector { + fn id(&self) -> &ConnectorId { + &self.id + } + + fn kind(&self) -> ConnectorKind { + ConnectorKind::Gateway + } + + fn owner(&self) -> &UserId { + &self.owner + } + + fn title(&self) -> &str { + &self.config.title + } + + fn state(&self) -> ConnectorState { + match self.state.try_read() { + Ok(state_guard) => state_guard.clone(), + Err(_) => ConnectorState::Initializing, + } + } + + fn command_tx(&self) -> mpsc::Sender<ConnectorCommand> { + self.cmd_tx.clone() + } + + fn subscribe(&self) -> broadcast::Receiver<Event> { + self.events_tx.subscribe() + } + + fn stop(&self) { + let cmd_tx = self.cmd_tx.clone(); + tokio::spawn(async move { + let _ = cmd_tx.send(ConnectorCommand::Shutdown).await; + }); + } + + fn get_available_commands(&self) -> Vec<crate::acp::protocol::streaming::Command> { + // Return the built-in Gateway commands + vec![ + crate::acp::protocol::streaming::Command { + name: "echo".to_string(), + description: Some("Enable or disable echo mode (/echo on|off)".to_string()), + }, + crate::acp::protocol::streaming::Command { + name: "help".to_string(), + description: Some("Show available commands".to_string()), + }, + crate::acp::protocol::streaming::Command { + name: "list-connectors".to_string(), + description: Some("Show available connectors".to_string()), + }, + crate::acp::protocol::streaming::Command { + name: "select-connector".to_string(), + description: Some( + "Transfer session to a connector (/select-connector <id>)".to_string(), + ), + }, + crate::acp::protocol::streaming::Command { + name: "claude".to_string(), + description: Some("Transfer session to a Claude connector".to_string()), + }, + ] + } +} + +/// Returns the transient commands that should be available on all connectors +/// +/// These are session management commands provided by the Gateway that should be +/// merged with any target connector's commands (unless the connector already +/// provides a command with the same name). +pub fn get_transient_commands() -> Vec<crate::acp::protocol::streaming::Command> { + vec![ + crate::acp::protocol::streaming::Command { + name: "list-connectors".to_string(), + description: Some("Show available connectors".to_string()), + }, + crate::acp::protocol::streaming::Command { + name: "select-connector".to_string(), + description: Some( + "Transfer session to a connector (/select-connector <id>)".to_string(), + ), + }, + crate::acp::protocol::streaming::Command { + name: "claude".to_string(), + description: Some("Transfer session to a Claude connector".to_string()), + }, + ] +} + +/// Merges agent commands with transient gateway commands +/// +/// Transient commands are added unless the agent already provides a command with the same name. +/// This ensures session management commands are available on all connectors. +pub fn merge_with_transient_commands( + mut agent_commands: Vec<crate::acp::protocol::streaming::Command>, +) -> Vec<crate::acp::protocol::streaming::Command> { + use std::collections::HashSet; + + let agent_names: HashSet<String> = agent_commands.iter().map(|c| c.name.clone()).collect(); + + for cmd in get_transient_commands() { + if !agent_names.contains(&cmd.name) { + agent_commands.push(cmd); + } + } + + agent_commands +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_gateway_config_default() { + let config = GatewayConfig::default(); + assert_eq!(config.title, "Gateway"); + assert!(!config.default_echo_enabled); + } + + #[test] + fn test_gateway_config_serialization() { + let config = GatewayConfig { + title: "Test".to_string(), + default_echo_enabled: true, + echo: EchoConfig::default(), + }; + + let json = serde_json::to_string(&config).unwrap(); + let deserialized: GatewayConfig = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.title, config.title); + assert_eq!( + deserialized.default_echo_enabled, + config.default_echo_enabled + ); + } + + #[tokio::test] + async fn test_gateway_connector_creation() { + let config = GatewayConfig::default(); + let connector = GatewayConnector::new( + "test-conn".to_string(), + uuid::Uuid::nil(), + config, + crate::sharing::bus::SharingBus::new(), + ); + + assert_eq!(connector.id(), "test-conn"); + assert_eq!(*connector.owner(), uuid::Uuid::nil()); + assert_eq!(connector.title(), "Gateway"); + assert_eq!(connector.kind(), ConnectorKind::Gateway); + assert_eq!(connector.state(), ConnectorState::Initializing); + } + + #[test] + fn test_gateway_session_creation() { + let session = + GatewaySession::new("session-1".to_string(), "Test Session".to_string(), true, None); + + assert_eq!(session.id, "session-1"); + assert_eq!(session.title, "Test Session"); + assert!(session.echo_enabled); + assert!(session.messages.is_empty()); + } + + #[test] + fn test_gateway_session_to_protocol() { + let session = + GatewaySession::new("session-1".to_string(), "Test Session".to_string(), true, None); + let protocol_session = session.to_protocol_session(); + + assert_eq!(protocol_session.id, "session-1"); + assert_eq!(protocol_session.title, "Test Session"); + assert_eq!(protocol_session.metadata.model, Some("gateway".to_string())); + } + + #[test] + fn test_session_transfer_result_types() { + let success = SessionTransferResult::Transferred { + connector_id: "opencode-1".to_string(), + session_id: "session-new".to_string(), + is_new: true, + models: None, + modes: None, + }; + + let failure = SessionTransferResult::Failed("Connector not ready".to_string()); + + match success { + SessionTransferResult::Transferred { is_new, .. } => assert!(is_new), + _ => panic!("Expected Transferred"), + } + + match failure { + SessionTransferResult::Failed(reason) => assert!(reason.contains("not ready")), + _ => panic!("Expected Failed"), + } + } +} diff --git a/crates/dirigent_core/src/connectors/mod.rs b/crates/dirigent_core/src/connectors/mod.rs new file mode 100644 index 0000000..329f900 --- /dev/null +++ b/crates/dirigent_core/src/connectors/mod.rs @@ -0,0 +1,930 @@ +//! Connector abstraction layer +//! +//! This module provides the connector abstraction that allows dirigent_core to +//! manage long-lived connections to external agent systems (OpenCode, ACP, etc.). +//! +//! # Architecture +//! +//! Each connector: +//! - Runs in its own async task with lifecycle management +//! - Has a command channel (mpsc) for receiving control commands +//! - Has an event broadcast channel for publishing events to subscribers +//! - Tracks its own state (Initializing, Connecting, Ready, Error, Stopped) +//! - Is owned by a specific user (for authorization) +//! +//! # Usage +//! +//! ```no_run +//! use dirigent_core::connectors::{Connector, ConnectorCommand}; +//! use tokio::sync::broadcast; +//! +//! async fn example(connector: impl Connector) { +//! // Subscribe to connector events +//! let mut event_rx = connector.subscribe(); +//! +//! // Send a command +//! let cmd_tx = connector.command_tx(); +//! cmd_tx.send(ConnectorCommand::ListSessions).await.ok(); +//! +//! // Receive events +//! while let Ok(event) = event_rx.recv().await { +//! println!("Received event: {:?}", event); +//! } +//! } +//! ``` + +use crate::types::{ConnectorErrorKind, ConnectorId, ConnectorKind, ConnectorState, UserId}; +use std::sync::Arc; +use tokio::sync::{broadcast, mpsc, RwLock}; +use tokio::task::JoinHandle; + +// Connector implementations +pub mod acceptor; +pub mod acp; +pub mod fingerprint; +pub mod gateway; +pub mod opencode; + +pub use fingerprint::compute_fingerprint; + +/// Commands that can be sent to a connector +/// +/// These commands control connector behavior and trigger operations on the +/// underlying agent system. Commands are sent via the connector's command +/// channel and processed asynchronously by the connector's task loop. +#[derive(Clone, Debug)] +pub enum ConnectorCommand { + /// Request a list of all sessions + /// + /// The connector will query the underlying agent system and emit + /// a SessionsListed event with the results. + ListSessions, + + /// Request messages for a specific session + /// + /// The connector will fetch messages for the given session and emit + /// a MessagesListed event with the results. + ListMessages { + /// ID of the session to list messages from + session_id: String, + }, + + /// Create a new session + /// + /// The connector will create a new session with the agent system and emit + /// a SessionCreated event with the new session information. + CreateSession { + /// Optional current working directory for the session + cwd: Option<String>, + /// Optional project ID to associate with the session + project_id: Option<String>, + /// Session ownership model (internal UI session vs external forwarded session) + ownership: dirigent_protocol::SessionOwnership, + }, + + /// Load an existing session + /// + /// The connector will load an existing session (if the protocol supports it) + /// and replay its history. This may emit multiple SessionUpdate events during + /// replay. + LoadSession { + /// ID of the session to load + session_id: String, + /// Working directory (project path from archivist metadata or session/list) + cwd: String, + /// MCP server configurations for the agent + mcp_servers: Option<serde_json::Value>, + }, + + /// Send a message to a session + /// + /// The connector will send the message to the agent system and emit + /// streaming events as the response is generated. + SendMessage { + /// ID of the session to send to + session_id: String, + /// Message text content + text: String, + }, + + /// Cancel generation in a session + /// + /// Requests the agent to stop generating the current response. + CancelGeneration { + /// ID of the session to cancel generation in + session_id: String, + }, + + /// Attempt to reconnect the connector + /// + /// This is useful for recovering from transient connection failures. + /// The connector will transition to Connecting state and attempt to + /// re-establish its connection to the agent system. + Reconnect, + + /// Shutdown the connector gracefully + /// + /// The connector will clean up resources, close connections, and + /// transition to Stopped state. The connector's task loop will exit. + Shutdown, + + /// Respond to an agent-initiated request + /// + /// This command allows external systems (like the ACP Server) to inject + /// responses for pending agent requests (e.g., permission prompts). + /// The connector will look up the pending request by request_id and + /// send the response to the agent via transport. + AgentResponse { + /// The request ID from the agent (for correlation) + request_id: serde_json::Value, + /// The response to send back to the agent + response: serde_json::Value, + }, + + /// Set the session mode + /// + /// Change the active mode for a session. The mode determines behavior + /// characteristics like conversation style or capabilities. + /// The connector will send the mode change request to the agent system. + SetSessionMode { + /// ID of the session to update + session_id: String, + /// ID of the mode to switch to + mode_id: String, + }, + + /// Set the session model + /// + /// Change the active model for a session. The model determines which + /// LLM/agent backend is used for processing. + /// The connector will send the model change request to the agent system. + SetSessionModel { + /// ID of the session to update + session_id: String, + /// ID of the model to switch to + model_id: String, + }, + + /// Close a session (release agent resources, session remains listable). + /// Only works when agent supports sessionCapabilities.close. + CloseSession { + session_id: String, + }, + + /// Set a config option value (preferred over SetSessionMode/SetSessionModel). + /// Uses session/set_config_option per ACP spec. + SetConfigOption { + session_id: String, + config_id: String, + value: String, + }, +} + +/// Trait for connector implementations +/// +/// This trait defines the interface that all connectors must implement. +/// It is object-safe, allowing for dynamic dispatch via `dyn Connector`. +/// +/// # Object Safety +/// +/// All methods return references or simple Copy types to ensure the trait +/// can be used as a trait object. Methods that need owned values use +/// cloneable types like `mpsc::Sender` and `broadcast::Receiver`. +pub trait Connector: Send + Sync { + /// Get the unique identifier for this connector + fn id(&self) -> &ConnectorId; + + /// Get the type of connector (OpenCode, Acp, Mock, etc.) + fn kind(&self) -> ConnectorKind; + + /// Get the user who owns this connector + /// + /// Used for authorization checks when routing commands and operations. + fn owner(&self) -> &UserId; + + /// Get the human-readable title for this connector + /// + /// This is typically set during connector creation and used for display + /// in UI components. + fn title(&self) -> &str; + + /// Get the current state of this connector + /// + /// The state reflects the connector's position in its lifecycle: + /// - Initializing: Just created, not yet connecting + /// - Connecting: Attempting to establish connection + /// - Ready: Connected and operational + /// - Error: Encountered a failure (with error message) + /// - Stopped: Shutdown or unrecoverable error + fn state(&self) -> ConnectorState; + + /// Get a sender for sending commands to this connector + /// + /// Commands sent via this channel will be processed asynchronously + /// by the connector's task loop. The sender is cloneable and can be + /// shared across multiple clients. + /// + /// # Returns + /// + /// An mpsc sender that can be used to send `ConnectorCommand` values. + fn command_tx(&self) -> mpsc::Sender<ConnectorCommand>; + + /// Subscribe to events from this connector + /// + /// Returns a broadcast receiver that will receive all events published + /// by this connector. Each call to subscribe creates a new independent + /// receiver. + /// + /// # Returns + /// + /// A broadcast receiver for `dirigent_protocol::Event` values. + /// + /// # Notes + /// + /// Broadcast channels have a bounded capacity. If a receiver falls too + /// far behind (doesn't consume events fast enough), it will miss events + /// and receive a `RecvError::Lagged` error. + fn subscribe(&self) -> broadcast::Receiver<dirigent_protocol::Event>; + + /// Stop this connector gracefully + /// + /// This is a convenience method that sends a Shutdown command to the + /// connector's command channel. The connector will clean up and transition + /// to Stopped state asynchronously. + fn stop(&self); + + /// Get available commands/tools for this connector + /// + /// Returns a list of commands that this connector exposes to external + /// ACP clients. These commands will be forwarded when sessions are + /// routed to this connector. + /// + /// # Returns + /// + /// A vector of command definitions with name and optional description. + /// Returns an empty vector if the connector doesn't expose commands. + fn get_available_commands(&self) -> Vec<crate::acp::protocol::streaming::Command> { + // Default implementation: no commands + Vec::new() + } +} + +/// Handle to a running connector +/// +/// This struct contains all the state and channels needed to interact with +/// a connector. It implements the `Connector` trait and can be cloned to +/// create multiple handles to the same underlying connector. +/// +/// # Lifecycle +/// +/// A ConnectorHandle is typically created by the CoreRuntime when a connector +/// is instantiated. The handle contains: +/// - Metadata (id, kind, owner, title) +/// - Shared state protected by RwLock +/// - Command channel sender +/// - Event broadcast sender (for publishing events) +/// - Optional task handle (for graceful shutdown) +/// +/// # Cloning +/// +/// ConnectorHandle uses Arc internally, so cloning is cheap and creates +/// a new handle to the same connector. All clones share the same state +/// and channels. +#[derive(Clone)] +pub struct ConnectorHandle { + /// Unique identifier for this connector + id: ConnectorId, + + /// Type of connector (OpenCode, Acp, Mock, etc.) + kind: ConnectorKind, + + /// User who owns this connector + owner: UserId, + + /// Human-readable title for display + title: String, + + /// Shared connector state + /// + /// Protected by RwLock to allow concurrent reads and exclusive writes. + /// The connector's task loop updates this as state transitions occur. + state: Arc<RwLock<ConnectorState>>, + + /// Sender for commands to the connector + /// + /// Commands sent via this channel are processed by the connector's + /// async task loop. + cmd_tx: mpsc::Sender<ConnectorCommand>, + + /// Broadcast sender for events + /// + /// The connector publishes events to this channel. Subscribers get + /// receivers via the `subscribe()` method. + events_tx: broadcast::Sender<dirigent_protocol::Event>, + + /// Join handle for the connector's background task + /// + /// This is None until the connector is started. Once started, this + /// handle can be used to wait for the task to complete during shutdown. + task_join: Option<Arc<RwLock<Option<JoinHandle<()>>>>>, + + /// Connector-specific configuration as JSON + /// + /// This stores the serialized configuration parameters (e.g., OpenCodeConfig) + /// that were used to create this connector. It is preserved across restarts + /// to allow recreation of the connector instance with the same parameters. + config: Arc<RwLock<serde_json::Value>>, + + /// Resolved working directory for this connector + /// + /// This is the actual working directory path resolved from the connector + /// configuration and global settings. It is computed once during creation + /// and stored for quick access. + working_directory: Arc<RwLock<Option<std::path::PathBuf>>>, + + /// T034: Optional custom icon path for this connector + /// + /// If set, the UI should display this custom icon instead of the default + /// connector type emoji. + icon_path: Option<String>, + + /// T035: Show connector type emoji as overlay on custom icon + /// + /// When true and icon_path is set, the connector type emoji should appear + /// as a small overlay in the lower-right corner of the custom icon. + show_type_overlay: bool, + + /// Structured error classification for the current error state. + /// + /// Shared with the connector's background task via Arc. When `ConnectorState` + /// is `Error`, this provides a machine-readable classification. Cleared to + /// `None` when the connector recovers. + error_kind: Arc<RwLock<Option<ConnectorErrorKind>>>, + + /// Dynamic commands reported by the remote agent. + /// + /// For ACP connectors, the agent sends `available_commands_update` notifications + /// with its current slash commands. These are stored here so that + /// `get_available_commands()` can return them. For non-ACP connectors this + /// stays empty and static commands are returned based on connector kind. + available_commands: Arc<RwLock<Vec<crate::acp::protocol::streaming::Command>>>, +} + +impl ConnectorHandle { + /// Create a new ConnectorHandle + /// + /// # Arguments + /// + /// * `id` - Unique identifier for this connector + /// * `kind` - Type of connector (OpenCode, Acp, Mock, etc.) + /// * `owner` - User who owns this connector + /// * `title` - Human-readable title for display + /// * `cmd_tx` - Sender for commands to the connector + /// * `events_tx` - Broadcast sender for events + /// * `config` - Connector-specific configuration as JSON + /// * `working_directory` - Optional working directory path + /// * `icon_path` - T034: Optional custom icon path + /// * `show_type_overlay` - T035: Show connector type emoji as overlay + /// + /// # Returns + /// + /// A new ConnectorHandle with state initialized to `Initializing` + #[allow(clippy::too_many_arguments)] + pub fn new( + id: ConnectorId, + kind: ConnectorKind, + owner: UserId, + title: String, + cmd_tx: mpsc::Sender<ConnectorCommand>, + events_tx: broadcast::Sender<dirigent_protocol::Event>, + config: serde_json::Value, + working_directory: Option<std::path::PathBuf>, + icon_path: Option<String>, + show_type_overlay: bool, + ) -> Self { + Self { + id, + kind, + owner, + title, + state: Arc::new(RwLock::new(ConnectorState::Initializing)), + cmd_tx, + events_tx, + task_join: Some(Arc::new(RwLock::new(None))), + config: Arc::new(RwLock::new(config)), + working_directory: Arc::new(RwLock::new(working_directory)), + icon_path, + show_type_overlay, + error_kind: Arc::new(RwLock::new(None)), + available_commands: Arc::new(RwLock::new(Vec::new())), + } + } + + /// Create a new ConnectorHandle with a shared state Arc + /// + /// This constructor allows the handle to share the same state Arc as the + /// connector implementation, ensuring state updates from the connector's + /// background task are visible through the handle. + /// + /// # Arguments + /// + /// * `id` - Unique identifier for this connector + /// * `kind` - Type of connector (OpenCode, Acp, Mock, etc.) + /// * `owner` - User who owns this connector + /// * `title` - Human-readable title for display + /// * `state` - Shared state Arc from the connector + /// * `cmd_tx` - Sender for commands to the connector + /// * `events_tx` - Broadcast sender for events + /// * `config` - Connector-specific configuration as JSON + /// * `working_directory` - Optional working directory path + /// * `icon_path` - T034: Optional custom icon path + /// * `show_type_overlay` - T035: Show connector type emoji as overlay + /// + /// # Returns + /// + /// A new ConnectorHandle sharing the provided state Arc + #[allow(clippy::too_many_arguments)] + pub fn new_with_state( + id: ConnectorId, + kind: ConnectorKind, + owner: UserId, + title: String, + state: Arc<RwLock<ConnectorState>>, + cmd_tx: mpsc::Sender<ConnectorCommand>, + events_tx: broadcast::Sender<dirigent_protocol::Event>, + config: serde_json::Value, + working_directory: Option<std::path::PathBuf>, + icon_path: Option<String>, + show_type_overlay: bool, + ) -> Self { + Self { + id, + kind, + owner, + title, + state, + cmd_tx, + events_tx, + task_join: Some(Arc::new(RwLock::new(None))), + config: Arc::new(RwLock::new(config)), + working_directory: Arc::new(RwLock::new(working_directory)), + icon_path, + show_type_overlay, + error_kind: Arc::new(RwLock::new(None)), + available_commands: Arc::new(RwLock::new(Vec::new())), + } + } + + /// Get a reference to the state RwLock + /// + /// This is useful for connector implementations that need to update + /// the state as the connector progresses through its lifecycle. + pub fn state_lock(&self) -> Arc<RwLock<ConnectorState>> { + Arc::clone(&self.state) + } + + /// Get the current error classification for this connector. + /// + /// Returns `None` when the connector is healthy, or `Some(kind)` when + /// the connector is in an error state with a classified error. + /// Uses `try_read()` to avoid blocking, returning `None` if the lock is held. + pub fn error_kind(&self) -> Option<ConnectorErrorKind> { + match self.error_kind.try_read() { + Ok(guard) => guard.clone(), + Err(_) => None, + } + } + + /// Replace the error_kind Arc with one shared from a connector implementation. + /// + /// This allows the ConnectorHandle to observe the same error_kind state + /// that the connector's background task updates. + pub fn set_error_kind_arc(&mut self, error_kind: Arc<RwLock<Option<ConnectorErrorKind>>>) { + self.error_kind = error_kind; + } + + /// Get a reference to the events broadcast sender + /// + /// This is useful for connector implementations that need to publish + /// events to subscribers. + pub fn events_sender(&self) -> broadcast::Sender<dirigent_protocol::Event> { + self.events_tx.clone() + } + + /// Set the task join handle + /// + /// This should be called by connector implementations after spawning + /// their background task. The handle is used during shutdown to wait + /// for the task to complete gracefully. + pub async fn set_task_handle(&self, handle: JoinHandle<()>) { + if let Some(task_join) = &self.task_join { + let mut guard = task_join.write().await; + *guard = Some(handle); + } + } + + /// Get the task join handle (if set) + /// + /// This is useful during shutdown to wait for the connector's task + /// to complete. + pub async fn take_task_handle(&self) -> Option<JoinHandle<()>> { + if let Some(task_join) = &self.task_join { + let mut guard = task_join.write().await; + guard.take() + } else { + None + } + } + + /// Get a reference to the config RwLock + /// + /// This returns a clone of the config Arc, allowing access to the + /// connector's configuration. This is primarily used during restart + /// to recreate the connector with the same parameters. + pub fn config(&self) -> Arc<RwLock<serde_json::Value>> { + Arc::clone(&self.config) + } + + /// Get a cloned copy of the connector configuration + /// + /// This is a convenience method that locks the config and returns + /// a cloned copy of the JSON value. Use this when you need to read + /// the config without holding the lock. + pub async fn get_config_cloned(&self) -> serde_json::Value { + let config_lock = self.config.read().await; + config_lock.clone() + } + + /// Update the stored configuration + /// + /// This method replaces the connector's configuration with the provided value. + /// The changes take effect immediately in the stored config but only affect + /// connector behavior after a restart. + /// + /// # Arguments + /// + /// * `config` - The new configuration value to store + pub async fn set_config(&self, config: serde_json::Value) { + let mut cfg = self.config.write().await; + *cfg = config; + } + + /// Get the resolved working directory for this connector + /// + /// Returns the working directory path that was resolved during connector + /// creation. This may be a custom path set in the connector configuration, + /// or the default project directory if no custom path was specified. + pub fn working_directory(&self) -> Option<std::path::PathBuf> { + // Try to read the working directory without blocking + if let Ok(guard) = self.working_directory.try_read() { + guard.clone() + } else { + None + } + } + + /// T034: Get the custom icon path for this connector + /// + /// Returns the optional custom icon path if set. The UI should use this + /// icon instead of the default connector type emoji when present. + pub fn icon_path(&self) -> Option<&str> { + self.icon_path.as_deref() + } + + /// T035: Check if type overlay should be shown on custom icon + /// + /// Returns true if the connector type emoji should appear as a small + /// overlay in the lower-right corner of the custom icon. + pub fn show_type_overlay(&self) -> bool { + self.show_type_overlay + } + + /// Set the icon path for this connector + /// + /// Updates the custom icon path. This change is reflected immediately. + pub fn set_icon_path(&mut self, icon_path: Option<String>) { + self.icon_path = icon_path; + } + + /// Set whether to show type overlay on custom icon + /// + /// Updates the overlay setting. This change is reflected immediately. + pub fn set_show_type_overlay(&mut self, show_overlay: bool) { + self.show_type_overlay = show_overlay; + } + + /// Set the connector title + /// + /// Updates the display title for this connector. + pub fn set_title(&mut self, title: String) { + self.title = title; + } + + /// Set the working directory + /// + /// Updates the working directory for this connector. This is an async operation + /// since working_directory is protected by an RwLock. + pub async fn set_working_directory(&self, working_directory: Option<std::path::PathBuf>) { + let mut wd = self.working_directory.write().await; + *wd = working_directory; + } + + /// Replace the command sender with a new one + /// + /// This is an internal method used during connector restart to update + /// the command channel after recreating the connector instance. + /// + /// # Warning + /// + /// This is a low-level operation intended only for use by CoreRuntime's + /// restart_connector method. External callers should not use this method. + pub async fn replace_command_tx(&self, _new_tx: mpsc::Sender<ConnectorCommand>) { + // We can't replace cmd_tx directly since it's not behind a lock + // Instead, we rely on the fact that restart creates a new handle + // This method is kept for potential future use with handle mutation + // For now, document that restart will update references via CoreRuntime + } + + /// Update the dynamic available commands for this connector. + /// + /// Called when the remote agent sends an `available_commands_update` + /// notification. The stored commands are returned by + /// `get_available_commands()` for ACP (and any other dynamic) connectors. + pub fn set_available_commands( + &self, + commands: Vec<crate::acp::protocol::streaming::Command>, + ) { + if let Ok(mut guard) = self.available_commands.try_write() { + *guard = commands; + } + } +} + +impl Connector for ConnectorHandle { + fn id(&self) -> &ConnectorId { + &self.id + } + + fn kind(&self) -> ConnectorKind { + self.kind.clone() + } + + fn owner(&self) -> &UserId { + &self.owner + } + + fn title(&self) -> &str { + &self.title + } + + fn state(&self) -> ConnectorState { + // Try to read the state without blocking + // If the lock is held, we return the last known state + // This is acceptable for status queries where eventual consistency is fine + match self.state.try_read() { + Ok(state_lock) => state_lock.clone(), + Err(_) => { + // Lock is held, return Initializing as a safe default + // In practice, this is extremely rare and only happens during + // state transitions, which are very fast + ConnectorState::Initializing + } + } + } + + fn command_tx(&self) -> mpsc::Sender<ConnectorCommand> { + self.cmd_tx.clone() + } + + fn subscribe(&self) -> broadcast::Receiver<dirigent_protocol::Event> { + self.events_tx.subscribe() + } + + fn stop(&self) { + // Send shutdown command + // We don't wait for it to be processed - this is a cooperative shutdown + let cmd_tx = self.cmd_tx.clone(); + tokio::spawn(async move { + let _ = cmd_tx.send(ConnectorCommand::Shutdown).await; + }); + } + + fn get_available_commands(&self) -> Vec<crate::acp::protocol::streaming::Command> { + // First, read any dynamically-stored commands (populated by ACP agents + // via available_commands_update notifications). + let dynamic = self + .available_commands + .try_read() + .map(|g| g.clone()) + .unwrap_or_default(); + + if !dynamic.is_empty() { + return dynamic; + } + + // Fall back to static commands based on connector kind + match self.kind { + ConnectorKind::Gateway => vec![ + crate::acp::protocol::streaming::Command { + name: "echo".to_string(), + description: Some("Enable or disable echo mode (/echo on|off)".to_string()), + }, + crate::acp::protocol::streaming::Command { + name: "help".to_string(), + description: Some("Show available commands".to_string()), + }, + crate::acp::protocol::streaming::Command { + name: "list-connectors".to_string(), + description: Some("Show available connectors".to_string()), + }, + crate::acp::protocol::streaming::Command { + name: "select-connector".to_string(), + description: Some( + "Transfer session to a connector (/select-connector <id>)".to_string(), + ), + }, + crate::acp::protocol::streaming::Command { + name: "claude".to_string(), + description: Some("Transfer session to a Claude connector".to_string()), + }, + ], + _ => Vec::new(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_connector_command_clone() { + let cmd1 = ConnectorCommand::ListSessions; + let cmd2 = cmd1.clone(); + + // Both should be ListSessions variant + matches!(cmd1, ConnectorCommand::ListSessions); + matches!(cmd2, ConnectorCommand::ListSessions); + } + + #[test] + fn test_connector_command_debug() { + let cmd = ConnectorCommand::SendMessage { + session_id: "test-session".to_string(), + text: "Hello".to_string(), + }; + + let debug_str = format!("{:?}", cmd); + assert!(debug_str.contains("SendMessage")); + assert!(debug_str.contains("test-session")); + } + + #[tokio::test] + async fn test_connector_handle_creation() { + let (cmd_tx, _cmd_rx) = mpsc::channel(100); + let (events_tx, _events_rx) = broadcast::channel(1000); + + let handle = ConnectorHandle::new( + "test-connector".to_string(), + ConnectorKind::Mock, + uuid::Uuid::nil(), + "Test Connector".to_string(), + cmd_tx, + events_tx, + serde_json::json!({}), + None, + None, + false, + ); + + assert_eq!(handle.id(), "test-connector"); + assert_eq!(handle.kind(), ConnectorKind::Mock); + assert_eq!(*handle.owner(), uuid::Uuid::nil()); + assert_eq!(handle.title(), "Test Connector"); + assert_eq!(handle.state(), ConnectorState::Initializing); + } + + #[tokio::test] + async fn test_connector_handle_state_update() { + let (cmd_tx, _cmd_rx) = mpsc::channel(100); + let (events_tx, _events_rx) = broadcast::channel(1000); + + let handle = ConnectorHandle::new( + "test-connector".to_string(), + ConnectorKind::Mock, + uuid::Uuid::nil(), + "Test Connector".to_string(), + cmd_tx, + events_tx, + serde_json::json!({}), + None, + None, + false, + ); + + // Update state to Ready + { + let state_lock = handle.state_lock(); + let mut state = state_lock.write().await; + *state = ConnectorState::Ready; + } + + assert_eq!(handle.state(), ConnectorState::Ready); + } + + #[tokio::test] + async fn test_connector_handle_subscribe() { + let (cmd_tx, _cmd_rx) = mpsc::channel(100); + let (events_tx, _events_rx) = broadcast::channel(1000); + + let handle = ConnectorHandle::new( + "test-connector".to_string(), + ConnectorKind::Mock, + uuid::Uuid::nil(), + "Test Connector".to_string(), + cmd_tx, + events_tx.clone(), + serde_json::json!({}), + None, + None, + false, + ); + + // Subscribe to events + let mut rx1 = handle.subscribe(); + let mut rx2 = handle.subscribe(); + + // Send an event + let event = dirigent_protocol::Event::Connected; + events_tx.send(event.clone()).ok(); + + // Both receivers should get the event + let received1 = rx1.recv().await.unwrap(); + let received2 = rx2.recv().await.unwrap(); + + matches!(received1, dirigent_protocol::Event::Connected); + matches!(received2, dirigent_protocol::Event::Connected); + } + + #[tokio::test] + async fn test_connector_handle_command_tx() { + let (cmd_tx, mut cmd_rx) = mpsc::channel(100); + let (events_tx, _events_rx) = broadcast::channel(1000); + + let handle = ConnectorHandle::new( + "test-connector".to_string(), + ConnectorKind::Mock, + uuid::Uuid::nil(), + "Test Connector".to_string(), + cmd_tx, + events_tx, + serde_json::json!({}), + None, + None, + false, + ); + + // Send a command + let sender = handle.command_tx(); + sender.send(ConnectorCommand::ListSessions).await.unwrap(); + + // Receive the command + let cmd = cmd_rx.recv().await.unwrap(); + matches!(cmd, ConnectorCommand::ListSessions); + } + + #[tokio::test] + async fn test_connector_handle_clone() { + let (cmd_tx, _cmd_rx) = mpsc::channel(100); + let (events_tx, _events_rx) = broadcast::channel(1000); + + let handle1 = ConnectorHandle::new( + "test-connector".to_string(), + ConnectorKind::Mock, + uuid::Uuid::nil(), + "Test Connector".to_string(), + cmd_tx, + events_tx, + serde_json::json!({}), + None, + None, + false, + ); + + let handle2 = handle1.clone(); + + // Both handles should refer to the same connector + assert_eq!(handle1.id(), handle2.id()); + assert_eq!(handle1.kind(), handle2.kind()); + + // Update state via handle1 + { + let state_lock = handle1.state_lock(); + let mut state = state_lock.write().await; + *state = ConnectorState::Ready; + } + + // handle2 should see the update + assert_eq!(handle2.state(), ConnectorState::Ready); + } +} diff --git a/crates/dirigent_core/src/connectors/opencode.rs b/crates/dirigent_core/src/connectors/opencode.rs new file mode 100644 index 0000000..4f90419 --- /dev/null +++ b/crates/dirigent_core/src/connectors/opencode.rs @@ -0,0 +1,1409 @@ +//! OpenCode connector implementation +//! +//! This module provides a concrete connector implementation that wraps the +//! `opencode_client` library to provide OpenCode.ai integration for Dirigent. +//! +//! # Architecture +//! +//! The OpenCodeConnector: +//! - Wraps `opencode_client::OpenCodeClient` for REST API operations +//! - Manages SSE (Server-Sent Events) subscription for real-time updates +//! - Translates OpenCode events to Dirigent protocol events using `OpenCodeAdapter` +//! - Handles automatic reconnection with exponential backoff +//! - Processes commands (list sessions, send messages, etc.) +//! +//! # Lifecycle +//! +//! 1. Create connector with `OpenCodeConnector::new(id, owner, config)` +//! 2. Start the connector task with `start_task()` which returns a JoinHandle +//! 3. Send commands via the command channel +//! 4. Receive events via the event broadcast channel +//! 5. Stop with `Connector::stop()` which sends Shutdown command +//! +//! # Example +//! +//! ```no_run +//! use dirigent_core::connectors::opencode::{OpenCodeConnector, OpenCodeConfig}; +//! use dirigent_core::connectors::{Connector, ConnectorCommand}; +//! +//! # async fn example() { +//! let config = OpenCodeConfig { +//! base_url: "http://localhost:12225".to_string(), +//! title: "My OpenCode".to_string(), +//! initial_session: None, +//! }; +//! +//! let connector = OpenCodeConnector::new( +//! "conn-1".to_string(), +//! uuid::Uuid::now_v7(), +//! "My OpenCode".to_string(), +//! config, +//! ); +//! +//! // Start the connector task +//! let task_handle = connector.start_task().await; +//! +//! // Subscribe to events +//! let mut events = connector.subscribe(); +//! +//! // Send a command +//! let cmd_tx = connector.command_tx(); +//! cmd_tx.send(ConnectorCommand::ListSessions).await.ok(); +//! +//! // Later: stop the connector +//! connector.stop(); +//! task_handle.await.ok(); +//! # } +//! ``` + +use crate::connectors::{Connector, ConnectorCommand}; +use crate::sharing::bus::SharingBus; +use crate::types::{ConnectorErrorKind, ConnectorId, ConnectorKind, ConnectorState, UserId}; +use dirigent_protocol::adapters::OpenCodeAdapter; +use dirigent_protocol::{Event, TurnCompleteTrigger}; +use futures::StreamExt; +use opencode_client::{OpenCodeClient, SseClient}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::{broadcast, mpsc, RwLock}; +use tokio::task::JoinHandle; +use tracing::{debug, error, info, warn}; +use uuid::Uuid; + +/// Configuration for OpenCode connector +/// +/// Contains all the necessary information to connect to an OpenCode.ai instance +/// and configure the connector's behavior. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct OpenCodeConfig { + /// Base URL of the OpenCode API + /// + /// Example: "http://localhost:12225" or "https://api.opencode.ai" + pub base_url: String, + + /// Optional initial session to load on connection + /// + /// If provided, the connector will attempt to load this session immediately + /// after connecting. If None, no initial session is loaded. + pub initial_session: Option<String>, +} + +/// OpenCode connector implementation +/// +/// Provides integration with OpenCode.ai by wrapping the `opencode_client` library +/// and implementing the `Connector` trait. This allows the Dirigent core runtime +/// to manage OpenCode connections alongside other connector types. +/// +/// # Fields +/// +/// The connector maintains both the client libraries and the channels needed for +/// async communication: +/// - REST client for API calls (list sessions, send messages, etc.) +/// - SSE client for real-time event streaming +/// - Command channel for receiving control commands +/// - Event broadcast channel for publishing events to subscribers +/// - Shared state for tracking connector lifecycle +pub struct OpenCodeConnector { + /// Unique identifier for this connector instance + id: ConnectorId, + + /// Optional connector UID used for BusEvent routing. + /// + /// Populated by the runtime when the connector is registered; legacy + /// constructor paths default this to `None` and later wiring can fill it in. + pub connector_uid: Option<Uuid>, + + /// User who owns this connector + owner: UserId, + + /// Human-readable title for this connector + /// + /// Used in UI display and logging. This is stored separately from + /// OpenCodeConfig as it's an orchestration-level field. + title: String, + + /// Connector configuration + config: OpenCodeConfig, + + /// Shared connector state (Initializing, Connecting, Ready, Error, Stopped) + /// + /// Protected by RwLock to allow concurrent reads and exclusive writes. + /// The connector task updates this as it progresses through its lifecycle. + state: Arc<RwLock<ConnectorState>>, + + /// Structured error classification for the current error state. + /// + /// Set alongside `ConnectorState::Error` to provide machine-readable error + /// categorization. Cleared to `None` when the connector recovers. + error_kind: Arc<RwLock<Option<ConnectorErrorKind>>>, + + /// Sender for commands to the connector + /// + /// Commands sent via this channel are processed by the connector's + /// background task loop. This is cloneable for sharing across tasks. + cmd_tx: mpsc::Sender<ConnectorCommand>, + + /// Receiver for commands (internal to the connector task) + /// + /// This is consumed by the `start_task()` method and moved into the + /// background task. It's wrapped in Option so we can take it. + cmd_rx: Arc<RwLock<Option<mpsc::Receiver<ConnectorCommand>>>>, + + /// Broadcast sender for events + /// + /// Events from the OpenCode SSE stream are translated and published here. + /// Subscribers get receivers via the `subscribe()` method. + events_tx: broadcast::Sender<Event>, + + /// Shared event bus for direct-to-bus publishes. + /// + /// Every event emitted by this connector is published here in addition + /// to `events_tx`, eliminating the forwarder task tier in the runtime. + sharing_bus: Arc<SharingBus>, + + /// OpenCode REST client + /// + /// Used for API operations like listing sessions, listing messages, + /// and sending messages. + client: OpenCodeClient, +} + +/// Helper that publishes an event to both the per-connector broadcast +/// channel and the global `SharingBus`. The bus publish goes first so any +/// subscriber reading from both sees the bus event no later than the +/// broadcast event. +async fn emit_event( + sharing_bus: &Arc<SharingBus>, + events_tx: &broadcast::Sender<Event>, + connector_id: &ConnectorId, + connector_uid: Option<Uuid>, + event: Event, +) { + let bus_event = dirigent_protocol::streaming::BusEvent::from_connector_event( + event.clone(), + connector_uid, + connector_id.clone(), + ); + sharing_bus.publish(bus_event).await; + let _ = events_tx.send(event); +} + +impl OpenCodeConnector { + /// Create a new OpenCode connector + /// + /// Initializes the connector with the given configuration but does not + /// start it. Call `start_task()` to begin the connector's background task. + /// + /// # Arguments + /// + /// * `id` - Unique identifier for this connector + /// * `owner` - User ID of the connector owner + /// * `title` - Human-readable title for this connector + /// * `config` - OpenCode configuration (base URL, etc.) + /// + /// # Returns + /// + /// A new OpenCodeConnector in Initializing state + /// + /// # Example + /// + /// ```no_run + /// use dirigent_core::connectors::opencode::{OpenCodeConnector, OpenCodeConfig}; + /// + /// let config = OpenCodeConfig { + /// base_url: "http://localhost:12225".to_string(), + /// initial_session: None, + /// }; + /// + /// let connector = OpenCodeConnector::new( + /// "my-connector".to_string(), + /// uuid::Uuid::now_v7(), + /// "Local OpenCode".to_string(), + /// config, + /// ); + /// ``` + pub fn new( + id: ConnectorId, + owner: UserId, + title: String, + config: OpenCodeConfig, + sharing_bus: Arc<SharingBus>, + ) -> Self { + // Create the OpenCode REST client + let client = OpenCodeClient::new(&config.base_url); + + // Create command channel (capacity 100 for buffering commands) + let (cmd_tx, cmd_rx) = mpsc::channel(100); + + // Create event broadcast channel (capacity 1000 for event buffering) + let (events_tx, _) = broadcast::channel(1000); + + info!( + connector_id = %id, + owner = %owner, + title = %title, + base_url = %config.base_url, + "Creating OpenCode connector" + ); + + Self { + id, + connector_uid: None, + owner, + title, + config, + state: Arc::new(RwLock::new(ConnectorState::Initializing)), + error_kind: Arc::new(RwLock::new(None)), + cmd_tx, + cmd_rx: Arc::new(RwLock::new(Some(cmd_rx))), + events_tx, + sharing_bus, + client, + } + } + + /// Wrap a raw connector `Event` in a `BusEvent` populated with this + /// connector's routing identity (`connector_uid` + `connector_id`). + /// + /// Call sites that currently broadcast a raw `Event` via `events_tx` can + /// use this helper when migrating to the `BusEvent` pipeline. This is an + /// additive helper — existing `events_tx.send(event)` emissions remain + /// unchanged. + pub fn to_bus_event( + &self, + event: dirigent_protocol::Event, + ) -> dirigent_protocol::streaming::BusEvent { + dirigent_protocol::streaming::BusEvent::from_connector_event( + event, + self.connector_uid, + self.id.clone(), + ) + } + + /// Get the events broadcast sender + /// + /// This method returns a clone of the broadcast sender for publishing events. + /// It's useful when creating a ConnectorHandle from this connector. + /// + /// # Returns + /// + /// A clone of the broadcast sender + pub fn events_sender(&self) -> broadcast::Sender<Event> { + self.events_tx.clone() + } + + /// Get the state Arc + /// + /// This method returns a clone of the state Arc, allowing the ConnectorHandle + /// to share the same state that the connector task updates. + /// + /// # Returns + /// + /// A clone of the Arc<RwLock<ConnectorState>> + pub fn state_arc(&self) -> Arc<RwLock<ConnectorState>> { + Arc::clone(&self.state) + } + + /// Get the error_kind Arc + /// + /// This method returns a clone of the error_kind Arc, allowing the ConnectorHandle + /// to share the same error classification that the connector task updates. + pub fn error_kind_arc(&self) -> Arc<RwLock<Option<ConnectorErrorKind>>> { + Arc::clone(&self.error_kind) + } + + /// Set connector_id on events that require it + /// + /// This helper function injects the connector ID into events that were + /// translated by the adapter with placeholder empty strings. + fn set_connector_id(event: &mut Event, connector_id: &str) { + match event { + Event::SessionsListed { + connector_id: cid, .. + } => *cid = connector_id.to_string(), + Event::SessionCreated { + connector_id: cid, .. + } => *cid = connector_id.to_string(), + Event::SessionUpdated { + connector_id: cid, .. + } => *cid = connector_id.to_string(), + Event::MessageStarted { + connector_id: cid, .. + } => *cid = connector_id.to_string(), + Event::MessageCompleted { + connector_id: cid, .. + } => *cid = connector_id.to_string(), + Event::SessionUpdate { + connector_id: cid, .. + } => *cid = connector_id.to_string(), + _ => {} // Other events don't need connector_id + } + } + + /// Start the connector's background task + /// + /// This spawns an async task that: + /// - Connects to the OpenCode SSE endpoint + /// - Processes incoming events and translates them to Dirigent events + /// - Handles commands from the command channel + /// - Manages automatic reconnection on failures + /// + /// # Returns + /// + /// A JoinHandle for the background task, which can be used to wait for + /// the task to complete or to detect if it panics. + /// + /// # Panics + /// + /// Panics if called more than once (the command receiver can only be taken once). + /// + /// # Example + /// + /// ```no_run + /// # use dirigent_core::connectors::opencode::{OpenCodeConnector, OpenCodeConfig}; + /// # async fn example(connector: OpenCodeConnector) { + /// let task_handle = connector.start_task().await; + /// + /// // Do other work... + /// + /// // Wait for task to complete + /// task_handle.await.ok(); + /// # } + /// ``` + pub async fn start_task(&self) -> JoinHandle<()> { + let id = self.id.clone(); + let connector_uid = self.connector_uid; + let config = self.config.clone(); + let state = Arc::clone(&self.state); + let error_kind = Arc::clone(&self.error_kind); + let events_tx = self.events_tx.clone(); + let sharing_bus = Arc::clone(&self.sharing_bus); + let client = self.client.clone(); + + // Take the command receiver (this can only be done once) + let cmd_rx = self + .cmd_rx + .write() + .await + .take() + .expect("start_task() called more than once - command receiver already taken"); + + info!(connector_id = %id, "Starting OpenCode connector task"); + + tokio::spawn(async move { + Self::run_task( + id, + connector_uid, + config, + state, + error_kind, + events_tx, + sharing_bus, + client, + cmd_rx, + ) + .await; + }) + } + + /// Main connector task loop + /// + /// This is the core of the connector's async behavior. It manages: + /// - SSE connection and reconnection + /// - Event translation and broadcasting + /// - Command processing + /// - Error handling and recovery + /// + /// This function runs until a Shutdown command is received or an + /// unrecoverable error occurs. + async fn run_task( + id: ConnectorId, + connector_uid: Option<Uuid>, + config: OpenCodeConfig, + state: Arc<RwLock<ConnectorState>>, + error_kind: Arc<RwLock<Option<ConnectorErrorKind>>>, + events_tx: broadcast::Sender<Event>, + sharing_bus: Arc<SharingBus>, + client: OpenCodeClient, + mut cmd_rx: mpsc::Receiver<ConnectorCommand>, + ) { + info!(connector_id = %id, "OpenCode connector task started"); + + // Create the event adapter for translating OpenCode events to Dirigent events + let adapter = OpenCodeAdapter::new(); + + // Reconnection state - Counter 1: connection failures (HTTP/TCP level) + let mut retry_count = 0; + let max_retries = 5; + let retry_delays = [ + Duration::from_secs(1), + Duration::from_secs(3), + Duration::from_secs(5), + Duration::from_secs(5), + Duration::from_secs(5), + ]; + + // Reconnection state - Counter 2: rapid disconnects (stream dies within stability threshold) + let mut rapid_disconnect_count: usize = 0; + let rapid_disconnect_max: usize = 5; + let rapid_disconnect_delays = [ + Duration::from_secs(2), + Duration::from_secs(5), + Duration::from_secs(10), + Duration::from_secs(30), + Duration::from_secs(60), + ]; + let stability_threshold = Duration::from_secs(30); + let mut connected_at: Option<tokio::time::Instant> = None; + // Track whether we've ever had a stable connection (for adaptive offline detection) + // If never stable, declare offline after just 2 rapid disconnects (fast startup detection) + // If was stable before, use full 5 rapid disconnect patience (transient issues) + let mut ever_stable = false; + + // Main loop: connect, process events, handle reconnects + 'reconnect_loop: loop { + // Update state to Connecting + { + let mut state_guard = state.write().await; + *state_guard = ConnectorState::Connecting; + } + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::ConnectorStateChanged { + connector_id: id.clone(), + state: "Connecting".to_string(), + error_kind: None, + }, + ) + .await; + + // Emit a Connecting pseudo-event (not part of protocol, but useful for logging) + debug!(connector_id = %id, "Attempting to connect to OpenCode SSE"); + + // Create SSE client and attempt connection + let sse_client = SseClient::new(&config.base_url); + let stream = match sse_client.connect() { + Ok(stream) => { + info!(connector_id = %id, "SSE stream channel acquired, awaiting events"); + + // Reset connection failure count (we did connect successfully) + retry_count = 0; + // Record when we connected to detect rapid disconnects + connected_at = Some(tokio::time::Instant::now()); + + // Update state to Ready and clear error_kind + { + let mut state_guard = state.write().await; + *state_guard = ConnectorState::Ready; + } + { + let mut ek = error_kind.write().await; + *ek = None; + } + + // Emit state change + Connected events + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::ConnectorStateChanged { + connector_id: id.clone(), + state: "Ready".to_string(), + error_kind: None, + }, + ) + .await; + emit_event(&sharing_bus, &events_tx, &id, connector_uid, Event::Connected).await; + + stream + } + Err(e) => { + error!( + connector_id = %id, + error = %e, + "Failed to connect to OpenCode SSE" + ); + + // Update state to Error with ConnectionFailed kind + { + let mut state_guard = state.write().await; + *state_guard = + ConnectorState::Error(format!("SSE connection failed: {}", e)); + } + { + let mut ek = error_kind.write().await; + *ek = Some(ConnectorErrorKind::ConnectionFailed); + } + + // Emit state change + error events + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::ConnectorStateChanged { + connector_id: id.clone(), + state: format!("Error(SSE connection failed: {})", e), + error_kind: Some("connection_failed".to_string()), + }, + ) + .await; + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::Error { + message: format!("Failed to connect: {}", e), + }, + ) + .await; + + // Handle retry logic + if retry_count < max_retries { + let delay = retry_delays[retry_count.min(retry_delays.len() - 1)]; + warn!( + connector_id = %id, + retry_count = retry_count + 1, + max_retries = max_retries, + delay_secs = delay.as_secs(), + "Retrying connection" + ); + + retry_count += 1; + + // Wait before retrying + tokio::time::sleep(delay).await; + + continue 'reconnect_loop; + } else { + error!( + connector_id = %id, + "Max retries exceeded, staying in Error state until manual Reconnect" + ); + + // Stay in Error state and wait for commands + loop { + match cmd_rx.recv().await { + Some(ConnectorCommand::Reconnect) => { + info!(connector_id = %id, "Received Reconnect command, resetting all retry counters"); + retry_count = 0; + rapid_disconnect_count = 0; + { + let mut ek = error_kind.write().await; + *ek = None; + } + continue 'reconnect_loop; + } + Some(ConnectorCommand::Shutdown) => { + info!(connector_id = %id, "Received Shutdown command"); + break 'reconnect_loop; + } + Some(cmd) => { + warn!( + connector_id = %id, + command = ?cmd, + "Received command while in Error state, ignoring" + ); + } + None => { + error!(connector_id = %id, "Command channel closed"); + break 'reconnect_loop; + } + } + } + } + } + }; + + // Convert stream to tokio::select!-compatible type + let mut stream = Box::pin(stream); + + // Event processing loop + 'event_loop: loop { + tokio::select! { + // Handle incoming SSE events + event_result = stream.next() => { + match event_result { + Some(Ok(oc_event)) => { + debug!(connector_id = %id, event = ?oc_event, "Received OpenCode event"); + + // Translate the event using the adapter + match adapter.translate_event(oc_event) { + Ok(mut dirigent_event) => { + // Set connector_id on events that need it + Self::set_connector_id(&mut dirigent_event, &id); + + // ENHANCEMENT: Fetch user message content if empty + // User messages arrive via SSE with empty content because they don't stream + // We need to fetch the full message with parts from the API + if let Event::MessageCompleted { message, connector_id: _ } = &mut dirigent_event { + if message.role == dirigent_protocol::MessageRole::User && message.content.is_empty() { + debug!( + connector_id = %id, + message_id = %message.id, + session_id = %message.session_id, + "User message has empty content, fetching from API" + ); + + // Fetch all messages for this session from the API + match client.list_messages(&message.session_id).await { + Ok(all_messages) => { + // Find the specific message we need + if let Some(msg_with_parts) = all_messages.iter().find(|m| { + match &m.info { + opencode_client::types::Message::User(u) => u.id == message.id, + opencode_client::types::Message::Assistant(a) => a.id == message.id, + } + }) { + // Defensive: clear content before adding fetched parts + // (should already be empty, but this prevents duplication) + message.content.clear(); + + // Translate parts and add to message + for oc_part in &msg_with_parts.parts { + match dirigent_protocol::adapters::opencode::OpenCodeAdapter::translate_part_for_list(oc_part.clone()) { + Ok(part) => { + message.content.push(part); + } + Err(e) => { + warn!( + connector_id = %id, + message_id = %message.id, + error = %e, + "Failed to translate part" + ); + } + } + } + + debug!( + connector_id = %id, + message_id = %message.id, + parts_count = message.content.len(), + "Fetched user message content from API" + ); + } else { + warn!( + connector_id = %id, + message_id = %message.id, + session_id = %message.session_id, + "Message not found in list_messages response" + ); + } + } + Err(e) => { + warn!( + connector_id = %id, + message_id = %message.id, + session_id = %message.session_id, + error = %e, + "Failed to fetch user message content from API" + ); + // Continue with empty content - better than failing + } + } + } + } + + debug!(connector_id = %id, event = ?dirigent_event, "Translated to Dirigent event"); + + // Track if this is a MessageCompleted event for TurnComplete and SessionIdle emission + let should_emit_turn_complete = if let Event::MessageCompleted { message, .. } = &dirigent_event { + // Emit TurnComplete and SessionIdle after messages complete + // User messages also get MessageCompleted, but we emit for all for consistency + Some((message.session_id.clone(), message.id.clone())) + } else { + None + }; + + // Broadcast the event (via bus + per-connector channel) + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + dirigent_event, + ) + .await; + + // Emit TurnComplete then SessionIdle after MessageCompleted to signal completion + // Order: MessageCompleted → TurnComplete → SessionIdle + // This ensures archivist receives complete turn signal before idle flush + if let Some((session_id, message_id)) = should_emit_turn_complete { + // First emit TurnComplete with ExplicitSignal trigger + // (OpenCode session.idle events are explicit completion signals) + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::TurnComplete { + connector_id: id.clone(), + session_id: session_id.clone(), + message_id: message_id.clone(), + trigger: TurnCompleteTrigger::ExplicitSignal, + }, + ) + .await; + debug!( + connector_id = %id, + session_id = %session_id, + message_id = %message_id, + "Emitted TurnComplete with ExplicitSignal trigger" + ); + + // Then emit SessionIdle + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionIdle { + connector_id: id.to_string(), + session_id: session_id.clone(), + }, + ) + .await; + debug!(connector_id = %id, session_id = %session_id, "Emitted SessionIdle after TurnComplete"); + } + } + Err(dirigent_protocol::adapters::opencode::TranslationError::Duplicate) => { + // Skip duplicate events (this is normal) + debug!(connector_id = %id, "Skipped duplicate event"); + } + Err(e) => { + warn!(connector_id = %id, error = %e, "Failed to translate event"); + } + } + } + Some(Err(e)) => { + error!(connector_id = %id, error = %e, "SSE stream error"); + + // Update state to Error + { + let mut state_guard = state.write().await; + *state_guard = ConnectorState::Error(format!("SSE stream error: {}", e)); + } + + // Emit state change + error events + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::ConnectorStateChanged { + connector_id: id.clone(), + state: format!("Error(SSE stream error: {})", e), + error_kind: None, + }, + ) + .await; + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::Error { + message: format!("SSE stream error: {}", e), + }, + ) + .await; + + // Attempt to reconnect + break 'event_loop; + } + None => { + warn!(connector_id = %id, "SSE stream closed"); + + // Update state to Connecting (will reconnect) + { + let mut state_guard = state.write().await; + *state_guard = ConnectorState::Connecting; + } + + // Emit state change + disconnected events + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::ConnectorStateChanged { + connector_id: id.clone(), + state: "Connecting".to_string(), + error_kind: None, + }, + ) + .await; + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::Disconnected, + ) + .await; + + // Attempt to reconnect + break 'event_loop; + } + } + } + + // Handle commands + cmd = cmd_rx.recv() => { + match cmd { + Some(ConnectorCommand::ListSessions) => { + debug!(connector_id = %id, "Processing ListSessions command"); + + match client.list_sessions().await { + Ok(oc_sessions) => { + info!(connector_id = %id, count = oc_sessions.len(), "Listed sessions"); + + // Translate all sessions to Dirigent protocol sessions + let mut sessions = Vec::new(); + for oc_session in oc_sessions { + match adapter.translate_event(opencode_client::types::Event::SessionCreated { + properties: opencode_client::types::SessionEventInfo { + info: oc_session, + }, + }) { + Ok(Event::SessionCreated { session, .. }) => { + sessions.push(session); + } + Err(e) => { + warn!(connector_id = %id, error = %e, "Failed to translate session"); + } + _ => { + // Unexpected event type from translation + warn!(connector_id = %id, "Unexpected event type from session translation"); + } + } + } + + // Emit a single SessionsListed event with all sessions + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionsListed { + connector_id: id.clone(), + sessions, + }, + ) + .await; + } + Err(e) => { + error!(connector_id = %id, error = %e, "Failed to list sessions"); + + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::Error { + message: format!("Failed to list sessions: {}", e), + }, + ) + .await; + } + } + } + + Some(ConnectorCommand::ListMessages { session_id }) => { + debug!(connector_id = %id, session_id = %session_id, "Processing ListMessages command"); + + match client.list_messages(&session_id).await { + Ok(messages) => { + info!(connector_id = %id, session_id = %session_id, count = messages.len(), "Listed messages"); + + // Collect all translated messages with their parts + let mut dirigent_messages = Vec::new(); + + // Use a fresh adapter for batch translation to avoid deduplication issues + // The stateful adapter is for SSE streams, not for batch list operations + let batch_adapter = dirigent_protocol::adapters::opencode::OpenCodeAdapter::new(); + + for msg_with_parts in messages { + // Translate the message using fresh adapter + match batch_adapter.translate_event(opencode_client::types::Event::MessageUpdated { + properties: opencode_client::types::MessageEventInfo { + info: msg_with_parts.info.clone(), + }, + }) { + Ok(Event::MessageCompleted { message: mut msg, .. }) | Ok(Event::MessageStarted { message: mut msg, .. }) => { + // Translate all parts and add them to the message + for oc_part in msg_with_parts.parts { + // Directly translate part to MessagePart using translate_part + // This avoids the roundtrip through SessionUpdate + match dirigent_protocol::adapters::opencode::OpenCodeAdapter::translate_part_for_list(oc_part.clone()) { + Ok(part) => { + msg.content.push(part); + } + Err(e) => { + warn!(connector_id = %id, error = %e, "Failed to translate part"); + } + } + } + dirigent_messages.push(msg); + } + Ok(Event::SessionSystemMessageSet { .. }) => { + // Skip system message events when listing messages + // (they're already part of session metadata) + } + Err(dirigent_protocol::adapters::opencode::TranslationError::Duplicate) => { + // Skip duplicates + } + Err(e) => { + warn!(connector_id = %id, error = %e, "Failed to translate message"); + } + _ => { + // Unexpected event type from translation + warn!(connector_id = %id, "Unexpected event type from message translation"); + } + } + } + + // Calculate message count for metadata update + let message_count = dirigent_messages.len() as u32; + + // Emit a single MessagesListed event with all messages + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::MessagesListed { messages: dirigent_messages }, + ) + .await; + + // Emit SessionMetadataUpdated with the accurate count + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::SessionMetadataUpdated { + connector_id: id.clone(), + session_id: session_id.clone(), + title: None, // OpenCode doesn't support title updates yet + total_messages: Some(message_count), + model: None, + }, + ) + .await; + } + Err(e) => { + error!(connector_id = %id, session_id = %session_id, error = %e, "Failed to list messages"); + + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::Error { + message: format!("Failed to list messages: {}", e), + }, + ) + .await; + } + } + } + + Some(ConnectorCommand::CreateSession { .. }) => { + warn!(connector_id = %id, "CreateSession command not implemented for OpenCode connector"); + // OpenCode doesn't have a create session API endpoint + } + + Some(ConnectorCommand::LoadSession { session_id, .. }) => { + warn!(connector_id = %id, session_id = %session_id, "LoadSession command not implemented for OpenCode connector"); + // OpenCode doesn't have a session/load API endpoint + } + + Some(ConnectorCommand::SendMessage { session_id, text }) => { + debug!(connector_id = %id, session_id = %session_id, text_len = text.len(), "Processing SendMessage command"); + + match client.send_message(&session_id, text).await { + Ok(response) => { + info!(connector_id = %id, session_id = %session_id, "Message sent successfully"); + + // The response streaming is handled by SSE events + // We just log success here + debug!(connector_id = %id, response = ?response, "Send message response"); + } + Err(e) => { + error!(connector_id = %id, session_id = %session_id, error = %e, "Failed to send message"); + + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::Error { + message: format!("Failed to send message: {}", e), + }, + ) + .await; + } + } + } + + Some(ConnectorCommand::CancelGeneration { .. }) => { + warn!(connector_id = %id, "CancelGeneration command not implemented for OpenCode connector"); + // OpenCode doesn't have a cancel generation API endpoint + } + + Some(ConnectorCommand::Reconnect) => { + info!(connector_id = %id, "Received Reconnect command, restarting SSE stream"); + + // Reset all retry counters and break to reconnect loop + retry_count = 0; + rapid_disconnect_count = 0; + connected_at = None; // Clear so stability check doesn't misclassify this + break 'event_loop; + } + + Some(ConnectorCommand::AgentResponse { .. }) => { + // OpenCode connector does not support agent-initiated requests + warn!( + connector_id = %id, + "Received AgentResponse command but OpenCode connector does not support agent-initiated requests" + ); + } + + Some(ConnectorCommand::SetSessionMode { .. }) => { + // OpenCode connector does not support mode switching + warn!( + connector_id = %id, + "Received SetSessionMode command but OpenCode connector does not support mode switching" + ); + } + + Some(ConnectorCommand::SetSessionModel { .. }) => { + // OpenCode connector does not support model switching + warn!( + connector_id = %id, + "Received SetSessionModel command but OpenCode connector does not support model switching" + ); + } + + Some(ConnectorCommand::CloseSession { session_id }) => { + warn!(connector_id = %id, session_id = %session_id, "CloseSession not supported by OpenCode connector"); + } + + Some(ConnectorCommand::SetConfigOption { .. }) => { + warn!( + connector_id = %id, + "Received SetConfigOption command but OpenCode connector does not support config options" + ); + } + + Some(ConnectorCommand::Shutdown) => { + info!(connector_id = %id, "Received Shutdown command, stopping connector"); + + // Update state to Stopped + { + let mut state_guard = state.write().await; + *state_guard = ConnectorState::Stopped; + } + + // Exit both loops + break 'reconnect_loop; + } + + None => { + error!(connector_id = %id, "Command channel closed, stopping connector"); + + // Update state to Stopped + { + let mut state_guard = state.write().await; + *state_guard = ConnectorState::Stopped; + } + + // Exit both loops + break 'reconnect_loop; + } + } + } + } + } + + // If we exited the event loop (not due to Shutdown), check connection stability + if let Some(connect_time) = connected_at.take() { + let uptime = connect_time.elapsed(); + if uptime >= stability_threshold { + // Established connection broke - reset rapid counter, retry immediately + info!( + connector_id = %id, + uptime_secs = uptime.as_secs(), + "Established connection lost, reconnecting immediately" + ); + rapid_disconnect_count = 0; + ever_stable = true; + } else { + // Rapid disconnect - connection died too quickly + let offline_threshold = if ever_stable { rapid_disconnect_max } else { 2 }; + if rapid_disconnect_count < offline_threshold { + let delay = rapid_disconnect_delays + [rapid_disconnect_count.min(rapid_disconnect_delays.len() - 1)]; + rapid_disconnect_count += 1; + warn!( + connector_id = %id, + rapid_disconnect_count, + rapid_disconnect_max, + uptime_secs = uptime.as_secs(), + delay_secs = delay.as_secs(), + "Rapid disconnect, backing off" + ); + + // Update state to Error while backing off + { + let mut state_guard = state.write().await; + *state_guard = ConnectorState::Error(format!( + "Rapid disconnect #{}/{} (connected for {}s), retrying in {}s", + rapid_disconnect_count, + rapid_disconnect_max, + uptime.as_secs(), + delay.as_secs() + )); + } + { + let mut ek = error_kind.write().await; + *ek = Some(ConnectorErrorKind::Unstable); + } + + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::ConnectorStateChanged { + connector_id: id.clone(), + state: format!("Error(Rapid disconnect #{}/{})", rapid_disconnect_count, rapid_disconnect_max), + error_kind: Some("unstable".to_string()), + }, + ) + .await; + emit_event(&sharing_bus, &events_tx, &id, connector_uid, Event::Disconnected).await; + + tokio::time::sleep(delay).await; + continue 'reconnect_loop; + } else { + error!( + connector_id = %id, + "Max rapid disconnects exceeded, pausing until manual Reconnect" + ); + + // Update state to Error + { + let mut state_guard = state.write().await; + *state_guard = ConnectorState::Error( + "Connection unstable: too many rapid disconnects. Use Reconnect to retry.".to_string() + ); + } + let final_error_kind = if ever_stable { + ConnectorErrorKind::Unstable + } else { + ConnectorErrorKind::Offline + }; + let ek_str = match &final_error_kind { + ConnectorErrorKind::Offline => "offline", + ConnectorErrorKind::Unstable => "unstable", + ConnectorErrorKind::ConnectionFailed => "connection_failed", + }; + { + let mut ek = error_kind.write().await; + *ek = Some(final_error_kind); + } + + emit_event( + &sharing_bus, + &events_tx, + &id, + connector_uid, + Event::ConnectorStateChanged { + connector_id: id.clone(), + state: "Error(Connection unstable: too many rapid disconnects)".to_string(), + error_kind: Some(ek_str.to_string()), + }, + ) + .await; + emit_event(&sharing_bus, &events_tx, &id, connector_uid, Event::Disconnected).await; + + // Reuse the error-state command loop pattern + loop { + match cmd_rx.recv().await { + Some(ConnectorCommand::Reconnect) => { + info!(connector_id = %id, "Received Reconnect command, resetting all retry counters"); + retry_count = 0; + rapid_disconnect_count = 0; + { + let mut ek = error_kind.write().await; + *ek = None; + } + continue 'reconnect_loop; + } + Some(ConnectorCommand::Shutdown) => { + info!(connector_id = %id, "Received Shutdown command"); + break 'reconnect_loop; + } + Some(cmd) => { + warn!( + connector_id = %id, + command = ?cmd, + "Received command while in Error state, ignoring" + ); + } + None => { + error!(connector_id = %id, "Command channel closed"); + break 'reconnect_loop; + } + } + } + } + } + } else { + // No connected_at means we never fully connected - just retry + info!(connector_id = %id, "Event loop exited before connection established, attempting reconnect"); + } + } + + info!(connector_id = %id, "OpenCode connector task stopped"); + } +} + +impl Connector for OpenCodeConnector { + fn id(&self) -> &ConnectorId { + &self.id + } + + fn kind(&self) -> ConnectorKind { + ConnectorKind::OpenCode + } + + fn owner(&self) -> &UserId { + &self.owner + } + + fn title(&self) -> &str { + &self.title + } + + fn state(&self) -> ConnectorState { + // Try to read the state without blocking + match self.state.try_read() { + Ok(state_guard) => state_guard.clone(), + Err(_) => { + // Lock is held, return Initializing as a safe default + ConnectorState::Initializing + } + } + } + + fn command_tx(&self) -> mpsc::Sender<ConnectorCommand> { + self.cmd_tx.clone() + } + + fn subscribe(&self) -> broadcast::Receiver<Event> { + self.events_tx.subscribe() + } + + fn stop(&self) { + // Send shutdown command + let cmd_tx = self.cmd_tx.clone(); + tokio::spawn(async move { + let _ = cmd_tx.send(ConnectorCommand::Shutdown).await; + }); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_opencode_config_serialization() { + let config = OpenCodeConfig { + base_url: "http://localhost:12225".to_string(), + initial_session: Some("session-123".to_string()), + }; + + let json = serde_json::to_string(&config).unwrap(); + let deserialized: OpenCodeConfig = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.base_url, config.base_url); + assert_eq!(deserialized.initial_session, config.initial_session); + } + + #[tokio::test] + async fn test_opencode_connector_creation() { + let config = OpenCodeConfig { + base_url: "http://localhost:12225".to_string(), + initial_session: None, + }; + + let connector = OpenCodeConnector::new( + "test-conn".to_string(), + uuid::Uuid::nil(), + "Test Connector".to_string(), + config, + SharingBus::new(), + ); + + assert_eq!(connector.id(), "test-conn"); + assert_eq!(*connector.owner(), uuid::Uuid::nil()); + assert_eq!(connector.title(), "Test Connector"); + assert_eq!(connector.kind(), ConnectorKind::OpenCode); + assert_eq!(connector.state(), ConnectorState::Initializing); + } + + #[tokio::test] + async fn test_opencode_connector_clone_command_tx() { + let config = OpenCodeConfig { + base_url: "http://localhost:12225".to_string(), + initial_session: None, + }; + + let connector = OpenCodeConnector::new( + "test-conn".to_string(), + uuid::Uuid::nil(), + "Test".to_string(), + config, + SharingBus::new(), + ); + + // Should be able to clone the command sender + let cmd_tx1 = connector.command_tx(); + let cmd_tx2 = connector.command_tx(); + + // Both should be valid senders + drop(cmd_tx1); + drop(cmd_tx2); + } + + #[tokio::test] + async fn test_opencode_connector_subscribe() { + let config = OpenCodeConfig { + base_url: "http://localhost:12225".to_string(), + initial_session: None, + }; + + let connector = OpenCodeConnector::new( + "test-conn".to_string(), + uuid::Uuid::nil(), + "Test".to_string(), + config, + SharingBus::new(), + ); + + // Should be able to create multiple subscriptions + let _rx1 = connector.subscribe(); + let _rx2 = connector.subscribe(); + } +} diff --git a/crates/dirigent_core/src/error.rs b/crates/dirigent_core/src/error.rs new file mode 100644 index 0000000..4d442c8 --- /dev/null +++ b/crates/dirigent_core/src/error.rs @@ -0,0 +1,163 @@ +//! Error types for dirigent_core +//! +//! This module defines the core error types used throughout the dirigent_core +//! orchestration engine. All operations that can fail should return a `Result<T, CoreError>`. + +use std::fmt; + +/// Core error type for dirigent_core operations +/// +/// This enum represents all possible errors that can occur during +/// core runtime operations, including connector management, configuration, +/// and authorization. +#[derive(Debug, Clone, PartialEq)] +pub enum CoreError { + /// Resource was not found + /// + /// Used when attempting to access a connector, user, or other resource + /// that doesn't exist in the runtime registry. + NotFound, + + /// Resource already exists + /// + /// Used when attempting to create a resource (e.g., connector) with an ID + /// that is already in use. + AlreadyExists, + + /// Invalid configuration provided + /// + /// Used when configuration validation fails, such as invalid URLs, + /// missing required fields, or malformed connector parameters. + InvalidConfig, + + /// Failed to start a connector or service + /// + /// Used when a connector fails to initialize or start its background task, + /// such as connection failures or missing dependencies. + StartFailed, + + /// Operation not authorized for the requesting user + /// + /// Used when ownership checks fail or the user doesn't have permission + /// to perform the requested operation on a resource. + Unauthorized, + + /// Internal error with descriptive message + /// + /// Used for unexpected errors, I/O failures, serialization errors, + /// or any other internal issues that don't fit other categories. + /// The string provides context about what went wrong. + Internal(String), + + /// Operation not valid in current state + /// + /// Used when attempting an operation on a resource that is in an + /// incompatible state (e.g., sending commands to an errored connector, + /// starting an already-running connector). + InvalidState, +} + +impl fmt::Display for CoreError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + CoreError::NotFound => { + write!(f, "Resource not found") + } + CoreError::AlreadyExists => { + write!(f, "Resource already exists") + } + CoreError::InvalidConfig => { + write!(f, "Invalid configuration provided") + } + CoreError::StartFailed => { + write!(f, "Failed to start connector or service") + } + CoreError::Unauthorized => { + write!(f, "Operation not authorized") + } + CoreError::Internal(msg) => { + write!(f, "Internal error: {}", msg) + } + CoreError::InvalidState => { + write!(f, "Operation not valid in current state") + } + } + } +} + +impl std::error::Error for CoreError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + // None of our error variants wrap other errors currently + None + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_error_display() { + assert_eq!(CoreError::NotFound.to_string(), "Resource not found"); + assert_eq!( + CoreError::AlreadyExists.to_string(), + "Resource already exists" + ); + assert_eq!( + CoreError::InvalidConfig.to_string(), + "Invalid configuration provided" + ); + assert_eq!( + CoreError::StartFailed.to_string(), + "Failed to start connector or service" + ); + assert_eq!( + CoreError::Unauthorized.to_string(), + "Operation not authorized" + ); + assert_eq!( + CoreError::Internal("test error".to_string()).to_string(), + "Internal error: test error" + ); + assert_eq!( + CoreError::InvalidState.to_string(), + "Operation not valid in current state" + ); + } + + #[test] + fn test_error_equality() { + assert_eq!(CoreError::NotFound, CoreError::NotFound); + assert_eq!(CoreError::AlreadyExists, CoreError::AlreadyExists); + assert_eq!( + CoreError::Internal("same".to_string()), + CoreError::Internal("same".to_string()) + ); + assert_ne!( + CoreError::Internal("different".to_string()), + CoreError::Internal("other".to_string()) + ); + assert_ne!(CoreError::NotFound, CoreError::AlreadyExists); + } + + #[test] + fn test_error_is_error_trait() { + // Verify that CoreError implements std::error::Error + fn assert_is_error<T: std::error::Error>() {} + assert_is_error::<CoreError>(); + } + + #[test] + fn test_error_clone() { + let error = CoreError::Internal("test".to_string()); + let cloned = error.clone(); + assert_eq!(error, cloned); + } + + #[test] + fn test_error_debug() { + let error = CoreError::NotFound; + let debug_str = format!("{:?}", error); + assert!(debug_str.contains("NotFound")); + } +} diff --git a/crates/dirigent_core/src/lib.rs b/crates/dirigent_core/src/lib.rs new file mode 100644 index 0000000..3050ecb --- /dev/null +++ b/crates/dirigent_core/src/lib.rs @@ -0,0 +1,116 @@ +//! Dirigent Core +//! +//! The core functionality of Dirigent: +//! - Setup and manage ACP (Agent-Client Protocol) agents +//! - Connect as ACP client to ACP agents +//! - Manage projects and sessions +//! - Handle file access and terminal operations +//! +//! # Core Runtime Architecture +//! +//! The runtime provides a centralized orchestrator for managing connectors: +//! - `CoreRuntime` - Main runtime that manages connector lifecycle +//! - `CoreHandle` - Cloneable handle to the runtime for easy sharing +//! - `Connector` - Trait for connector implementations +//! - `ConnectorHandle` - Handle to a running connector instance +//! +//! # Quick Start +//! +//! ```no_run +//! use dirigent_core::{CoreRuntime, CoreConfig, CoreHandle}; +//! +//! # async fn example() { +//! // Create runtime with default config +//! let runtime = CoreRuntime::new(CoreConfig::default()); +//! let handle = CoreHandle::new(runtime); +//! +//! // List connectors +//! let connectors = handle.list_connectors(None).await; +//! +//! // Subscribe to events via the SharingBus +//! let _bus_rx = handle.sharing_bus().subscribe_all().await; +//! # } +//! ``` + +// Core types module - always available +pub mod types; + +// Plugin system types (scaffolding) - always available +pub mod plugins; + +// Tool directive and configuration types - always available (WASM-compatible) +pub mod tools; + +// Re-export commonly used types (always available for WASM) +pub use types::{ + ConnectorErrorKind, ConnectorId, ConnectorKind, ConnectorState, ConnectorSummary, User, UserId, + UserProfile, +}; + +// Server-only modules and exports +#[cfg(feature = "server")] +mod error; +#[cfg(feature = "server")] +mod runtime; + +#[cfg(feature = "server")] +pub use error::CoreError; +#[cfg(feature = "server")] +pub use runtime::{CoreHandle, CoreRuntime}; + +// Re-export Zed agent → ConnectorConfig conversion (used by API) +#[cfg(feature = "server")] +pub use runtime::zed_detection::{refresh_zed_connector_binaries, zed_agent_to_connector_config}; + +// Configuration module (server-only) +#[cfg(feature = "server")] +pub mod config; + +// Connectors module - abstraction layer for agent system connections (server-only) +#[cfg(feature = "server")] +pub mod connectors; + +// Vendors module - vendor-specific knowledge (CLI detection, mappings, templates) +#[cfg(feature = "server")] +pub mod vendors; + +// ACP module - Agent-Client Protocol implementation (server-only) +#[cfg(feature = "server")] +pub mod acp; + +// Re-export configuration types (server-only) +#[cfg(feature = "server")] +pub use config::{ + apply_template, resolve_default_runtime_working_directory, AcpServerConfig, ConnectorConfig, + CoreConfig, TaskConfig, +}; + +// Re-export connector abstractions (server-only) +#[cfg(feature = "server")] +pub use connectors::{Connector, ConnectorCommand, ConnectorHandle}; + +// Re-export connector implementations (server-only) +#[cfg(feature = "server")] +pub use connectors::opencode::{OpenCodeConfig, OpenCodeConnector}; + +// Re-export acceptor types for incoming connections (server-only) +#[cfg(feature = "server")] +pub use connectors::acceptor::{ + Acceptor, AcceptorError, AcceptorHandle, AcceptorId, AcceptorSummary, AcpAcceptor, + IncomingSession, SessionRouting, +}; + +// Re-export gateway connector types (server-only) +#[cfg(feature = "server")] +pub use connectors::gateway::{ + get_transient_commands, merge_with_transient_commands, Command as GatewayCommand, + CommandResult, ConnectorListCallback, ConnectorSummaryInfo, EchoConfig, GatewayConfig, + GatewayConnector, GatewaySession, SessionTransferCallback, +}; + +// Session sharing abstraction (server-only) +#[cfg(feature = "server")] +pub mod sharing; +#[cfg(feature = "server")] +pub use dirigent_protocol::sharing::{SessionShare, ShareId, ShareSummary}; + diff --git a/crates/dirigent_core/src/plugins/definition.rs b/crates/dirigent_core/src/plugins/definition.rs new file mode 100644 index 0000000..e41c558 --- /dev/null +++ b/crates/dirigent_core/src/plugins/definition.rs @@ -0,0 +1,30 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum PluginKind { + Observer, + Modifier, + Provider, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PluginDefinition { + pub id: String, + pub name: String, + pub kind: PluginKind, + #[serde(default)] + pub config: serde_json::Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct PluginAssignment { + pub plugin_id: String, + pub enabled: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct ProjectAssignment { + pub project_id: uuid::Uuid, + pub assigned: bool, +} diff --git a/crates/dirigent_core/src/plugins/mod.rs b/crates/dirigent_core/src/plugins/mod.rs new file mode 100644 index 0000000..dbfeb51 --- /dev/null +++ b/crates/dirigent_core/src/plugins/mod.rs @@ -0,0 +1,7 @@ +//! Plugin system types (scaffolding). +//! Defines plugin definitions and assignments for connectors and sessions. +//! The runtime execution layer is not yet implemented. + +mod definition; + +pub use definition::{PluginAssignment, PluginDefinition, PluginKind, ProjectAssignment}; diff --git a/crates/dirigent_core/src/runtime/config_manager.rs b/crates/dirigent_core/src/runtime/config_manager.rs new file mode 100644 index 0000000..c210b88 --- /dev/null +++ b/crates/dirigent_core/src/runtime/config_manager.rs @@ -0,0 +1,274 @@ +//! Configuration Management +//! +//! This module handles saving and updating runtime configuration. +//! It provides atomic file writes and supports JSON/TOML formats. + +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use tokio::sync::RwLock; +use tracing::{debug, error, info, warn}; + +use crate::config::{AcpServerConfig, CoreConfig}; +use crate::CoreError; + +/// Recursively strip JSON null values from a serde_json::Value. +/// +/// TOML has no null type, so `serde_json::Value::Null` serializes as `()` +/// which the toml crate rejects with "unsupported unit type". This function +/// removes null entries from objects and arrays before TOML serialization. +pub fn strip_json_nulls(value: &mut serde_json::Value) { + match value { + serde_json::Value::Object(map) => { + map.retain(|_, v| !v.is_null()); + for v in map.values_mut() { + strip_json_nulls(v); + } + } + serde_json::Value::Array(arr) => { + arr.retain(|v| !v.is_null()); + for v in arr.iter_mut() { + strip_json_nulls(v); + } + } + _ => {} + } +} + +/// Check whether a config path is project-local (relative or under CWD). +fn is_project_local(path: &Path) -> bool { + path.is_relative() || path.starts_with(std::env::current_dir().unwrap_or_default()) +} + +/// Promote a project-local config path to the user config directory, +/// preserving the file extension. Returns the original path if the +/// user config directory cannot be resolved. +fn promote_to_user_dir(source: &Path) -> PathBuf { + if let Ok(paths) = dirigent_config::DirigentPaths::resolve() { + let user_dir = paths.config_dir().to_path_buf(); + let _ = std::fs::create_dir_all(&user_dir); + let ext = source + .extension() + .and_then(|e| e.to_str()) + .unwrap_or("toml"); + user_dir.join(format!("dirigent.{}", ext)) + } else { + source.to_path_buf() + } +} + +/// Return the default user config path for new saves. +fn default_user_config_path() -> PathBuf { + if let Ok(paths) = dirigent_config::DirigentPaths::resolve() { + let user_dir = paths.config_dir().to_path_buf(); + let _ = std::fs::create_dir_all(&user_dir); + user_dir.join("dirigent.toml") + } else { + PathBuf::from("dirigent.json") + } +} + +/// Update the ACP server configuration +/// +/// This updates the in-memory configuration. Call `save_config` to persist. +pub async fn update_acp_server_config( + config: &Arc<RwLock<CoreConfig>>, + acp_config: AcpServerConfig, +) -> Result<(), CoreError> { + let mut cfg = config.write().await; + cfg.acp_server = Some(acp_config); + info!("Updated ACP Server configuration"); + Ok(()) +} + +/// Save the current configuration to disk +/// +/// This method serializes the current runtime configuration to JSON or TOML format +/// and writes it to the filesystem atomically (write to temp file, then rename). +/// +/// The save path is determined as follows: +/// 1. The file the config was originally loaded from (`config_source_path`), +/// promoted to the user config directory if it was project-local +/// 2. User config directory default (`~/.config/dirigent/dirigent.toml`) +/// +/// Project-local config files (relative paths or under CWD) are treated as +/// read-only: on save the config is promoted to the user config directory +/// so that the original project-local file is never modified. +/// +/// The file format is determined by the extension of the save path. +pub async fn save_config(config: &Arc<RwLock<CoreConfig>>) -> Result<(), CoreError> { + use std::io::Write; + + // Lock config for read and create a sanitized copy for saving. + // Zed-sourced connectors (source == "zed") are transient runtime entries. + // They should not be persisted — stripping them prevents ghost duplicates. + let config_guard = config.read().await; + let mut save_config = config_guard.clone(); + let zed_count = save_config + .connectors + .iter() + .filter(|c| c.source.as_deref() == Some("zed")) + .count(); + if zed_count > 0 { + debug!( + count = zed_count, + "Stripping transient Zed-sourced connectors from saved config" + ); + save_config + .connectors + .retain(|c| c.source.as_deref() != Some("zed")); + } + // Capture the source path before releasing the lock + let source_path = config_guard.config_source_path.clone(); + drop(config_guard); // Release lock early — we work from the cloned copy + + // Determine save path with promotion logic: + // - Project-local sources are promoted to the user config directory + // - Explicit user-dir or env-var paths are used as-is + let save_path = if let Some(ref source) = source_path { + if is_project_local(source) { + let promoted = promote_to_user_dir(source); + info!( + from = %source.display(), + to = %promoted.display(), + "Promoting project-local config to user config directory" + ); + promoted + } else { + source.clone() + } + } else { + default_user_config_path() + }; + + // Update config_source_path so subsequent saves go to the same location + if source_path.as_ref() != Some(&save_path) { + config.write().await.config_source_path = Some(save_path.clone()); + } + + info!(path = ?save_path, "Saving configuration"); + + // Determine format by extension + let contents = match save_path.extension().and_then(|s| s.to_str()) { + Some("toml") => { + debug!("Serializing config as TOML"); + // TOML has no null type — strip JSON nulls from connector params + for connector in &mut save_config.connectors { + strip_json_nulls(&mut connector.params); + } + toml::to_string_pretty(&save_config).map_err(|e| { + error!(error = %e, "Failed to serialize config as TOML"); + CoreError::Internal(format!("Failed to serialize config as TOML: {}", e)) + })? + } + Some("json") | None => { + debug!("Serializing config as JSON"); + serde_json::to_string_pretty(&save_config).map_err(|e| { + error!(error = %e, "Failed to serialize config as JSON"); + CoreError::Internal(format!("Failed to serialize config as JSON: {}", e)) + })? + } + Some(ext) => { + warn!( + extension = ext, + "Unsupported config file extension, defaulting to JSON" + ); + serde_json::to_string_pretty(&save_config).map_err(|e| { + error!(error = %e, "Failed to serialize config as JSON"); + CoreError::Internal(format!("Failed to serialize config as JSON: {}", e)) + })? + } + }; + + // Ensure the parent directory exists (critical for first save to user config dir) + if let Some(parent) = save_path.parent() { + if !parent.exists() { + std::fs::create_dir_all(parent).map_err(|e| { + error!(error = %e, path = ?parent, "Failed to create config directory"); + CoreError::Internal(format!( + "Failed to create config directory {}: {}", + parent.display(), + e + )) + })?; + info!(path = ?parent, "Created config directory"); + } + } + + // Write atomically: write to temp file, then rename + let temp_path = save_path.with_extension("tmp"); + debug!(temp_path = ?temp_path, "Writing to temporary file"); + + // Write to temp file + { + let mut file = std::fs::File::create(&temp_path).map_err(|e| { + error!(error = %e, path = ?temp_path, "Failed to create temp config file"); + CoreError::Internal(format!( + "Failed to create temp config file {}: {}", + temp_path.display(), + e + )) + })?; + + file.write_all(contents.as_bytes()).map_err(|e| { + error!(error = %e, path = ?temp_path, "Failed to write to temp config file"); + CoreError::Internal(format!( + "Failed to write to temp config file {}: {}", + temp_path.display(), + e + )) + })?; + + file.sync_all().map_err(|e| { + error!(error = %e, path = ?temp_path, "Failed to sync temp config file"); + CoreError::Internal(format!( + "Failed to sync temp config file {}: {}", + temp_path.display(), + e + )) + })?; + } + + // Rename temp file to final path + debug!(from = ?temp_path, to = ?save_path, "Renaming temp file to final path"); + std::fs::rename(&temp_path, &save_path).map_err(|e| { + error!(error = %e, from = ?temp_path, to = ?save_path, "Failed to rename temp file"); + CoreError::Internal(format!( + "Failed to rename temp file {} to {}: {}", + temp_path.display(), + save_path.display(), + e + )) + })?; + + info!(path = ?save_path, "Configuration saved successfully"); + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::CoreConfig; + use std::sync::Arc; + use tokio::sync::RwLock; + + #[tokio::test] + async fn test_update_acp_server_config() { + let config = Arc::new(RwLock::new(CoreConfig::default())); + + let acp_config = AcpServerConfig { + enabled: true, + port: Some(3001), + allowed_origins: None, + max_connections: 100, + default_connector_id: None, + }; + + let result = update_acp_server_config(&config, acp_config).await; + assert!(result.is_ok()); + + let cfg = config.read().await; + assert!(cfg.acp_server.is_some()); + assert_eq!(cfg.acp_server.as_ref().unwrap().port, Some(3001)); + } +} diff --git a/crates/dirigent_core/src/runtime/mod.rs b/crates/dirigent_core/src/runtime/mod.rs new file mode 100644 index 0000000..a3b02e9 --- /dev/null +++ b/crates/dirigent_core/src/runtime/mod.rs @@ -0,0 +1,4591 @@ +//! Core runtime for connector orchestration +//! +//! This module implements the `CoreRuntime`, which is the central orchestrator +//! for managing long-lived connectors and routing events in the Dirigent system. +//! +//! # Architecture +//! +//! The CoreRuntime: +//! - Maintains a registry of active connectors keyed by ConnectorId +//! - Provides lifecycle management (create, start, stop, remove) +//! - Routes commands to specific connectors +//! - Broadcasts global events to all subscribers +//! - Enforces basic ownership and authorization +//! +//! # Module Organization +//! +//! The runtime functionality is split across focused submodules: +//! - `config_manager` - Configuration persistence and updates +//! - `summary_cache` - Connector summary cache for sync callbacks +//! +//! # Usage +//! +//! ```no_run +//! use dirigent_core::{CoreRuntime, CoreConfig, CoreHandle}; +//! +//! # async fn example() { +//! // Create a runtime with default config +//! let config = CoreConfig::default(); +//! let runtime = CoreRuntime::new(config, None); +//! let handle = CoreHandle::new(runtime); +//! +//! // List all connectors (none initially) +//! let connectors = handle.list_connectors(None).await; +//! assert_eq!(connectors.len(), 0); +//! +//! // Subscribe to events via the SharingBus +//! let _bus_rx = handle.sharing_bus().subscribe_all().await; +//! # } +//! ``` + +// Submodules +pub(crate) mod config_manager; +mod session_transfer; +mod summary_cache; +#[cfg(feature = "server")] +pub mod zed_detection; + +use crate::config::{AcpServerConfig, ConnectorConfig, CoreConfig}; +use crate::connectors::acp::{AcpConfig, AcpConnector}; +use crate::connectors::gateway::{ + ConnectorSummaryInfo, SessionTransferCallback, SessionTransferRequest, SessionTransferResult, +}; +use crate::connectors::{Connector, ConnectorCommand, ConnectorHandle}; +use crate::sharing::bus::SharingBus; +use crate::types::{ConnectorId, ConnectorKind, ConnectorState, ConnectorSummary, User, UserId}; +use crate::{CoreError, OpenCodeConfig, OpenCodeConnector}; + +use std::collections::HashMap; +use std::ops::Deref; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::RwLock; +use tracing::{debug, error, info, warn}; + +/// Error type for [`CoreRuntime::replay_session_to_stream`]. +/// +/// Wraps the standalone [`crate::sharing::replay::ReplayError`] with the +/// two runtime-level preconditions that the replay fn itself has no way to +/// check (missing stream id, missing archivist). +#[cfg(feature = "server")] +#[derive(Debug, thiserror::Error)] +pub enum ReplaySessionError { + /// No stream with the given id is currently registered. + #[error("stream not found: {0:?}")] + StreamNotFound(crate::sharing::StreamId), + /// The runtime has no archivist configured, so there's nothing to + /// replay from. + #[error("no archivist configured")] + NoArchivist, + /// The replay itself failed (session missing, archive I/O, …). + #[error(transparent)] + Replay(#[from] crate::sharing::replay::ReplayError), +} + +/// Tracks an active session transfer for fallback purposes +#[derive(Clone, Debug)] +struct TransferredSession { + /// Client's session (in Gateway before transfer) + gateway_session_id: String, + /// Gateway connector ID (for fallback) + gateway_connector_id: String, + /// Target connector that received the session + target_connector_id: String, + /// Session ID in target connector + target_session_id: String, + /// When the transfer occurred + transferred_at: std::time::Instant, +} + +/// Core runtime for connector orchestration +/// +/// The CoreRuntime is the central component that manages all connectors +/// in the Dirigent system. It maintains: +/// - A registry of active connectors +/// - Configuration state +/// - A global event broadcast channel +/// - User registry (for ownership and authorization) +/// +/// # Thread Safety +/// +/// CoreRuntime uses tokio's async-aware RwLock for interior mutability, +/// allowing safe concurrent access from multiple async tasks. Read operations +/// can proceed in parallel, while write operations require exclusive access. +/// +/// # Lifecycle +/// +/// 1. Create with `CoreRuntime::new(config, None)` +/// 2. Wrap in `CoreHandle` for cheap cloning +/// 3. Create and manage connectors via runtime methods +/// 4. Subscribe to global events for system-wide event streaming +pub struct CoreRuntime { + /// Runtime configuration + /// + /// Wrapped in Arc<RwLock> to allow dynamic updates while maintaining + /// shared ownership across clones. + config: Arc<RwLock<CoreConfig>>, + + /// Registry of active connectors + /// + /// Maps ConnectorId to ConnectorHandle. Protected by RwLock to allow + /// concurrent reads (list/get operations) while serializing writes + /// (create/remove operations). + connectors: RwLock<HashMap<ConnectorId, ConnectorHandle>>, + + /// Primary event fan-out: a filtered, routing-aware event bus. + /// + /// All connector and runtime events are published onto the `SharingBus`. + /// Subscribers pick an `EventFilter` and receive only matching events. + /// See `docs/plans/2026-04-21-archivist-phase4-design.md`. + sharing_bus: Arc<SharingBus>, + + /// Registry of users + /// + /// Maps UserId to User. Currently a simple in-memory registry. + /// Future versions will integrate with authentication providers. + // TODO: Multi-user support - user registry for authorization and ownership tracking + #[allow(dead_code)] + users: RwLock<HashMap<UserId, User>>, + + /// Optional archivist for persistent storage + /// + /// If Some, the archivist coordinates connector UID generation and + /// persists session/message data. If None, the runtime operates + /// in ephemeral mode with no archival storage. + #[cfg(feature = "server")] + archivist: Arc<tokio::sync::RwLock<Option<Arc<dirigent_archivist::Archivist>>>>, + + /// Mapping from connector_id (user-facing) to archivist connector_uid (internal UUID) + /// + /// This mapping enables archive-first reads to work with non-UUID connector IDs. + /// Only populated when archivist is available and connectors are successfully registered. + #[cfg(feature = "server")] + archivist_connector_uids: Arc<RwLock<HashMap<String, uuid::Uuid>>>, + + /// Sync-accessible cache of connector summaries for use by GatewayConnector callbacks + /// + /// This cache is updated whenever connectors are added or removed. It uses a sync RwLock + /// (std::sync::RwLock) so that it can be accessed from synchronous callback functions. + /// The cache is wrapped in Arc so it can be shared with callbacks. + connector_summary_cache: Arc<std::sync::RwLock<Vec<ConnectorSummaryInfo>>>, + + /// Weak reference to self (set when wrapped in Arc by CoreHandle) + /// + /// This allows creating callbacks that need to call runtime methods asynchronously. + /// The weak reference prevents circular references and allows the runtime to be dropped. + self_weak: Arc<RwLock<Option<std::sync::Weak<CoreRuntime>>>>, + + /// Tracks sessions that were transferred, for fallback handling + /// + /// Maps target_session_id -> TransferredSession. When a connector fails, + /// we check this map to see if any transferred sessions need to fall back + /// to Gateway with user notification. + transferred_sessions: Arc<RwLock<HashMap<String, TransferredSession>>>, + + /// Optional inspector registry for internal visualization and introspection. + /// + /// When present, connectors and services register themselves as nodes + /// in the inspector tree, enabling monitoring of state, process health, + /// and system resources. + #[cfg(feature = "server")] + inspector: Option<Arc<dirigent_inspector::InspectorRegistry>>, + + /// Optional process group manager for lifecycle management of child processes. + /// + /// When present, ACP connectors using StdioTransport will create per-process + /// lifecycle handles so spawned agents are tracked in the platform job object / + /// process group and are shut down gracefully on disconnect. + #[cfg(feature = "server")] + process_manager: Option<Arc<dyn dirigent_process::ProcessGroupManager>>, + + /// Optional task runner for background process management + #[cfg(feature = "server")] + task_runner: Arc<tokio::sync::RwLock<Option<Arc<dirigent_taskrunner::TaskRunner>>>>, + + /// Optional Matrix service for session sharing + #[cfg(feature = "server")] + matrix_service: Arc<tokio::sync::RwLock<Option<Arc<dirigent_matrix::MatrixService>>>>, + + /// Registry of active streams wired to the `SharingBus`. + /// + /// Streams are attached at boot from `[[streams]]` config, or at runtime + /// via [`CoreRuntime::attach_stream`]. See + /// `docs/plans/2026-04-21-archivist-phase4-design.md`. + stream_registry: Arc<crate::sharing::StreamRegistry>, + + /// Factory lookup used to build concrete streams from `StreamConfig` + /// blocks. Populated by callers at boot with their compiled-in factories; + /// defaults to an empty registry for tests and pre-Phase-4 call sites. + stream_factories: Arc<crate::sharing::StreamFactoryRegistry>, +} + +impl CoreRuntime { + /// Create a new CoreRuntime with the given configuration + /// + /// This initializes all runtime state with defaults: + /// - Empty connector registry + /// - Global event channel with capacity 1000 + /// - Empty user registry + /// - Optional archivist for persistent storage + /// + /// # Arguments + /// + /// * `config` - Core configuration (port, project directory, etc.) + /// * `archivist` - Optional archivist for persistence (server feature only) + /// + /// # Returns + /// + /// A new CoreRuntime ready to manage connectors + /// + /// # Example + /// + /// ```no_run + /// use dirigent_core::{CoreRuntime, CoreConfig}; + /// + /// let config = CoreConfig::default(); + /// let runtime = CoreRuntime::new(config, None); + /// ``` + #[cfg(feature = "server")] + pub fn new(config: CoreConfig, archivist: Option<Arc<dirigent_archivist::Archivist>>) -> Self { + Self::new_with_inspector(config, archivist, None) + } + + /// Create a new CoreRuntime with archivist and inspector. + /// + /// This is a thin wrapper around [`Self::new_with_factories`] that + /// supplies an empty [`StreamFactoryRegistry`]. Callers that want to + /// attach streams at runtime should use `new_with_factories` directly + /// and pre-register their compiled-in factories. + #[cfg(feature = "server")] + pub fn new_with_inspector( + config: CoreConfig, + archivist: Option<Arc<dirigent_archivist::Archivist>>, + inspector: Option<Arc<dirigent_inspector::InspectorRegistry>>, + ) -> Self { + Self::new_with_factories( + config, + archivist, + inspector, + Arc::new(crate::sharing::StreamFactoryRegistry::default()), + ) + } + + /// Create a new CoreRuntime with archivist, inspector, and a + /// pre-populated [`StreamFactoryRegistry`]. + /// + /// Callers at boot build the factory registry with their compiled-in + /// `StreamFactory` implementations (e.g. Matrix, Langfuse, …) and hand + /// it in here so that [`CoreRuntime::attach_stream`] can resolve + /// `StreamConfig::kind` strings without further plumbing. + /// + /// Existing call sites that don't yet know about streams should keep + /// calling [`Self::new`] / [`Self::new_with_inspector`], which forward + /// to this constructor with an empty factory registry. + #[cfg(feature = "server")] + pub fn new_with_factories( + config: CoreConfig, + archivist: Option<Arc<dirigent_archivist::Archivist>>, + inspector: Option<Arc<dirigent_inspector::InspectorRegistry>>, + stream_factories: Arc<crate::sharing::StreamFactoryRegistry>, + ) -> Self { + // Primary event fan-out. All connector/runtime events flow through + // this bus; subscribers choose an `EventFilter` to receive only the + // events they care about. + let sharing_bus = SharingBus::new(); + let stream_registry = Arc::new(crate::sharing::StreamRegistry::new(Arc::clone(&sharing_bus))); + + Self { + config: Arc::new(RwLock::new(config)), + connectors: RwLock::new(HashMap::new()), + sharing_bus, + users: RwLock::new(HashMap::new()), + archivist: Arc::new(tokio::sync::RwLock::new(archivist)), + archivist_connector_uids: Arc::new(RwLock::new(HashMap::new())), + connector_summary_cache: Arc::new(std::sync::RwLock::new(Vec::new())), + self_weak: Arc::new(RwLock::new(None)), + transferred_sessions: Arc::new(RwLock::new(HashMap::new())), + inspector, + process_manager: None, + task_runner: Arc::new(tokio::sync::RwLock::new(None)), + matrix_service: Arc::new(tokio::sync::RwLock::new(None)), + stream_registry, + stream_factories, + } + } + + /// Create a new CoreRuntime without archivist (non-server builds) + #[cfg(not(feature = "server"))] + pub fn new(config: CoreConfig) -> Self { + let sharing_bus = SharingBus::new(); + let stream_registry = Arc::new(crate::sharing::StreamRegistry::new(Arc::clone(&sharing_bus))); + let stream_factories = Arc::new(crate::sharing::StreamFactoryRegistry::default()); + + Self { + config: Arc::new(RwLock::new(config)), + connectors: RwLock::new(HashMap::new()), + sharing_bus, + users: RwLock::new(HashMap::new()), + connector_summary_cache: Arc::new(std::sync::RwLock::new(Vec::new())), + self_weak: Arc::new(RwLock::new(None)), + transferred_sessions: Arc::new(RwLock::new(HashMap::new())), + stream_registry, + stream_factories, + } + } + + /// Set the process group manager used for lifecycle management of stdio agent processes. + /// + /// Call this before creating ACP connectors to ensure newly spawned processes are + /// tracked in the platform job object (Windows) or process group (Unix). After this + /// is set, every new `StdioTransport` will receive a per-process lifecycle handle. + #[cfg(feature = "server")] + pub fn set_process_manager(&mut self, mgr: Arc<dyn dirigent_process::ProcessGroupManager>) { + self.process_manager = Some(mgr); + } + + /// Get a reference to the process group manager, if configured. + #[cfg(feature = "server")] + pub fn process_manager(&self) -> Option<&Arc<dyn dirigent_process::ProcessGroupManager>> { + self.process_manager.as_ref() + } + + /// List all connectors, optionally filtered by owner + /// + /// Returns a summary of all connectors in the registry. If an owner + /// is specified, only connectors owned by that user are returned. + /// + /// # Arguments + /// + /// * `owner` - Optional user ID to filter by ownership + /// + /// # Returns + /// + /// A vector of `ConnectorSummary` structs containing id, kind, owner, + /// title, and current state for each matching connector. + /// + /// # Example + /// + /// ```no_run + /// # use dirigent_core::{CoreRuntime, CoreConfig}; + /// # async fn example(runtime: &CoreRuntime) { + /// // List all connectors + /// let all = runtime.list_connectors(None).await; + /// + /// // List connectors for a specific user + /// let user_connectors = runtime.list_connectors(Some(uuid::Uuid::nil())).await; + /// # } + /// ``` + pub async fn list_connectors(&self, owner: Option<UserId>) -> Vec<ConnectorSummary> { + let connectors = self.connectors.read().await; + + // Collect handles first, then build summaries with async config reading + let filtered_handles: Vec<_> = connectors + .values() + .filter(|handle| { + // Filter by owner if provided + if let Some(ref owner_id) = owner { + handle.owner() == owner_id + } else { + true + } + }) + .cloned() + .collect(); + + // Drop the connectors lock before async operations + drop(connectors); + + let mut summaries = Vec::with_capacity(filtered_handles.len()); + + for handle in filtered_handles { + // Extract supported_features and agent_type based on connector type + let (supported_features, agent_type) = match handle.kind() { + ConnectorKind::OpenCode => { + // OpenCode connectors support cancellation via stop_session, + // session resume, and session listing from the connector API + ( + vec![ + "cancellation".to_string(), + "session_resume".to_string(), + "session_list".to_string(), + ], + None, + ) + } + ConnectorKind::Acp => { + // Read supported_features and agent_type from the persisted ConnectorConfig + // (not from handle.get_config_cloned() which returns AcpConfig params + // that don't contain supported_features) + let config_guard = self.config.read().await; + let (features, agent_type) = config_guard + .connectors + .iter() + .find(|c| c.id.as_deref() == Some(handle.id().as_str())) + .map(|c| { + let features = c.supported_features.clone(); + let agent_type = serde_json::from_value::<crate::connectors::acp::config::AcpConfig>( + c.params.clone(), + ) + .ok() + .map(|cfg| cfg.agent_type); + (features, agent_type) + }) + .unwrap_or_default(); + drop(config_guard); + (features, agent_type) + } + ConnectorKind::Gateway => { + // Gateway supports session resume via Archivist and session listing + ( + vec![ + "session_resume".to_string(), + "session_list".to_string(), + ], + None, + ) + } + ConnectorKind::Acceptor | ConnectorKind::Mock => { + // These connectors don't have special features + (vec![], None) + } + }; + + // Determine if archiving can be toggled for this connector + // - OpenCode, ACP, Gateway: archiving can always be toggled + // - Mock, Acceptor: archiving toggle not supported (test/routing connectors) + let archiving_toggleable = match handle.kind() { + ConnectorKind::OpenCode | ConnectorKind::Acp | ConnectorKind::Gateway => true, + ConnectorKind::Mock | ConnectorKind::Acceptor => false, + }; + + // Look up tool_configuration and other config fields from saved config + let (tool_configuration, plugin_assignments, project_assignments, use_in_new_projects, source, zed_agent_name) = { + let config_lock = self.config.read().await; + config_lock + .connectors + .iter() + .find(|c| c.id.as_deref() == Some(handle.id().as_str())) + .map(|c| ( + c.tool_configuration.clone(), + c.plugin_assignments.clone(), + vec![], // project_assignments live on ConnectorSummary, not ConnectorConfig + c.use_in_new_projects, + c.source.clone(), + c.zed_agent_name.clone(), + )) + .unwrap_or((None, vec![], vec![], true, None, None)) + }; + + summaries.push(ConnectorSummary { + id: handle.id().clone(), + kind: handle.kind(), + owner: handle.owner().clone(), + title: handle.title().to_string(), + state: handle.state(), + working_directory: handle.working_directory().map(|p| p.display().to_string()), + supported_features, + icon_path: handle.icon_path().map(|s| s.to_string()), + show_type_overlay: handle.show_type_overlay(), + archiving_toggleable, + agent_type, + tool_configuration, + plugin_assignments, + project_assignments, + use_in_new_projects, + source, + zed_agent_name, + error_kind: handle.error_kind(), + }); + } + + summaries + } + + /// Get a handle to a specific connector by ID + /// + /// Returns a cloned ConnectorHandle if the connector exists, or None + /// if no connector with the given ID is registered. + /// + /// # Arguments + /// + /// * `id` - The unique identifier of the connector to retrieve + /// + /// # Returns + /// + /// `Some(ConnectorHandle)` if found, `None` otherwise + /// + /// # Example + /// + /// ```no_run + /// # use dirigent_core::{CoreRuntime, CoreConfig}; + /// # use dirigent_core::connectors::Connector; + /// # async fn example(runtime: &CoreRuntime) { + /// if let Some(handle) = runtime.get_connector(&"conn-123".to_string()).await { + /// println!("Found connector: {}", handle.title()); + /// println!("State: {:?}", handle.state()); + /// } else { + /// println!("Connector not found"); + /// } + /// # } + /// ``` + pub async fn get_connector(&self, id: &ConnectorId) -> Option<ConnectorHandle> { + let connectors = self.connectors.read().await; + connectors.get(id).cloned() + } + + /// Get the inspector registry, if configured. + #[cfg(feature = "server")] + pub fn inspector(&self) -> Option<&Arc<dirigent_inspector::InspectorRegistry>> { + self.inspector.as_ref() + } + + /// Create a new connector from configuration + /// + /// This method instantiates a connector of the specified type with the given + /// configuration. The connector is created but not started - call `start_connector` + /// to begin its background task. + /// + /// # Arguments + /// + /// * `owner` - User ID of the connector owner (for authorization) + /// * `cfg` - Connector configuration including type and parameters + /// + /// # Returns + /// + /// The unique ID of the created connector, or an error if creation failed. + /// + /// # Errors + /// + /// - `AlreadyExists` - A connector with the specified ID already exists + /// - `InvalidConfig` - Configuration validation failed or deserialization error + /// - `Internal` - Other errors during connector instantiation + /// + /// # Example + /// + /// ```no_run + /// # use dirigent_core::{CoreRuntime, CoreConfig, ConnectorConfig, ConnectorKind}; + /// # use serde_json::json; + /// # async fn example(runtime: &CoreRuntime) -> Result<(), Box<dyn std::error::Error>> { + /// let cfg = ConnectorConfig { + /// id: None, + /// kind: ConnectorKind::OpenCode, + /// owner: None, + /// title: Some("My Connector".to_string()), + /// params: json!({ + /// "base_url": "http://localhost:12225", + /// "title": "My Connector", + /// "initial_session": null + /// }), + /// }; + /// + /// let connector_id = runtime.create_connector(uuid::Uuid::now_v7(), cfg).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn create_connector( + &self, + owner: UserId, + mut cfg: ConnectorConfig, + ) -> Result<ConnectorId, CoreError> { + // Generate ID if not provided + let connector_id = cfg + .id + .clone() + .unwrap_or_else(|| uuid::Uuid::now_v7().to_string()); + + // Log connector count before creation + let count_before = { + let connectors = self.connectors.read().await; + connectors.len() + }; + + info!( + connector_id = %connector_id, + kind = ?cfg.kind, + owner = %owner, + connector_count = count_before, + "Creating connector (current count: {})", count_before + ); + + // Check if connector already exists + { + let connectors = self.connectors.read().await; + if connectors.contains_key(&connector_id) { + error!(connector_id = %connector_id, "Connector already exists"); + return Err(CoreError::AlreadyExists); + } + } + + // Override owner from config with the provided owner + cfg.owner = Some(owner.clone()); + + // Match on connector kind and instantiate the appropriate type + // For OpenCode connectors, we start the task immediately since they can't be restarted + let handle = match cfg.kind { + ConnectorKind::OpenCode => { + // Deserialize params to OpenCodeConfig + let oc_config: OpenCodeConfig = serde_json::from_value(cfg.params.clone()) + .map_err(|e| { + error!( + connector_id = %connector_id, + error = %e, + "Failed to deserialize OpenCodeConfig" + ); + CoreError::InvalidConfig + })?; + + debug!( + connector_id = %connector_id, + base_url = %oc_config.base_url, + "Creating and starting OpenCode connector" + ); + + // Get title from main config, with fallback + let title = cfg.title.clone().unwrap_or_else(|| "OpenCode".to_string()); + + // Create the connector + let oc_connector = OpenCodeConnector::new( + connector_id.clone(), + owner.clone(), + title, + oc_config, + Arc::clone(&self.sharing_bus), + ); + + // Resolve working directory using global config + let config_lock = self.config.read().await; + let working_directory = crate::resolve_default_runtime_working_directory(&cfg, &config_lock); + + // Create a handle from the connector's channels + // IMPORTANT: Use new_with_state to share the same state Arc + let mut handle = ConnectorHandle::new_with_state( + connector_id.clone(), + oc_connector.kind(), + owner.clone(), + oc_connector.title().to_string(), + oc_connector.state_arc(), // Share the state Arc from the connector + oc_connector.command_tx(), + oc_connector.events_sender(), + cfg.params.clone(), // Store config for restart + Some(working_directory), + cfg.icon_path.clone(), + cfg.show_type_overlay, + ); + // Share the error_kind Arc so the handle observes connector error classification + handle.set_error_kind_arc(oc_connector.error_kind_arc()); + + // Start the connector task immediately (before dropping the connector) + // This is necessary because OpenCodeConnector consumes the cmd_rx on start + let task_handle = oc_connector.start_task().await; + + // Store the task handle in the ConnectorHandle + handle.set_task_handle(task_handle).await; + + // Events are published directly to the SharingBus by the + // connector (Task 9); no forwarder task needed. + + info!( + connector_id = %connector_id, + "OpenCode connector task started" + ); + + handle + } + ConnectorKind::Acp => { + // Deserialize params to AcpConfig + let acp_config: AcpConfig = + serde_json::from_value(cfg.params.clone()).map_err(|e| { + error!( + connector_id = %connector_id, + error = %e, + "Failed to deserialize AcpConfig" + ); + CoreError::InvalidConfig + })?; + + // Extract title from main config or use default + let title = cfg + .title + .clone() + .unwrap_or_else(|| "ACP Connector".to_string()); + + debug!( + connector_id = %connector_id, + title = %title, + transport = ?acp_config.transport, + "Creating and starting ACP connector" + ); + + // Create the connector + let acp_connector = { + let connector = AcpConnector::new( + connector_id.clone(), + owner.clone(), + title, + acp_config, + Arc::clone(&self.sharing_bus), + ) + .map_err(|e| { + error!( + connector_id = %connector_id, + error = %e, + "Failed to create ACP connector" + ); + CoreError::InvalidConfig + })?; + let connector = + connector.with_tool_configuration(cfg.tool_configuration.clone()); + #[cfg(feature = "server")] + let connector = connector.with_inspector(self.inspector.clone()); + #[cfg(feature = "server")] + let connector = connector.with_process_manager(self.process_manager.clone()); + connector + }; + + // Resolve working directory using global config + let config_lock = self.config.read().await; + let working_directory = crate::resolve_default_runtime_working_directory(&cfg, &config_lock); + + // Create a handle from the connector's channels + // IMPORTANT: Use new_with_state to share the same state Arc + let handle = ConnectorHandle::new_with_state( + connector_id.clone(), + acp_connector.kind(), + owner.clone(), + acp_connector.title().to_string(), + acp_connector.state_arc(), // Share the state Arc from the connector + acp_connector.command_tx(), + acp_connector.events_sender(), + cfg.params.clone(), // Store config for restart + Some(working_directory), + cfg.icon_path.clone(), + cfg.show_type_overlay, + ); + + // Start the connector task immediately + let task_handle = acp_connector.start_task().await; + + // Store the task handle in the ConnectorHandle + handle.set_task_handle(task_handle).await; + + // The connector publishes events directly to the SharingBus + // (Task 9). We still subscribe to the per-connector broadcast + // here to intercept `available_commands_update` notifications + // and mirror the dynamic commands onto the handle so that + // `get_available_commands()` can return them. + let mut connector_events = handle.subscribe(); + let conn_id_for_intercept: ConnectorId = connector_id.clone(); + let handle_for_commands = handle.clone(); + tokio::spawn(async move { + while let Ok(event) = connector_events.recv().await { + if let dirigent_protocol::Event::SessionUpdate { + update: + dirigent_protocol::SessionUpdate::Unknown { ref data }, + .. + } = event + { + if data + .get("sessionUpdate") + .and_then(|s| s.as_str()) + == Some("available_commands_update") + { + if let Some(commands_json) = data.get("availableCommands") { + if let Ok(commands) = serde_json::from_value::< + Vec<crate::acp::protocol::streaming::Command>, + >( + commands_json.clone() + ) { + tracing::debug!( + connector_id = %conn_id_for_intercept, + count = commands.len(), + "Storing dynamic available commands on handle" + ); + handle_for_commands.set_available_commands(commands); + } + } + } + } + } + debug!(connector_id = %conn_id_for_intercept, "available_commands interceptor ended"); + }); + + info!( + connector_id = %connector_id, + "ACP connector task started" + ); + + handle + } + ConnectorKind::Mock => { + // Mock connector is only for testing + error!(connector_id = %connector_id, "Mock connector cannot be created via API"); + return Err(CoreError::InvalidConfig); + } + ConnectorKind::Acceptor => { + // Create an Acceptor connector (represents incoming ACP connections) + // This is a lightweight connector that shows in the UI but delegates + // actual session handling to the ACP Server and SessionManager + + let title = cfg + .title + .clone() + .unwrap_or_else(|| "Acceptor (ACP)".to_string()); + + debug!( + connector_id = %connector_id, + title = %title, + "Creating Acceptor connector for incoming connections" + ); + + // Create channels (mostly unused for Acceptor, but needed for ConnectorHandle) + let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel::<ConnectorCommand>(100); + let (events_tx, _) = + tokio::sync::broadcast::channel::<dirigent_protocol::Event>(1000); + + // Acceptor is immediately "Ready" since it's passive (waits for incoming connections) + let state = std::sync::Arc::new(tokio::sync::RwLock::new(ConnectorState::Ready)); + + // Resolve working directory using global config + let config_lock = self.config.read().await; + let working_directory = crate::resolve_default_runtime_working_directory(&cfg, &config_lock); + + // Create handle + let handle = ConnectorHandle::new_with_state( + connector_id.clone(), + ConnectorKind::Acceptor, + owner.clone(), + title, + state, + cmd_tx, + events_tx.clone(), + cfg.params.clone(), + Some(working_directory), + cfg.icon_path.clone(), + cfg.show_type_overlay, + ); + + // Emit Connected event for consistency + let _ = events_tx.send(dirigent_protocol::Event::Connected); + + info!( + connector_id = %connector_id, + "Acceptor connector created and ready for incoming connections" + ); + + handle + } + ConnectorKind::Gateway => { + // Deserialize params to GatewayConfig + let gateway_config: crate::connectors::gateway::GatewayConfig = + serde_json::from_value(cfg.params.clone()).map_err(|e| { + error!( + connector_id = %connector_id, + error = %e, + "Failed to deserialize GatewayConfig" + ); + CoreError::InvalidConfig + })?; + + debug!( + connector_id = %connector_id, + title = %gateway_config.title, + echo_enabled = gateway_config.default_echo_enabled, + "Creating and starting Gateway connector" + ); + + // Create the connector + let mut gateway_connector = crate::connectors::gateway::GatewayConnector::new( + connector_id.clone(), + owner.clone(), + gateway_config, + Arc::clone(&self.sharing_bus), + ); + + // Wire up callbacks for /list-connectors command + // The cache is a sync RwLock that gets updated whenever connectors change + let cache_for_list = self.connector_summary_cache(); + let list_callback: crate::connectors::gateway::ConnectorListCallback = + Arc::new(move || { + // Read from the sync cache (non-blocking) + match cache_for_list.read() { + Ok(cache) => cache.clone(), + Err(_) => Vec::new(), // Return empty if lock is poisoned + } + }); + gateway_connector.set_connector_list_callback(list_callback); + + // Wire up session transfer callback for /select-connector command + let self_weak_clone = Arc::clone(&self.self_weak); + let transfer_callback: SessionTransferCallback = Arc::new(move |request| { + let self_weak = Arc::clone(&self_weak_clone); + tokio::spawn(async move { + // Try to upgrade the weak reference + let weak_guard = self_weak.read().await; + if let Some(ref weak) = *weak_guard { + if let Some(runtime) = weak.upgrade() { + runtime.transfer_session(request).await; + } else { + tracing::warn!( + "Runtime has been dropped, cannot process transfer request" + ); + let _ = request.result_tx.send(SessionTransferResult::Failed( + "Runtime is shutting down".to_string(), + )); + } + } else { + tracing::error!("Runtime weak reference not initialized"); + let _ = request.result_tx.send(SessionTransferResult::Failed( + "Runtime not properly initialized".to_string(), + )); + } + }); + }); + gateway_connector.set_session_transfer_callback(transfer_callback); + + // Wire up inspector for session tracking + #[cfg(feature = "server")] + gateway_connector.set_inspector(self.inspector.clone()); + + // Resolve working directory using global config + let config_lock = self.config.read().await; + let working_directory = crate::resolve_default_runtime_working_directory(&cfg, &config_lock); + + // Create a handle from the connector's channels + let handle = ConnectorHandle::new_with_state( + connector_id.clone(), + gateway_connector.kind(), + owner.clone(), + gateway_connector.title().to_string(), + gateway_connector.state_arc(), + gateway_connector.command_tx(), + gateway_connector.events_sender(), + cfg.params.clone(), + Some(working_directory), + cfg.icon_path.clone(), + cfg.show_type_overlay, + ); + + // Start the connector task immediately + let task_handle = gateway_connector.start_task().await; + + // Store the task handle in the ConnectorHandle + handle.set_task_handle(task_handle).await; + + // Events are published directly to the SharingBus by the + // connector (Task 9); no forwarder task needed. + + info!( + connector_id = %connector_id, + "Gateway connector task started" + ); + + handle + } + }; + + // Insert the handle into the connectors map + { + let mut connectors = self.connectors.write().await; + connectors.insert(connector_id.clone(), handle); + } + + // Update the connector summary cache for GatewayConnector callbacks + self.update_connector_summary_cache().await; + + // Save values we'll need after moving cfg + let connector_kind = cfg.kind.clone(); + let connector_title = cfg + .title + .clone() + .unwrap_or_else(|| format!("{:?} Connector", cfg.kind)); + + // Register connector with archivist (if available and server feature enabled) + #[cfg(feature = "server")] + if let Some(archivist) = self.archivist.read().await.clone() { + // Determine custom_uid: use connector_id if it's a UUID, otherwise let archivist generate one + let custom_uid = uuid::Uuid::parse_str(&connector_id).ok(); + + // Compute a deterministic fingerprint for archivist re-association across restarts + let fingerprint = + crate::connectors::fingerprint::compute_fingerprint(&cfg.kind, &cfg.params); + + // Create registration request + let register_req = dirigent_archivist::types::RegisterConnectorRequest { + custom_uid, // Some(uuid) if connector_id was UUID, None otherwise + r#type: format!("{:?}", connector_kind), + title: connector_title.clone(), + client_native_id: connector_id.clone(), + metadata: serde_json::json!({}), + fingerprint: fingerprint.clone(), + }; + + // Attempt registration (non-blocking, best-effort) + match archivist.register_connector(register_req, None).await { + Ok(response) => { + // Store the mapping for archive-first reads (works for both UUID and custom IDs) + self.archivist_connector_uids + .write() + .await + .insert(connector_id.clone(), response.connector_uid); + + // Backfill fingerprint on existing records that were registered + // before fingerprinting was introduced (Task 11) + if response.status == dirigent_archivist::types::RegisterStatus::Aliased { + if let Some(ref fp) = fingerprint { + let _ = archivist + .update_connector_fingerprint( + response.connector_uid, + fp.clone(), + None, + ) + .await; + } + } + + info!( + connector_id = %connector_id, + connector_uid = %response.connector_uid, + status = ?response.status, + "Registered connector with archivist" + ); + } + Err(e) => { + warn!( + connector_id = %connector_id, + error = %e, + "Failed to register connector with archivist (non-fatal)" + ); + } + } + } + + // Register connector with inspector (if available) + #[cfg(feature = "server")] + if let Some(ref inspector) = self.inspector { + let node_id = + dirigent_inspector::NodeId::new(format!("dirigent/connectors/{}", connector_id)); + let parent_id = dirigent_inspector::NodeId::new("dirigent/connectors"); + let mut metadata = dirigent_inspector::NodeMetadata::new( + dirigent_inspector::NodeKind::Connector, + &connector_title, + ) + .with_state(dirigent_inspector::NodeState::Initializing) + .with_property("kind", serde_json::json!(format!("{:?}", connector_kind))) + .with_property("owner", serde_json::json!(&owner)); + + // Extract command/executable from params for display in inspector + if let Some(params_obj) = cfg.params.as_object() { + if let Some(transport) = params_obj.get("transport") { + match transport.get("type").and_then(|t| t.as_str()) { + Some("stdio") => { + if let Some(cmd) = transport.get("command").and_then(|c| c.as_str()) { + let args = transport + .get("args") + .and_then(|a| a.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str()) + .collect::<Vec<_>>() + .join(" ") + }) + .unwrap_or_default(); + let full_command = if args.is_empty() { + cmd.to_string() + } else { + format!("{} {}", cmd, args) + }; + metadata = metadata.with_property( + "command", + serde_json::json!(full_command), + ); + } + } + Some("http") => { + if let Some(url) = + transport.get("base_url").and_then(|u| u.as_str()) + { + metadata = metadata + .with_property("command", serde_json::json!(url)); + } + } + _ => {} + } + } else if let Some(base_url) = params_obj.get("base_url").and_then(|u| u.as_str()) + { + // OpenCode connectors have base_url at the top level + metadata = + metadata.with_property("command", serde_json::json!(base_url)); + } + } + match inspector + .register(node_id, &parent_id, metadata, None) + .await + { + Ok(mut handle) => { + handle.detach(); + info!( + connector_id = %connector_id, + "Registered connector with inspector" + ); + } + Err(e) => { + warn!( + connector_id = %connector_id, + error = %e, + "Failed to register connector with inspector (non-fatal)" + ); + } + } + } + + // Auto-save: persist the new connector to the configuration file + // This ensures the connector will be restored on server restart + { + let mut config_lock = self.config.write().await; + // Ensure the connector config has the generated ID set + cfg.id = Some(connector_id.clone()); + + // Check if connector already exists in config + if let Some(existing_idx) = config_lock + .connectors + .iter() + .position(|c| c.id == Some(connector_id.clone())) + { + // Update existing entry instead of appending + debug!(connector_id = %connector_id, "Updating existing connector in config"); + config_lock.connectors[existing_idx] = cfg; + } else { + // Append new entry + debug!(connector_id = %connector_id, "Adding new connector to config"); + config_lock.connectors.push(cfg); + } + } + + // Save configuration to disk (log warning on error, don't fail operation) + if let Err(e) = self.save_config().await { + warn!( + connector_id = %connector_id, + error = %e, + "Failed to save configuration after connector creation" + ); + } + + // Log connector count after creation + let count_after = { + let connectors = self.connectors.read().await; + connectors.len() + }; + + info!( + connector_id = %connector_id, + connector_count = count_after, + "Connector created successfully (count: {} -> {})", + count_before, + count_after + ); + + // Log special case: first connector created + if count_before == 0 && count_after == 1 { + info!("First connector created - transitioning from empty state"); + } + + // Emit ConnectorCreated onto the SharingBus. + let bus_event = dirigent_protocol::streaming::BusEvent::from_connector_event( + dirigent_protocol::Event::ConnectorCreated { + connector_id: connector_id.clone(), + kind: format!("{:?}", connector_kind), + title: connector_title, + }, + None, // Task 9 will wire connector_uid + connector_id.clone(), + ); + self.sharing_bus.publish(bus_event).await; + + Ok(connector_id) + } + + /// Start a connector's background task + /// + /// This method starts the connector's async task loop, which manages the + /// connection to the underlying agent system and processes commands and events. + /// + /// # Arguments + /// + /// * `id` - The unique identifier of the connector to start + /// + /// # Returns + /// + /// Ok(()) if the connector was started successfully, or an error. + /// + /// # Errors + /// + /// - `NotFound` - No connector with the specified ID exists + /// - `InvalidState` - Connector is already running or in an invalid state + /// - `StartFailed` - Failed to start the connector task + /// + /// # Example + /// + /// ```no_run + /// # use dirigent_core::{CoreRuntime, CoreConfig}; + /// # async fn example(runtime: &CoreRuntime, connector_id: &str) -> Result<(), Box<dyn std::error::Error>> { + /// runtime.start_connector(&connector_id.to_string()).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn start_connector(&self, id: &ConnectorId) -> Result<(), CoreError> { + info!(connector_id = %id, "Starting connector"); + + // Get the connector handle + let handle = { + let connectors = self.connectors.read().await; + connectors.get(id).cloned().ok_or_else(|| { + error!(connector_id = %id, "Connector not found"); + CoreError::NotFound + })? + }; + + // Check if already started (has task handle) + if handle.take_task_handle().await.is_some() { + warn!(connector_id = %id, "Connector already started"); + return Err(CoreError::InvalidState); + } + + // Start the connector based on its kind + match handle.kind() { + ConnectorKind::OpenCode => { + // For OpenCode, we need to recreate the connector and start its task + // This is necessary because the start_task() method consumes the command receiver + + // We'll create a new OpenCodeConnector instance with the same config + // Extract config from handle metadata + // Since we can't easily get the config back, we'll use the command pattern + + // For now, we'll document that connectors should be started immediately after creation + // A better approach would be to store the connector trait object alongside the handle + error!( + connector_id = %id, + "Starting existing connector not yet supported - start immediately after creation" + ); + Err(CoreError::Internal( + "Starting existing connector not yet supported - start immediately after creation" + .to_string(), + )) + } + ConnectorKind::Acp => { + error!(connector_id = %id, "ACP connector not yet implemented"); + Err(CoreError::Internal( + "ACP connector not yet implemented".to_string(), + )) + } + ConnectorKind::Mock => { + // Mock connectors don't have background tasks + debug!(connector_id = %id, "Mock connector has no background task"); + Ok(()) + } + ConnectorKind::Acceptor => { + // Acceptors don't have background tasks - they are event-driven + debug!(connector_id = %id, "Acceptor connector has no background task"); + Ok(()) + } + ConnectorKind::Gateway => { + // Gateway connectors are started immediately after creation + // like OpenCode connectors + error!( + connector_id = %id, + "Starting existing Gateway connector not yet supported - start immediately after creation" + ); + Err(CoreError::Internal( + "Starting existing Gateway connector not yet supported - start immediately after creation" + .to_string(), + )) + } + } + } + + /// Stop a running connector gracefully + /// + /// This method sends a shutdown command to the connector and waits for its + /// background task to complete. The connector will clean up resources and + /// transition to the Stopped state. + /// + /// # Arguments + /// + /// * `id` - The unique identifier of the connector to stop + /// + /// # Returns + /// + /// Ok(()) if the connector was stopped successfully, or an error. + /// + /// # Errors + /// + /// - `NotFound` - No connector with the specified ID exists + /// - `Internal` - Failed to send shutdown command or task join failed + /// + /// # Example + /// + /// ```no_run + /// # use dirigent_core::{CoreRuntime, CoreConfig}; + /// # async fn example(runtime: &CoreRuntime, connector_id: &str) -> Result<(), Box<dyn std::error::Error>> { + /// runtime.stop_connector(&connector_id.to_string()).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn stop_connector(&self, id: &ConnectorId) -> Result<(), CoreError> { + info!(connector_id = %id, "Stopping connector"); + + // Get the connector handle + let handle = { + let connectors = self.connectors.read().await; + connectors.get(id).cloned().ok_or_else(|| { + error!(connector_id = %id, "Connector not found"); + CoreError::NotFound + })? + }; + + // Check if connector has been started + let current_state = handle.state(); + if matches!(current_state, ConnectorState::Initializing) { + // Connector was never started, so there's no task to shutdown + debug!(connector_id = %id, "Connector is in Initializing state, skipping shutdown command"); + } else { + // Send shutdown command + let cmd_tx = handle.command_tx(); + match cmd_tx.send(ConnectorCommand::Shutdown).await { + Ok(_) => { + debug!(connector_id = %id, "Shutdown command sent"); + } + Err(e) => { + // Command send failed - likely the receiver was dropped (connector not started) + debug!(connector_id = %id, error = %e, "Failed to send shutdown command (connector may not have been started)"); + } + } + } + + // Wait for task to complete (with timeout) + if let Some(task_handle) = handle.take_task_handle().await { + debug!(connector_id = %id, "Waiting for connector task to complete"); + + // Use tokio::time::timeout to avoid waiting forever + match tokio::time::timeout(Duration::from_secs(30), task_handle).await { + Ok(Ok(())) => { + info!(connector_id = %id, "Connector task completed successfully"); + } + Ok(Err(e)) if e.is_panic() => { + // Task panicked - log but don't treat as error + // The task is stopped (panicked tasks are stopped), which is what we wanted + error!(connector_id = %id, "Connector task panicked during shutdown: {:?}", e); + info!(connector_id = %id, "Treating panicked task as stopped (goal achieved)"); + } + Ok(Err(e)) => { + // Other JoinHandle errors (e.g., task was cancelled) + warn!(connector_id = %id, error = %e, "Connector task terminated with error"); + // Don't return error - task is stopped regardless + } + Err(_) => { + warn!(connector_id = %id, "Connector task did not complete within timeout"); + // Don't return error - the connector may still be shutting down + } + } + } else { + debug!(connector_id = %id, "No task handle found (connector may not have been started)"); + } + + // Update state to Stopped + { + let state_lock = handle.state_lock(); + let mut state = state_lock.write().await; + *state = crate::types::ConnectorState::Stopped; + } + + info!(connector_id = %id, "Connector stopped successfully"); + + Ok(()) + } + + /// Restart a stopped connector + /// + /// This method recreates a stopped connector's background task with fresh + /// channels while preserving its identity (ID, owner, configuration). + /// The connector must be in the `Stopped` or `Error` state to be restarted. + /// + /// # Arguments + /// + /// * `id` - The unique identifier of the connector to restart + /// + /// # Returns + /// + /// Ok(()) if the connector was restarted successfully, or an error. + /// + /// # Errors + /// + /// - `NotFound` - No connector with the specified ID exists + /// - `InvalidState` - Connector is not in `Stopped` or `Error` state + /// - `InvalidConfig` - Failed to deserialize connector configuration + /// - `Internal` - Other internal errors during restart + /// + /// # State Transitions + /// + /// The connector will transition through the following states: + /// - `Stopped` or `Error` → `Initializing` → `Connecting` → `Ready` + /// + /// # Preservation + /// + /// The following are preserved across restart: + /// - Connector ID (remains the same) + /// - Owner (authorization unchanged) + /// - Configuration (same parameters) + /// - Event broadcast channel (existing subscribers continue receiving) + /// - State Arc (external observers see state transitions) + /// + /// The following are recreated: + /// - Command channel (new sender/receiver pair) + /// - Connector instance (fresh OpenCodeConnector, etc.) + /// - Background task (new spawn) + /// - Task handle (new JoinHandle) + /// + /// # Example + /// + /// ```no_run + /// # use dirigent_core::{CoreRuntime, CoreConfig}; + /// # async fn example(runtime: &CoreRuntime, connector_id: &str) -> Result<(), Box<dyn std::error::Error>> { + /// // Stop a connector + /// runtime.stop_connector(&connector_id.to_string()).await?; + /// + /// // Restart it + /// runtime.restart_connector(&connector_id.to_string()).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn restart_connector(&self, id: &ConnectorId) -> Result<(), CoreError> { + info!(connector_id = %id, "Restarting connector"); + + // Get the connector handle from registry + let handle = { + let connectors = self.connectors.read().await; + connectors.get(id).cloned() + }; + + let handle = match handle { + Some(h) => h, + None => { + error!(connector_id = %id, "Connector not found"); + return Err(CoreError::NotFound); + } + }; + + // Verify state is Stopped or Error (only valid states for restart) + let current_state = handle.state(); + match current_state { + crate::types::ConnectorState::Stopped | crate::types::ConnectorState::Error(_) => { + debug!(connector_id = %id, state = ?current_state, "Connector is in restartable state"); + } + _ => { + error!( + connector_id = %id, + state = ?current_state, + "Cannot restart connector in current state. Stop it first." + ); + return Err(CoreError::InvalidState); + } + } + + // Get config from handle + let config_json = handle.get_config_cloned().await; + debug!(connector_id = %id, "Retrieved connector config for restart"); + + // Match on connector kind and recreate + match handle.kind() { + ConnectorKind::OpenCode => { + // Deserialize config to OpenCodeConfig + let oc_config: crate::connectors::opencode::OpenCodeConfig = + serde_json::from_value(config_json).map_err(|e| { + error!( + connector_id = %id, + error = %e, + "Failed to deserialize OpenCodeConfig during restart" + ); + CoreError::InvalidConfig + })?; + + debug!( + connector_id = %id, + base_url = %oc_config.base_url, + "Creating new OpenCode connector instance for restart" + ); + + // Use existing title from handle + let title = handle.title().to_string(); + + // Create new connector instance + let oc_connector = crate::connectors::opencode::OpenCodeConnector::new( + id.clone(), + handle.owner().clone(), + title, + oc_config, + Arc::clone(&self.sharing_bus), + ); + + // Create new handle with same identity but fresh channels + // IMPORTANT: Use new connector's events_sender, not old handle's + // The new connector creates a new broadcast channel in its constructor + let mut new_handle = ConnectorHandle::new_with_state( + id.clone(), + oc_connector.kind(), + handle.owner().clone(), + oc_connector.title().to_string(), + handle.state_lock(), // Reuse existing state Arc + oc_connector.command_tx(), + oc_connector.events_sender(), // NEW: Use fresh events channel + handle.get_config_cloned().await, // Preserve config + handle.working_directory(), // Preserve working directory + handle.icon_path().map(|s| s.to_string()), + handle.show_type_overlay(), + ); + // Share the error_kind Arc so the handle observes connector error classification + new_handle.set_error_kind_arc(oc_connector.error_kind_arc()); + + // Start the new connector task + let task_handle = oc_connector.start_task().await; + + // Store the task handle + new_handle.set_task_handle(task_handle).await; + + // Events are published directly to the SharingBus by the + // connector (Task 9); no forwarder task needed. + + info!( + connector_id = %id, + "OpenCode connector task restarted" + ); + + // Replace the handle in the registry + { + let mut connectors = self.connectors.write().await; + connectors.insert(id.clone(), new_handle); + } + } + ConnectorKind::Acp => { + // Deserialize config to AcpConfig + let acp_config: AcpConfig = serde_json::from_value(config_json).map_err(|e| { + error!( + connector_id = %id, + error = %e, + "Failed to deserialize AcpConfig during restart" + ); + CoreError::InvalidConfig + })?; + + // Use existing title from handle + let title = handle.title().to_string(); + + debug!( + connector_id = %id, + title = %title, + transport = ?acp_config.transport, + "Creating new ACP connector instance for restart" + ); + + // Create new connector instance + let acp_connector = AcpConnector::new( + id.clone(), + handle.owner().clone(), + title, + acp_config, + Arc::clone(&self.sharing_bus), + ) + .map_err(|e| { + error!( + connector_id = %id, + error = %e, + "Failed to create ACP connector during restart" + ); + CoreError::InvalidConfig + })?; + + // Look up tool configuration from saved config + let tool_config = { + let config_lock = self.config.read().await; + config_lock + .connectors + .iter() + .find(|c| c.id.as_deref() == Some(id.as_str())) + .and_then(|c| c.tool_configuration.clone()) + }; + let acp_connector = + acp_connector.with_tool_configuration(tool_config); + #[cfg(feature = "server")] + let acp_connector = acp_connector.with_process_manager(self.process_manager.clone()); + + // Create new handle with same identity but fresh channels + // IMPORTANT: Use new connector's events_sender, not old handle's + // The new connector creates a new broadcast channel in its constructor + let new_handle = ConnectorHandle::new_with_state( + id.clone(), + acp_connector.kind(), + handle.owner().clone(), + acp_connector.title().to_string(), + handle.state_lock(), // Reuse existing state Arc + acp_connector.command_tx(), + acp_connector.events_sender(), // NEW: Use fresh events channel + handle.get_config_cloned().await, // Preserve config + handle.working_directory(), // Preserve working directory + handle.icon_path().map(|s| s.to_string()), + handle.show_type_overlay(), + ); + + // Start the new connector task + let task_handle = acp_connector.start_task().await; + + // Store the task handle + new_handle.set_task_handle(task_handle).await; + + // The connector publishes events directly to the SharingBus + // (Task 9). Subscribe to the per-connector broadcast only to + // intercept `available_commands_update` notifications and + // mirror dynamic commands onto the handle. + let mut connector_events = new_handle.subscribe(); + let conn_id_for_intercept: ConnectorId = id.clone(); + let handle_for_commands = new_handle.clone(); + tokio::spawn(async move { + while let Ok(event) = connector_events.recv().await { + if let dirigent_protocol::Event::SessionUpdate { + update: + dirigent_protocol::SessionUpdate::Unknown { ref data }, + .. + } = event + { + if data + .get("sessionUpdate") + .and_then(|s| s.as_str()) + == Some("available_commands_update") + { + if let Some(commands_json) = data.get("availableCommands") { + if let Ok(commands) = serde_json::from_value::< + Vec<crate::acp::protocol::streaming::Command>, + >( + commands_json.clone() + ) { + handle_for_commands.set_available_commands(commands); + } + } + } + } + } + debug!(connector_id = %conn_id_for_intercept, "available_commands interceptor ended for restarted ACP connector"); + }); + + info!( + connector_id = %id, + "ACP connector task restarted" + ); + + // Replace the handle in the registry + { + let mut connectors = self.connectors.write().await; + connectors.insert(id.clone(), new_handle); + } + } + ConnectorKind::Mock => { + // Mock connectors don't have background tasks, so restart is a no-op + debug!(connector_id = %id, "Mock connector restart is a no-op"); + } + ConnectorKind::Acceptor => { + // Acceptors don't have background tasks, so restart is a no-op + debug!(connector_id = %id, "Acceptor connector restart is a no-op"); + } + ConnectorKind::Gateway => { + // Deserialize config to GatewayConfig + let gateway_config: crate::connectors::gateway::GatewayConfig = + serde_json::from_value(config_json).map_err(|e| { + error!( + connector_id = %id, + error = %e, + "Failed to deserialize GatewayConfig during restart" + ); + CoreError::InvalidConfig + })?; + + debug!( + connector_id = %id, + title = %gateway_config.title, + echo_enabled = gateway_config.default_echo_enabled, + "Creating new Gateway connector instance for restart" + ); + + // Create new connector instance + let mut gateway_connector = crate::connectors::gateway::GatewayConnector::new( + id.clone(), + handle.owner().clone(), + gateway_config, + Arc::clone(&self.sharing_bus), + ); + + // Wire up inspector for session tracking + #[cfg(feature = "server")] + gateway_connector.set_inspector(self.inspector.clone()); + + // Create new handle with same identity but fresh channels + // IMPORTANT: Use new connector's events_sender, not old handle's + // The new connector creates a new broadcast channel in its constructor + let new_handle = ConnectorHandle::new_with_state( + id.clone(), + gateway_connector.kind(), + handle.owner().clone(), + gateway_connector.title().to_string(), + handle.state_lock(), + gateway_connector.command_tx(), + gateway_connector.events_sender(), // NEW: Use fresh events channel + handle.get_config_cloned().await, + handle.working_directory(), + handle.icon_path().map(|s| s.to_string()), + handle.show_type_overlay(), + ); + + // Start the new connector task + let task_handle = gateway_connector.start_task().await; + + // Store the task handle + new_handle.set_task_handle(task_handle).await; + + // Events are published directly to the SharingBus by the + // connector (Task 9); no forwarder task needed. + + info!( + connector_id = %id, + "Gateway connector task restarted" + ); + + // Replace the handle in the registry + { + let mut connectors = self.connectors.write().await; + connectors.insert(id.clone(), new_handle); + } + } + } + + info!(connector_id = %id, "Connector restarted successfully"); + + Ok(()) + } + + /// Update a connector's configuration + /// + /// This method updates the stored configuration for a connector. + /// The changes are persisted immediately but only take effect when + /// the connector is restarted. + /// + /// # Arguments + /// + /// * `id` - The connector ID + /// * `patch` - JSON patch to merge into existing config + /// + /// # Returns + /// + /// Ok(()) on success, or CoreError on failure + /// + /// # Errors + /// + /// - `NotFound` - No connector with the specified ID exists + /// - `Internal` - Failed to save configuration to disk + /// + /// # Example + /// + /// ```no_run + /// # use dirigent_core::{CoreRuntime, CoreConfig}; + /// # use serde_json::json; + /// # async fn example(runtime: &CoreRuntime, connector_id: &str) -> Result<(), Box<dyn std::error::Error>> { + /// // Update the connector's title + /// runtime.update_connector_config( + /// &connector_id.to_string(), + /// json!({ + /// "title": "New Title" + /// }) + /// ).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn update_connector_config( + &self, + id: &ConnectorId, + patch: serde_json::Value, + ) -> Result<(), CoreError> { + info!(connector_id = %id, "Updating connector configuration"); + + // Get the connector handle + let handle = { + let connectors = self.connectors.read().await; + connectors.get(id).cloned().ok_or_else(|| { + error!(connector_id = %id, "Connector not found"); + CoreError::NotFound + })? + }; + + // Get current config + let mut config = handle.get_config_cloned().await; + debug!(connector_id = %id, "Retrieved current config for update"); + + // Merge the patch into the config (shallow merge) + if let (Some(base_obj), Some(patch_obj)) = (config.as_object_mut(), patch.as_object()) { + for (key, value) in patch_obj { + base_obj.insert(key.clone(), value.clone()); + } + } + + // Update the handle's config + handle.set_config(config.clone()).await; + debug!(connector_id = %id, "Updated connector handle config"); + + // Update the config in CoreConfig's connectors list and live ConnectorHandle + { + let mut config_lock = self.config.write().await; + let mut connectors_lock = self.connectors.write().await; + + // Update persisted config + if let Some(connector_cfg) = config_lock + .connectors + .iter_mut() + .find(|c| c.id.as_ref() == Some(id)) + { + // Handle top-level orchestration fields + if let Some(patch_obj) = patch.as_object() { + // Update title if present in patch + if let Some(title_value) = patch_obj.get("title") { + let new_title = title_value.as_str().map(|s| s.to_string()); + connector_cfg.title = new_title.clone(); + + // Also update the live ConnectorHandle + if let Some(live_handle) = connectors_lock.get_mut(id) { + if let Some(title) = new_title { + live_handle.set_title(title); + } + } + } + + // Update working_directory if present in patch + if let Some(wd_value) = patch_obj.get("working_directory") { + let new_wd = if wd_value.is_null() { + None + } else { + wd_value.as_str().map(|s| std::path::PathBuf::from(s)) + }; + connector_cfg.working_directory = new_wd.clone(); + + // Also update the live ConnectorHandle (async operation) + if let Some(live_handle) = connectors_lock.get(id) { + live_handle.set_working_directory(new_wd).await; + } + } + + // Update supported_features if present in patch + if let Some(features_value) = patch_obj.get("supported_features") { + if let Some(features_array) = features_value.as_array() { + let new_features: Vec<String> = features_array + .iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect(); + connector_cfg.supported_features = new_features; + } + } + + // Update icon_path if present in patch + if let Some(icon_value) = patch_obj.get("icon_path") { + let new_icon_path = if icon_value.is_null() { + None + } else { + icon_value.as_str().map(|s| s.to_string()) + }; + connector_cfg.icon_path = new_icon_path.clone(); + + // Also update the live ConnectorHandle + if let Some(live_handle) = connectors_lock.get_mut(id) { + live_handle.set_icon_path(new_icon_path); + } + } + + // Update show_type_overlay if present in patch + if let Some(overlay_value) = patch_obj.get("show_type_overlay") { + if let Some(overlay_bool) = overlay_value.as_bool() { + connector_cfg.show_type_overlay = overlay_bool; + + // Also update the live ConnectorHandle + if let Some(live_handle) = connectors_lock.get_mut(id) { + live_handle.set_show_type_overlay(overlay_bool); + } + } + } + + // Merge other fields into the stored connector params + // ONLY connector-specific fields should go in params + if let Some(params_obj) = connector_cfg.params.as_object_mut() { + for (key, value) in patch_obj { + // Skip ALL orchestration fields - they belong in main config, not params + if key != "title" + && key != "working_directory" + && key != "supported_features" + && key != "icon_path" + && key != "show_type_overlay" + && key != "kind" + { + params_obj.insert(key.clone(), value.clone()); + } + } + } + } + debug!(connector_id = %id, "Updated connector config in CoreConfig and live handle"); + } + } + + // Persist changes to disk + if let Err(e) = self.save_config().await { + warn!( + connector_id = %id, + error = %e, + "Failed to save configuration after connector update" + ); + return Err(CoreError::Internal(format!( + "Failed to save configuration: {}", + e + ))); + } + + info!(connector_id = %id, "Connector configuration updated successfully"); + + Ok(()) + } + + /// Remove a connector from the runtime + /// + /// This method stops the connector (if running) and removes it from the + /// connector registry. After removal, the connector ID can be reused. + /// + /// # Arguments + /// + /// * `id` - The unique identifier of the connector to remove + /// + /// # Returns + /// + /// Ok(()) if the connector was removed successfully, or an error. + /// + /// # Errors + /// + /// - `NotFound` - No connector with the specified ID exists + /// - Other errors from `stop_connector` if the connector is running + /// + /// # Example + /// + /// ```no_run + /// # use dirigent_core::{CoreRuntime, CoreConfig}; + /// # async fn example(runtime: &CoreRuntime, connector_id: &str) -> Result<(), Box<dyn std::error::Error>> { + /// runtime.remove_connector(&connector_id.to_string()).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn remove_connector(&self, id: &ConnectorId) -> Result<(), CoreError> { + // Log connector count before removal + let count_before = { + let connectors = self.connectors.read().await; + connectors.len() + }; + + info!( + connector_id = %id, + connector_count = count_before, + "Removing connector (current count: {})", count_before + ); + + // Check if connector exists + { + let connectors = self.connectors.read().await; + if !connectors.contains_key(id) { + error!(connector_id = %id, "Connector not found"); + return Err(CoreError::NotFound); + } + } + + // Stop the connector first (if running) + // We ignore NotFound errors here since we already checked above + match self.stop_connector(id).await { + Ok(()) => { + debug!(connector_id = %id, "Connector stopped before removal"); + } + Err(CoreError::NotFound) => { + // This shouldn't happen, but handle it gracefully + debug!(connector_id = %id, "Connector not found during stop (concurrent removal?)"); + } + Err(e) => { + warn!(connector_id = %id, error = %e, "Error stopping connector, continuing with removal"); + } + } + + // Remove from registry + { + let mut connectors = self.connectors.write().await; + connectors.remove(id); + } + + // Update the connector summary cache for GatewayConnector callbacks + self.update_connector_summary_cache().await; + + // Auto-save: remove the connector from the configuration file + // This ensures the connector won't be restored on server restart + { + let mut config_lock = self.config.write().await; + // If this is a Zed-sourced connector, extract its title so we can + // add it to the dismissed list (preventing re-addition on restart) + let dismissed_title = config_lock + .connectors + .iter() + .find(|cfg| cfg.id.as_ref().map(|cfg_id| cfg_id == id).unwrap_or(false)) + .filter(|cfg| cfg.source.as_deref() == Some("zed")) + .and_then(|cfg| cfg.title.clone()); + + if let Some(ref title) = dismissed_title { + if !config_lock.dismissed_zed_agents.contains(title) { + info!( + connector_id = %id, + title = %title, + "Adding Zed connector to dismissed list" + ); + config_lock.dismissed_zed_agents.push(title.clone()); + } + } + + config_lock.connectors.retain(|cfg| { + // Remove the connector by its ID + if cfg.id.as_ref().map(|cfg_id| cfg_id == id).unwrap_or(false) { + return false; + } + // Also remove null-ID Zed entries that match the dismissed title + // (these are transient entries from enrichment that haven't been + // assigned IDs yet) + if let Some(ref dismissed) = dismissed_title { + if cfg.id.is_none() + && cfg.source.as_deref() == Some("zed") + && cfg.title.as_deref() == Some(dismissed) + { + return false; + } + } + true + }); + } + + // Save configuration to disk (log warning on error, don't fail operation) + if let Err(e) = self.save_config().await { + warn!( + connector_id = %id, + error = %e, + "Failed to save configuration after connector removal" + ); + } + + // Log connector count after removal + let count_after = { + let connectors = self.connectors.read().await; + connectors.len() + }; + + info!( + connector_id = %id, + connector_count = count_after, + "Connector removed successfully (count: {} -> {})", + count_before, + count_after + ); + + // Log special case: last connector removed + if count_before == 1 && count_after == 0 { + info!("Last connector removed - transitioning to empty state"); + } + + // Deregister connector from inspector (if available) + #[cfg(feature = "server")] + if let Some(ref inspector) = self.inspector { + let node_id = dirigent_inspector::NodeId::new(format!("dirigent/connectors/{}", id)); + let _ = inspector.deregister_subtree(&node_id).await; + } + + // Emit ConnectorRemoved event onto the SharingBus. + let bus_event = dirigent_protocol::streaming::BusEvent::from_connector_event( + dirigent_protocol::Event::ConnectorRemoved { + connector_id: id.clone(), + }, + None, // Task 9 will wire connector_uid + id.clone(), + ); + self.sharing_bus.publish(bus_event).await; + + Ok(()) + } + + /// Send a command to a specific connector + /// + /// This method validates ownership and then sends the command to the + /// connector's command channel for async processing. + /// + /// # Arguments + /// + /// * `id` - The unique identifier of the connector + /// * `cmd` - The command to send + /// + /// # Returns + /// + /// Ok(()) if the command was sent successfully, or an error. + /// + /// # Errors + /// + /// - `NotFound` - No connector with the specified ID exists + /// - `Internal` - Failed to send command to the connector's channel + /// + /// # Note + /// + /// This method does not perform ownership validation yet. In a future + /// version, it will check that the requesting user owns the connector. + /// + /// # Example + /// + /// ```no_run + /// # use dirigent_core::{CoreRuntime, CoreConfig, ConnectorCommand}; + /// # async fn example(runtime: &CoreRuntime, connector_id: &str) -> Result<(), Box<dyn std::error::Error>> { + /// runtime.send_command( + /// &connector_id.to_string(), + /// ConnectorCommand::ListSessions + /// ).await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn send_command( + &self, + id: &ConnectorId, + cmd: ConnectorCommand, + ) -> Result<(), CoreError> { + debug!(connector_id = %id, command = ?cmd, "Sending command to connector"); + + // Get the connector handle + let handle = { + let connectors = self.connectors.read().await; + connectors.get(id).cloned().ok_or_else(|| { + error!(connector_id = %id, "Connector not found"); + CoreError::NotFound + })? + }; + + // TODO: Add ownership validation here + // For now, we allow any command to any connector + // In the future, we'll check: + // if handle.owner() != current_user { + // return Err(CoreError::Unauthorized); + // } + + // Send the command + let cmd_tx = handle.command_tx(); + cmd_tx.send(cmd).await.map_err(|e| { + error!(connector_id = %id, error = %e, "Failed to send command to connector"); + CoreError::Internal(format!("Failed to send command: {}", e)) + })?; + + debug!(connector_id = %id, "Command sent successfully"); + + Ok(()) + } + + /// Transfer a session to another connector + /// + /// This is called by the Gateway connector when a user invokes /select-connector. + /// It validates the target, creates or loads a session, and emits events. + /// + /// # Arguments + /// * `request` - The transfer request from Gateway + /// + /// # Returns + /// The result is sent via the oneshot channel in the request + pub async fn transfer_session(&self, request: SessionTransferRequest) { + let result = self.execute_transfer(&request).await; + + // Emit event on success + // Models/modes are included directly in the SessionTransferred event + // so the event bridge can emit config_option_update without needing + // a separate SessionMetadataReceived event. + if let SessionTransferResult::Transferred { + ref connector_id, + ref session_id, + is_new, + ref models, + ref modes, + } = result + { + let event = dirigent_protocol::Event::SessionTransferred { + from_connector: "gateway".to_string(), // Could track actual gateway ID + from_session: request.gateway_session_id.clone(), + to_connector: connector_id.clone(), + to_session: session_id.clone(), + is_new_session: is_new, + models: models.clone(), + modes: modes.clone(), + }; + + // Broadcast onto the SharingBus. + let bus_event = dirigent_protocol::streaming::BusEvent::from_connector_event( + event.clone(), + None, // Task 9 will wire connector_uid + connector_id.clone(), + ); + self.sharing_bus.publish(bus_event).await; + + // Also emit to Gateway connector's local channel so send_message() can receive it + if let Some(gateway_connector) = self.get_connector(&request.gateway_connector_id).await + { + let _ = gateway_connector.events_sender().send(event.clone()); + } + + // Track the transfer for potential fallback + // Find a Gateway connector to use as fallback + let gateway_connector_id = self + .find_gateway_connector() + .await + .unwrap_or_else(|| "gateway-placeholder".to_string()); + + let transfer_record = TransferredSession { + gateway_session_id: request.gateway_session_id.clone(), + gateway_connector_id, + target_connector_id: connector_id.clone(), + target_session_id: session_id.clone(), + transferred_at: std::time::Instant::now(), + }; + + let mut transfers = self.transferred_sessions.write().await; + transfers.insert(session_id.clone(), transfer_record); + } + + // Send result back to Gateway + let _ = request.result_tx.send(result); + } + + /// Get the agent_type from a connector's configuration + /// + /// Returns `ConnectorAgentType::Custom` if the connector is not found or not an ACP connector. + /// Used by the ACP server to apply mode/model mappings when setting session modes/models. + pub async fn get_connector_agent_type( + &self, + connector_id: &str, + ) -> crate::connectors::acp::config::ConnectorAgentType { + use crate::connectors::acp::config::ConnectorAgentType; + + // Read config to get connector config + let config_guard = self.config.read().await; + + // Find the connector config by ID + if let Some(connector_cfg) = config_guard + .connectors + .iter() + .find(|c| c.id.as_deref() == Some(connector_id)) + { + // Only ACP connectors have agent_type + if connector_cfg.kind == ConnectorKind::Acp { + // Try to deserialize to AcpConfig to extract agent_type + if let Ok(acp_config) = serde_json::from_value::< + crate::connectors::acp::config::AcpConfig, + >(connector_cfg.params.clone()) + { + return acp_config.agent_type; + } + } + } + + // Default to Custom if not found or not an ACP connector + ConnectorAgentType::Custom + } + + /// Execute the transfer logic + /// + /// Creates or loads a session in the target connector. The target connector's + /// actual modes/models are forwarded to the editor via `config_option_update` + /// (handled by the event bridge). + async fn execute_transfer(&self, request: &SessionTransferRequest) -> SessionTransferResult { + // 1. Validate target connector exists + let connector = match self.get_connector(&request.target_connector_id).await { + Some(c) => c, + None => { + return SessionTransferResult::Failed(format!( + "Connector '{}' not found", + request.target_connector_id + )); + } + }; + + // 2. Check connector kind supports transfer + if !connector.kind().supports_session_transfer() { + return SessionTransferResult::Failed(format!( + "Connector '{}' does not support session transfer", + request.target_connector_id + )); + } + + // 3. Check connector is Ready + let state = connector.state(); + if state != ConnectorState::Ready { + return SessionTransferResult::Failed(format!( + "Connector '{}' is not ready (state: {:?})", + request.target_connector_id, state + )); + } + + // 4. Create or load session in target connector + let cmd_tx = connector.command_tx(); + let mut events = connector.subscribe(); + + if let Some(ref target_session_id) = request.target_session_id { + // Try to load existing session + let transfer_cwd = connector + .working_directory() + .map(|p| p.to_string_lossy().to_string()) + .unwrap_or_else(|| ".".to_string()); + if let Err(e) = cmd_tx + .send(ConnectorCommand::LoadSession { + session_id: target_session_id.clone(), + cwd: transfer_cwd, + mcp_servers: None, + }) + .await + { + return SessionTransferResult::Failed(format!( + "Failed to send load command: {}", + e + )); + } + + // Wait for response + match session_transfer::wait_for_session_event( + &mut events, + target_session_id, + Duration::from_secs(10), + ) + .await + { + Ok(session_id) => { + // For loaded sessions, we don't have models/modes from the event + // They will be populated via SessionMetadataReceived event if the connector sends it + return SessionTransferResult::Transferred { + connector_id: request.target_connector_id.clone(), + session_id, + is_new: false, + models: None, + modes: None, + }; + } + Err(_) => { + // Load failed, create new session instead + tracing::info!( + "Session '{}' not found in connector '{}', creating new session", + target_session_id, + request.target_connector_id + ); + match session_transfer::create_session_in_connector(&cmd_tx, &mut events).await + { + Ok((id, models, modes)) => { + return SessionTransferResult::Transferred { + connector_id: request.target_connector_id.clone(), + session_id: id, + is_new: true, + models, + modes, + }; + } + Err(e) => return SessionTransferResult::Failed(e), + } + } + } + } else { + // Create new session + match session_transfer::create_session_in_connector(&cmd_tx, &mut events).await { + Ok((id, models, modes)) => { + return SessionTransferResult::Transferred { + connector_id: request.target_connector_id.clone(), + session_id: id, + is_new: true, + models, + modes, + }; + } + Err(e) => return SessionTransferResult::Failed(e), + } + } + } + + /// Find an available Gateway connector for fallback + /// + /// Returns the ID of a Ready Gateway connector, or None if none found. + async fn find_gateway_connector(&self) -> Option<String> { + let connectors = self.connectors.read().await; + + for (id, handle) in connectors.iter() { + if handle.kind() == ConnectorKind::Gateway && handle.state() == ConnectorState::Ready { + return Some(id.clone()); + } + } + + None + } + + /// Called when a connector's state changes + /// + /// Checks if any transferred sessions need to fall back to Gateway. + pub async fn check_transferred_sessions_fallback( + &self, + connector_id: &str, + new_state: &ConnectorState, + ) { + // Only care about failure states + if !matches!( + new_state, + ConnectorState::Error(_) | ConnectorState::Stopped + ) { + return; + } + + let affected_sessions: Vec<TransferredSession> = { + let transfers = self.transferred_sessions.read().await; + + transfers + .values() + .filter(|t| t.target_connector_id == connector_id) + .cloned() + .collect() + }; + + for session in affected_sessions { + self.trigger_fallback_to_gateway(&session, new_state).await; + } + } + + /// Trigger fallback for a transferred session + async fn trigger_fallback_to_gateway( + &self, + session: &TransferredSession, + failed_state: &ConnectorState, + ) { + let reason = match failed_state { + ConnectorState::Error(msg) => format!("Connector error: {}", msg), + ConnectorState::Stopped => "Connector stopped".to_string(), + _ => "Unknown failure".to_string(), + }; + + warn!( + target_connector = %session.target_connector_id, + target_session = %session.target_session_id, + gateway_session = %session.gateway_session_id, + reason = %reason, + "Triggering fallback to Gateway for transferred session" + ); + + // Emit ForwardingPanic onto the SharingBus. + let event = dirigent_protocol::Event::ForwardingPanic { + connector_id: session.target_connector_id.clone(), + session_id: session.target_session_id.clone(), + reason: reason.clone(), + fallback_gateway_session: Some(session.gateway_session_id.clone()), + }; + let bus_event = dirigent_protocol::streaming::BusEvent::from_connector_event( + event, + None, // Task 9 will wire connector_uid + session.target_connector_id.clone(), + ); + self.sharing_bus.publish(bus_event).await; + + // Remove from tracking + let mut transfers = self.transferred_sessions.write().await; + transfers.remove(&session.target_session_id); + + // Notify Gateway to create/restore session with error message + self.notify_gateway_fallback(session, &reason).await; + } + + /// Notify Gateway about a fallback + async fn notify_gateway_fallback(&self, session: &TransferredSession, reason: &str) { + // Get Gateway connector + let gateway = self.get_connector(&session.gateway_connector_id).await; + + if let Some(gateway) = gateway { + // Send a message to Gateway to display to the user + let cmd_tx = gateway.command_tx(); + let error_message = format!( + "**Connection Lost**\n\n\ + The session with {} was interrupted: {}\n\n\ + You have been returned to Gateway. Use `/list-connectors` to reconnect.", + session.target_connector_id, reason + ); + + // Ensure the Gateway session exists and send error message + let _ = cmd_tx + .send(ConnectorCommand::SendMessage { + session_id: session.gateway_session_id.clone(), + text: format!("__SYSTEM_ERROR__:{}", error_message), + }) + .await; + } + } + + /// Remove stale transfer records older than the given duration + pub fn cleanup_stale_transfers(&self, max_age: Duration) { + let now = std::time::Instant::now(); + + if let Ok(mut transfers) = self.transferred_sessions.try_write() { + transfers.retain(|_, t| now.duration_since(t.transferred_at) < max_age); + } + } + + /// Return a reference to the primary event bus. + /// + /// The bus is the primary fan-out for all runtime/connector events. + /// Callers can subscribe with an + /// [`dirigent_protocol::streaming::EventFilter`] to receive only the + /// events they care about. + pub fn sharing_bus(&self) -> &Arc<SharingBus> { + &self.sharing_bus + } + + /// Return a reference to the live [`StreamRegistry`]. + /// + /// Useful when callers want direct access (e.g. for telemetry) without + /// going through the [`attach_stream`](Self::attach_stream) / + /// [`detach_stream`](Self::detach_stream) convenience methods. + pub fn stream_registry(&self) -> &Arc<crate::sharing::StreamRegistry> { + &self.stream_registry + } + + /// Build a stream from `config` using the registered factory for + /// `config.kind`, then attach it to the bus. + /// + /// Returns the freshly assigned [`StreamId`] on success. Errors if no + /// factory is registered for the given kind, or if the factory's + /// `build` call fails (see [`StreamBuildError`]). + pub async fn attach_stream( + &self, + config: crate::sharing::StreamConfig, + ) -> Result<crate::sharing::StreamId, crate::sharing::StreamBuildError> { + let factory = self + .stream_factories + .get(&config.kind) + .cloned() + .ok_or_else(|| crate::sharing::StreamBuildError::UnknownKind(config.kind.clone()))?; + let stream = factory.build(&config).await?; + let id = self.stream_registry.attach(config.name, stream).await; + Ok(id) + } + + /// Detach a previously attached stream. Returns `true` if the id was + /// live and the stream has now been shut down; `false` if the id was + /// already absent from the registry. + pub async fn detach_stream(&self, id: crate::sharing::StreamId) -> bool { + self.stream_registry.detach(id).await.is_some() + } + + /// Snapshot every registered stream. See [`StreamRegistry::list`]. + pub async fn list_streams(&self) -> Vec<crate::sharing::StreamInfo> { + self.stream_registry.list().await + } + + /// Replay the archived messages of `scroll_id` onto the stream + /// identified by `stream_id`. + /// + /// Reads metadata + messages from the attached archivist and dispatches + /// synthetic `BusEvent`s with `EventOrigin::Replay { .. }` directly to + /// the target stream, bypassing the `SharingBus`. Live events remain + /// unaffected. + /// + /// Errors if the stream isn't registered, no archivist is configured on + /// this runtime, or the archive read fails fatally. Per-event stream + /// failures are counted in the returned [`ReplayReport`] instead of + /// propagated. + #[cfg(feature = "server")] + pub async fn replay_session_to_stream( + &self, + scroll_id: uuid::Uuid, + stream_id: crate::sharing::StreamId, + opts: crate::sharing::ReplayOptions, + ) -> Result<crate::sharing::ReplayReport, ReplaySessionError> { + let stream = self + .stream_registry + .get_stream(stream_id) + .await + .ok_or(ReplaySessionError::StreamNotFound(stream_id))?; + let archivist_guard = self.archivist.read().await; + let archivist = archivist_guard + .as_ref() + .ok_or(ReplaySessionError::NoArchivist)?; + crate::sharing::replay::replay_session_to_stream( + archivist.as_ref(), + scroll_id, + stream, + opts, + ) + .await + .map_err(ReplaySessionError::Replay) + } + + /// Emit a runtime-origin event onto the primary event bus. + /// + /// This allows non-connector components (e.g., system task registry, + /// archivist-adjacent flows, import jobs) to publish events that are + /// delivered to all bus subscribers. + /// + /// # Return value + /// + /// Returns `0` for historical signature compatibility with the + /// pre-Phase-4 API (which reported the number of `global_events` + /// receivers that saw the event). Routing is now fire-and-forget + /// through the bus worker, so this number is no longer available + /// synchronously and no remaining caller inspects it. + pub fn emit_event(&self, event: dirigent_protocol::Event) -> usize { + let bus = Arc::clone(&self.sharing_bus); + let bus_event = dirigent_protocol::streaming::BusEvent { + routing: dirigent_protocol::streaming::EventRouting::default(), + origin: dirigent_protocol::streaming::EventOrigin::Runtime, + event: Arc::new(event), + }; + // `publish` is async; we have no connector context here and must + // remain `fn` (non-async) for the existing call-site ergonomics. + tokio::spawn(async move { + bus.publish(bus_event).await; + }); + 0 + } + + /// Get a reference to the runtime configuration + /// + /// Returns an Arc to the RwLock-protected configuration. This allows + /// reading and potentially updating the configuration at runtime. + /// + /// # Returns + /// + /// Arc<RwLock<CoreConfig>> for shared access to configuration + pub fn config(&self) -> Arc<RwLock<CoreConfig>> { + Arc::clone(&self.config) + } + + /// Get a reference to the connectors registry + /// + /// This is primarily for internal use by methods that need direct + /// access to the connector map for mutations (create, remove). + /// + /// # Returns + /// + /// Reference to the RwLock-protected connector HashMap + // TODO: Low-level runtime access - internal methods for advanced connector/user management operations + #[allow(dead_code)] + pub(crate) fn connectors_lock(&self) -> &RwLock<HashMap<ConnectorId, ConnectorHandle>> { + &self.connectors + } + + /// Get a reference to the users registry + /// + /// This is primarily for internal use by methods that need to manage + /// user data and perform authorization checks. + /// + /// # Returns + /// + /// Reference to the RwLock-protected user HashMap + // TODO: Low-level runtime access - internal methods for advanced connector/user management operations + #[allow(dead_code)] + pub(crate) fn users_lock(&self) -> &RwLock<HashMap<UserId, User>> { + &self.users + } + + /// Get a reference to the archivist (server feature only) + /// + /// Returns None if no archivist is configured or if not running in server mode. + /// + /// # Returns + /// + /// Option<Arc<Archivist>> for archival storage operations + #[cfg(feature = "server")] + pub async fn archivist(&self) -> Option<Arc<dirigent_archivist::Archivist>> { + self.archivist.read().await.clone() + } + + /// Set the archivist at runtime (for hot-reload activation) + #[cfg(feature = "server")] + pub async fn set_archivist(&self, archivist: Option<Arc<dirigent_archivist::Archivist>>) { + let mut guard = self.archivist.write().await; + *guard = archivist; + } + + /// Get a reference to the archivist slot (for sharing with AppState) + #[cfg(feature = "server")] + pub fn archivist_slot(&self) -> &Arc<tokio::sync::RwLock<Option<Arc<dirigent_archivist::Archivist>>>> { + &self.archivist + } + + /// Get the task runner slot (shared reference for AppState integration) + #[cfg(feature = "server")] + pub fn task_runner_slot( + &self, + ) -> &Arc<tokio::sync::RwLock<Option<Arc<dirigent_taskrunner::TaskRunner>>>> { + &self.task_runner + } + + // ----------------------------------------------------------------------- + // Matrix integration + // ----------------------------------------------------------------------- + + /// Initialize and start the Matrix service if configured. + /// + /// Reads the `matrix` behavior config and resolves the referenced account, + /// then creates and starts the service. + /// Call during server startup after CoreRuntime is created. + #[cfg(feature = "server")] + pub async fn start_matrix_service(&self) -> std::result::Result<(), String> { + let config_guard = self.config.read().await; + let behavior = match &config_guard.matrix { + Some(b) => b.clone(), + None => { + debug!("No [matrix] configuration found, skipping Matrix service startup"); + return Ok(()); + } + }; + + let account = config_guard + .accounts + .get(&behavior.account) + .ok_or_else(|| { + format!( + "Matrix config references account '{}' but it is not defined in [accounts]", + behavior.account + ) + })? + .clone(); + drop(config_guard); + + if account.kind != dirigent_auth::AccountKind::Matrix { + return Err(format!( + "Account '{}' has type {:?}, expected 'matrix'", + behavior.account, account.kind + )); + } + + let data_dir = dirigent_config::DirigentPaths::resolve() + .map_err(|e| format!("Failed to resolve data directory: {}", e))? + .data_dir() + .to_path_buf(); + + let service = dirigent_matrix::MatrixService::from_account(&account, behavior, data_dir) + .map_err(|e| format!("Failed to create Matrix service: {}", e))?; + + service + .login() + .await + .map_err(|e| format!("Matrix login failed: {}", e))?; + service + .start_sync() + .await + .map_err(|e| format!("Matrix sync start failed: {}", e))?; + + let service_arc = Arc::new(service); + *self.matrix_service.write().await = Some(service_arc); + + info!("Matrix service started successfully"); + Ok(()) + } + + /// Create a Matrix session share. + /// + /// Creates a Matrix room and starts a bidirectional bridge between the + /// specified connector session and the room. + /// + /// Returns the Matrix room ID string on success. + #[cfg(feature = "server")] + pub async fn create_matrix_share( + &self, + connector_id: &str, + session_id: &str, + session_title: Option<&str>, + ) -> std::result::Result<String, String> { + let service_guard = self.matrix_service.read().await; + let service = service_guard + .as_ref() + .ok_or("Matrix service not started")?; + + let client = service + .client_cloned() + .await + .ok_or("Matrix client not logged in")?; + + // Get the connector + let connectors = self.connectors.read().await; + let connector = connectors + .get(connector_id) + .ok_or_else(|| format!("Connector '{}' not found", connector_id))?; + + let event_rx = connector.subscribe(); + let connector_cmd_tx = connector.command_tx(); + drop(connectors); + + // Create room + let room_name = + dirigent_matrix::room::room_name_for_session(connector_id, session_title); + let invite = service.behavior().default_invite.clone(); + + let room_id = dirigent_matrix::room::create_share_room( + &client, + dirigent_matrix::CreateRoomOptions { + name: room_name, + topic: Some(format!( + "Dirigent session share: {} / {}", + connector_id, session_id + )), + invite, + }, + ) + .await + .map_err(|e| format!("Failed to create Matrix room: {}", e))?; + + let room_id_str = room_id.to_string(); + + // Get the Room handle for the share + let room = client + .get_room(&room_id) + .ok_or_else(|| format!("Room {} not found after creation", room_id_str))?; + + // Start the share — returns (share, command_rx) + let (share, mut command_rx) = dirigent_matrix::MatrixSessionShare::start( + connector_id.to_string(), + session_id.to_string(), + room_id_str.clone(), + room, + event_rx, + ); + + // Spawn proxy task: command_rx -> ConnectorCommand::SendMessage + let cmd_tx = connector_cmd_tx.clone(); + tokio::spawn(async move { + while let Some(proxy) = command_rx.recv().await { + let cmd = crate::connectors::ConnectorCommand::SendMessage { + session_id: proxy.session_id, + text: proxy.text, + }; + if cmd_tx.send(cmd).await.is_err() { + break; + } + } + }); + + // Register share + service + .register_share(share) + .await + .map_err(|e| format!("Failed to register share: {}", e))?; + + // Persist sharing metadata in archivist + if let Some(archivist) = self.archivist().await { + if let Some(connector_uid) = self.get_archivist_connector_uid(connector_id).await { + if let Ok(scroll_id) = archivist.resolve_session(connector_uid, session_id, None).await { + if let Err(e) = archivist.update_session_sharing( + scroll_id, + Some(room_id_str.clone()), + true, + None, + ).await { + tracing::warn!("Failed to persist sharing metadata: {}", e); + } + } + } + } + + info!( + connector_id = %connector_id, + session_id = %session_id, + room_id = %room_id_str, + "Matrix session share created" + ); + + Ok(room_id_str) + } + + /// Stop the Matrix service (if running). + /// + /// Shuts down all active shares and releases the client, then clears the + /// service slot. Returns `true` if a service was running and was stopped. + #[cfg(feature = "server")] + pub async fn stop_matrix_service(&self) -> bool { + let mut guard = self.matrix_service.write().await; + if let Some(service) = guard.take() { + service.shutdown().await; + info!("Matrix service stopped"); + true + } else { + false + } + } + + /// Restart the Matrix service. + /// + /// Stops any running instance, then starts a fresh one from the current + /// configuration. This picks up any config changes made since the last + /// start. + #[cfg(feature = "server")] + pub async fn restart_matrix_service(&self) -> std::result::Result<(), String> { + self.stop_matrix_service().await; + self.start_matrix_service().await + } + + /// Get the Matrix service (if started). + #[cfg(feature = "server")] + pub async fn matrix_service(&self) -> Option<Arc<dirigent_matrix::MatrixService>> { + self.matrix_service.read().await.clone() + } + + /// Register an existing connector with the archivist (for hot-reload) + /// + /// Used during hot-reload to register connectors that were created + /// before the archivist was activated. + #[cfg(feature = "server")] + pub async fn register_connector_with_archivist( + &self, + connector_id: &str, + ) -> Result<(), CoreError> { + let archivist = self + .archivist() + .await + .ok_or_else(|| CoreError::Internal("Archivist not configured".to_string()))?; + + let connector = self + .get_connector(&connector_id.to_string()) + .await + .ok_or(CoreError::NotFound)?; + + let connector_kind = connector.kind(); + let connector_title = connector.title().to_string(); + let fingerprint = { + let config_arc = self.config.read().await; + config_arc + .connectors + .iter() + .find(|c| c.id.as_deref() == Some(connector_id)) + .and_then(|cfg| { + crate::connectors::fingerprint::compute_fingerprint(&cfg.kind, &cfg.params) + }) + }; + + let register_req = dirigent_archivist::types::RegisterConnectorRequest { + custom_uid: uuid::Uuid::try_parse(connector_id).ok(), + r#type: format!("{:?}", connector_kind), + title: connector_title, + client_native_id: connector_id.to_string(), + metadata: serde_json::json!({}), + fingerprint: fingerprint.clone(), + }; + + match archivist.register_connector(register_req, None).await { + Ok(response) => { + self.archivist_connector_uids + .write() + .await + .insert(connector_id.to_string(), response.connector_uid); + info!( + connector_id = %connector_id, + connector_uid = %response.connector_uid, + "Registered existing connector with archivist (hot-reload)" + ); + Ok(()) + } + Err(e) => { + warn!( + connector_id = %connector_id, + error = %e, + "Failed to register connector with archivist" + ); + Err(CoreError::Internal(format!( + "Failed to register connector: {}", + e + ))) + } + } + } + + /// Get the archivist connector_uid for a given connector_id + /// + /// This mapping enables archive-first reads to work with non-UUID connector IDs. + /// When a connector is created with a custom ID (e.g., "opencode-1" from config), + /// the archivist still uses a UUID internally. This method retrieves that UUID. + /// + /// # Arguments + /// + /// * `connector_id` - The connector ID (may be UUID or custom string) + /// + /// # Returns + /// + /// Some(Uuid) if the connector is registered with the archivist, None otherwise + #[cfg(feature = "server")] + pub async fn get_archivist_connector_uid(&self, connector_id: &str) -> Option<uuid::Uuid> { + self.archivist_connector_uids + .read() + .await + .get(connector_id) + .copied() + } + + /// Get the connector_id for a given archivist connector_uid (reverse lookup) + /// + /// This is the inverse of `get_archivist_connector_uid`. Since the map is small + /// (typically < 10 entries), iterating is efficient. + /// + /// # Arguments + /// * `uid` - The archivist connector UUID + /// + /// # Returns + /// Some(String) with the connector_id if found, None otherwise + #[cfg(feature = "server")] + pub async fn get_connector_id_by_uid(&self, uid: uuid::Uuid) -> Option<String> { + self.archivist_connector_uids + .read() + .await + .iter() + .find(|(_, &v)| v == uid) + .map(|(k, _)| k.clone()) + } + + /// Backfill sessions from a connector into the archivist + /// + /// This method imports existing sessions from a connector that supports + /// `list_sessions()` and `list_messages()` operations into the archivist. + /// + /// # Arguments + /// + /// * `connector_id` - The unique identifier of the connector to backfill from + /// + /// # Returns + /// + /// Statistics about the backfill operation including number of sessions and messages imported + /// + /// # Errors + /// + /// - `NotFound` - Connector with the specified ID doesn't exist + /// - `Internal` - Archivist is not configured or other internal errors + /// + /// # Example + /// + /// ```no_run + /// # use dirigent_core::{CoreRuntime, CoreConfig}; + /// # async fn example(runtime: &CoreRuntime) -> Result<(), Box<dyn std::error::Error>> { + /// let stats = runtime.backfill_connector_sessions("opencode-1").await?; + /// println!("Imported {} sessions with {} messages", + /// stats.sessions_imported, stats.messages_imported); + /// # Ok(()) + /// # } + /// ``` + #[cfg(feature = "server")] + pub async fn backfill_connector_sessions( + &self, + connector_id: &str, + ) -> Result<dirigent_archivist::BackfillStats, CoreError> { + use futures::future::BoxFuture; + use std::time::Duration; + + info!(connector_id = %connector_id, "Starting connector session backfill"); + + // Get archivist + let archivist = self + .archivist + .read() + .await + .clone() + .ok_or_else(|| CoreError::Internal("Archivist not configured".to_string()))?; + + // Get connector handle + let connector = self + .get_connector(&connector_id.to_string()) + .await + .ok_or_else(|| { + error!(connector_id = %connector_id, "Connector not found for backfill"); + CoreError::NotFound + })?; + + // Get connector UID from archivist mapping + // This is the same UID used for archive-first reads and session registration + let connector_uid = self + .archivist_connector_uids + .read() + .await + .get(connector_id) + .copied() + .ok_or_else(|| { + error!( + connector_id = %connector_id, + "Connector not registered with archivist (no connector_uid mapping)" + ); + CoreError::Internal(format!( + "Connector {} not registered with archivist. \ + This should not happen - connector registration happens on creation.", + connector_id + )) + })?; + + debug!( + connector_id = %connector_id, + connector_uid = %connector_uid, + "Using archivist connector_uid for backfill" + ); + + // Send ListSessions command and wait for response + let mut events = connector.subscribe(); + let cmd_tx = connector.command_tx(); + + cmd_tx + .send(ConnectorCommand::ListSessions) + .await + .map_err(|e| { + error!(connector_id = %connector_id, error = %e, "Failed to send ListSessions command"); + CoreError::Internal(format!("Failed to send command: {}", e)) + })?; + + // Wait for SessionsListed event with timeout + let sessions = tokio::time::timeout(Duration::from_secs(30), async { + while let Ok(event) = events.recv().await { + if let dirigent_protocol::Event::SessionsListed { + connector_id: _, + sessions, + } = event + { + return Ok(sessions); + } + } + Err(CoreError::Internal( + "No SessionsListed event received".to_string(), + )) + }) + .await + .map_err(|_| { + error!(connector_id = %connector_id, "Timeout waiting for sessions list"); + CoreError::Internal("Timeout waiting for sessions list".to_string()) + })??; + + debug!( + connector_id = %connector_id, + session_count = sessions.len(), + "Received {} sessions from connector", + sessions.len() + ); + + // Create closure to fetch messages for each session + let connector_clone = connector.clone(); + let connector_id_clone = connector_id.to_string(); + let fetch_messages = move |session_id: &str| { + let session_id = session_id.to_string(); + let connector = connector_clone.clone(); + let connector_id = connector_id_clone.clone(); + + Box::pin(async move { + debug!( + connector_id = %connector_id, + session_id = %session_id, + "Fetching messages for session" + ); + + // Subscribe to events + let mut events = connector.subscribe(); + let cmd_tx = connector.command_tx(); + + // Send ListMessages command + cmd_tx + .send(ConnectorCommand::ListMessages { + session_id: session_id.clone(), + }) + .await + .map_err(|e| { + dirigent_archivist::ArchivistError::InvalidRequest(format!( + "Failed to send command: {}", + e + )) + })?; + + // Wait for MessagesListed event + let messages = tokio::time::timeout(Duration::from_secs(30), async { + while let Ok(event) = events.recv().await { + if let dirigent_protocol::Event::MessagesListed { messages } = event { + // Note: MessagesListed doesn't have a session_id field + // We assume the messages are for the session we requested + return Ok(messages); + } + } + Err(dirigent_archivist::ArchivistError::InvalidRequest( + "No MessagesListed event received".to_string(), + )) + }) + .await + .map_err(|_| { + dirigent_archivist::ArchivistError::InvalidRequest( + "Timeout waiting for messages list".to_string(), + ) + })??; + + debug!( + connector_id = %connector_id, + session_id = %session_id, + message_count = messages.len(), + "Fetched {} messages for session", + messages.len() + ); + + Ok(messages) + }) + as BoxFuture< + 'static, + Result<Vec<dirigent_protocol::Message>, dirigent_archivist::ArchivistError>, + > + }; + + // Perform backfill + let stats = dirigent_archivist::backfill_from_sessions( + &*archivist, + connector_uid, + sessions, + fetch_messages, + ) + .await + .map_err(|e| { + error!( + connector_id = %connector_id, + error = %e, + "Backfill operation failed" + ); + CoreError::Internal(format!("Backfill failed: {}", e)) + })?; + + info!( + connector_id = %connector_id, + sessions_imported = stats.sessions_imported, + messages_imported = stats.messages_imported, + errors = stats.errors.len(), + "Connector session backfill completed" + ); + + Ok(stats) + } + + /// Update ACP Server configuration + /// + /// This method updates the ACP Server configuration in the runtime's CoreConfig. + /// After updating, you should call `save_config()` to persist the changes to disk. + /// + /// # Arguments + /// + /// * `config` - The new ACP Server configuration + /// + /// # Returns + /// + /// Ok(()) if the configuration was updated successfully. + /// + /// # Example + /// + /// ```no_run + /// # use dirigent_core::{CoreRuntime, CoreConfig, AcpServerConfig}; + /// # async fn example() -> Result<(), Box<dyn std::error::Error>> { + /// let runtime = CoreRuntime::new(CoreConfig::default()); + /// let acp_config = AcpServerConfig { + /// enabled: true, + /// port: 3001, + /// allowed_origins: None, + /// max_connections: 100, + /// default_connector_id: None, + /// }; + /// runtime.update_acp_server_config(acp_config).await?; + /// runtime.save_config().await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn update_acp_server_config(&self, config: AcpServerConfig) -> Result<(), CoreError> { + config_manager::update_acp_server_config(&self.config, config).await + } + + /// Save the current configuration to disk + /// + /// This method serializes the current runtime configuration to JSON or TOML format + /// and writes it to the filesystem atomically (write to temp file, then rename). + /// + /// The save path is determined as follows: + /// 1. The file the config was originally loaded from, promoted to user config dir if project-local + /// 2. User config directory default (`~/.config/dirigent/dirigent.toml`) + /// + /// The file format is determined by the extension of the save path. + /// + /// # Returns + /// + /// Ok(()) if the configuration was saved successfully, or an error. + /// + /// # Errors + /// + /// Returns `CoreError::Internal` if: + /// - Failed to serialize the configuration + /// - Failed to write to the temp file + /// - Failed to rename the temp file to the final path + /// + /// # Example + /// + /// ```no_run + /// # use dirigent_core::{CoreRuntime, CoreConfig}; + /// # async fn example() -> Result<(), Box<dyn std::error::Error>> { + /// let runtime = CoreRuntime::new(CoreConfig::default(, None)); + /// runtime.save_config().await?; + /// # Ok(()) + /// # } + /// ``` + pub async fn save_config(&self) -> Result<(), CoreError> { + config_manager::save_config(&self.config).await + } + + /// Update the connector summary cache from the current connectors map + /// + /// This method rebuilds the sync-accessible cache used by GatewayConnector + /// callbacks. It should be called after any connector is added or removed. + /// + /// The cache uses a std::sync::RwLock so it can be accessed from synchronous + /// callback functions that don't have access to an async runtime. + pub async fn update_connector_summary_cache(&self) { + summary_cache::update_connector_summary_cache( + &self.connectors, + &self.config, + &self.connector_summary_cache, + ) + .await + } + + /// Get a clone of the connector summary cache Arc + /// + /// This returns the Arc<std::sync::RwLock<Vec<ConnectorSummaryInfo>>> that can + /// be shared with synchronous callbacks (e.g., in GatewayConnector). + pub fn connector_summary_cache(&self) -> Arc<std::sync::RwLock<Vec<ConnectorSummaryInfo>>> { + Arc::clone(&self.connector_summary_cache) + } +} + +/// Handle to a CoreRuntime +/// +/// CoreHandle is a lightweight, cloneable wrapper around CoreRuntime. +/// It uses Arc internally, so cloning is cheap and all clones refer +/// to the same underlying runtime. +/// +/// # Purpose +/// +/// CoreHandle provides ergonomic access to the runtime: +/// - Cheap to clone and pass between tasks +/// - Implements Deref to CoreRuntime for transparent method access +/// - Can be stored in application state (e.g., Axum/Dioxus) +/// +/// # Usage +/// +/// ```no_run +/// use dirigent_core::{CoreRuntime, CoreConfig, CoreHandle}; +/// +/// # async fn example() { +/// let runtime = CoreRuntime::new(CoreConfig::default(, None)); +/// let handle = CoreHandle::new(runtime); +/// +/// // Clone is cheap - just increments Arc refcount +/// let handle2 = handle.clone(); +/// +/// // Both handles refer to the same runtime +/// let connectors1 = handle.list_connectors(None).await; +/// let connectors2 = handle2.list_connectors(None).await; +/// assert_eq!(connectors1.len(), connectors2.len()); +/// # } +/// ``` +#[derive(Clone)] +pub struct CoreHandle { + /// The underlying runtime wrapped in Arc for cheap cloning + runtime: Arc<CoreRuntime>, +} + +impl CoreHandle { + /// Create a new CoreHandle wrapping the given runtime + /// + /// # Arguments + /// + /// * `runtime` - The CoreRuntime to wrap + /// + /// # Returns + /// + /// A new CoreHandle that can be cloned and shared + /// + /// # Example + /// + /// ```no_run + /// use dirigent_core::{CoreRuntime, CoreConfig, CoreHandle}; + /// + /// let runtime = CoreRuntime::new(CoreConfig::default(, None)); + /// let handle = CoreHandle::new(runtime); + /// ``` + pub fn new(runtime: CoreRuntime) -> Self { + let arc = Arc::new(runtime); + + // Set the weak reference to self + let weak = Arc::downgrade(&arc); + { + // Use try_write() instead of blocking_write() to avoid blocking the async runtime + let mut self_weak = arc + .self_weak + .try_write() + .expect("self_weak lock should not be contended during construction"); + *self_weak = Some(weak); + } // Drop the lock guard here + + Self { runtime: arc } + } + + /// Get a reference to the underlying Arc<CoreRuntime> + /// + /// This is useful when you need direct access to the Arc, for example + /// when storing the runtime in application state. + /// + /// # Returns + /// + /// A reference to the Arc<CoreRuntime> + pub fn inner(&self) -> &Arc<CoreRuntime> { + &self.runtime + } + + /// Start a background task to monitor connector state changes and handle fallbacks + /// + /// This task subscribes to the SharingBus and watches for connector + /// errors and disconnections. When a connector that has transferred sessions + /// fails, it triggers automatic fallback to Gateway. + /// + /// # Returns + /// + /// A JoinHandle for the monitoring task + pub fn start_fallback_monitor(&self) -> tokio::task::JoinHandle<()> { + let runtime = self.runtime.clone(); + + tokio::spawn(async move { + let mut bus_rx = runtime.sharing_bus.subscribe_all().await; + + info!("Started fallback monitor for transferred sessions"); + + while let Some(bus_event) = bus_rx.rx.recv().await { + // Check for connector failure events + match bus_event.event.as_ref() { + dirigent_protocol::Event::Error { .. } + | dirigent_protocol::Event::Disconnected => { + // When we see an error or disconnect, check all connectors + // to see if any are in failed states with transferred sessions + let connectors = runtime.connectors.read().await; + for (id, handle) in connectors.iter() { + let state = handle.state(); + runtime + .check_transferred_sessions_fallback(id, &state) + .await; + } + } + _ => {} + } + } + info!("Bus event stream closed, stopping fallback monitor"); + }) + } +} + +impl Deref for CoreHandle { + type Target = CoreRuntime; + + fn deref(&self) -> &Self::Target { + &self.runtime + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::{ConnectorKind, ConnectorState}; + use tokio::sync::{broadcast, mpsc}; + + #[tokio::test] + async fn test_core_runtime_new() { + // `SharingBus::new` spawns a worker task, so we need a Tokio + // runtime in scope. `CoreRuntime::new` itself remains sync. + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + // Runtime should be initialized with empty state. We can't + // easily test internal state without async, but we can verify + // construction succeeded. + drop(runtime); + } + + #[tokio::test] + async fn test_list_connectors_empty() { + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + let connectors = runtime.list_connectors(None).await; + assert_eq!(connectors.len(), 0); + } + + #[tokio::test] + async fn test_list_connectors_with_data() { + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + // Add a test connector + let (cmd_tx, _cmd_rx) = mpsc::channel(100); + let (events_tx, _events_rx) = broadcast::channel(1000); + let handle = ConnectorHandle::new( + "test-conn".to_string(), + ConnectorKind::Mock, + uuid::Uuid::nil(), + "Test Connector".to_string(), + cmd_tx, + events_tx, + serde_json::json!({}), + None, + None, // icon_path + false, // show_type_overlay + ); + + { + let mut connectors = runtime.connectors.write().await; + connectors.insert("test-conn".to_string(), handle); + } + + // List all connectors + let all = runtime.list_connectors(None).await; + assert_eq!(all.len(), 1); + assert_eq!(all[0].id, "test-conn"); + assert_eq!(all[0].title, "Test Connector"); + assert_eq!(all[0].owner, uuid::Uuid::nil()); + + // List connectors for specific owner + let user1 = runtime.list_connectors(Some(uuid::Uuid::nil())).await; + assert_eq!(user1.len(), 1); + + let user2 = runtime + .list_connectors(Some(uuid::Uuid::from_u128(2))) + .await; + assert_eq!(user2.len(), 0); + } + + #[tokio::test] + async fn test_list_connectors_filters_by_owner() { + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + // Add connectors for different users + let (cmd_tx1, _) = mpsc::channel(100); + let (events_tx1, _) = broadcast::channel(1000); + let handle1 = ConnectorHandle::new( + "conn-1".to_string(), + ConnectorKind::Mock, + uuid::Uuid::nil(), + "Connector 1".to_string(), + cmd_tx1, + events_tx1, + serde_json::json!({}), + None, + None, + false, + ); + + let (cmd_tx2, _) = mpsc::channel(100); + let (events_tx2, _) = broadcast::channel(1000); + let handle2 = ConnectorHandle::new( + "conn-2".to_string(), + ConnectorKind::OpenCode, + uuid::Uuid::from_u128(2), + "Connector 2".to_string(), + cmd_tx2, + events_tx2, + serde_json::json!({}), + None, + None, + false, + ); + + { + let mut connectors = runtime.connectors.write().await; + connectors.insert("conn-1".to_string(), handle1); + connectors.insert("conn-2".to_string(), handle2); + } + + // List all + let all = runtime.list_connectors(None).await; + assert_eq!(all.len(), 2); + + // List for user-1 + let user1 = runtime.list_connectors(Some(uuid::Uuid::nil())).await; + assert_eq!(user1.len(), 1); + assert_eq!(user1[0].id, "conn-1"); + + // List for user-2 + let user2 = runtime + .list_connectors(Some(uuid::Uuid::from_u128(2))) + .await; + assert_eq!(user2.len(), 1); + assert_eq!(user2[0].id, "conn-2"); + } + + #[tokio::test] + async fn test_get_connector_exists() { + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + // Add a test connector + let (cmd_tx, _cmd_rx) = mpsc::channel(100); + let (events_tx, _events_rx) = broadcast::channel(1000); + let handle = ConnectorHandle::new( + "test-conn".to_string(), + ConnectorKind::Mock, + uuid::Uuid::nil(), + "Test Connector".to_string(), + cmd_tx, + events_tx, + serde_json::json!({}), + None, + None, + false, + ); + + { + let mut connectors = runtime.connectors.write().await; + connectors.insert("test-conn".to_string(), handle); + } + + // Get the connector + let result = runtime.get_connector(&"test-conn".to_string()).await; + assert!(result.is_some()); + + let retrieved = result.unwrap(); + assert_eq!(retrieved.id(), "test-conn"); + assert_eq!(retrieved.title(), "Test Connector"); + } + + #[tokio::test] + async fn test_get_connector_not_found() { + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + let result = runtime.get_connector(&"nonexistent".to_string()).await; + assert!(result.is_none()); + } + + #[tokio::test] + async fn test_core_handle_new() { + // `SharingBus::new` spawns a worker task, so we need a Tokio + // runtime in scope. + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + let handle = CoreHandle::new(runtime); + + // Handle should be created successfully + drop(handle); + } + + #[tokio::test] + async fn test_core_handle_deref() { + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + let handle = CoreHandle::new(runtime); + + // Should be able to call runtime methods directly via Deref + let connectors = handle.list_connectors(None).await; + assert_eq!(connectors.len(), 0); + } + + #[tokio::test] + async fn test_core_handle_clone() { + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + let handle1 = CoreHandle::new(runtime); + + // Clone should be cheap and share the same runtime + let handle2 = handle1.clone(); + + // Add a connector via handle1 + let (cmd_tx, _) = mpsc::channel(100); + let (events_tx, _) = broadcast::channel(1000); + let connector = ConnectorHandle::new( + "test-conn".to_string(), + ConnectorKind::Mock, + uuid::Uuid::nil(), + "Test".to_string(), + cmd_tx, + events_tx, + serde_json::json!({}), + None, + None, + false, + ); + + { + let mut connectors = handle1.connectors_lock().write().await; + connectors.insert("test-conn".to_string(), connector); + } + + // handle2 should see the same connector + let connectors = handle2.list_connectors(None).await; + assert_eq!(connectors.len(), 1); + } + + #[tokio::test] + async fn test_core_handle_inner() { + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + let handle = CoreHandle::new(runtime); + + // Should be able to get the inner Arc + let inner = handle.inner(); + assert!(Arc::strong_count(inner) >= 1); + } + + // Tests for create_connector (T024) + + #[tokio::test] + async fn test_create_connector_generates_id() { + use crate::ConnectorConfig; + use serde_json::json; + + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + let cfg = ConnectorConfig { + id: None, // No ID provided + kind: ConnectorKind::OpenCode, + owner: None, + title: Some("Test".to_string()), + working_directory: None, + params: json!({ + "base_url": "http://localhost:12225", + "title": "Test", + "initial_session": null + }), + ..Default::default() + }; + + let result = runtime.create_connector(uuid::Uuid::nil(), cfg).await; + + assert!(result.is_ok()); + let connector_id = result.unwrap(); + assert!(!connector_id.is_empty()); + + // Verify connector was created + let connector = runtime.get_connector(&connector_id).await; + assert!(connector.is_some()); + } + + #[tokio::test] + async fn test_create_connector_with_id() { + use crate::ConnectorConfig; + use serde_json::json; + + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + let cfg = ConnectorConfig { + id: Some("my-connector".to_string()), + kind: ConnectorKind::OpenCode, + owner: None, + title: Some("Test".to_string()), + working_directory: None, + params: json!({ + "base_url": "http://localhost:12225", + "title": "Test", + "initial_session": null + }), + ..Default::default() + }; + + let result = runtime.create_connector(uuid::Uuid::nil(), cfg).await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "my-connector"); + } + + #[tokio::test] + async fn test_create_connector_already_exists() { + use crate::ConnectorConfig; + use serde_json::json; + + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + let cfg = ConnectorConfig { + id: Some("duplicate".to_string()), + kind: ConnectorKind::OpenCode, + owner: None, + title: Some("Test".to_string()), + working_directory: None, + params: json!({ + "base_url": "http://localhost:12225", + "title": "Test", + "initial_session": null + }), + ..Default::default() + }; + + // Create first connector + let result1 = runtime + .create_connector(uuid::Uuid::nil(), cfg.clone()) + .await; + assert!(result1.is_ok()); + + // Try to create duplicate + let result2 = runtime.create_connector(uuid::Uuid::nil(), cfg).await; + assert!(result2.is_err()); + assert_eq!(result2.unwrap_err(), CoreError::AlreadyExists); + } + + #[tokio::test] + async fn test_create_connector_invalid_config() { + use crate::ConnectorConfig; + use serde_json::json; + + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + let cfg = ConnectorConfig { + id: None, + kind: ConnectorKind::OpenCode, + owner: None, + title: Some("Test".to_string()), + working_directory: None, + params: json!({ + // Missing required fields + "invalid": "config" + }), + ..Default::default() + }; + + let result = runtime.create_connector(uuid::Uuid::nil(), cfg).await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), CoreError::InvalidConfig); + } + + #[tokio::test] + async fn test_create_connector_mock_not_allowed() { + use crate::ConnectorConfig; + use serde_json::json; + + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + let cfg = ConnectorConfig { + id: None, + kind: ConnectorKind::Mock, + owner: None, + title: Some("Test".to_string()), + working_directory: None, + params: json!({}), + ..Default::default() + }; + + let result = runtime.create_connector(uuid::Uuid::nil(), cfg).await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), CoreError::InvalidConfig); + } + + // Tests for stop_connector (T026) + + #[tokio::test] + async fn test_stop_connector_not_found() { + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + let result = runtime.stop_connector(&"nonexistent".to_string()).await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), CoreError::NotFound); + } + + #[tokio::test] + async fn test_stop_connector_success() { + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + // Add a mock connector + let (cmd_tx, _cmd_rx) = mpsc::channel(100); + let (events_tx, _events_rx) = broadcast::channel(1000); + let handle = ConnectorHandle::new( + "test-conn".to_string(), + ConnectorKind::Mock, + uuid::Uuid::nil(), + "Test".to_string(), + cmd_tx, + events_tx, + serde_json::json!({}), + None, + None, + false, + ); + + { + let mut connectors = runtime.connectors.write().await; + connectors.insert("test-conn".to_string(), handle); + } + + // Stop should succeed even without a running task + let result = runtime.stop_connector(&"test-conn".to_string()).await; + assert!(result.is_ok()); + + // Verify state was updated to Stopped + let connector = runtime.get_connector(&"test-conn".to_string()).await; + assert!(connector.is_some()); + assert_eq!(connector.unwrap().state(), ConnectorState::Stopped); + } + + // Tests for remove_connector (T027) + + #[tokio::test] + async fn test_remove_connector_not_found() { + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + let result = runtime.remove_connector(&"nonexistent".to_string()).await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), CoreError::NotFound); + } + + #[tokio::test] + async fn test_remove_connector_success() { + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + // Add a mock connector + let (cmd_tx, _cmd_rx) = mpsc::channel(100); + let (events_tx, _events_rx) = broadcast::channel(1000); + let handle = ConnectorHandle::new( + "test-conn".to_string(), + ConnectorKind::Mock, + uuid::Uuid::nil(), + "Test".to_string(), + cmd_tx, + events_tx, + serde_json::json!({}), + None, + None, + false, + ); + + { + let mut connectors = runtime.connectors.write().await; + connectors.insert("test-conn".to_string(), handle); + } + + // Verify it exists + assert!(runtime + .get_connector(&"test-conn".to_string()) + .await + .is_some()); + + // Remove it + let result = runtime.remove_connector(&"test-conn".to_string()).await; + assert!(result.is_ok()); + + // Verify it's gone + assert!(runtime + .get_connector(&"test-conn".to_string()) + .await + .is_none()); + } + + // Tests for send_command (T028) + + #[tokio::test] + async fn test_send_command_not_found() { + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + let result = runtime + .send_command(&"nonexistent".to_string(), ConnectorCommand::ListSessions) + .await; + assert!(result.is_err()); + assert_eq!(result.unwrap_err(), CoreError::NotFound); + } + + #[tokio::test] + async fn test_send_command_success() { + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + // Add a mock connector + let (cmd_tx, mut cmd_rx) = mpsc::channel(100); + let (events_tx, _events_rx) = broadcast::channel(1000); + let handle = ConnectorHandle::new( + "test-conn".to_string(), + ConnectorKind::Mock, + uuid::Uuid::nil(), + "Test".to_string(), + cmd_tx, + events_tx, + serde_json::json!({}), + None, + None, + false, + ); + + { + let mut connectors = runtime.connectors.write().await; + connectors.insert("test-conn".to_string(), handle); + } + + // Send a command + let result = runtime + .send_command(&"test-conn".to_string(), ConnectorCommand::ListSessions) + .await; + assert!(result.is_ok()); + + // Verify the command was received + let received = cmd_rx.try_recv(); + assert!(received.is_ok()); + matches!(received.unwrap(), ConnectorCommand::ListSessions); + } + + #[tokio::test] + async fn test_send_command_all_types() { + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + // Add a mock connector + let (cmd_tx, mut cmd_rx) = mpsc::channel(100); + let (events_tx, _events_rx) = broadcast::channel(1000); + let handle = ConnectorHandle::new( + "test-conn".to_string(), + ConnectorKind::Mock, + uuid::Uuid::nil(), + "Test".to_string(), + cmd_tx, + events_tx, + serde_json::json!({}), + None, + None, + false, + ); + + { + let mut connectors = runtime.connectors.write().await; + connectors.insert("test-conn".to_string(), handle); + } + + // Test ListSessions + runtime + .send_command(&"test-conn".to_string(), ConnectorCommand::ListSessions) + .await + .unwrap(); + matches!(cmd_rx.recv().await.unwrap(), ConnectorCommand::ListSessions); + + // Test ListMessages + runtime + .send_command( + &"test-conn".to_string(), + ConnectorCommand::ListMessages { + session_id: "sess-1".to_string(), + }, + ) + .await + .unwrap(); + let cmd = cmd_rx.recv().await.unwrap(); + if let ConnectorCommand::ListMessages { session_id } = cmd { + assert_eq!(session_id, "sess-1"); + } else { + panic!("Expected ListMessages command"); + } + + // Test SendMessage + runtime + .send_command( + &"test-conn".to_string(), + ConnectorCommand::SendMessage { + session_id: "sess-1".to_string(), + text: "Hello".to_string(), + }, + ) + .await + .unwrap(); + let cmd = cmd_rx.recv().await.unwrap(); + if let ConnectorCommand::SendMessage { session_id, text } = cmd { + assert_eq!(session_id, "sess-1"); + assert_eq!(text, "Hello"); + } else { + panic!("Expected SendMessage command"); + } + + // Test Reconnect + runtime + .send_command(&"test-conn".to_string(), ConnectorCommand::Reconnect) + .await + .unwrap(); + matches!(cmd_rx.recv().await.unwrap(), ConnectorCommand::Reconnect); + + // Test Shutdown + runtime + .send_command(&"test-conn".to_string(), ConnectorCommand::Shutdown) + .await + .unwrap(); + matches!(cmd_rx.recv().await.unwrap(), ConnectorCommand::Shutdown); + } + + // Tests for the SharingBus-based event fan-out (T029 replacement). + // + // Originally T029 exercised `subscribe_global()`; that API was removed + // in Phase 4 Task 10. The replacement test verifies the SharingBus + // fan-out directly: publishing a `BusEvent` must reach every + // `subscribe_all()` subscriber. + #[tokio::test] + async fn test_sharing_bus_subscribe_all() { + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + // Two independent subscribers to the primary event fan-out. + let mut rx1 = runtime.sharing_bus.subscribe_all().await; + let mut rx2 = runtime.sharing_bus.subscribe_all().await; + + let event = dirigent_protocol::Event::Connected; + let bus_event = dirigent_protocol::streaming::BusEvent { + routing: dirigent_protocol::streaming::EventRouting::default(), + origin: dirigent_protocol::streaming::EventOrigin::Runtime, + event: Arc::new(event), + }; + runtime.sharing_bus.publish(bus_event).await; + + // Both receivers should get it. + let received1 = rx1.rx.recv().await.expect("rx1 should receive"); + let received2 = rx2.rx.recv().await.expect("rx2 should receive"); + + matches!(received1.event.as_ref(), dirigent_protocol::Event::Connected); + matches!(received2.event.as_ref(), dirigent_protocol::Event::Connected); + } + + // Integration test: create, list, remove + + #[tokio::test] + async fn test_integration_create_list_remove() { + use crate::ConnectorConfig; + use serde_json::json; + + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + // Initially empty + let list = runtime.list_connectors(None).await; + assert_eq!(list.len(), 0); + + // Create a connector + let cfg = ConnectorConfig { + id: Some("conn-1".to_string()), + kind: ConnectorKind::OpenCode, + owner: None, + title: Some("Test 1".to_string()), + working_directory: None, + params: json!({ + "base_url": "http://localhost:12225", + "title": "Test 1", + "initial_session": null + }), + ..Default::default() + }; + + let connector_id = runtime + .create_connector(uuid::Uuid::nil(), cfg) + .await + .unwrap(); + assert_eq!(connector_id, "conn-1"); + + // List should show 1 connector + let list = runtime.list_connectors(None).await; + assert_eq!(list.len(), 1); + assert_eq!(list[0].id, "conn-1"); + assert_eq!(list[0].owner, uuid::Uuid::nil()); + + // Remove the connector + runtime.remove_connector(&connector_id).await.unwrap(); + + // List should be empty again + let list = runtime.list_connectors(None).await; + assert_eq!(list.len(), 0); + } + + #[tokio::test] + async fn test_integration_multiple_connectors() { + use crate::ConnectorConfig; + use serde_json::json; + + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + // Create multiple connectors + for i in 1..=3 { + let cfg = ConnectorConfig { + id: Some(format!("conn-{}", i)), + kind: ConnectorKind::OpenCode, + owner: None, + title: Some(format!("Test {}", i)), + working_directory: None, + params: json!({ + "base_url": "http://localhost:12225", + "title": format!("Test {}", i), + "initial_session": null + }), + ..Default::default() + }; + + runtime + .create_connector(uuid::Uuid::nil(), cfg) + .await + .unwrap(); + } + + // List should show all 3 + let list = runtime.list_connectors(None).await; + assert_eq!(list.len(), 3); + + // Filter by owner + let user1_list = runtime.list_connectors(Some(uuid::Uuid::nil())).await; + assert_eq!(user1_list.len(), 3); + + let user2_list = runtime + .list_connectors(Some(uuid::Uuid::from_u128(2))) + .await; + assert_eq!(user2_list.len(), 0); + + // Remove one + runtime + .remove_connector(&"conn-2".to_string()) + .await + .unwrap(); + + // List should show 2 + let list = runtime.list_connectors(None).await; + assert_eq!(list.len(), 2); + assert!(list.iter().any(|c| c.id == "conn-1")); + assert!(list.iter().any(|c| c.id == "conn-3")); + assert!(!list.iter().any(|c| c.id == "conn-2")); + } + + #[tokio::test] + async fn test_create_connector_after_removing_all() { + use serde_json::json; + + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + // Verify empty state initially + assert_eq!(runtime.list_connectors(None).await.len(), 0); + + // Create first connector + let cfg1 = ConnectorConfig { + id: Some("conn-1".to_string()), + kind: ConnectorKind::OpenCode, + owner: None, + title: Some("First Connector".to_string()), + working_directory: None, + params: json!({ + "base_url": "http://localhost:12225", + "title": "First Connector", + "initial_session": null + }), + ..Default::default() + }; + + let id1 = runtime + .create_connector(uuid::Uuid::nil(), cfg1) + .await + .unwrap(); + assert_eq!(id1, "conn-1"); + assert_eq!(runtime.list_connectors(None).await.len(), 1); + + // Remove connector (transition to empty state) + runtime.remove_connector(&id1).await.unwrap(); + assert_eq!(runtime.list_connectors(None).await.len(), 0); + + // Create new connector after empty state (this is the key test) + let cfg2 = ConnectorConfig { + id: Some("conn-2".to_string()), + kind: ConnectorKind::OpenCode, + owner: None, + title: Some("Second Connector".to_string()), + working_directory: None, + params: json!({ + "base_url": "http://localhost:12225", + "title": "Second Connector", + "initial_session": null + }), + ..Default::default() + }; + + let id2 = runtime + .create_connector(uuid::Uuid::nil(), cfg2) + .await + .unwrap(); + assert_eq!(id2, "conn-2"); + assert_eq!(runtime.list_connectors(None).await.len(), 1); + + // Verify new connector is functional + let connector = runtime.get_connector(&id2).await.unwrap(); + assert_eq!(connector.state(), ConnectorState::Initializing); // Will transition to Ready + assert_ne!(id1, id2); // Different IDs + + // Verify we can list connectors + let list = runtime.list_connectors(None).await; + assert_eq!(list.len(), 1); + assert_eq!(list[0].id, "conn-2"); + assert_eq!(list[0].title, "Second Connector"); + + // Clean up + runtime.remove_connector(&id2).await.unwrap(); + assert_eq!(runtime.list_connectors(None).await.len(), 0); + } + + #[tokio::test] + async fn test_rapid_remove_create_cycles() { + use serde_json::json; + + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + // Perform 5 rapid remove-create cycles + for i in 1..=5 { + let connector_id = format!("conn-{}", i); + + // Create + let cfg = ConnectorConfig { + id: Some(connector_id.clone()), + kind: ConnectorKind::OpenCode, + owner: None, + title: Some(format!("Connector {}", i)), + working_directory: None, + params: json!({ + "base_url": "http://localhost:12225", + "title": format!("Connector {}", i), + "initial_session": null + }), + ..Default::default() + }; + + runtime + .create_connector(uuid::Uuid::nil(), cfg) + .await + .unwrap(); + assert_eq!(runtime.list_connectors(None).await.len(), 1); + + // Verify it exists + let connector = runtime.get_connector(&connector_id).await; + assert!(connector.is_some()); + + // Immediately remove (no delay) + runtime.remove_connector(&connector_id).await.unwrap(); + assert_eq!(runtime.list_connectors(None).await.len(), 0); + + // Verify it's gone + let connector = runtime.get_connector(&connector_id).await; + assert!(connector.is_none()); + } + + // Verify clean state after all cycles + assert_eq!(runtime.list_connectors(None).await.len(), 0); + } + + #[tokio::test] + async fn test_fallback_on_connector_error() { + let config = CoreConfig::default(); + let runtime = Arc::new(CoreRuntime::new(config, None)); + + // Create a Gateway connector + let (gateway_cmd_tx, _) = mpsc::channel(100); + let (gateway_events_tx, _) = broadcast::channel(1000); + let gateway_handle = ConnectorHandle::new( + "gateway-1".to_string(), + ConnectorKind::Gateway, + uuid::Uuid::nil(), + "Gateway".to_string(), + gateway_cmd_tx, + gateway_events_tx, + serde_json::json!({}), + None, + None, + false, + ); + + // Create an OpenCode connector + let (opencode_cmd_tx, _) = mpsc::channel(100); + let (opencode_events_tx, _) = broadcast::channel(1000); + let opencode_handle = ConnectorHandle::new( + "opencode-1".to_string(), + ConnectorKind::OpenCode, + uuid::Uuid::nil(), + "OpenCode".to_string(), + opencode_cmd_tx, + opencode_events_tx, + serde_json::json!({}), + None, + None, + false, + ); + + // Set OpenCode to Ready state + { + let state_arc = opencode_handle.state_lock(); + let mut state = state_arc.write().await; + *state = ConnectorState::Ready; + } + + { + let mut connectors = runtime.connectors.write().await; + connectors.insert("gateway-1".to_string(), gateway_handle); + connectors.insert("opencode-1".to_string(), opencode_handle); + } + + // Simulate a transfer + let transfer = TransferredSession { + gateway_session_id: "gateway-session-1".to_string(), + gateway_connector_id: "gateway-1".to_string(), + target_connector_id: "opencode-1".to_string(), + target_session_id: "opencode-session-1".to_string(), + transferred_at: std::time::Instant::now(), + }; + + { + let mut transfers = runtime.transferred_sessions.write().await; + transfers.insert("opencode-session-1".to_string(), transfer); + } + + // Subscribe to the SharingBus and wait for ForwardingPanic. + let mut bus_rx = runtime.sharing_bus.subscribe_all().await; + + // Trigger connector failure + runtime + .check_transferred_sessions_fallback( + "opencode-1", + &ConnectorState::Error("Connection lost".to_string()), + ) + .await; + + // Verify ForwardingPanic event emitted + let bus_event = tokio::time::timeout(Duration::from_secs(1), bus_rx.rx.recv()) + .await + .expect("Should receive event") + .expect("Event channel should not be closed"); + + match bus_event.event.as_ref() { + dirigent_protocol::Event::ForwardingPanic { + connector_id, + reason, + fallback_gateway_session, + .. + } => { + assert_eq!(connector_id, "opencode-1"); + assert!(reason.contains("Connection lost")); + assert_eq!( + fallback_gateway_session.as_deref(), + Some("gateway-session-1") + ); + } + other => panic!("Expected ForwardingPanic event, got {:?}", other), + } + + // Verify transfer record removed + let transfers = runtime.transferred_sessions.read().await; + assert!(transfers.is_empty()); + } + + #[tokio::test] + async fn test_cleanup_stale_transfers() { + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + // Add old transfer + let old_transfer = TransferredSession { + gateway_session_id: "old-session".to_string(), + gateway_connector_id: "gateway-1".to_string(), + target_connector_id: "opencode-1".to_string(), + target_session_id: "old-session-1".to_string(), + transferred_at: std::time::Instant::now() - Duration::from_secs(3600), // 1 hour ago + }; + + { + let mut transfers = runtime.transferred_sessions.write().await; + transfers.insert("old-session-1".to_string(), old_transfer); + } + + // Add recent transfer + let recent_transfer = TransferredSession { + gateway_session_id: "recent-session".to_string(), + gateway_connector_id: "gateway-1".to_string(), + target_connector_id: "opencode-2".to_string(), + target_session_id: "recent-session-1".to_string(), + transferred_at: std::time::Instant::now(), + }; + + { + let mut transfers = runtime.transferred_sessions.write().await; + transfers.insert("recent-session-1".to_string(), recent_transfer); + } + + // Cleanup with 30 min max age + runtime.cleanup_stale_transfers(Duration::from_secs(1800)); + + // Old removed, recent kept + let transfers = runtime.transferred_sessions.read().await; + assert!(!transfers.contains_key("old-session-1")); + assert!(transfers.contains_key("recent-session-1")); + } + + #[tokio::test] + async fn test_find_gateway_connector() { + let config = CoreConfig::default(); + let runtime = CoreRuntime::new(config, None); + + // Initially no Gateway + assert!(runtime.find_gateway_connector().await.is_none()); + + // Add a non-Gateway connector + let (cmd_tx, _) = mpsc::channel(100); + let (events_tx, _) = broadcast::channel(1000); + let opencode_handle = ConnectorHandle::new( + "opencode-1".to_string(), + ConnectorKind::OpenCode, + uuid::Uuid::nil(), + "OpenCode".to_string(), + cmd_tx, + events_tx, + serde_json::json!({}), + None, + None, + false, + ); + + { + let mut connectors = runtime.connectors.write().await; + connectors.insert("opencode-1".to_string(), opencode_handle); + } + + // Still no Gateway + assert!(runtime.find_gateway_connector().await.is_none()); + + // Add a Gateway connector + let (gateway_cmd_tx, _) = mpsc::channel(100); + let (gateway_events_tx, _) = broadcast::channel(1000); + let gateway_handle = ConnectorHandle::new( + "gateway-1".to_string(), + ConnectorKind::Gateway, + uuid::Uuid::nil(), + "Gateway".to_string(), + gateway_cmd_tx, + gateway_events_tx, + serde_json::json!({}), + None, + None, + false, + ); + + // Set Gateway to Ready state + { + let state_arc = gateway_handle.state_lock(); + let mut state = state_arc.write().await; + *state = ConnectorState::Ready; + } + + { + let mut connectors = runtime.connectors.write().await; + connectors.insert("gateway-1".to_string(), gateway_handle); + } + + // Now should find Gateway + let gateway_id = runtime.find_gateway_connector().await; + assert_eq!(gateway_id, Some("gateway-1".to_string())); + } +} diff --git a/crates/dirigent_core/src/runtime/session_transfer.rs b/crates/dirigent_core/src/runtime/session_transfer.rs new file mode 100644 index 0000000..30823c1 --- /dev/null +++ b/crates/dirigent_core/src/runtime/session_transfer.rs @@ -0,0 +1,109 @@ +//! Session Transfer Helpers +//! +//! This module contains helper functions for transferring sessions between connectors. +//! These are primarily used by the Gateway connector when selecting a target connector. + +use std::time::Duration; +use tokio::sync::broadcast; + +use crate::connectors::ConnectorCommand; +use dirigent_protocol::session::{SessionModeState, SessionModelState}; + +/// Create a session in a connector and wait for the event +/// +/// Returns the session ID along with optional models/modes from the SessionCreated event. +/// The Session struct in SessionCreated already contains models/modes from the +/// connector's session/new response. +pub async fn create_session_in_connector( + cmd_tx: &tokio::sync::mpsc::Sender<ConnectorCommand>, + events: &mut broadcast::Receiver<dirigent_protocol::Event>, +) -> Result<(String, Option<SessionModelState>, Option<SessionModeState>), String> { + if let Err(e) = cmd_tx + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await + { + return Err(format!("Failed to send create command: {}", e)); + } + + // Wait for SessionCreated event + let timeout = Duration::from_secs(30); + let start = std::time::Instant::now(); + + while start.elapsed() < timeout { + match tokio::time::timeout(Duration::from_secs(1), events.recv()).await { + Ok(Ok(dirigent_protocol::Event::SessionCreated { session, .. })) => { + // Extract models/modes directly from the Session struct + return Ok((session.id, session.models, session.modes)); + } + Ok(Ok(dirigent_protocol::Event::Error { message })) => { + return Err(format!("Connector error: {}", message)); + } + Ok(Err(_)) => { + return Err("Event channel closed".to_string()); + } + Err(_) => continue, // Timeout on this iteration, keep waiting + _ => continue, + } + } + + Err("Timeout waiting for session creation".to_string()) +} + +/// Wait for a session load event +/// +/// Waits for a SessionUpdated or SessionCreated event for the given session ID, +/// or returns an error if a SessionError event is received or timeout occurs. +pub async fn wait_for_session_event( + events: &mut broadcast::Receiver<dirigent_protocol::Event>, + session_id: &str, + timeout: Duration, +) -> Result<String, String> { + let start = std::time::Instant::now(); + + while start.elapsed() < timeout { + match tokio::time::timeout(Duration::from_millis(500), events.recv()).await { + Ok(Ok(dirigent_protocol::Event::SessionUpdated { session, .. })) + if session.id == session_id => + { + return Ok(session.id); + } + Ok(Ok(dirigent_protocol::Event::SessionCreated { session, .. })) + if session.id == session_id => + { + return Ok(session.id); + } + Ok(Ok(dirigent_protocol::Event::SessionError { + session_id: sid, + error_message, + .. + })) if sid == session_id => { + return Err(error_message); + } + Ok(Err(_)) => return Err("Event channel closed".to_string()), + Err(_) => continue, + _ => continue, + } + } + + Err(format!("Session '{}' not found", session_id)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_wait_for_session_event_timeout() { + let (_tx, mut rx) = broadcast::channel::<dirigent_protocol::Event>(10); + + let result = + wait_for_session_event(&mut rx, "test-session", Duration::from_millis(100)).await; + + assert!(result.is_err()); + assert!(result.unwrap_err().contains("not found")); + } +} diff --git a/crates/dirigent_core/src/runtime/summary_cache.rs b/crates/dirigent_core/src/runtime/summary_cache.rs new file mode 100644 index 0000000..cfe52ca --- /dev/null +++ b/crates/dirigent_core/src/runtime/summary_cache.rs @@ -0,0 +1,96 @@ +//! Connector Summary Cache +//! +//! This module manages a synchronous cache of connector summary information +//! that can be accessed from callbacks without an async runtime. + +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; +use tracing::warn; + +use crate::config::CoreConfig; +use crate::connectors::gateway::ConnectorSummaryInfo; +use crate::connectors::{Connector, ConnectorHandle}; +use crate::types::ConnectorKind; + +/// Update the connector summary cache from the current connectors map +/// +/// This function rebuilds the sync-accessible cache used by GatewayConnector +/// callbacks. It should be called after any connector is added or removed. +/// +/// The cache uses a std::sync::RwLock so it can be accessed from synchronous +/// callback functions that don't have access to an async runtime. +/// +/// # Arguments +/// +/// * `connectors` - Read-locked map of connector handles +/// * `config` - Read-locked runtime configuration +/// * `cache` - The std::sync::RwLock cache to update +pub async fn update_connector_summary_cache( + connectors: &RwLock<HashMap<String, ConnectorHandle>>, + config: &RwLock<CoreConfig>, + cache: &Arc<std::sync::RwLock<Vec<ConnectorSummaryInfo>>>, +) { + let connectors_guard = connectors.read().await; + let config_guard = config.read().await; + + let summaries: Vec<ConnectorSummaryInfo> = connectors_guard + .values() + .map(|handle| { + let kind = handle.kind(); + let connector_id = handle.id(); + + // Get agent_type for ACP connectors + let agent_type = if kind == ConnectorKind::Acp { + config_guard + .connectors + .iter() + .find(|c| c.id.as_deref() == Some(connector_id)) + .and_then(|cfg| { + serde_json::from_value::<crate::connectors::acp::config::AcpConfig>( + cfg.params.clone(), + ) + .ok() + }) + .map(|cfg| cfg.agent_type) + } else { + None + }; + + ConnectorSummaryInfo { + id: connector_id.to_string(), + title: handle.title().to_string(), + kind: format!("{:?}", kind), + state: format!("{:?}", handle.state()), + supports_session_transfer: kind.supports_session_transfer(), + agent_type, + } + }) + .collect(); + + // Update the sync cache (this lock should never block for long since we're just writing a Vec) + if let Ok(mut cache_guard) = cache.write() { + *cache_guard = summaries; + } else { + warn!("Failed to acquire connector summary cache lock for update"); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::CoreConfig; + use std::sync::Arc; + + #[tokio::test] + async fn test_update_empty_cache() { + let connectors = RwLock::new(HashMap::new()); + let config = RwLock::new(CoreConfig::default()); + let cache = Arc::new(std::sync::RwLock::new(Vec::new())); + + update_connector_summary_cache(&connectors, &config, &cache).await; + + let cached = cache.read().unwrap(); + assert!(cached.is_empty()); + } +} diff --git a/crates/dirigent_core/src/runtime/zed_detection.rs b/crates/dirigent_core/src/runtime/zed_detection.rs new file mode 100644 index 0000000..1cc6a93 --- /dev/null +++ b/crates/dirigent_core/src/runtime/zed_detection.rs @@ -0,0 +1,589 @@ +//! Zed editor agent detection and connector config generation. +//! +//! This module detects Zed editor installations and generates `ConnectorConfig` +//! entries for discovered ACP agents. It is called during runtime initialization +//! to auto-populate connectors from Zed's agent configuration. + +use crate::config::{ConnectorConfig, CoreConfig}; +use crate::connectors::acp::config::ConnectorAgentType; +use crate::connectors::acp::{AcpConfig, TransportKind}; +use crate::types::ConnectorKind; +use dirigent_tools::EmbeddingConfig; +use tracing::{debug, info, warn}; + +/// Default supported features for known ACP agent types. +/// +/// These are conservative defaults based on confirmed agent capabilities. +/// Users can override via the connector config UI. +fn default_features_for_agent(agent_type: &ConnectorAgentType) -> Vec<String> { + match agent_type { + ConnectorAgentType::Claude => vec![ + "cancellation".to_string(), + "session_resume".to_string(), + "session_list".to_string(), + ], + ConnectorAgentType::Codex => vec![ + "session_resume".to_string(), + ], + // Gemini: no confirmed features yet (hangs on connect — BUG-7) + ConnectorAgentType::Gemini => vec![], + ConnectorAgentType::Custom => vec![], + } +} + +/// Convert a discovered Zed agent into a `ConnectorConfig` for the runtime. +/// +/// Only creates connectors for agents with resolved binary paths (typically +/// registry agents whose binaries have been downloaded by Zed). +/// +/// Returns `None` if the agent has no binary path. +pub fn zed_agent_to_connector_config(agent: &dirigent_zed::ZedAgent) -> Option<ConnectorConfig> { + let binary_path = agent.binary_path.as_ref()?; + + // Map agent names to proper types. Handles both Zed settings keys + // (e.g. "claude-acp") and external_agents directory names (e.g. "claude-agent-acp"). + let name_lower = agent.name.to_lowercase(); + let (default_title, default_icon, agent_type): (&str, &str, ConnectorAgentType) = + if name_lower.contains("claude") { + ("Claude (Zed)", "claude", ConnectorAgentType::Claude) + } else if name_lower.contains("codex") { + ("Codex (Zed)", "codex", ConnectorAgentType::Codex) + } else if name_lower.contains("gemini") { + ("Gemini (Zed)", "gemini", ConnectorAgentType::Gemini) + } else { + (agent.name.as_str(), "acp", ConnectorAgentType::Custom) + }; + + // Use registry display name with "(Zed)" suffix when available, falling + // back to the hardcoded title. + let title: String = match agent.display_name.as_deref() { + Some(display) => format!("{display} (Zed)"), + None => default_title.to_string(), + }; + + // Use the locally cached SVG icon path from the registry when available, + // otherwise fall back to the built-in icon name. + let icon: String = match agent.icon_local_path.as_ref() { + Some(path) => path.to_string_lossy().to_string(), + None => default_icon.to_string(), + }; + + let features = default_features_for_agent(&agent_type); + + // Build a proper AcpConfig with stdio transport pointing to the Zed-managed binary. + let env: Vec<(String, String)> = agent + .env_overrides + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + + // Use args from registry metadata (e.g. ["--acp"]) when available. + let args = if agent.args.is_empty() { + vec![] + } else { + agent.args.clone() + }; + + let acp_config = AcpConfig { + transport: TransportKind::Stdio { + command: binary_path.to_string_lossy().to_string(), + args, + cwd: None, + env, + }, + protocol_version: 1, + cwd: ".".to_string(), + retry: Default::default(), + embedding: EmbeddingConfig::default(), + default_ownership: Default::default(), + acp_log_dir: None, + agent_type, + }; + + let params = match serde_json::to_value(&acp_config) { + Ok(v) => v, + Err(e) => { + warn!( + agent = %agent.name, + error = %e, + "Failed to serialize AcpConfig for Zed agent" + ); + return None; + } + }; + + Some(ConnectorConfig { + id: None, + kind: ConnectorKind::Acp, + owner: None, + title: Some(title), + working_directory: None, + params, + icon_path: Some(icon), + show_type_overlay: false, + supported_features: features, + tool_configuration: None, + plugin_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: Some(agent.name.clone()), + }) +} + +/// Refresh binary paths for Zed-sourced connectors in the config. +/// +/// When Zed upgrades agent binaries in the background, the binary path changes +/// (e.g. new version directory). This function re-detects the current binary +/// paths from Zed installations and updates any connector that has a +/// `zed_agent_name` set. +/// +/// Returns the number of connectors updated. +pub fn refresh_zed_connector_binaries(config: &mut CoreConfig) -> usize { + let installations = dirigent_zed::detect_installations(); + if installations.is_empty() { + debug!("No Zed installations detected, skipping binary refresh"); + return 0; + } + + // Build a map of agent_name -> latest binary path from all Zed installations. + let mut agent_binaries: std::collections::HashMap<String, String> = + std::collections::HashMap::new(); + for installation in &installations { + for agent in &installation.agents { + if let Some(ref binary_path) = agent.binary_path { + agent_binaries + .insert(agent.name.clone(), binary_path.to_string_lossy().to_string()); + } + } + } + + let mut updated = 0usize; + + for connector in &mut config.connectors { + let zed_name = match connector.zed_agent_name.as_deref() { + Some(n) => n, + None => continue, + }; + + let new_binary = match agent_binaries.get(zed_name) { + Some(b) => b.clone(), + None => continue, + }; + + // Parse current ACP config to check the existing binary path. + let mut acp_config: AcpConfig = match serde_json::from_value(connector.params.clone()) { + Ok(c) => c, + Err(_) => continue, + }; + + let current_command = match &acp_config.transport { + TransportKind::Stdio { command, .. } => command.clone(), + _ => continue, + }; + + if current_command == new_binary { + continue; + } + + // Update the transport command to the new binary path. + match &mut acp_config.transport { + TransportKind::Stdio { command, .. } => { + info!( + zed_agent = %zed_name, + old = %current_command, + new = %new_binary, + "Updating Zed connector binary path" + ); + *command = new_binary; + } + _ => continue, + } + + // Re-serialize back into params. + match serde_json::to_value(&acp_config) { + Ok(v) => { + connector.params = v; + updated += 1; + } + Err(e) => { + warn!( + zed_agent = %zed_name, + error = %e, + "Failed to re-serialize AcpConfig after binary update" + ); + } + } + } + + if updated > 0 { + info!(count = updated, "Updated Zed connector binary paths"); + } + + updated +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::CoreConfig; + + #[test] + fn test_zed_agent_to_connector_config_no_binary() { + let agent = dirigent_zed::ZedAgent { + name: "claude-acp".to_string(), + agent_type: dirigent_zed::AgentServerType::Registry, + binary_path: None, + env_overrides: std::collections::HashMap::new(), + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }; + assert!(zed_agent_to_connector_config(&agent).is_none()); + } + + #[test] + fn test_zed_agent_to_connector_config_with_binary() { + let agent = dirigent_zed::ZedAgent { + name: "claude-acp".to_string(), + agent_type: dirigent_zed::AgentServerType::Registry, + binary_path: Some(std::path::PathBuf::from("/usr/local/bin/claude-acp")), + env_overrides: std::collections::HashMap::new(), + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }; + + let config = zed_agent_to_connector_config(&agent).unwrap(); + assert_eq!(config.kind, ConnectorKind::Acp); + assert_eq!(config.title.as_deref(), Some("Claude (Zed)")); + assert_eq!(config.icon_path.as_deref(), Some("claude")); + assert_eq!(config.source, None); + + // Verify the params contain a proper AcpConfig + let acp_config: AcpConfig = serde_json::from_value(config.params).unwrap(); + match &acp_config.transport { + TransportKind::Stdio { command, .. } => { + assert_eq!(command, "/usr/local/bin/claude-acp"); + } + _ => panic!("Expected stdio transport"), + } + assert_eq!(acp_config.agent_type, ConnectorAgentType::Claude); + assert_eq!(config.supported_features, vec!["cancellation", "session_resume", "session_list"]); + } + + #[test] + fn test_zed_agent_to_connector_config_with_env() { + let mut env = std::collections::HashMap::new(); + env.insert( + "CLAUDE_CODE_EXECUTABLE".to_string(), + "/usr/bin/claude".to_string(), + ); + + let agent = dirigent_zed::ZedAgent { + name: "claude-acp".to_string(), + agent_type: dirigent_zed::AgentServerType::Registry, + binary_path: Some(std::path::PathBuf::from("/path/to/binary")), + env_overrides: env, + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }; + + let config = zed_agent_to_connector_config(&agent).unwrap(); + let acp_config: AcpConfig = serde_json::from_value(config.params).unwrap(); + match &acp_config.transport { + TransportKind::Stdio { env, .. } => { + assert!(env + .iter() + .any(|(k, v)| k == "CLAUDE_CODE_EXECUTABLE" && v == "/usr/bin/claude")); + } + _ => panic!("Expected stdio transport"), + } + } + + #[test] + fn test_zed_agent_to_connector_config_unknown_agent() { + let agent = dirigent_zed::ZedAgent { + name: "my-custom-agent".to_string(), + agent_type: dirigent_zed::AgentServerType::Custom, + binary_path: Some(std::path::PathBuf::from("/path/to/custom")), + env_overrides: std::collections::HashMap::new(), + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }; + + let config = zed_agent_to_connector_config(&agent).unwrap(); + assert_eq!(config.title.as_deref(), Some("my-custom-agent")); + assert_eq!(config.icon_path.as_deref(), Some("acp")); + + let acp_config: AcpConfig = serde_json::from_value(config.params).unwrap(); + assert_eq!(acp_config.agent_type, ConnectorAgentType::Custom); + assert!(config.supported_features.is_empty()); + } + + #[test] + fn test_zed_agent_codex() { + let agent = dirigent_zed::ZedAgent { + name: "codex-acp".to_string(), + agent_type: dirigent_zed::AgentServerType::Registry, + binary_path: Some(std::path::PathBuf::from("/path/to/codex")), + env_overrides: std::collections::HashMap::new(), + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }; + + let config = zed_agent_to_connector_config(&agent).unwrap(); + assert_eq!(config.title.as_deref(), Some("Codex (Zed)")); + assert_eq!(config.icon_path.as_deref(), Some("codex")); + + let acp_config: AcpConfig = serde_json::from_value(config.params).unwrap(); + assert_eq!(acp_config.agent_type, ConnectorAgentType::Codex); + assert_eq!(config.supported_features, vec!["session_resume"]); + } + + #[test] + fn test_zed_agent_gemini() { + let agent = dirigent_zed::ZedAgent { + name: "gemini".to_string(), + agent_type: dirigent_zed::AgentServerType::Registry, + binary_path: Some(std::path::PathBuf::from("/path/to/gemini")), + env_overrides: std::collections::HashMap::new(), + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }; + + let config = zed_agent_to_connector_config(&agent).unwrap(); + assert_eq!(config.title.as_deref(), Some("Gemini (Zed)")); + assert_eq!(config.icon_path.as_deref(), Some("gemini")); + + let acp_config: AcpConfig = serde_json::from_value(config.params).unwrap(); + assert_eq!(acp_config.agent_type, ConnectorAgentType::Gemini); + assert!(config.supported_features.is_empty()); + } + + #[test] + fn test_dismissed_zed_agent_title_matches_generated_config() { + // Verify that a dismissed title like "Claude (Zed)" matches the title + // generated by zed_agent_to_connector_config for a claude-acp agent + let mut core_config = CoreConfig::default(); + core_config + .dismissed_zed_agents + .push("Claude (Zed)".to_string()); + + let agent = dirigent_zed::ZedAgent { + name: "claude-acp".to_string(), + agent_type: dirigent_zed::AgentServerType::Registry, + binary_path: Some(std::path::PathBuf::from("/usr/local/bin/claude-acp")), + env_overrides: std::collections::HashMap::new(), + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }; + + let connector_config = zed_agent_to_connector_config(&agent).unwrap(); + assert_eq!(connector_config.title.as_deref(), Some("Claude (Zed)")); + // The dismissed list should contain the generated title + assert!(core_config + .dismissed_zed_agents + .contains(connector_config.title.as_ref().unwrap())); + } + + #[test] + fn test_dismissed_list_does_not_block_other_agents() { + // Dismissing Claude should not block Codex or Gemini + let mut core_config = CoreConfig::default(); + core_config + .dismissed_zed_agents + .push("Claude (Zed)".to_string()); + + let codex_title = "Codex (Zed)".to_string(); + let gemini_title = "Gemini (Zed)".to_string(); + + assert!(!core_config.dismissed_zed_agents.contains(&codex_title)); + assert!(!core_config.dismissed_zed_agents.contains(&gemini_title)); + } + + #[test] + fn test_dismissed_zed_agents_serde_roundtrip() { + // Verify dismissed_zed_agents survives serialization/deserialization + let mut core_config = CoreConfig::default(); + core_config + .dismissed_zed_agents + .push("Claude (Zed)".to_string()); + core_config + .dismissed_zed_agents + .push("Gemini (Zed)".to_string()); + + let json = serde_json::to_string(&core_config).unwrap(); + let deserialized: CoreConfig = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.dismissed_zed_agents.len(), 2); + assert!(deserialized + .dismissed_zed_agents + .contains(&"Claude (Zed)".to_string())); + assert!(deserialized + .dismissed_zed_agents + .contains(&"Gemini (Zed)".to_string())); + } + + #[test] + fn test_dismissed_zed_agents_empty_not_serialized() { + // With skip_serializing_if = "Vec::is_empty", empty list should not appear in JSON + let core_config = CoreConfig::default(); + let json = serde_json::to_string(&core_config).unwrap(); + assert!( + !json.contains("dismissed_zed_agents"), + "Empty dismissed_zed_agents should be omitted from serialization" + ); + } + + #[test] + fn test_dismissed_zed_agents_deserialized_from_missing_field() { + // Old config files without dismissed_zed_agents should still deserialize + // thanks to #[serde(default)] + let json = r#"{"project_dir":".","connectors":[]}"#; + let config: CoreConfig = serde_json::from_str(json).unwrap(); + assert!(config.dismissed_zed_agents.is_empty()); + } + + #[test] + fn test_zed_agent_name_is_set() { + let agent = dirigent_zed::ZedAgent { + name: "claude-acp".to_string(), + agent_type: dirigent_zed::AgentServerType::Registry, + binary_path: Some(std::path::PathBuf::from("/usr/local/bin/claude-acp")), + env_overrides: std::collections::HashMap::new(), + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }; + let config = zed_agent_to_connector_config(&agent).unwrap(); + assert_eq!(config.zed_agent_name.as_deref(), Some("claude-acp")); + } + + #[test] + fn test_zed_agent_name_preserves_original_name() { + // The zed_agent_name should be the exact Zed agent name, not the display title + let agent = dirigent_zed::ZedAgent { + name: "claude-agent-acp".to_string(), + agent_type: dirigent_zed::AgentServerType::Registry, + binary_path: Some(std::path::PathBuf::from("/path/to/binary")), + env_overrides: std::collections::HashMap::new(), + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }; + let config = zed_agent_to_connector_config(&agent).unwrap(); + assert_eq!(config.title.as_deref(), Some("Claude (Zed)")); + assert_eq!(config.zed_agent_name.as_deref(), Some("claude-agent-acp")); + } + + #[test] + fn test_zed_agent_name_serde_roundtrip() { + let agent = dirigent_zed::ZedAgent { + name: "codex".to_string(), + agent_type: dirigent_zed::AgentServerType::Registry, + binary_path: Some(std::path::PathBuf::from("/path/to/codex")), + env_overrides: std::collections::HashMap::new(), + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }; + let config = zed_agent_to_connector_config(&agent).unwrap(); + let json = serde_json::to_string(&config).unwrap(); + assert!(json.contains("\"zed_agent_name\":\"codex\"")); + let deserialized: ConnectorConfig = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.zed_agent_name.as_deref(), Some("codex")); + } + + #[test] + fn test_non_zed_connector_has_no_zed_agent_name() { + // ConnectorConfig created via template should not have zed_agent_name + let config = ConnectorConfig::default(); + assert!(config.zed_agent_name.is_none()); + } + + #[test] + fn test_enriched_display_name_used_in_title() { + let agent = dirigent_zed::ZedAgent { + name: "claude-acp".to_string(), + agent_type: dirigent_zed::AgentServerType::Registry, + binary_path: Some(std::path::PathBuf::from("/path/to/claude")), + env_overrides: std::collections::HashMap::new(), + display_name: Some("Claude Agent".to_string()), + description: Some("ACP wrapper for Anthropic's Claude".to_string()), + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }; + + let config = zed_agent_to_connector_config(&agent).unwrap(); + // Title should use the registry display name with "(Zed)" suffix. + assert_eq!(config.title.as_deref(), Some("Claude Agent (Zed)")); + } + + #[test] + fn test_enriched_args_passed_to_transport() { + let agent = dirigent_zed::ZedAgent { + name: "auggie".to_string(), + agent_type: dirigent_zed::AgentServerType::Registry, + binary_path: Some(std::path::PathBuf::from("/path/to/auggie")), + env_overrides: std::collections::HashMap::new(), + display_name: Some("Auggie CLI".to_string()), + description: None, + args: vec!["--acp".to_string()], + icon_local_path: None, + icon_url: None, + }; + + let config = zed_agent_to_connector_config(&agent).unwrap(); + let acp_config: AcpConfig = serde_json::from_value(config.params).unwrap(); + match &acp_config.transport { + TransportKind::Stdio { args, .. } => { + assert_eq!(args, &["--acp"]); + } + _ => panic!("Expected stdio transport"), + } + } + + #[test] + fn test_enriched_icon_path_used() { + let agent = dirigent_zed::ZedAgent { + name: "claude-acp".to_string(), + agent_type: dirigent_zed::AgentServerType::Registry, + binary_path: Some(std::path::PathBuf::from("/path/to/claude")), + env_overrides: std::collections::HashMap::new(), + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: Some(std::path::PathBuf::from("/icons/claude-acp.svg")), + icon_url: None, + }; + + let config = zed_agent_to_connector_config(&agent).unwrap(); + assert_eq!(config.icon_path.as_deref(), Some("/icons/claude-acp.svg")); + } +} diff --git a/crates/dirigent_core/src/sharing/bus.rs b/crates/dirigent_core/src/sharing/bus.rs new file mode 100644 index 0000000..5f93f95 --- /dev/null +++ b/crates/dirigent_core/src/sharing/bus.rs @@ -0,0 +1,477 @@ +//! SharingBus: single-producer, many-subscriber event multiplexer with +//! subscriber-side filtering performed by a worker task. See +//! docs/plans/2026-04-21-archivist-phase4-design.md §1. +//! +//! Architecture: +//! - One internal `tokio::sync::broadcast::Sender<BusEvent>` feeds a single +//! worker task. The worker iterates `Vec<SubscriberSlot>` (behind `RwLock`), +//! filter-matches each slot, and `try_send`s the event onto each slot's +//! `mpsc::Sender<BusEvent>`. +//! - Slow subscribers drop their own events at their mpsc (counted in the +//! slot's `lagged` atomic). The bus-internal broadcast channel never drops +//! due to a slow subscriber — only due to the broadcast lag contract, which +//! we log and continue. +//! - `SessionRegistered` events late-bind `(connector_id, native_session_id) -> +//! scroll_id` via a small cache consulted on every publish. + +use std::collections::HashMap; +use std::sync::Arc; +use std::sync::atomic::{AtomicU64, Ordering}; + +use tokio::sync::{broadcast, mpsc, RwLock}; +use tokio::task::JoinHandle; +use tracing::{debug, warn}; +use uuid::Uuid; + +use dirigent_protocol::streaming::{BusEvent, EventFilter}; +pub use dirigent_protocol::streaming::BusReceiver; +use dirigent_protocol::Event; + +const BUS_INTERNAL_CAPACITY: usize = 1024; +const SUBSCRIBER_QUEUE_DEFAULT: usize = 256; + +/// Single-producer, many-subscriber event multiplexer. +/// +/// Subscribers see a `mpsc::Receiver<BusEvent>` that only yields events +/// matching their `EventFilter`. Filtering happens inside a single worker +/// task, so the cost per event is O(n_subscribers) regardless of publisher +/// count. Slow subscribers lose events at their own mpsc, not at the bus. +pub struct SharingBus { + publish_tx: broadcast::Sender<BusEvent>, + subscribers: Arc<RwLock<Vec<SubscriberSlot>>>, + scroll_id_cache: Arc<RwLock<HashMap<(String, String), Uuid>>>, + next_id: Arc<AtomicU64>, + _worker: JoinHandle<()>, +} + +struct SubscriberSlot { + id: u64, + filter: EventFilter, + sender: mpsc::Sender<BusEvent>, + lagged: Arc<AtomicU64>, +} + +impl SharingBus { + /// Construct a new bus and spawn its dispatch worker. + pub fn new() -> Arc<Self> { + let (publish_tx, publish_rx) = broadcast::channel(BUS_INTERNAL_CAPACITY); + let subscribers: Arc<RwLock<Vec<SubscriberSlot>>> = Arc::new(RwLock::new(Vec::new())); + let scroll_id_cache: Arc<RwLock<HashMap<(String, String), Uuid>>> = + Arc::new(RwLock::new(HashMap::new())); + let next_id = Arc::new(AtomicU64::new(0)); + + let worker = tokio::spawn(run_worker(publish_rx, Arc::clone(&subscribers))); + + Arc::new(Self { + publish_tx, + subscribers, + scroll_id_cache, + next_id, + _worker: worker, + }) + } + + /// Publish a `BusEvent` to all matching subscribers. + /// + /// This method also performs two side-effects on the scroll-id cache: + /// + /// 1. If the wrapped event is `Event::SessionRegistered`, the binding + /// `(connector_id, session_id) -> scroll_id` is inserted into the + /// cache, and the current event's `routing.scroll_id` is set so the + /// binding event itself carries its own scroll_id downstream. + /// 2. If the event's `routing.scroll_id` is absent but it carries both a + /// `connector_id` and `native_session_id`, the cache is consulted to + /// late-bind `scroll_id` before broadcasting. + pub async fn publish(&self, mut bus_event: BusEvent) { + // (2) Late-bind scroll_id from cache if we can, BEFORE the possibly + // more specific (1) handling overrides it. This is a no-op for + // SessionRegistered (its scroll_id is always populated in (1)). + if bus_event.routing.scroll_id.is_none() { + if let (Some(cid), Some(nsid)) = ( + bus_event.routing.connector_id.as_ref(), + bus_event.routing.native_session_id.as_ref(), + ) { + let cache = self.scroll_id_cache.read().await; + if let Some(uuid) = cache.get(&(cid.clone(), nsid.clone())) { + bus_event.routing.scroll_id = Some(*uuid); + } + } + } + + // (1) If the wrapped event is SessionRegistered, populate the cache + // and set scroll_id on the event itself. + if let Event::SessionRegistered { + connector_id, + session_id, + scroll_id, + } = bus_event.event.as_ref() + { + match Uuid::parse_str(scroll_id) { + Ok(uuid) => { + self.scroll_id_cache + .write() + .await + .insert((connector_id.clone(), session_id.clone()), uuid); + bus_event.routing.scroll_id = Some(uuid); + } + Err(e) => { + warn!( + connector_id = %connector_id, + session_id = %session_id, + scroll_id = %scroll_id, + error = %e, + "SessionRegistered carried an unparseable scroll_id; skipping late-bind cache insert", + ); + } + } + } + + // No subscribers is not an error — ignore the Result. + let _ = self.publish_tx.send(bus_event); + } + + /// Subscribe to every event on the bus. + pub async fn subscribe_all(&self) -> BusReceiver { + self.subscribe_filtered(EventFilter::All, SUBSCRIBER_QUEUE_DEFAULT) + .await + } + + /// Subscribe to events that match `filter`. `queue_capacity` caps the + /// buffered events between the worker and the caller's `recv()`. + pub async fn subscribe_filtered( + &self, + filter: EventFilter, + queue_capacity: usize, + ) -> BusReceiver { + let (tx, rx) = mpsc::channel(queue_capacity); + let lagged = Arc::new(AtomicU64::new(0)); + // Relaxed ordering is sufficient: subscriber IDs are only compared for + // equality with other IDs issued by this same bus; there is no + // cross-thread ordering dependency on this counter. + let id = self.next_id.fetch_add(1, Ordering::Relaxed); + self.subscribers.write().await.push(SubscriberSlot { + id, + filter, + sender: tx, + lagged: Arc::clone(&lagged), + }); + BusReceiver { id, rx, lagged } + } + + /// Remove a subscriber by id. Idempotent. + pub async fn unsubscribe(&self, id: u64) { + self.subscribers.write().await.retain(|s| s.id != id); + } +} + +async fn run_worker( + mut rx: broadcast::Receiver<BusEvent>, + subscribers: Arc<RwLock<Vec<SubscriberSlot>>>, +) { + loop { + match rx.recv().await { + Ok(evt) => { + let mut closed_ids: Vec<u64> = Vec::new(); + { + let subs = subscribers.read().await; + for slot in subs.iter() { + if !slot.filter.matches(&evt) { + continue; + } + match slot.sender.try_send(evt.clone()) { + Ok(()) => {} + Err(mpsc::error::TrySendError::Full(_)) => { + slot.lagged.fetch_add(1, Ordering::Relaxed); + warn!( + subscriber_id = slot.id, + "bus subscriber queue full; dropping event" + ); + } + Err(mpsc::error::TrySendError::Closed(_)) => { + closed_ids.push(slot.id); + } + } + } + } + if !closed_ids.is_empty() { + subscribers + .write() + .await + .retain(|s| !closed_ids.contains(&s.id)); + debug!(removed = closed_ids.len(), "GC'd closed subscriber slots"); + } + } + Err(broadcast::error::RecvError::Lagged(n)) => { + warn!(skipped = n, "SharingBus internal broadcast lagged"); + } + Err(broadcast::error::RecvError::Closed) => { + debug!("SharingBus worker exiting (sender closed)"); + return; + } + } + } +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use std::sync::Arc; + use std::sync::atomic::Ordering; + use std::time::Duration; + + use tokio::time::timeout; + use uuid::Uuid; + + use super::*; + use dirigent_protocol::streaming::{BusEvent, EventKind, EventOrigin, EventRouting}; + use dirigent_protocol::Event; + + /// Build a minimal `BusEvent` for tests. Uses `Event::Connected` as payload + /// unless a specific event is needed for late-bind checks. + fn make_event( + scroll_id: Option<Uuid>, + connector_uid: Option<Uuid>, + connector_id: Option<String>, + native_session_id: Option<String>, + kind: EventKind, + event: Event, + ) -> BusEvent { + BusEvent { + routing: EventRouting { + scroll_id, + connector_uid, + connector_id, + native_session_id, + kind, + }, + origin: EventOrigin::Runtime, + event: Arc::new(event), + } + } + + // 1. subscribe_all + publish: one event round-trips to receiver. + #[tokio::test] + async fn subscribe_all_receives_published_event() { + let bus = SharingBus::new(); + let mut recv = bus.subscribe_all().await; + + let ev = make_event( + None, + None, + None, + None, + EventKind::System, + Event::Connected, + ); + bus.publish(ev).await; + + let got = timeout(Duration::from_millis(200), recv.rx.recv()) + .await + .expect("timed out waiting for event") + .expect("channel closed unexpectedly"); + + match got.event.as_ref() { + Event::Connected => {} + other => panic!("expected Event::Connected, got {:?}", other), + } + } + + // 2. ConnectorUid filter: matching UID passes, other UID skipped. + #[tokio::test] + async fn connector_uid_filter_only_forwards_matching_events() { + let bus = SharingBus::new(); + let target = Uuid::new_v4(); + let other = Uuid::new_v4(); + + let mut recv = bus + .subscribe_filtered(EventFilter::ConnectorUid(target), 16) + .await; + + // Publish one matching and one non-matching event. + let ev_match = make_event( + None, + Some(target), + None, + None, + EventKind::System, + Event::Connected, + ); + let ev_other = make_event( + None, + Some(other), + None, + None, + EventKind::System, + Event::Connected, + ); + bus.publish(ev_match).await; + bus.publish(ev_other).await; + + // First recv returns the matching event. + let got = timeout(Duration::from_millis(200), recv.rx.recv()) + .await + .expect("timed out waiting for first event") + .expect("channel closed unexpectedly"); + assert_eq!(got.routing.connector_uid, Some(target)); + + // Second recv must time out — no other matching event was published. + let result = timeout(Duration::from_millis(100), recv.rx.recv()).await; + assert!( + result.is_err(), + "expected no further events, got: {:?}", + result.ok().flatten().map(|e| e.routing.connector_uid) + ); + } + + // 3. Queue full = lagged counter increments, first event still delivered. + #[tokio::test] + async fn full_queue_increments_lagged_counter() { + let bus = SharingBus::new(); + // Capacity 1 — only one event can be buffered before try_send fails. + let mut recv = bus.subscribe_filtered(EventFilter::All, 1).await; + + // Publish 5 events without draining. + for _ in 0..5 { + let ev = make_event( + None, + None, + None, + None, + EventKind::System, + Event::Connected, + ); + bus.publish(ev).await; + } + + // Give the worker a chance to process all 5. + for _ in 0..10 { + tokio::task::yield_now().await; + } + tokio::time::sleep(Duration::from_millis(20)).await; + + // First event is still in the queue. + let first = timeout(Duration::from_millis(200), recv.rx.recv()) + .await + .expect("timed out waiting for first event") + .expect("channel closed unexpectedly"); + match first.event.as_ref() { + Event::Connected => {} + other => panic!("expected Event::Connected, got {:?}", other), + } + + // At minimum 4 events were dropped (5 published, 1 fit). + let lagged = recv.lagged.load(Ordering::Relaxed); + assert!( + lagged >= 4, + "expected lagged >= 4 after publishing 5 events to a capacity-1 queue, got {}", + lagged + ); + } + + // 4. scroll_id late-bind: SessionRegistered populates cache; subsequent + // events with matching (connector_id, native_session_id) get their + // scroll_id filled in before dispatch. + #[tokio::test] + async fn session_registered_populates_cache_and_late_binds_subsequent_events() { + let bus = SharingBus::new(); + let scroll = Uuid::new_v4(); + + // Subscriber filters on ScrollId(scroll). It should see: + // - the SessionRegistered event (bus sets its own scroll_id at publish) + // - a follow-up event with (connector_id="c", native_session_id="s") + // that had no scroll_id on entry (late-bound from the cache). + let mut recv = bus + .subscribe_filtered(EventFilter::ScrollId(scroll), 16) + .await; + + // --- publish SessionRegistered (binding event) --- + let reg_event = Event::SessionRegistered { + connector_id: "c".to_string(), + session_id: "s".to_string(), + scroll_id: scroll.to_string(), + }; + // We pass through the routing fields the producer would populate. + // `scroll_id` starts as None; publish() sets it from the event payload. + let reg_bus = make_event( + None, + None, + Some("c".to_string()), + Some("s".to_string()), + EventKind::SessionLifecycle, + reg_event, + ); + bus.publish(reg_bus).await; + + let got1 = timeout(Duration::from_millis(200), recv.rx.recv()) + .await + .expect("timed out waiting for SessionRegistered") + .expect("channel closed unexpectedly"); + assert!(matches!( + got1.event.as_ref(), + Event::SessionRegistered { .. } + )); + assert_eq!(got1.routing.scroll_id, Some(scroll)); + + // --- publish a follow-up event with no scroll_id but matching + // connector_id + native_session_id --- + let follow_up = make_event( + None, + None, + Some("c".to_string()), + Some("s".to_string()), + EventKind::System, + Event::Connected, + ); + bus.publish(follow_up).await; + + let got2 = timeout(Duration::from_millis(200), recv.rx.recv()) + .await + .expect("timed out waiting for late-bound follow-up") + .expect("channel closed unexpectedly"); + assert_eq!( + got2.routing.scroll_id, + Some(scroll), + "follow-up event should have had scroll_id late-bound from the cache" + ); + assert!(matches!(got2.event.as_ref(), Event::Connected)); + } + + // 5. Dropped receiver is GC'd after the next publish. + #[tokio::test] + async fn closed_receiver_slot_is_reaped_on_next_publish() { + let bus = SharingBus::new(); + + // Subscribe, then immediately drop the receiver — simulates a caller + // that forgets (or skips) `unsubscribe()`. + let recv = bus.subscribe_all().await; + drop(recv); + + // Sanity check: slot is present before GC. + assert_eq!(bus.subscribers.read().await.len(), 1); + + // Publish one event; the worker encounters TrySendError::Closed and + // schedules the slot for removal. + let ev = make_event( + None, + None, + None, + None, + EventKind::System, + Event::Connected, + ); + bus.publish(ev).await; + + // Give the worker a moment to process and GC. + for _ in 0..10 { + tokio::task::yield_now().await; + } + tokio::time::sleep(Duration::from_millis(10)).await; + + assert_eq!( + bus.subscribers.read().await.len(), + 0, + "closed subscriber slot should have been GC'd after publish" + ); + } +} diff --git a/crates/dirigent_core/src/sharing/config.rs b/crates/dirigent_core/src/sharing/config.rs new file mode 100644 index 0000000..3ae166d --- /dev/null +++ b/crates/dirigent_core/src/sharing/config.rs @@ -0,0 +1,101 @@ +//! `[[streams]]` TOML config block parsed from `dirigent.toml`. + +use serde::{Deserialize, Serialize}; +use dirigent_protocol::streaming::StreamScope; + +#[derive(Debug, Clone, Default, Deserialize, Serialize)] +pub struct StreamsConfig { + #[serde(default, rename = "streams")] + pub entries: Vec<StreamConfig>, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct StreamConfig { + pub name: String, + #[serde(rename = "type")] + pub kind: String, // "matrix" | "langfuse" | ... + pub scope: StreamScope, + #[serde(default = "default_enabled")] + pub enabled: bool, + #[serde(default = "default_params")] + pub params: toml::Value, // type-specific +} + +fn default_enabled() -> bool { true } + +fn default_params() -> toml::Value { toml::Value::Table(toml::map::Map::new()) } + +#[cfg(test)] +mod tests { + use super::*; + use uuid::Uuid; + + const FULL_TOML: &str = r#" +[[streams]] +name = "matrix-main" +type = "matrix" +enabled = true + +[streams.scope] +kind = "session" +scroll_id = "01985d00-0000-7000-8000-000000000000" + +[streams.params] +homeserver_url = "https://matrix.org" +room_id = "!abc:matrix.org" +"#; + + const MINIMAL_TOML: &str = r#" +[[streams]] +name = "minimal" +type = "langfuse" + +[streams.scope] +kind = "archive_wide" +acknowledged = false +"#; + + #[test] + fn round_trip_full() { + let cfg: StreamsConfig = toml::from_str(FULL_TOML).expect("parse failed"); + + assert_eq!(cfg.entries.len(), 1); + let entry = &cfg.entries[0]; + assert_eq!(entry.name, "matrix-main"); + assert_eq!(entry.kind, "matrix"); + assert!(entry.enabled); + + let expected_id = Uuid::parse_str("01985d00-0000-7000-8000-000000000000").unwrap(); + match &entry.scope { + StreamScope::Session { scroll_id } => { + assert_eq!(*scroll_id, expected_id); + } + other => panic!("expected Session scope, got {:?}", other), + } + + let params = entry.params.as_table().expect("params should be a table"); + assert_eq!( + params.get("homeserver_url").and_then(|v| v.as_str()), + Some("https://matrix.org") + ); + assert_eq!( + params.get("room_id").and_then(|v| v.as_str()), + Some("!abc:matrix.org") + ); + } + + #[test] + fn default_enabled_when_omitted() { + let cfg: StreamsConfig = toml::from_str(MINIMAL_TOML).expect("parse failed"); + assert_eq!(cfg.entries.len(), 1); + let entry = &cfg.entries[0]; + assert_eq!(entry.name, "minimal"); + assert!(entry.enabled, "enabled should default to true"); + } + + #[test] + fn empty_config_is_valid() { + let cfg: StreamsConfig = toml::from_str("").expect("empty parse failed"); + assert!(cfg.entries.is_empty()); + } +} diff --git a/crates/dirigent_core/src/sharing/factory.rs b/crates/dirigent_core/src/sharing/factory.rs new file mode 100644 index 0000000..a96788f --- /dev/null +++ b/crates/dirigent_core/src/sharing/factory.rs @@ -0,0 +1,143 @@ +//! Maps a `StreamConfig`'s `type` string to a concrete [`SessionStream`]. +//! +//! Stream implementations register themselves with a +//! [`StreamFactoryRegistry`] at boot so the runtime can construct streams +//! from `[[streams]]` config blocks without knowing every concrete type up +//! front. This is the stream-side analogue of the archivist's +//! `BackendRegistry` / `BackendFactory` pattern (Phase 3). + +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +use thiserror::Error; + +use dirigent_protocol::streaming::SessionStream; + +use super::config::StreamConfig; + +/// Builds a concrete [`SessionStream`] from a `StreamConfig`. +/// +/// Each implementation advertises a single `kind` (matching the `type` +/// string in the TOML config). The registry routes config blocks to the +/// matching factory at boot. +#[async_trait] +pub trait StreamFactory: Send + Sync { + /// The `type` discriminator in `[[streams]]` — e.g. `"matrix"`, + /// `"langfuse"`. Must be unique across all registered factories. + fn kind(&self) -> &'static str; + + /// Build a running stream from its config. Implementations are + /// expected to read `cfg.params` (type-specific TOML table), establish + /// any transport, and return an [`Arc<dyn SessionStream>`] ready to + /// receive events. + async fn build(&self, cfg: &StreamConfig) -> Result<Arc<dyn SessionStream>, StreamBuildError>; +} + +/// Lookup table of `StreamFactory` implementations keyed by `kind()`. +/// +/// Populate this at startup (typically once, on the main thread) before +/// handing it off to the runtime. The registry is cheap to clone via +/// `Arc` at the call site. +#[derive(Default)] +pub struct StreamFactoryRegistry { + factories: HashMap<&'static str, Arc<dyn StreamFactory>>, +} + +impl std::fmt::Debug for StreamFactoryRegistry { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("StreamFactoryRegistry") + .field("kinds", &self.factories.keys().collect::<Vec<_>>()) + .finish() + } +} + +impl StreamFactoryRegistry { + /// Construct an empty registry. + pub fn new() -> Self { + Self::default() + } + + /// Register a factory. Consumes and returns `self` so callers can + /// chain `.register(...)` calls at startup. + pub fn register<F: StreamFactory + 'static>(mut self, f: F) -> Self { + self.factories.insert(f.kind(), Arc::new(f)); + self + } + + /// Look up a factory by its `kind` discriminator, or `None` if no + /// factory for that kind has been registered. + pub fn get(&self, kind: &str) -> Option<&Arc<dyn StreamFactory>> { + self.factories.get(kind) + } +} + +/// Errors raised while building a `SessionStream` from its config. +#[derive(Debug, Error)] +pub enum StreamBuildError { + /// No factory is registered for `cfg.kind`. + #[error("unknown kind: {0}")] + UnknownKind(String), + /// Config was structurally valid TOML but semantically invalid + /// (missing required field, enum out of range, etc). + #[error("config: {0}")] + Config(String), + /// The factory accepted the config but the transport refused to come + /// up (network error, auth failure, …). + #[error("transport: {0}")] + Transport(String), +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use async_trait::async_trait; + use dirigent_protocol::streaming::{ + BusEvent, SessionStream, StreamKind, StreamOutcome, StreamScope, StreamSummary, + }; + + struct DummyStream; + + #[async_trait] + impl SessionStream for DummyStream { + fn summary(&self) -> StreamSummary { + StreamSummary { + name: "dummy".to_string(), + kind: StreamKind::Custom, + target: "dummy".to_string(), + active_since: chrono::Utc::now(), + } + } + fn scope(&self) -> StreamScope { + StreamScope::ArchiveWide { acknowledged: false } + } + async fn on_event(&self, _event: &BusEvent) -> StreamOutcome { + StreamOutcome::Ok + } + async fn shutdown(&self) {} + } + + struct DummyFactory; + + #[async_trait] + impl StreamFactory for DummyFactory { + fn kind(&self) -> &'static str { + "dummy" + } + async fn build( + &self, + _cfg: &StreamConfig, + ) -> Result<Arc<dyn SessionStream>, StreamBuildError> { + Ok(Arc::new(DummyStream)) + } + } + + #[test] + fn register_and_lookup() { + let reg = StreamFactoryRegistry::new().register(DummyFactory); + assert!(reg.get("dummy").is_some()); + assert!(reg.get("missing").is_none()); + } +} diff --git a/crates/dirigent_core/src/sharing/health.rs b/crates/dirigent_core/src/sharing/health.rs new file mode 100644 index 0000000..e37b080 --- /dev/null +++ b/crates/dirigent_core/src/sharing/health.rs @@ -0,0 +1,118 @@ +//! Consecutive-failure health drift for streams (K=5 threshold). +//! +//! Mirrors the archivist's drift logic but tracks a single stream's +//! outcomes. Re-exports the shared `HealthStatus` enum from the +//! archivist's backend module to avoid duplication. + +pub use dirigent_archivist::backend::HealthStatus; + +/// Number of consecutive failures before a stream drifts from `Degraded` +/// to `Unavailable`. Matches the archivist's backend drift threshold. +pub const FAILURE_THRESHOLD: u32 = 5; + +/// Update health state after a successful event delivery. +/// Resets consecutive-failure counter; lifts Degraded → Healthy. +pub fn record_success(health: &mut HealthStatus, consecutive_failures: &mut u32) { + *consecutive_failures = 0; + if matches!(health, HealthStatus::Degraded { .. }) { + *health = HealthStatus::Healthy; + } +} + +/// Update health state after a failed event delivery. +/// Increments counter; drifts Healthy → Degraded → Unavailable at K=5. +pub fn record_failure( + health: &mut HealthStatus, + consecutive_failures: &mut u32, + reason: String, +) { + *consecutive_failures += 1; + if *consecutive_failures >= FAILURE_THRESHOLD { + *health = HealthStatus::Unavailable { + reason: format!("{} failures: {}", *consecutive_failures, reason), + }; + } else { + *health = HealthStatus::Degraded { reason }; + } +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn record_success_promotes_degraded_to_healthy() { + let mut health = HealthStatus::Degraded { + reason: "earlier hiccup".to_string(), + }; + let mut counter: u32 = 3; + record_success(&mut health, &mut counter); + assert_eq!(counter, 0); + assert_eq!(health, HealthStatus::Healthy); + } + + #[test] + fn record_success_keeps_healthy_and_resets_counter() { + let mut health = HealthStatus::Healthy; + let mut counter: u32 = 2; + record_success(&mut health, &mut counter); + assert_eq!(counter, 0); + assert_eq!(health, HealthStatus::Healthy); + } + + #[test] + fn record_success_does_not_rescue_unavailable() { + // The archivist rule is: once Unavailable, only operational events + // rescue. Success on a stream does clear the counter but we leave + // the final clearing decision to the caller; document the current + // behaviour: we only downgrade Degraded, not Unavailable. + let mut health = HealthStatus::Unavailable { + reason: "still broken".to_string(), + }; + let mut counter: u32 = 7; + record_success(&mut health, &mut counter); + assert_eq!(counter, 0); + assert!(matches!(health, HealthStatus::Unavailable { .. })); + } + + #[test] + fn record_failure_once_moves_healthy_to_degraded() { + let mut health = HealthStatus::Healthy; + let mut counter: u32 = 0; + record_failure(&mut health, &mut counter, "boom".to_string()); + assert_eq!(counter, 1); + match health { + HealthStatus::Degraded { reason } => assert_eq!(reason, "boom"), + other => panic!("expected Degraded, got {:?}", other), + } + } + + #[test] + fn record_failure_five_times_drifts_to_unavailable() { + let mut health = HealthStatus::Healthy; + let mut counter: u32 = 0; + for i in 0..5 { + record_failure(&mut health, &mut counter, format!("err-{i}")); + } + assert_eq!(counter, 5); + match health { + HealthStatus::Unavailable { reason } => { + assert!(reason.contains("5 failures"), "reason: {reason}"); + } + other => panic!("expected Unavailable after 5 failures, got {:?}", other), + } + } + + #[test] + fn record_failure_from_degraded_drifts_to_unavailable_at_threshold() { + let mut health = HealthStatus::Degraded { + reason: "early".to_string(), + }; + let mut counter: u32 = 4; + record_failure(&mut health, &mut counter, "final".to_string()); + assert_eq!(counter, 5); + assert!(matches!(health, HealthStatus::Unavailable { .. })); + } +} diff --git a/crates/dirigent_core/src/sharing/matrix.rs b/crates/dirigent_core/src/sharing/matrix.rs new file mode 100644 index 0000000..d092cac --- /dev/null +++ b/crates/dirigent_core/src/sharing/matrix.rs @@ -0,0 +1,217 @@ +//! `MatrixFactory`: build a Matrix [`SessionStream`] from a `[[streams]]` +//! config block. +//! +//! This is the first stream-side factory wired for the Phase 4 migration +//! (Task 18). The factory lives in `dirigent_core` rather than +//! `dirigent_matrix` because `StreamFactory` is defined here and +//! `dirigent_core` already depends on `dirigent_matrix` — putting it in +//! `dirigent_matrix` would create a cycle. +//! +//! ## Scope +//! +//! The factory's responsibility is narrow: parse `cfg.params`, resolve +//! the target Matrix room via a running `MatrixService`, and construct a +//! `MatrixSessionShare` configured for the stream path (no legacy +//! forwarder task). Command-proxy wiring (Matrix → Dirigent +//! `ConnectorCommand::SendMessage`) remains in +//! `CoreRuntime::create_matrix_share` for now; a follow-up will extend +//! the factory to cover that path. +//! +//! ## Config shape +//! +//! ```toml +//! [[streams]] +//! name = "matrix-main" +//! type = "matrix" +//! +//! [streams.scope] +//! kind = "session" +//! scroll_id = "01985d00-..." +//! +//! [streams.params] +//! connector_id = "opencode-1" # dirigent connector key +//! session_id = "native-abc123" # native connector session id +//! room_id = "!abc:matrix.org" # pre-existing room to attach to +//! homeserver_url = "https://matrix.org" # informational (service already knows) +//! ``` + +use std::sync::Arc; + +use async_trait::async_trait; +use serde::Deserialize; + +use dirigent_protocol::streaming::{SessionStream, StreamScope}; + +use super::config::StreamConfig; +use super::factory::{StreamBuildError, StreamFactory}; + +/// Stream-side factory for Matrix. See module docs for the expected +/// TOML shape. +pub struct MatrixFactory { + service: Arc<dirigent_matrix::MatrixService>, +} + +impl MatrixFactory { + /// Build a factory bound to a running `MatrixService`. The service + /// is expected to be logged in and sync-started by the time + /// `build()` is called; if it isn't, `build()` returns + /// `StreamBuildError::Transport`. + pub fn new(service: Arc<dirigent_matrix::MatrixService>) -> Self { + Self { service } + } +} + +#[derive(Debug, Deserialize)] +struct MatrixStreamParams { + /// Dirigent connector id that owns the session being bridged. + connector_id: String, + /// Native connector session id. + session_id: String, + /// Matrix room id — must be a pre-existing room the bot can access. + /// Room creation is still handled by + /// `CoreRuntime::create_matrix_share` until the factory path is + /// expanded to cover it. + room_id: String, + /// Informational; the logged-in `MatrixService` is the authority on + /// which homeserver to talk to. Accepted so configs can be + /// self-documenting and round-trip through TOML. + #[serde(default)] + #[allow(dead_code)] + homeserver_url: Option<String>, +} + +#[async_trait] +impl StreamFactory for MatrixFactory { + fn kind(&self) -> &'static str { + "matrix" + } + + async fn build( + &self, + cfg: &StreamConfig, + ) -> Result<Arc<dyn SessionStream>, StreamBuildError> { + // Scope must be Session; Matrix shares are intrinsically + // per-session bi-directional bridges. + let scroll_id = match &cfg.scope { + StreamScope::Session { scroll_id } => *scroll_id, + other => { + return Err(StreamBuildError::Config(format!( + "matrix stream requires scope.kind = \"session\", got {:?}", + other + ))); + } + }; + + // Parse type-specific params. + let params: MatrixStreamParams = cfg + .params + .clone() + .try_into() + .map_err(|e: toml::de::Error| { + StreamBuildError::Config(format!( + "matrix stream '{}': invalid params: {}", + cfg.name, e + )) + })?; + + // Look up the room via the service. We intentionally don't + // create or join rooms here — the room must already exist. + // Creation remains the responsibility of + // `CoreRuntime::create_matrix_share`. + let room = match self.service.room_by_id(¶ms.room_id).await { + Ok(Some(room)) => room, + Ok(None) => { + return Err(StreamBuildError::Transport(format!( + "matrix stream '{}': room '{}' not found on client \ + — ensure the bot has joined it", + cfg.name, params.room_id + ))); + } + Err(dirigent_matrix::MatrixError::NotLoggedIn) => { + return Err(StreamBuildError::Transport( + "matrix service is not logged in; cannot build stream" + .to_string(), + )); + } + Err(dirigent_matrix::MatrixError::Config(msg)) => { + return Err(StreamBuildError::Config(format!( + "matrix stream '{}': {}", + cfg.name, msg + ))); + } + Err(other) => { + return Err(StreamBuildError::Transport(format!( + "matrix stream '{}': {}", + cfg.name, other + ))); + } + }; + + // Construct the share for the stream path (no legacy forwarder + // task). We drop the command receiver on the floor here — the + // Matrix → Dirigent direction is not covered by this factory + // yet; see the follow-up TODO in the module docs. + let (share, _command_rx) = dirigent_matrix::MatrixSessionShare::new_for_stream( + params.connector_id, + params.session_id, + scroll_id, + params.room_id, + room, + ); + + Ok(Arc::new(share) as Arc<dyn SessionStream>) + } +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn factory_kind_is_matrix() { + // The factory's `kind()` is static and doesn't require a running + // MatrixService to read — covered by a minimal construction + // check in the integration test suite. + fn assert_is_factory<F: StreamFactory>(_: &F) {} + + // We can't easily build a MatrixService in a unit test (it needs + // an Account + data dir + SQLite store). The full smoke test + // lives in `packages/dirigent_matrix/tests/factory_test.rs` and + // the cross-crate registry test in + // `packages/dirigent_core/tests/matrix_migration_test.rs`. + // + // This module-local test exists only to assert that the impl + // block type-checks against the `StreamFactory` trait bound. + fn _compile_check(f: &MatrixFactory) { + assert_is_factory(f); + } + } + + #[test] + fn matrix_stream_params_deserialise_ok() { + let toml_str = r#" +connector_id = "opencode-1" +session_id = "native-abc" +room_id = "!foo:example.com" +homeserver_url = "https://matrix.org" +"#; + let p: MatrixStreamParams = toml::from_str(toml_str).expect("parse"); + assert_eq!(p.connector_id, "opencode-1"); + assert_eq!(p.session_id, "native-abc"); + assert_eq!(p.room_id, "!foo:example.com"); + assert_eq!(p.homeserver_url.as_deref(), Some("https://matrix.org")); + } + + #[test] + fn matrix_stream_params_reject_missing_required() { + // Missing room_id should fail. + let toml_str = r#" +connector_id = "opencode-1" +session_id = "native-abc" +"#; + let err: Result<MatrixStreamParams, _> = toml::from_str(toml_str); + assert!(err.is_err()); + } +} diff --git a/crates/dirigent_core/src/sharing/mock.rs b/crates/dirigent_core/src/sharing/mock.rs new file mode 100644 index 0000000..9873e67 --- /dev/null +++ b/crates/dirigent_core/src/sharing/mock.rs @@ -0,0 +1,80 @@ +//! Test-only `SessionStream` implementation for integration tests. +//! +//! Records every received `BusEvent` into an in-memory buffer; can be +//! configured to fail the next N events to exercise health drift paths. + +#![cfg(any(test, feature = "test-utils"))] + +use std::sync::Arc; +use std::sync::Mutex; +use std::sync::atomic::{AtomicU32, Ordering}; + +use async_trait::async_trait; +use chrono::Utc; + +use dirigent_protocol::streaming::{ + BusEvent, SessionStream, StreamError, StreamKind, StreamOutcome, StreamScope, StreamSummary, +}; + +/// In-memory `SessionStream` used in integration tests. +pub struct MockStream { + scope: StreamScope, + name: String, + pub received: Arc<Mutex<Vec<BusEvent>>>, + pub fail_remaining: Arc<AtomicU32>, + pub shutdown_called: Arc<Mutex<bool>>, +} + +impl MockStream { + pub fn new(name: impl Into<String>, scope: StreamScope) -> Arc<Self> { + Arc::new(Self { + scope, + name: name.into(), + received: Arc::new(Mutex::new(Vec::new())), + fail_remaining: Arc::new(AtomicU32::new(0)), + shutdown_called: Arc::new(Mutex::new(false)), + }) + } + + /// Configure the next N `on_event` calls to return `StreamOutcome::Failed`. + pub fn fail_next(&self, n: u32) { + self.fail_remaining.store(n, Ordering::Relaxed); + } + + pub fn received_count(&self) -> usize { + self.received.lock().unwrap().len() + } + + pub fn was_shutdown(&self) -> bool { + *self.shutdown_called.lock().unwrap() + } +} + +#[async_trait] +impl SessionStream for MockStream { + fn summary(&self) -> StreamSummary { + StreamSummary { + name: self.name.clone(), + kind: StreamKind::Custom, + target: "mock".into(), + active_since: Utc::now(), + } + } + + fn scope(&self) -> StreamScope { + self.scope.clone() + } + + async fn on_event(&self, event: &BusEvent) -> StreamOutcome { + if self.fail_remaining.load(Ordering::Relaxed) > 0 { + self.fail_remaining.fetch_sub(1, Ordering::Relaxed); + return StreamOutcome::Failed(StreamError::Rejected("mock fail".into())); + } + self.received.lock().unwrap().push(event.clone()); + StreamOutcome::Ok + } + + async fn shutdown(&self) { + *self.shutdown_called.lock().unwrap() = true; + } +} diff --git a/crates/dirigent_core/src/sharing/mod.rs b/crates/dirigent_core/src/sharing/mod.rs new file mode 100644 index 0000000..ff37eec --- /dev/null +++ b/crates/dirigent_core/src/sharing/mod.rs @@ -0,0 +1,23 @@ +//! SharingBus, StreamRegistry, and replay. See docs/plans/2026-04-21-archivist-phase4-design.md. + +pub mod bus; +pub mod config; +pub mod factory; +pub mod health; +#[cfg(feature = "server")] +pub mod matrix; +#[cfg(any(test, feature = "test-utils"))] +pub mod mock; +pub mod registry; +pub mod replay; + +pub use bus::{BusReceiver, SharingBus}; +pub use config::{StreamConfig, StreamsConfig}; +pub use factory::{StreamBuildError, StreamFactory, StreamFactoryRegistry}; +pub use health::HealthStatus; +#[cfg(feature = "server")] +pub use matrix::MatrixFactory; +pub use registry::{StreamId, StreamInfo, StreamRegistration, StreamRegistry}; +pub use replay::{ReplayError, ReplayOptions, ReplayReport, ReplaySpeed}; +#[cfg(any(test, feature = "test-utils"))] +pub use mock::MockStream; diff --git a/crates/dirigent_core/src/sharing/registry.rs b/crates/dirigent_core/src/sharing/registry.rs new file mode 100644 index 0000000..3d9bb67 --- /dev/null +++ b/crates/dirigent_core/src/sharing/registry.rs @@ -0,0 +1,236 @@ +//! Owns all active streams. +//! +//! Populated at boot from `[[streams]]` config and at runtime via +//! [`StreamRegistry::attach`]. Each attached stream gets: +//! +//! - a bus subscription with an [`EventFilter`] derived from its scope, +//! - a dedicated worker task that drives `SessionStream::on_event`, +//! - a per-stream [`HealthStatus`] that drifts on consecutive failures +//! (see [`super::health`]). +//! +//! The worker is cancellable via a one-shot `mpsc::Sender<()>` on the +//! registration so [`detach`](StreamRegistry::detach) can stop delivery +//! deterministically before invoking `SessionStream::shutdown`. + +use std::sync::Arc; +use std::sync::atomic::{AtomicU32, Ordering}; + +use tokio::sync::{RwLock, mpsc}; +use tokio::task::JoinHandle; +use tracing::warn; +use uuid::Uuid; + +use dirigent_protocol::streaming::{ + BusReceiver, EventFilter, SessionStream, StreamOutcome, StreamScope, StreamSummary, +}; + +use super::bus::SharingBus; +use super::health::{HealthStatus, record_failure, record_success}; + +/// Per-subscriber queue capacity for a stream's bus subscription. Matches +/// the default used by `SharingBus::subscribe_all`. +const STREAM_QUEUE_CAPACITY: usize = 256; + +/// Identifier of a registered stream. Opaque wrapper around a `Uuid` so +/// that callers can't confuse stream ids with scroll/connector ids. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct StreamId(pub Uuid); + +/// Full registration record for a live stream. +/// +/// Held inside the registry behind an `Arc`; `detach` returns the Arc so +/// callers can inspect the final health state or await the worker handle +/// if they wish. Most fields are `Arc`s so the worker task and the +/// registry share state without serialising on a single lock. +pub struct StreamRegistration { + pub id: StreamId, + pub name: String, + pub stream: Arc<dyn SessionStream>, + pub scope: StreamScope, + pub enabled: bool, + pub health: Arc<RwLock<HealthStatus>>, + /// Number of consecutive delivery failures; drives the K=5 drift to + /// `Unavailable`. Stored atomic so the worker can update without + /// taking the health lock on every success. + pub consecutive_failures: Arc<AtomicU32>, + pub worker: JoinHandle<()>, + pub stop_tx: mpsc::Sender<()>, +} + +/// Snapshot view of a registered stream. Returned by +/// [`StreamRegistry::list`] for telemetry / UI. +#[derive(Debug, Clone)] +pub struct StreamInfo { + pub id: StreamId, + pub name: String, + pub summary: StreamSummary, + pub scope: StreamScope, + pub enabled: bool, + pub health: HealthStatus, + /// Current consecutive-failure count (mirrors + /// `StreamRegistration::consecutive_failures` at read time). + pub lagged_count: u64, +} + +/// The live registry of all streams wired to a [`SharingBus`]. +pub struct StreamRegistry { + bus: Arc<SharingBus>, + regs: RwLock<Vec<Arc<StreamRegistration>>>, +} + +impl StreamRegistry { + /// Build an empty registry bound to `bus`. + pub fn new(bus: Arc<SharingBus>) -> Self { + Self { + bus, + regs: RwLock::new(Vec::new()), + } + } + + /// Attach a running stream. + /// + /// Subscribes to the bus with a filter derived from `stream.scope()`, + /// spawns a worker task that ferries events into `on_event`, and + /// stores a registration with fresh health state. + pub async fn attach(&self, name: String, stream: Arc<dyn SessionStream>) -> StreamId { + let id = StreamId(Uuid::now_v7()); + let scope = stream.scope(); + let filter = scope_to_filter(&scope); + let bus_rx = self + .bus + .subscribe_filtered(filter, STREAM_QUEUE_CAPACITY) + .await; + + let (stop_tx, stop_rx) = mpsc::channel(1); + let health = Arc::new(RwLock::new(HealthStatus::Healthy)); + let failures = Arc::new(AtomicU32::new(0)); + + let stream_for_worker = Arc::clone(&stream); + let health_for_worker = Arc::clone(&health); + let failures_for_worker = Arc::clone(&failures); + let name_for_worker = name.clone(); + + let worker = tokio::spawn(run_stream_worker( + name_for_worker, + bus_rx, + stream_for_worker, + health_for_worker, + failures_for_worker, + stop_rx, + )); + + let reg = Arc::new(StreamRegistration { + id, + name, + stream, + scope, + enabled: true, + health, + consecutive_failures: failures, + worker, + stop_tx, + }); + self.regs.write().await.push(reg); + id + } + + /// Detach a stream. Signals the worker to exit, then invokes + /// `SessionStream::shutdown`. Returns the registration if the stream + /// was found, or `None` if the id was already detached. + pub async fn detach(&self, id: StreamId) -> Option<Arc<StreamRegistration>> { + let mut regs = self.regs.write().await; + let idx = regs.iter().position(|r| r.id == id)?; + let reg = regs.remove(idx); + drop(regs); + + // Best-effort stop: if the channel is already closed (worker panicked) + // we still want to run shutdown. + let _ = reg.stop_tx.send(()).await; + reg.stream.shutdown().await; + Some(reg) + } + + /// Look up a live stream by id. + pub async fn get_stream(&self, id: StreamId) -> Option<Arc<dyn SessionStream>> { + self.regs + .read() + .await + .iter() + .find(|r| r.id == id) + .map(|r| Arc::clone(&r.stream)) + } + + /// Snapshot every registered stream. Clones the underlying health + /// value so the returned `Vec` is safe to hand across async tasks + /// without holding any locks. + pub async fn list(&self) -> Vec<StreamInfo> { + let regs = self.regs.read().await; + let mut out = Vec::with_capacity(regs.len()); + for r in regs.iter() { + let health = r.health.read().await.clone(); + out.push(StreamInfo { + id: r.id, + name: r.name.clone(), + summary: r.stream.summary(), + scope: r.scope.clone(), + enabled: r.enabled, + health, + lagged_count: r.consecutive_failures.load(Ordering::Relaxed) as u64, + }); + } + out + } +} + +/// Translate a declarative [`StreamScope`] into the subscriber-side +/// [`EventFilter`] applied on the bus. +fn scope_to_filter(scope: &StreamScope) -> EventFilter { + match scope { + StreamScope::Session { scroll_id } => EventFilter::ScrollId(*scroll_id), + StreamScope::Connector { connector_uid } => EventFilter::ConnectorUid(*connector_uid), + StreamScope::ArchiveWide { .. } => EventFilter::All, + } +} + +/// Worker loop: pulls events from the bus subscription, forwards them to +/// the stream, and updates health state on every outcome. +async fn run_stream_worker( + name: String, + mut rx: BusReceiver, + stream: Arc<dyn SessionStream>, + health: Arc<RwLock<HealthStatus>>, + failures: Arc<AtomicU32>, + mut stop_rx: mpsc::Receiver<()>, +) { + loop { + tokio::select! { + biased; + _ = stop_rx.recv() => { + return; + } + maybe_evt = rx.rx.recv() => { + let Some(evt) = maybe_evt else { + // Bus hung up — registry should detach but we can exit + // here regardless. + return; + }; + match stream.on_event(&evt).await { + StreamOutcome::Ok | StreamOutcome::Skipped => { + let mut h = health.write().await; + let mut counter = failures.load(Ordering::Relaxed); + record_success(&mut h, &mut counter); + failures.store(counter, Ordering::Relaxed); + } + StreamOutcome::Failed(err) => { + let reason = err.to_string(); + warn!(stream = %name, error = %reason, "stream rejected event"); + let mut h = health.write().await; + let mut counter = failures.load(Ordering::Relaxed); + record_failure(&mut h, &mut counter, reason); + failures.store(counter, Ordering::Relaxed); + } + } + } + } + } +} diff --git a/crates/dirigent_core/src/sharing/replay.rs b/crates/dirigent_core/src/sharing/replay.rs new file mode 100644 index 0000000..6cc7c23 --- /dev/null +++ b/crates/dirigent_core/src/sharing/replay.rs @@ -0,0 +1,226 @@ +//! Replay: reads a session from the archive and dispatches synthetic +//! `BusEvent`s with `EventOrigin::Replay` directly to a target stream, +//! bypassing the `SharingBus`. +//! +//! Consumed by `CoreRuntime::replay_session_to_stream` (task 16). This +//! module intentionally exposes a free function that takes +//! `&Archivist`, `scroll_id`, `Arc<dyn SessionStream>`, and `ReplayOptions` +//! so it can be unit-tested without a full runtime. + +use std::sync::Arc; +use std::time::Duration; + +use uuid::Uuid; + +use dirigent_archivist::coordinator::Archivist; +use dirigent_archivist::error::ArchivistError; +use dirigent_archivist::types::MessageRecord; +use dirigent_protocol::{ + Event, Message, MessagePart, MessageRole, MessageStatus, + streaming::{BusEvent, EventOrigin, EventRouting, SessionStream, StreamOutcome}, +}; + +/// Options controlling a replay pass. +#[derive(Debug, Clone)] +pub struct ReplayOptions { + /// When true and the session is an AcpConnection, meta-events are read from + /// the archive (currently only counted — rendering meta events as + /// `BusEvent`s is out of scope for Phase 4). + pub include_meta_events: bool, + /// Pace events in real time (sleep between consecutive timestamps) or emit + /// as fast as the target stream can consume. + pub speed: ReplaySpeed, +} + +impl Default for ReplayOptions { + fn default() -> Self { + Self { + include_meta_events: false, + speed: ReplaySpeed::AsFastAsPossible, + } + } +} + +/// Controls inter-event pacing during replay. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ReplaySpeed { + /// Sleep the wall-clock delta between consecutive message timestamps. + Realtime, + /// Emit events as fast as the stream can consume. + AsFastAsPossible, +} + +/// Outcome of a replay pass. +#[derive(Debug, Default, Clone)] +pub struct ReplayReport { + /// Total events dispatched to the stream (includes failed attempts). + pub events_sent: usize, + /// Events the stream rejected (`StreamOutcome::Failed`). + pub failures: usize, + /// Wall-clock duration of the replay in milliseconds. + pub duration_ms: u64, +} + +/// Errors raised by `replay_session_to_stream` itself. Stream-side failures are +/// counted in `ReplayReport::failures` rather than propagated, so one bad event +/// doesn't abort the replay. +#[derive(Debug, thiserror::Error)] +pub enum ReplayError { + /// The archive has no session with the given scroll id. + #[error("session not found: {0}")] + SessionNotFound(Uuid), + /// Archivist returned a non-SessionUnknown error (I/O, decoding, etc). + #[error("archivist: {0}")] + Archivist(String), +} + +/// Replay a session's archived messages to a single `SessionStream`. +/// +/// Reads metadata + messages from `archivist`, synthesises a `BusEvent` per +/// message with `EventOrigin::Replay { replay_id }`, and dispatches directly +/// to the target stream. The `SharingBus` is not involved; live events remain +/// unaffected. +/// +/// The function continues on stream failures and records the count in the +/// returned `ReplayReport`; only unrecoverable archive errors propagate. +pub async fn replay_session_to_stream( + archivist: &Archivist, + scroll_id: Uuid, + stream: Arc<dyn SessionStream>, + opts: ReplayOptions, +) -> Result<ReplayReport, ReplayError> { + let start = std::time::Instant::now(); + let replay_id = Uuid::new_v4(); + + // Load metadata. Translate the archivist's typed `SessionUnknown` into + // the replay-level `SessionNotFound` variant; everything else becomes + // `Archivist(_)` so callers can distinguish "missing" from "broken". + let metadata = archivist + .get_session_metadata(scroll_id, None) + .await + .map_err(|e| match e { + ArchivistError::SessionUnknown(id) => ReplayError::SessionNotFound(id), + other => ReplayError::Archivist(other.to_string()), + })?; + + let messages = archivist + .get_messages(scroll_id, None) + .await + .map_err(|e| ReplayError::Archivist(e.to_string()))?; + + let connector_uid = Some(metadata.connector_uid); + let native_session_id = metadata.native_session_id.clone(); + // We do not persist the orchestrator-side `connector_id` string in session + // metadata; the native session id is the best reversible handle we have. + let connector_id = native_session_id.clone().unwrap_or_default(); + + let mut events_sent = 0usize; + let mut failures = 0usize; + let mut prev_ts: Option<chrono::DateTime<chrono::Utc>> = None; + + for record in messages { + if matches!(opts.speed, ReplaySpeed::Realtime) { + if let Some(prev) = prev_ts { + let delta = record.ts.signed_duration_since(prev); + if let Ok(d) = delta.to_std() { + // Cap per-step sleep at 1h to avoid pathological archives + // where a session sat idle for days. + if d > Duration::from_millis(0) && d < Duration::from_secs(3600) { + tokio::time::sleep(d).await; + } + } + } + prev_ts = Some(record.ts); + } + + let message = message_from_record(&record, native_session_id.as_deref()); + let event = Event::MessageCompleted { + connector_id: connector_id.clone(), + message, + }; + + let mut routing = EventRouting::derive(&event, connector_uid, &connector_id); + // `derive()` leaves scroll_id=None (the bus cache normally fills it in). + // During replay we have the authoritative scroll_id up front. + routing.scroll_id = Some(scroll_id); + + let bus_event = BusEvent { + routing, + origin: EventOrigin::Replay { replay_id }, + event: Arc::new(event), + }; + + match stream.on_event(&bus_event).await { + StreamOutcome::Ok | StreamOutcome::Skipped => { + events_sent += 1; + } + StreamOutcome::Failed(_err) => { + failures += 1; + events_sent += 1; // count attempted regardless + } + } + } + + if opts.include_meta_events { + // Meta-events exist only on AcpConnection sessions; the read is + // cheap and idempotent, so we don't gate on `metadata.kind`. Render- + // as-BusEvent is out of scope for Phase 4 — we just probe the + // archive so missing meta-event storage surfaces as a log line + // here rather than later in the call chain. + let _ = archivist.get_meta_events(scroll_id, None).await; + } + + Ok(ReplayReport { + events_sent, + failures, + duration_ms: start.elapsed().as_millis() as u64, + }) +} + +/// Synthesize a protocol `Message` from an archived `MessageRecord`. +/// +/// The session_id we emit is the connector's native session id when known, +/// falling back to the stringified scroll_id so downstream routing at least +/// has a stable handle. +fn message_from_record(record: &MessageRecord, native_session_id: Option<&str>) -> Message { + Message { + id: record.message_id.to_string(), + session_id: native_session_id + .map(str::to_string) + .unwrap_or_else(|| record.session.to_string()), + role: parse_role(&record.role), + created_at: record.ts, + content: content_parts_from_record(record), + status: MessageStatus::Completed, + metadata: None, + } +} + +/// Parse the archivist's stringly-typed role into the protocol enum. +/// +/// `MessageRole` only has `User` and `Assistant` today; archived "system" / +/// "tool" rows (which the protocol layer does not support) fall back to +/// `User` rather than drop the message entirely. Lossy but preserves content. +fn parse_role(role: &str) -> MessageRole { + match role { + "assistant" => MessageRole::Assistant, + "user" => MessageRole::User, + // Protocol has no System/Tool variant; surface these as user messages + // so their content still reaches the stream. + _ => MessageRole::User, + } +} + +/// Prefer the archived structured `content_parts` (round-trips tool calls, +/// code blocks, etc). Fall back to a single `Text` part built from the +/// markdown rendering when parts are missing or fail to parse. +fn content_parts_from_record(record: &MessageRecord) -> Vec<MessagePart> { + if let Some(parts) = &record.content_parts { + if let Ok(parsed) = serde_json::from_value::<Vec<MessagePart>>(parts.clone()) { + return parsed; + } + } + vec![MessagePart::Text { + text: record.content_md.clone(), + }] +} diff --git a/crates/dirigent_core/src/tools/configuration.rs b/crates/dirigent_core/src/tools/configuration.rs new file mode 100644 index 0000000..41d3185 --- /dev/null +++ b/crates/dirigent_core/src/tools/configuration.rs @@ -0,0 +1,289 @@ +//! Tool configuration container for managing a set of [`ToolDirective`]s. +//! +//! A [`ToolConfiguration`] holds a flat list of directives and provides +//! query helpers used by the ACP intercept layer, the UI, and session +//! persistence (via metadata helpers). + +use super::directive::{ToolDirective, ToolHandler}; +use serde::{Deserialize, Serialize}; + +/// Metadata key used for storing tool configuration in session metadata. +const METADATA_KEY: &str = "tool_configuration"; + +/// A collection of tool directives that together define the tool routing +/// policy for a session or connector. +/// +/// # Examples +/// +/// ``` +/// use dirigent_core::tools::{ToolConfiguration, ToolDirective, ToolHandler}; +/// +/// let mut config = ToolConfiguration::new(); +/// config.set(ToolDirective::checked("shell_exec", ToolHandler::Deny)); +/// +/// assert!(config.should_intercept("shell_exec")); +/// assert!(!config.should_intercept("read_file")); +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)] +pub struct ToolConfiguration { + /// The ordered list of tool directives. + #[serde(default)] + pub directives: Vec<ToolDirective>, +} + +impl ToolConfiguration { + /// Create an empty configuration with no directives. + pub fn new() -> Self { + Self::default() + } + + /// Look up the directive for a given tool by its ID. + pub fn get(&self, tool_id: &str) -> Option<&ToolDirective> { + self.directives.iter().find(|d| d.tool_id == tool_id) + } + + /// Look up the directive for a given tool by its ID (mutable). + pub fn get_mut(&mut self, tool_id: &str) -> Option<&mut ToolDirective> { + self.directives.iter_mut().find(|d| d.tool_id == tool_id) + } + + /// Insert or replace a directive. If a directive for the same `tool_id` + /// already exists, it is replaced in place; otherwise the new directive + /// is appended. + pub fn set(&mut self, directive: ToolDirective) { + if let Some(existing) = self + .directives + .iter_mut() + .find(|d| d.tool_id == directive.tool_id) + { + *existing = directive; + } else { + self.directives.push(directive); + } + } + + /// Returns `true` if the tool with the given ID has a checked directive + /// with a non-Agent handler, meaning the call should be intercepted. + /// + /// Returns `false` for unknown tools (no directive) or passthrough + /// directives. + pub fn should_intercept(&self, tool_id: &str) -> bool { + self.get(tool_id) + .is_some_and(|d| d.checked && d.handler != ToolHandler::Agent) + } + + /// Returns the active handler for a tool, if it has a checked directive. + /// + /// Returns `None` for unknown tools or passthrough directives. + pub fn active_handler(&self, tool_id: &str) -> Option<&ToolHandler> { + self.get(tool_id) + .filter(|d| d.checked) + .map(|d| &d.handler) + } + + /// Returns the tool IDs of all directives that are checked (intercepted). + pub fn checked_tools(&self) -> Vec<&str> { + self.directives + .iter() + .filter(|d| d.checked) + .map(|d| d.tool_id.as_str()) + .collect() + } + + // -- Metadata helpers for session persistence -- + + /// Deserialize a `ToolConfiguration` from the `"tool_configuration"` key + /// in a session metadata JSON value. + /// + /// Returns `None` if the key is absent or the value cannot be deserialized. + pub fn from_metadata(metadata: &serde_json::Value) -> Option<Self> { + metadata + .get(METADATA_KEY) + .and_then(|v| serde_json::from_value(v.clone()).ok()) + } + + /// Serialize this configuration into the `"tool_configuration"` key of + /// a session metadata JSON object. + /// + /// If `metadata` is not already an object, this is a no-op. + pub fn to_metadata(&self, metadata: &mut serde_json::Value) { + if let Some(obj) = metadata.as_object_mut() { + if let Ok(value) = serde_json::to_value(self) { + obj.insert(METADATA_KEY.to_string(), value); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn empty_config_no_interception() { + let config = ToolConfiguration::new(); + assert!(!config.should_intercept("any_tool")); + assert!(config.active_handler("any_tool").is_none()); + assert!(config.checked_tools().is_empty()); + } + + #[test] + fn passthrough_directive() { + let mut config = ToolConfiguration::new(); + config.set(ToolDirective::passthrough("read_file")); + + assert!(!config.should_intercept("read_file")); + // active_handler returns None for passthrough (not checked) + assert!(config.active_handler("read_file").is_none()); + assert!(config.checked_tools().is_empty()); + } + + #[test] + fn checked_directive_deny() { + let mut config = ToolConfiguration::new(); + config.set(ToolDirective::checked("shell_exec", ToolHandler::Deny)); + + assert!(config.should_intercept("shell_exec")); + assert_eq!(config.active_handler("shell_exec"), Some(&ToolHandler::Deny)); + assert_eq!(config.checked_tools(), vec!["shell_exec"]); + } + + #[test] + fn checked_directive_hide() { + let mut config = ToolConfiguration::new(); + config.set(ToolDirective::checked("dangerous_tool", ToolHandler::Hide)); + + assert!(config.should_intercept("dangerous_tool")); + assert_eq!( + config.active_handler("dangerous_tool"), + Some(&ToolHandler::Hide) + ); + } + + #[test] + fn checked_directive_agent_not_intercepted() { + // A checked directive with Agent handler should NOT be intercepted + // (Agent means passthrough even when checked). + let mut config = ToolConfiguration::new(); + config.set(ToolDirective::checked("file_write", ToolHandler::Agent)); + + assert!(!config.should_intercept("file_write")); + assert_eq!( + config.active_handler("file_write"), + Some(&ToolHandler::Agent) + ); + assert_eq!(config.checked_tools(), vec!["file_write"]); + } + + #[test] + fn set_replaces_existing() { + let mut config = ToolConfiguration::new(); + config.set(ToolDirective::checked("shell_exec", ToolHandler::Deny)); + assert!(config.should_intercept("shell_exec")); + + // Replace with passthrough + config.set(ToolDirective::passthrough("shell_exec")); + assert!(!config.should_intercept("shell_exec")); + + // Should still be exactly one directive + assert_eq!(config.directives.len(), 1); + } + + #[test] + fn checked_tools_list() { + let mut config = ToolConfiguration::new(); + config.set(ToolDirective::passthrough("read_file")); + config.set(ToolDirective::checked("shell_exec", ToolHandler::Deny)); + config.set(ToolDirective::checked("file_write", ToolHandler::Editor)); + config.set(ToolDirective::passthrough("search")); + + let checked = config.checked_tools(); + assert_eq!(checked.len(), 2); + assert!(checked.contains(&"shell_exec")); + assert!(checked.contains(&"file_write")); + } + + #[test] + fn get_and_get_mut() { + let mut config = ToolConfiguration::new(); + config.set(ToolDirective::passthrough("read_file")); + + assert!(config.get("read_file").is_some()); + assert!(config.get("nonexistent").is_none()); + + // Mutate via get_mut + if let Some(d) = config.get_mut("read_file") { + d.checked = true; + d.handler = ToolHandler::Dirigent; + } + assert!(config.should_intercept("read_file")); + } + + #[test] + fn serde_roundtrip() { + let mut config = ToolConfiguration::new(); + config.set(ToolDirective::passthrough("read_file")); + config.set(ToolDirective::checked("shell_exec", ToolHandler::Deny)); + config.set(ToolDirective::checked( + "custom_tool", + ToolHandler::Plugin { name: "my_plugin".to_string() }, + )); + + let json = serde_json::to_string_pretty(&config).expect("serialize"); + let deserialized: ToolConfiguration = + serde_json::from_str(&json).expect("deserialize"); + assert_eq!(config, deserialized); + } + + #[test] + fn serde_roundtrip_plugin_variant() { + let handler = ToolHandler::Plugin { name: "security_scanner".to_string() }; + let json = serde_json::to_string(&handler).expect("serialize"); + + // Verify the tagged representation + assert!(json.contains("Plugin")); + assert!(json.contains("security_scanner")); + + let deserialized: ToolHandler = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(handler, deserialized); + } + + #[test] + fn from_metadata_roundtrip() { + let mut config = ToolConfiguration::new(); + config.set(ToolDirective::checked("shell_exec", ToolHandler::Deny)); + config.set(ToolDirective::checked( + "plugin_tool", + ToolHandler::Plugin { name: "my_plugin".to_string() }, + )); + + // Write to metadata + let mut metadata = serde_json::json!({ "other_key": "other_value" }); + config.to_metadata(&mut metadata); + + // Read back from metadata + let restored = ToolConfiguration::from_metadata(&metadata) + .expect("should deserialize from metadata"); + assert_eq!(config, restored); + + // Original metadata keys are preserved + assert_eq!(metadata["other_key"], "other_value"); + } + + #[test] + fn from_metadata_missing_key() { + let metadata = serde_json::json!({ "something_else": 42 }); + assert!(ToolConfiguration::from_metadata(&metadata).is_none()); + } + + #[test] + fn to_metadata_noop_on_non_object() { + let mut config = ToolConfiguration::new(); + config.set(ToolDirective::passthrough("tool")); + + // Attempting to write to a non-object value should be a no-op + let mut metadata = serde_json::json!("not an object"); + config.to_metadata(&mut metadata); + assert_eq!(metadata, serde_json::json!("not an object")); + } +} diff --git a/crates/dirigent_core/src/tools/directive.rs b/crates/dirigent_core/src/tools/directive.rs new file mode 100644 index 0000000..5906333 --- /dev/null +++ b/crates/dirigent_core/src/tools/directive.rs @@ -0,0 +1,153 @@ +//! Tool directive types for controlling how individual tools are handled. +//! +//! A [`ToolDirective`] associates a tool (identified by its `tool_id` string, +//! matching `ToolCall.tool_name` from `dirigent_protocol`) with a [`ToolHandler`] +//! that determines how invocations of that tool should be routed. +//! +//! Directives come in two flavors: +//! - **Passthrough** (`checked: false`) -- the tool call is forwarded to the +//! agent without interception. +//! - **Checked** (`checked: true`) -- the tool call is intercepted and routed +//! to the specified [`ToolHandler`]. + +use serde::{Deserialize, Serialize}; + +/// Determines how a tool invocation is routed when intercepted. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(tag = "type")] +pub enum ToolHandler { + /// Let the agent handle the tool call natively (default / passthrough). + Agent, + /// Route the tool call to the editor (e.g., file writes shown in the UI). + Editor, + /// Deny the tool call entirely. + Deny, + /// Hide the tool from the agent's tool list. + Hide, + /// Handle the tool call inside Dirigent itself. + Dirigent, + /// Route to a named plugin handler. + Plugin { name: String }, +} + +impl Default for ToolHandler { + fn default() -> Self { + Self::Agent + } +} + +/// A directive that binds a tool identifier to a handler configuration. +/// +/// # Examples +/// +/// ``` +/// use dirigent_core::tools::{ToolDirective, ToolHandler}; +/// +/// // Create a passthrough directive (agent handles the tool natively) +/// let passthrough = ToolDirective::passthrough("file_write"); +/// assert!(!passthrough.checked); +/// +/// // Create a checked directive that denies the tool +/// let denied = ToolDirective::checked("shell_exec", ToolHandler::Deny); +/// assert!(denied.checked); +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct ToolDirective { + /// Tool identifier -- matches `ToolCall.tool_name` from dirigent_protocol. + pub tool_id: String, + /// Whether this tool should be intercepted (`true`) or passed through (`false`). + pub checked: bool, + /// The handler to use when `checked` is `true`. + pub handler: ToolHandler, +} + +impl ToolDirective { + /// Create a passthrough directive that lets the agent handle the tool natively. + pub fn passthrough(tool_id: impl Into<String>) -> Self { + Self { + tool_id: tool_id.into(), + checked: false, + handler: ToolHandler::Agent, + } + } + + /// Create a checked directive that intercepts the tool and routes it to the + /// given handler. + pub fn checked(tool_id: impl Into<String>, handler: ToolHandler) -> Self { + Self { + tool_id: tool_id.into(), + checked: true, + handler, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn passthrough_has_agent_handler() { + let d = ToolDirective::passthrough("read_file"); + assert_eq!(d.tool_id, "read_file"); + assert!(!d.checked); + assert_eq!(d.handler, ToolHandler::Agent); + } + + #[test] + fn checked_stores_handler() { + let d = ToolDirective::checked("shell_exec", ToolHandler::Deny); + assert_eq!(d.tool_id, "shell_exec"); + assert!(d.checked); + assert_eq!(d.handler, ToolHandler::Deny); + } + + #[test] + fn default_handler_is_agent() { + assert_eq!(ToolHandler::default(), ToolHandler::Agent); + } + + #[test] + fn serde_roundtrip_all_variants() { + let variants = vec![ + ToolHandler::Agent, + ToolHandler::Editor, + ToolHandler::Deny, + ToolHandler::Hide, + ToolHandler::Dirigent, + ToolHandler::Plugin { name: "my_plugin".to_string() }, + ]; + + for handler in variants { + let json = serde_json::to_string(&handler).expect("serialize"); + let deserialized: ToolHandler = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(handler, deserialized); + } + } + + #[cfg(feature = "server")] + #[test] + fn serde_toml_roundtrip_all_variants() { + // Regression test: adjacently-tagged enums with unit variants fail TOML serialization. + // ToolHandler must use internally-tagged (#[serde(tag = "type")]) to work with TOML. + let directives = vec![ + ToolDirective::passthrough("read_file"), + ToolDirective::checked("shell_exec", ToolHandler::Deny), + ToolDirective::checked("editor_write", ToolHandler::Editor), + ToolDirective::checked("hidden_tool", ToolHandler::Hide), + ToolDirective::checked("dirigent_tool", ToolHandler::Dirigent), + ToolDirective::checked("plugin_tool", ToolHandler::Plugin { name: "my_plugin".to_string() }), + ]; + + // Wrap in a table so TOML can serialize (TOML needs a top-level table) + #[derive(Serialize, Deserialize, PartialEq, Debug)] + struct Wrapper { + directives: Vec<ToolDirective>, + } + + let wrapper = Wrapper { directives }; + let toml_str = toml::to_string_pretty(&wrapper).expect("TOML serialize"); + let roundtripped: Wrapper = toml::from_str(&toml_str).expect("TOML deserialize"); + assert_eq!(wrapper, roundtripped); + } +} diff --git a/crates/dirigent_core/src/tools/mod.rs b/crates/dirigent_core/src/tools/mod.rs new file mode 100644 index 0000000..91cef43 --- /dev/null +++ b/crates/dirigent_core/src/tools/mod.rs @@ -0,0 +1,24 @@ +//! Tool directive and configuration types. +//! +//! This module defines the domain types for controlling how individual tools +//! are routed during an agent session. A tool can be passed through to the +//! agent, intercepted by the editor, denied, hidden, handled by Dirigent +//! itself, or delegated to a plugin. +//! +//! # Key types +//! +//! - [`ToolHandler`] -- Enum of routing targets for a tool call. +//! - [`ToolDirective`] -- Binds a tool ID to a handler (passthrough or checked). +//! - [`ToolConfiguration`] -- Collection of directives with query and +//! persistence helpers. +//! +//! These types are **not** feature-gated and are available on WASM targets, +//! since they are pure data types used by both the server runtime and the +//! web UI. + +pub mod configuration; +pub mod directive; + +// Re-export the main types at the module level for convenience. +pub use configuration::ToolConfiguration; +pub use directive::{ToolDirective, ToolHandler}; diff --git a/crates/dirigent_core/src/types.rs b/crates/dirigent_core/src/types.rs new file mode 100644 index 0000000..2811d74 --- /dev/null +++ b/crates/dirigent_core/src/types.rs @@ -0,0 +1,385 @@ +//! Core types for the Dirigent runtime +//! +//! This module defines the fundamental types used throughout the dirigent_core +//! system for connector management and orchestration. + +use serde::{Deserialize, Serialize}; + +// Re-export user types from dirigent_auth +pub use dirigent_auth::{User, UserId, UserProfile}; + +/// Unique identifier for a connector instance +/// +/// Each connector in the system has a unique ID that is used to reference it +/// across API calls, events, and internal state management. +pub type ConnectorId = String; + +/// Type of connector +/// +/// Identifies the underlying agent system or protocol that a connector bridges to. +/// Each variant represents a different integration with an external agent provider. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq, Hash)] +pub enum ConnectorKind { + /// OpenCode.ai connector (REST + SSE) + OpenCode, + /// Agent-Client Protocol connector + Acp, + /// Mock connector for testing + Mock, + /// Acceptor for incoming connections (accepts inbound ACP connections) + /// + /// Unlike other ConnectorKind variants which initiate outbound connections, + /// Acceptor represents an entry point for incoming connections from external + /// ACP clients. Sessions from Acceptors are routed to other connectors for processing. + Acceptor, + /// Gateway connector for handling messages locally + /// + /// The Gateway connector handles messages locally with configurable behavior + /// including echo mode and built-in commands. It serves as the default connector + /// for incoming ACP sessions before they are routed to an external agent. + Gateway, +} + +impl std::fmt::Display for ConnectorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ConnectorKind::OpenCode => write!(f, "OpenCode"), + ConnectorKind::Acp => write!(f, "ACP"), + ConnectorKind::Mock => write!(f, "Mock"), + ConnectorKind::Acceptor => write!(f, "Acceptor"), + ConnectorKind::Gateway => write!(f, "Gateway"), + } + } +} + +impl ConnectorKind { + /// Whether this connector kind can receive session transfers + /// + /// Returns true for connector types that can have sessions + /// transferred to them from Gateway or other sources. + pub fn supports_session_transfer(&self) -> bool { + match self { + ConnectorKind::OpenCode => true, + ConnectorKind::Acp => true, + ConnectorKind::Gateway => true, + ConnectorKind::Mock => false, // Testing only + ConnectorKind::Acceptor => false, // Entry point only, doesn't process + } + } +} + +/// Current state of a connector +/// +/// Tracks the lifecycle state of a connector from initialization through +/// active operation to shutdown. State transitions are managed by the +/// connector's task loop and can be observed by clients. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub enum ConnectorState { + /// Connector is being initialized but not yet connecting + Initializing, + /// Connector is attempting to establish connection + Connecting, + /// Connector is connected and operational + Ready, + /// Connector encountered an error (contains error message) + Error(String), + /// Connector has been stopped (intentionally or after unrecoverable error) + Stopped, +} + +/// Structured reason for connector error state. +/// +/// When `ConnectorState` is `Error(String)`, this provides a machine-readable +/// classification of the error for UI display and capability computation. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub enum ConnectorErrorKind { + /// Service is unreachable (max retries/rapid disconnects exhausted at startup) + Offline, + /// Connection keeps dropping (max rapid disconnects exhausted after stable connection) + Unstable, + /// Initial connection failed (auth, config, network error) + ConnectionFailed, +} + +/// Summary information about a connector +/// +/// Provides a lightweight view of a connector's essential properties. +/// Used by list operations and UI display without requiring full connector details. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub struct ConnectorSummary { + /// Unique connector identifier + pub id: ConnectorId, + /// Type of connector + pub kind: ConnectorKind, + /// User who owns this connector + pub owner: UserId, + /// Human-readable title for display + pub title: String, + /// Current operational state + pub state: ConnectorState, + /// Working directory for this connector + #[serde(skip_serializing_if = "Option::is_none")] + pub working_directory: Option<String>, + /// Supported features for this connector (e.g., "cancellation", "session_resume") + #[serde(default)] + pub supported_features: Vec<String>, + /// T034: Optional custom icon path for this connector + #[serde(skip_serializing_if = "Option::is_none")] + pub icon_path: Option<String>, + /// T035: Show connector type emoji as overlay on custom icon + #[serde(default)] + pub show_type_overlay: bool, + /// Whether archiving can be toggled for sessions on this connector + /// + /// When `false`, the UI should show archiving status as read-only (non-clickable). + /// When `true`, the user can toggle archiving on/off for sessions. + #[serde(default = "default_archiving_toggleable")] + pub archiving_toggleable: bool, + /// Agent type for automatic mode/model mapping (Claude, Codex, Gemini, or Custom) + /// + /// This is only set for ACP connectors and indicates the specific agent + /// implementation behind the connector. Used during session transfers to + /// automatically map Gateway mode/model identifiers to agent-specific ones. + #[serde(skip_serializing_if = "Option::is_none")] + pub agent_type: Option<crate::connectors::acp::config::ConnectorAgentType>, + + /// Tool configuration for this connector. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub tool_configuration: Option<crate::tools::ToolConfiguration>, + + /// Plugin assignments for this connector. + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub plugin_assignments: Vec<crate::plugins::PluginAssignment>, + + /// Project assignments for this connector. + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub project_assignments: Vec<crate::plugins::ProjectAssignment>, + + /// Whether this connector is used in newly created projects by default. + #[serde(default = "default_true")] + pub use_in_new_projects: bool, + + /// Source of this connector (e.g., "user", "zed"). + #[serde(default, skip_serializing_if = "Option::is_none")] + pub source: Option<String>, + + /// Zed agent name this connector was created from. + /// + /// When set, identifies this connector as Zed-managed and enables + /// automatic binary path refresh when Zed upgrades agents. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub zed_agent_name: Option<String>, + + /// Structured error classification when state is Error + /// + /// Provides machine-readable error type for UI indicators and capability computation. + /// Only set when `state` is `Error(_)`. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub error_kind: Option<ConnectorErrorKind>, +} + +fn default_true() -> bool { + true +} + +fn default_archiving_toggleable() -> bool { + true +} + +impl ConnectorSummary { + /// Check if this connector supports cancellation + pub fn supports_cancellation(&self) -> bool { + self.supported_features.iter().any(|f| f == "cancellation") + } + + /// Check if this connector supports session resume + pub fn supports_session_resume(&self) -> bool { + self.supported_features + .iter() + .any(|f| f == "session_resume") + } + + /// Check if this connector supports listing sessions from the connector API + /// + /// When false, session listing falls back to archived sessions only. + /// Connectors that do not expose a session list endpoint (e.g., ACP agents + /// that only support creating new sessions) should not advertise this feature. + pub fn supports_session_list(&self) -> bool { + self.supported_features + .iter() + .any(|f| f == "session_list") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_type_aliases() { + let connector_id: ConnectorId = "conn-001".to_string(); + let user_id: UserId = uuid::Uuid::nil(); + assert_eq!(connector_id, "conn-001"); + assert_eq!(user_id, uuid::Uuid::nil()); + } + + #[test] + fn test_user_struct() { + let user = User::new(UserProfile { + name: Some("Test User".to_string()), + ..Default::default() + }); + assert_eq!(user.display_name(), "Test User"); + } + + #[test] + fn test_connector_kind_enum() { + let open_code = ConnectorKind::OpenCode; + let acp = ConnectorKind::Acp; + let mock = ConnectorKind::Mock; + + assert_eq!(open_code, ConnectorKind::OpenCode); + assert_ne!(open_code, acp); + assert_ne!(acp, mock); + } + + #[test] + fn test_connector_state_enum() { + let _initializing = ConnectorState::Initializing; + let _connecting = ConnectorState::Connecting; + let ready = ConnectorState::Ready; + let error = ConnectorState::Error("Connection failed".to_string()); + let stopped = ConnectorState::Stopped; + + assert_eq!(ready, ConnectorState::Ready); + assert_eq!( + error, + ConnectorState::Error("Connection failed".to_string()) + ); + assert_ne!(ready, stopped); + } + + #[test] + fn test_connector_summary_struct() { + let summary = ConnectorSummary { + id: "conn-001".to_string(), + kind: ConnectorKind::OpenCode, + owner: uuid::Uuid::nil(), + title: "My Connector".to_string(), + state: ConnectorState::Ready, + working_directory: None, + supported_features: vec!["cancellation".to_string()], + icon_path: None, + show_type_overlay: false, + archiving_toggleable: true, + agent_type: None, + tool_configuration: None, + plugin_assignments: vec![], + project_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + error_kind: None, + }; + + assert_eq!(summary.id, "conn-001"); + assert_eq!(summary.kind, ConnectorKind::OpenCode); + assert_eq!(summary.owner, uuid::Uuid::nil()); + assert_eq!(summary.title, "My Connector"); + assert_eq!(summary.state, ConnectorState::Ready); + assert!(summary.supports_cancellation()); + assert!(summary.archiving_toggleable); + } + + #[test] + fn test_serialization() { + let user = User::new(UserProfile { + name: Some("Test User".to_string()), + ..Default::default() + }); + + // Test serialization + let json = serde_json::to_string(&user).expect("Failed to serialize"); + assert!(json.contains("Test User")); + + // Test deserialization + let deserialized: User = serde_json::from_str(&json).expect("Failed to deserialize"); + assert_eq!(deserialized.id, user.id); + assert_eq!(deserialized.display_name(), user.display_name()); + } + + #[test] + fn test_connector_summary_serialization() { + use crate::connectors::acp::config::ConnectorAgentType; + + let summary = ConnectorSummary { + id: "conn-001".to_string(), + kind: ConnectorKind::Acp, + owner: uuid::Uuid::nil(), + title: "ACP Connector".to_string(), + state: ConnectorState::Connecting, + working_directory: None, + supported_features: vec!["session_resume".to_string()], + icon_path: Some("/path/to/icon.png".to_string()), + show_type_overlay: true, + archiving_toggleable: true, + agent_type: Some(ConnectorAgentType::Claude), + tool_configuration: None, + plugin_assignments: vec![], + project_assignments: vec![], + use_in_new_projects: true, + source: None, + zed_agent_name: None, + error_kind: None, + }; + + let json = serde_json::to_string(&summary).expect("Failed to serialize"); + let deserialized: ConnectorSummary = + serde_json::from_str(&json).expect("Failed to deserialize"); + + assert_eq!(deserialized.id, summary.id); + assert_eq!(deserialized.kind, summary.kind); + assert_eq!(deserialized.owner, summary.owner); + assert_eq!(deserialized.title, summary.title); + assert_eq!(deserialized.state, summary.state); + assert_eq!(deserialized.supported_features, summary.supported_features); + assert_eq!(deserialized.icon_path, summary.icon_path); + assert_eq!(deserialized.show_type_overlay, summary.show_type_overlay); + assert_eq!( + deserialized.archiving_toggleable, + summary.archiving_toggleable + ); + assert_eq!(deserialized.agent_type, summary.agent_type); + assert!(deserialized.supports_session_resume()); + assert!(!deserialized.supports_cancellation()); + } + + #[test] + fn test_archiving_toggleable_default() { + // Test that archiving_toggleable defaults to true when not specified in JSON + let json = r#"{ + "id": "conn-001", + "kind": "OpenCode", + "owner": "00000000-0000-0000-0000-000000000000", + "title": "Test Connector", + "state": "Ready", + "supported_features": [] + }"#; + + let deserialized: ConnectorSummary = + serde_json::from_str(json).expect("Failed to deserialize"); + assert!( + deserialized.archiving_toggleable, + "archiving_toggleable should default to true" + ); + } + + #[test] + fn test_supports_session_transfer() { + assert!(ConnectorKind::OpenCode.supports_session_transfer()); + assert!(ConnectorKind::Acp.supports_session_transfer()); + assert!(ConnectorKind::Gateway.supports_session_transfer()); + assert!(!ConnectorKind::Mock.supports_session_transfer()); + assert!(!ConnectorKind::Acceptor.supports_session_transfer()); + } +} diff --git a/crates/dirigent_core/src/vendors/claude.rs b/crates/dirigent_core/src/vendors/claude.rs new file mode 100644 index 0000000..0138d39 --- /dev/null +++ b/crates/dirigent_core/src/vendors/claude.rs @@ -0,0 +1,260 @@ +//! Claude vendor definition +//! +//! This module contains all Claude-specific knowledge including CLI detection, +//! mode/model mappings, and connector templates. + +use super::{MappingResult, ModeMapping, ModelMapping, VendorInfo}; +use crate::connectors::acp::config::ConnectorAgentType; + +/// Claude mode identifiers +pub mod modes { + pub const DEFAULT: &str = "default"; + pub const PLAN: &str = "plan"; + pub const ACCEPT_EDITS: &str = "acceptEdits"; + pub const BYPASS_PERMISSIONS: &str = "bypassPermissions"; +} + +/// Claude model identifiers +pub mod models { + pub const DEFAULT: &str = "default"; + pub const HAIKU: &str = "haiku"; + pub const SONNET: &str = "sonnet"; + pub const OPUS: &str = "opus"; +} + +/// Gateway mode identifiers (for reference) +mod gateway_modes { + pub const PLAN: &str = "plan"; + pub const READONLY: &str = "readonly"; + pub const ASK: &str = "ask"; + pub const WRITE: &str = "write"; + pub const YOLO: &str = "yolo"; +} + +/// Gateway model identifiers (for reference) +mod gateway_models { + pub const SIMPLE: &str = "simple"; + pub const DAILYDRIVER: &str = "dailydriver"; + pub const HIGH: &str = "high"; +} + +/// Static mode mappings for Claude +static MODE_MAPPINGS: &[ModeMapping] = &[ + ModeMapping { + gateway: gateway_modes::ASK, + vendor: modes::DEFAULT, + }, + ModeMapping { + gateway: gateway_modes::PLAN, + vendor: modes::PLAN, + }, + ModeMapping { + gateway: gateway_modes::WRITE, + vendor: modes::ACCEPT_EDITS, + }, + ModeMapping { + gateway: gateway_modes::YOLO, + vendor: modes::BYPASS_PERMISSIONS, + }, +]; + +/// Static model mappings for Claude +static MODEL_MAPPINGS: &[ModelMapping] = &[ + ModelMapping { + gateway: gateway_models::SIMPLE, + vendor: models::HAIKU, + }, + ModelMapping { + gateway: gateway_models::DAILYDRIVER, + vendor: models::SONNET, + }, + ModelMapping { + gateway: gateway_models::HIGH, + vendor: models::OPUS, + }, +]; + +/// Claude vendor implementation +pub struct ClaudeVendor; + +impl VendorInfo for ClaudeVendor { + fn id(&self) -> &'static str { + "claude" + } + + fn display_name(&self) -> &'static str { + "Claude" + } + + fn aliases(&self) -> &'static [&'static str] { + &["anthropic"] + } + + fn agent_type(&self) -> ConnectorAgentType { + ConnectorAgentType::Claude + } + + fn cli_command(&self) -> &'static str { + "claude" + } + + fn cli_args(&self) -> &'static [&'static str] { + &["--acp"] + } + + fn mode_mappings(&self) -> &'static [ModeMapping] { + MODE_MAPPINGS + } + + fn model_mappings(&self) -> &'static [ModelMapping] { + MODEL_MAPPINGS + } + + fn map_mode(&self, gateway_mode: &str) -> MappingResult { + match gateway_mode { + gateway_modes::ASK => MappingResult::exact(modes::DEFAULT), + gateway_modes::PLAN => MappingResult::exact(modes::PLAN), + gateway_modes::READONLY => MappingResult::approximate( + modes::PLAN, + "Gateway 'readonly' mode mapped to Claude 'plan' mode", + ), + gateway_modes::WRITE => MappingResult::exact(modes::ACCEPT_EDITS), + gateway_modes::YOLO => MappingResult::exact(modes::BYPASS_PERMISSIONS), + _ => MappingResult::approximate( + modes::DEFAULT, + format!("Unknown gateway mode '{}' mapped to Claude 'default' mode", gateway_mode), + ), + } + } + + fn map_model(&self, gateway_model: &str) -> MappingResult { + match gateway_model { + gateway_models::SIMPLE => MappingResult::exact(models::HAIKU), + gateway_models::DAILYDRIVER => MappingResult::exact(models::SONNET), + gateway_models::HIGH => MappingResult::exact(models::OPUS), + _ => MappingResult::approximate( + models::SONNET, + format!( + "Unknown gateway model '{}' mapped to Claude 'sonnet' model", + gateway_model + ), + ), + } + } + + fn reverse_map_mode(&self, vendor_mode: &str) -> MappingResult { + match vendor_mode { + modes::DEFAULT => MappingResult::exact(gateway_modes::ASK), + modes::PLAN => MappingResult::exact(gateway_modes::PLAN), + modes::ACCEPT_EDITS => MappingResult::exact(gateway_modes::WRITE), + modes::BYPASS_PERMISSIONS => MappingResult::exact(gateway_modes::YOLO), + _ => MappingResult::approximate( + gateway_modes::ASK, + format!("Unknown Claude mode '{}' mapped to Gateway 'ask' mode", vendor_mode), + ), + } + } + + fn reverse_map_model(&self, vendor_model: &str) -> MappingResult { + match vendor_model { + models::HAIKU => MappingResult::exact(gateway_models::SIMPLE), + models::SONNET => MappingResult::exact(gateway_models::DAILYDRIVER), + models::OPUS => MappingResult::exact(gateway_models::HIGH), + models::DEFAULT => MappingResult::exact(gateway_models::DAILYDRIVER), + _ => MappingResult::approximate( + gateway_models::DAILYDRIVER, + format!( + "Unknown Claude model '{}' mapped to Gateway 'dailydriver' model", + vendor_model + ), + ), + } + } + + fn acp_template(&self) -> serde_json::Value { + serde_json::json!({ + "transport": { + "type": "stdio", + "command": "claude", + "args": ["--acp"] + }, + "protocol_version": 1, + "cwd": ".", + "retry": { + "max_retries": 5, + "retry_delays_ms": [1000, 3000, 5000, 5000, 5000] + } + }) + } + + fn default_features(&self) -> &'static [&'static str] { + // Claude Code has limited ACP support: + // - NOT session_resume: generates ephemeral ACP session IDs + // - NOT cancellation: session/cancel not implemented + &[] + } + + fn icon_path(&self) -> &'static str { + "claude" + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_claude_vendor_id() { + let vendor = ClaudeVendor; + assert_eq!(vendor.id(), "claude"); + assert_eq!(vendor.display_name(), "Claude"); + assert_eq!(vendor.aliases(), &["anthropic"]); + } + + #[test] + fn test_claude_mode_mappings() { + let vendor = ClaudeVendor; + + // Forward mappings + assert_eq!(vendor.map_mode("ask").mapped_id, "default"); + assert_eq!(vendor.map_mode("plan").mapped_id, "plan"); + assert_eq!(vendor.map_mode("write").mapped_id, "acceptEdits"); + assert_eq!(vendor.map_mode("yolo").mapped_id, "bypassPermissions"); + + // Reverse mappings + assert_eq!(vendor.reverse_map_mode("default").mapped_id, "ask"); + assert_eq!(vendor.reverse_map_mode("plan").mapped_id, "plan"); + assert_eq!(vendor.reverse_map_mode("acceptEdits").mapped_id, "write"); + assert_eq!(vendor.reverse_map_mode("bypassPermissions").mapped_id, "yolo"); + } + + #[test] + fn test_claude_model_mappings() { + let vendor = ClaudeVendor; + + // Forward mappings + assert_eq!(vendor.map_model("simple").mapped_id, "haiku"); + assert_eq!(vendor.map_model("dailydriver").mapped_id, "sonnet"); + assert_eq!(vendor.map_model("high").mapped_id, "opus"); + + // Reverse mappings + assert_eq!(vendor.reverse_map_model("haiku").mapped_id, "simple"); + assert_eq!(vendor.reverse_map_model("sonnet").mapped_id, "dailydriver"); + assert_eq!(vendor.reverse_map_model("opus").mapped_id, "high"); + } + + #[test] + fn test_claude_unknown_mode() { + let vendor = ClaudeVendor; + let result = vendor.map_mode("unknown"); + assert_eq!(result.mapped_id, "default"); + assert!(result.warning.is_some()); + } + + #[test] + fn test_claude_cli() { + let vendor = ClaudeVendor; + assert_eq!(vendor.cli_command(), "claude"); + assert_eq!(vendor.cli_args(), &["--acp"]); + } +} diff --git a/crates/dirigent_core/src/vendors/codex.rs b/crates/dirigent_core/src/vendors/codex.rs new file mode 100644 index 0000000..2a83b04 --- /dev/null +++ b/crates/dirigent_core/src/vendors/codex.rs @@ -0,0 +1,50 @@ +//! Codex vendor definition (placeholder) +//! +//! This module contains Codex-specific knowledge. Currently a placeholder +//! with pass-through behavior until Codex mappings are implemented. + +use super::{ModeMapping, ModelMapping, VendorInfo}; +use crate::connectors::acp::config::ConnectorAgentType; + +/// Codex vendor implementation (placeholder) +pub struct CodexVendor; + +impl VendorInfo for CodexVendor { + fn id(&self) -> &'static str { + "codex" + } + + fn display_name(&self) -> &'static str { + "Codex" + } + + fn aliases(&self) -> &'static [&'static str] { + &["openai"] + } + + fn agent_type(&self) -> ConnectorAgentType { + ConnectorAgentType::Codex + } + + fn cli_command(&self) -> &'static str { + "codex" + } + + fn mode_mappings(&self) -> &'static [ModeMapping] { + // Not yet implemented - pass through + &[] + } + + fn model_mappings(&self) -> &'static [ModelMapping] { + // Not yet implemented - pass through + &[] + } + + fn default_features(&self) -> &'static [&'static str] { + &[] + } + + fn icon_path(&self) -> &'static str { + "codex" + } +} diff --git a/crates/dirigent_core/src/vendors/custom.rs b/crates/dirigent_core/src/vendors/custom.rs new file mode 100644 index 0000000..a22ea48 --- /dev/null +++ b/crates/dirigent_core/src/vendors/custom.rs @@ -0,0 +1,57 @@ +//! Custom/fallback vendor definition +//! +//! This vendor is used when the agent type is unknown or custom. +//! All mappings pass through unchanged. + +use super::{MappingResult, ModeMapping, ModelMapping, VendorInfo}; +use crate::connectors::acp::config::ConnectorAgentType; + +/// Custom vendor implementation (pass-through) +pub struct CustomVendor; + +impl VendorInfo for CustomVendor { + fn id(&self) -> &'static str { + "custom" + } + + fn display_name(&self) -> &'static str { + "Custom" + } + + fn agent_type(&self) -> ConnectorAgentType { + ConnectorAgentType::Custom + } + + fn cli_command(&self) -> &'static str { + "" + } + + fn mode_mappings(&self) -> &'static [ModeMapping] { + &[] + } + + fn model_mappings(&self) -> &'static [ModelMapping] { + &[] + } + + // Override to pass through without warnings + fn map_mode(&self, gateway_mode: &str) -> MappingResult { + MappingResult::exact(gateway_mode) + } + + fn map_model(&self, gateway_model: &str) -> MappingResult { + MappingResult::exact(gateway_model) + } + + fn reverse_map_mode(&self, vendor_mode: &str) -> MappingResult { + MappingResult::exact(vendor_mode) + } + + fn reverse_map_model(&self, vendor_model: &str) -> MappingResult { + MappingResult::exact(vendor_model) + } + + fn icon_path(&self) -> &'static str { + "acp" + } +} diff --git a/crates/dirigent_core/src/vendors/gemini.rs b/crates/dirigent_core/src/vendors/gemini.rs new file mode 100644 index 0000000..6606592 --- /dev/null +++ b/crates/dirigent_core/src/vendors/gemini.rs @@ -0,0 +1,50 @@ +//! Gemini vendor definition (placeholder) +//! +//! This module contains Gemini-specific knowledge. Currently a placeholder +//! with pass-through behavior until Gemini mappings are implemented. + +use super::{ModeMapping, ModelMapping, VendorInfo}; +use crate::connectors::acp::config::ConnectorAgentType; + +/// Gemini vendor implementation (placeholder) +pub struct GeminiVendor; + +impl VendorInfo for GeminiVendor { + fn id(&self) -> &'static str { + "gemini" + } + + fn display_name(&self) -> &'static str { + "Gemini" + } + + fn aliases(&self) -> &'static [&'static str] { + &["google"] + } + + fn agent_type(&self) -> ConnectorAgentType { + ConnectorAgentType::Gemini + } + + fn cli_command(&self) -> &'static str { + "gemini" + } + + fn mode_mappings(&self) -> &'static [ModeMapping] { + // Not yet implemented - pass through + &[] + } + + fn model_mappings(&self) -> &'static [ModelMapping] { + // Not yet implemented - pass through + &[] + } + + fn default_features(&self) -> &'static [&'static str] { + &[] + } + + fn icon_path(&self) -> &'static str { + "gemini" + } +} diff --git a/crates/dirigent_core/src/vendors/mod.rs b/crates/dirigent_core/src/vendors/mod.rs new file mode 100644 index 0000000..a0acb75 --- /dev/null +++ b/crates/dirigent_core/src/vendors/mod.rs @@ -0,0 +1,222 @@ +//! Vendor registry for agent-specific knowledge +//! +//! This module consolidates vendor-specific configuration and behavior into a single location. +//! When adding support for a new vendor (Claude, Codex, Gemini, etc.), create a new module +//! in this directory implementing the `VendorInfo` trait. +//! +//! # Adding a New Vendor +//! +//! 1. Create a new file in `vendors/` (e.g., `gemini.rs`) +//! 2. Implement the `VendorInfo` trait for your vendor +//! 3. Register your vendor in the `VENDOR_REGISTRY` in this file +//! +//! # Example +//! +//! ```rust,ignore +//! use dirigent_core::vendors::{VENDOR_REGISTRY, VendorInfo}; +//! +//! // Look up vendor by ID +//! if let Some(vendor) = VENDOR_REGISTRY.get("claude") { +//! println!("CLI command: {}", vendor.cli_command()); +//! println!("Display name: {}", vendor.display_name()); +//! } +//! +//! // Detect which vendors are available on the system +//! let available = VENDOR_REGISTRY.detect_available(); +//! for vendor in available { +//! println!("{} is installed", vendor.display_name()); +//! } +//! ``` + +pub mod claude; +mod codex; +mod custom; +mod gemini; +mod registry; + +pub use claude::ClaudeVendor; +pub use codex::CodexVendor; +pub use custom::CustomVendor; +pub use gemini::GeminiVendor; +pub use registry::{VendorRegistry, VENDOR_REGISTRY}; + +use crate::connectors::acp::config::ConnectorAgentType; + +/// Mode mapping entry +#[derive(Debug, Clone)] +pub struct ModeMapping { + /// Gateway mode identifier + pub gateway: &'static str, + /// Vendor-specific mode identifier + pub vendor: &'static str, +} + +/// Model mapping entry +#[derive(Debug, Clone)] +pub struct ModelMapping { + /// Gateway model identifier + pub gateway: &'static str, + /// Vendor-specific model identifier + pub vendor: &'static str, +} + +/// Result of a mode/model mapping operation +#[derive(Debug, Clone, PartialEq)] +pub struct MappingResult { + /// The mapped identifier + pub mapped_id: String, + /// Warning message if mapping was approximate + pub warning: Option<String>, +} + +impl MappingResult { + /// Create an exact mapping result + pub fn exact(id: impl Into<String>) -> Self { + Self { + mapped_id: id.into(), + warning: None, + } + } + + /// Create an approximate mapping with a warning + pub fn approximate(id: impl Into<String>, warning: impl Into<String>) -> Self { + Self { + mapped_id: id.into(), + warning: Some(warning.into()), + } + } +} + +/// Trait defining vendor-specific knowledge +/// +/// Implement this trait for each vendor (Claude, Codex, Gemini, etc.) to consolidate +/// all vendor-specific configuration in one place. +pub trait VendorInfo: Send + Sync { + /// Unique vendor identifier (lowercase, e.g., "claude", "codex") + fn id(&self) -> &'static str; + + /// Human-readable display name (e.g., "Claude", "Codex") + fn display_name(&self) -> &'static str; + + /// Alternative names that map to this vendor (e.g., ["anthropic"] for Claude) + fn aliases(&self) -> &'static [&'static str] { + &[] + } + + /// The corresponding ConnectorAgentType enum value + fn agent_type(&self) -> ConnectorAgentType; + + // ========================================================================= + // CLI Detection + // ========================================================================= + + /// CLI command name for this vendor (e.g., "claude" for Claude) + fn cli_command(&self) -> &'static str; + + /// Default CLI arguments (e.g., ["--acp"] for Claude) + fn cli_args(&self) -> &'static [&'static str] { + &[] + } + + /// CLI command name on Windows (defaults to cli_command) + fn cli_command_windows(&self) -> &'static str { + self.cli_command() + } + + // ========================================================================= + // Mode/Model Mappings + // ========================================================================= + + /// Mode mappings from Gateway identifiers to vendor-specific identifiers + fn mode_mappings(&self) -> &'static [ModeMapping] { + &[] + } + + /// Model mappings from Gateway identifiers to vendor-specific identifiers + fn model_mappings(&self) -> &'static [ModelMapping] { + &[] + } + + /// Map a Gateway mode to vendor-specific mode + fn map_mode(&self, gateway_mode: &str) -> MappingResult { + for mapping in self.mode_mappings() { + if mapping.gateway == gateway_mode { + return MappingResult::exact(mapping.vendor); + } + } + // Default: pass through with warning + MappingResult::approximate( + gateway_mode, + format!( + "Unknown mode '{}' for {}", + gateway_mode, + self.display_name() + ), + ) + } + + /// Map a Gateway model to vendor-specific model + fn map_model(&self, gateway_model: &str) -> MappingResult { + for mapping in self.model_mappings() { + if mapping.gateway == gateway_model { + return MappingResult::exact(mapping.vendor); + } + } + // Default: pass through with warning + MappingResult::approximate( + gateway_model, + format!( + "Unknown model '{}' for {}", + gateway_model, + self.display_name() + ), + ) + } + + /// Reverse map a vendor-specific mode to Gateway mode + fn reverse_map_mode(&self, vendor_mode: &str) -> MappingResult { + for mapping in self.mode_mappings() { + if mapping.vendor == vendor_mode { + return MappingResult::exact(mapping.gateway); + } + } + // Default: pass through with warning + MappingResult::approximate( + vendor_mode, + format!("Unknown {} mode '{}'", self.display_name(), vendor_mode), + ) + } + + /// Reverse map a vendor-specific model to Gateway model + fn reverse_map_model(&self, vendor_model: &str) -> MappingResult { + for mapping in self.model_mappings() { + if mapping.vendor == vendor_model { + return MappingResult::exact(mapping.gateway); + } + } + // Default: pass through with warning + MappingResult::approximate( + vendor_model, + format!("Unknown {} model '{}'", self.display_name(), vendor_model), + ) + } + + // ========================================================================= + // Templates & Defaults + // ========================================================================= + + /// Default ACP connector template for this vendor + fn acp_template(&self) -> serde_json::Value { + serde_json::json!({}) + } + + /// Default supported features for this vendor + fn default_features(&self) -> &'static [&'static str] { + &[] + } + + /// Default icon path for this vendor + fn icon_path(&self) -> &'static str { + "acp" + } +} diff --git a/crates/dirigent_core/src/vendors/registry.rs b/crates/dirigent_core/src/vendors/registry.rs new file mode 100644 index 0000000..aad4ae0 --- /dev/null +++ b/crates/dirigent_core/src/vendors/registry.rs @@ -0,0 +1,220 @@ +//! Vendor registry for looking up vendor implementations +//! +//! The registry provides a centralized way to look up vendors by ID, alias, +//! or ConnectorAgentType. + +use once_cell::sync::Lazy; +use std::collections::HashMap; + +use super::{ClaudeVendor, CodexVendor, CustomVendor, GeminiVendor, VendorInfo}; +use crate::connectors::acp::config::ConnectorAgentType; + +/// Global vendor registry instance +pub static VENDOR_REGISTRY: Lazy<VendorRegistry> = Lazy::new(VendorRegistry::new); + +/// Registry of all known vendors +pub struct VendorRegistry { + /// Vendors indexed by their primary ID + vendors: HashMap<&'static str, Box<dyn VendorInfo>>, + /// Alias to vendor ID mapping + aliases: HashMap<&'static str, &'static str>, +} + +impl VendorRegistry { + /// Create a new vendor registry with all known vendors + pub fn new() -> Self { + let mut registry = Self { + vendors: HashMap::new(), + aliases: HashMap::new(), + }; + + // Register all vendors + registry.register(Box::new(ClaudeVendor)); + registry.register(Box::new(CodexVendor)); + registry.register(Box::new(GeminiVendor)); + registry.register(Box::new(CustomVendor)); + + registry + } + + /// Register a vendor in the registry + fn register(&mut self, vendor: Box<dyn VendorInfo>) { + let id = vendor.id(); + + // Register aliases + for alias in vendor.aliases() { + self.aliases.insert(alias, id); + } + + // Register the vendor + self.vendors.insert(id, vendor); + } + + /// Look up a vendor by ID or alias + pub fn get(&self, id_or_alias: &str) -> Option<&dyn VendorInfo> { + let id = self + .aliases + .get(id_or_alias) + .copied() + .unwrap_or(id_or_alias); + self.vendors.get(id).map(|v| v.as_ref()) + } + + /// Look up a vendor by ConnectorAgentType + pub fn get_by_agent_type(&self, agent_type: ConnectorAgentType) -> Option<&dyn VendorInfo> { + match agent_type { + ConnectorAgentType::Claude => self.get("claude"), + ConnectorAgentType::Codex => self.get("codex"), + ConnectorAgentType::Gemini => self.get("gemini"), + ConnectorAgentType::Custom => self.get("custom"), + } + } + + /// List all registered vendors + pub fn list_all(&self) -> Vec<&dyn VendorInfo> { + self.vendors.values().map(|v| v.as_ref()).collect() + } + + /// List vendor IDs + pub fn list_ids(&self) -> Vec<&'static str> { + self.vendors.keys().copied().collect() + } + + /// Check if a vendor CLI is available on the system + #[cfg(not(target_arch = "wasm32"))] + pub fn is_cli_available(&self, vendor_id: &str) -> bool { + use std::process::Command; + + let Some(vendor) = self.get(vendor_id) else { + return false; + }; + + let command = vendor.cli_command(); + if command.is_empty() { + return false; + } + + // Use 'which' on Unix, 'where' on Windows + #[cfg(target_os = "windows")] + let check_cmd = "where"; + #[cfg(not(target_os = "windows"))] + let check_cmd = "which"; + + Command::new(check_cmd) + .arg(command) + .output() + .map(|o| o.status.success()) + .unwrap_or(false) + } + + /// Detect which vendor CLIs are available on the system + #[cfg(not(target_arch = "wasm32"))] + pub fn detect_available(&self) -> Vec<&dyn VendorInfo> { + self.vendors + .values() + .filter(|v| { + let cmd = v.cli_command(); + !cmd.is_empty() && self.is_cli_available(v.id()) + }) + .map(|v| v.as_ref()) + .collect() + } +} + +impl Default for VendorRegistry { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_get_by_id() { + let registry = VendorRegistry::new(); + + let claude = registry.get("claude").unwrap(); + assert_eq!(claude.id(), "claude"); + assert_eq!(claude.display_name(), "Claude"); + + let codex = registry.get("codex").unwrap(); + assert_eq!(codex.id(), "codex"); + + let gemini = registry.get("gemini").unwrap(); + assert_eq!(gemini.id(), "gemini"); + + let custom = registry.get("custom").unwrap(); + assert_eq!(custom.id(), "custom"); + } + + #[test] + fn test_get_by_alias() { + let registry = VendorRegistry::new(); + + // "anthropic" should resolve to "claude" + let claude = registry.get("anthropic").unwrap(); + assert_eq!(claude.id(), "claude"); + + // "openai" should resolve to "codex" + let codex = registry.get("openai").unwrap(); + assert_eq!(codex.id(), "codex"); + + // "google" should resolve to "gemini" + let gemini = registry.get("google").unwrap(); + assert_eq!(gemini.id(), "gemini"); + } + + #[test] + fn test_get_by_agent_type() { + let registry = VendorRegistry::new(); + + let claude = registry + .get_by_agent_type(ConnectorAgentType::Claude) + .unwrap(); + assert_eq!(claude.id(), "claude"); + + let codex = registry + .get_by_agent_type(ConnectorAgentType::Codex) + .unwrap(); + assert_eq!(codex.id(), "codex"); + + let gemini = registry + .get_by_agent_type(ConnectorAgentType::Gemini) + .unwrap(); + assert_eq!(gemini.id(), "gemini"); + + let custom = registry + .get_by_agent_type(ConnectorAgentType::Custom) + .unwrap(); + assert_eq!(custom.id(), "custom"); + } + + #[test] + fn test_list_all() { + let registry = VendorRegistry::new(); + let vendors = registry.list_all(); + + assert_eq!(vendors.len(), 4); + + let ids: Vec<_> = vendors.iter().map(|v| v.id()).collect(); + assert!(ids.contains(&"claude")); + assert!(ids.contains(&"codex")); + assert!(ids.contains(&"gemini")); + assert!(ids.contains(&"custom")); + } + + #[test] + fn test_unknown_vendor() { + let registry = VendorRegistry::new(); + assert!(registry.get("unknown").is_none()); + } + + #[test] + fn test_global_registry() { + // Test the static VENDOR_REGISTRY + let claude = VENDOR_REGISTRY.get("claude").unwrap(); + assert_eq!(claude.id(), "claude"); + } +} diff --git a/crates/dirigent_core/tests/acp_advanced_tests.rs b/crates/dirigent_core/tests/acp_advanced_tests.rs new file mode 100644 index 0000000..e25e655 --- /dev/null +++ b/crates/dirigent_core/tests/acp_advanced_tests.rs @@ -0,0 +1,972 @@ +//! TEST-003 through TEST-007: Advanced ACP Tests +//! +//! This file contains advanced integration tests covering: +//! - TEST-003: Session lifecycle tests (create, prompt, cancel, load) +//! - TEST-004: Error condition tests (network failures, invalid JSON, timeouts) +//! - TEST-005: Reconnection tests (during init, active session, max retries) +//! - TEST-006: Edge case tests (empty messages, long messages, rapid-fire, leaks) +//! - TEST-007: Performance tests (latency, throughput, memory, concurrent sessions) + +#![cfg(feature = "server")] + +use dirigent_core::connectors::acp::{AcpConfig, AcpConnector}; +use dirigent_core::connectors::{Connector, ConnectorCommand}; +use dirigent_core::types::ConnectorState; +use dirigent_protocol::Event; +use std::path::PathBuf; +use std::time::{Duration, Instant}; +use tokio::time::timeout; + +// ============================================================================ +// Test Helpers +// ============================================================================ + +fn test_fixture_path() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("tests") + .join("fixtures") + .join("acp_test.yaml") +} + +fn mocker_binary() -> String { + let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let project_root = manifest_dir.parent().unwrap().parent().unwrap(); + let debug_path = project_root + .join("target") + .join("debug") + .join("dirigate.exe"); + + if debug_path.exists() { + debug_path.to_string_lossy().to_string() + } else { + "cargo".to_string() + } +} + +fn mocker_args() -> Vec<String> { + let bin = mocker_binary(); + let fixture = test_fixture_path(); + + if bin == "cargo" { + vec![ + "run".to_string(), + "--package".to_string(), + "dirigate".to_string(), + "--".to_string(), + "serve".to_string(), + "--stdio".to_string(), + "--fixtures".to_string(), + fixture.to_string_lossy().to_string(), + ] + } else { + vec![ + "serve".to_string(), + "--stdio".to_string(), + "--fixtures".to_string(), + fixture.to_string_lossy().to_string(), + ] + } +} + +fn create_test_connector(id: &str) -> AcpConnector { + let config = AcpConfig::stdio(mocker_binary(), mocker_args()) + .with_retry_max_attempts(3) + .with_retry_initial_delay(Duration::from_millis(500)); + + AcpConnector::new( + id.to_string(), + uuid::Uuid::nil(), + "Test ACP Connector".to_string(), + config, + dirigent_core::sharing::bus::SharingBus::new(), + ) + .expect("Failed to create connector") +} + +async fn wait_for_event<F>( + events_rx: &mut tokio::sync::broadcast::Receiver<Event>, + predicate: F, + timeout_duration: Duration, +) -> anyhow::Result<Event> +where + F: Fn(&Event) -> bool, +{ + let result = timeout(timeout_duration, async { + loop { + match events_rx.recv().await { + Ok(event) => { + if predicate(&event) { + return Ok(event); + } + } + Err(e) => return Err(anyhow::anyhow!("Recv error: {}", e)), + } + } + }) + .await; + + match result { + Ok(Ok(event)) => Ok(event), + Ok(Err(e)) => Err(e), + Err(_) => Err(anyhow::anyhow!("Timeout")), + } +} + +// ============================================================================ +// TEST-003: Session Lifecycle Tests +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires mocker binary"] +async fn test_t003_01_session_create_flow() -> anyhow::Result<()> { + tracing_subscriber::fmt().with_test_writer().try_init().ok(); + + let connector = create_test_connector("lifecycle-create"); + let mut events = connector.subscribe(); + let handle = connector.start_task().await; + + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + // Create session + connector + .command_tx() + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await?; + + // Verify session created + let session = match wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await? + { + Event::SessionCreated { session, .. } => session, + _ => anyhow::bail!("Expected SessionCreated"), + }; + + assert_eq!(session.title, "Lifecycle Test"); + tracing::info!("✓ Session created successfully"); + + connector.stop(); + timeout(Duration::from_secs(5), handle).await??; + Ok(()) +} + +#[tokio::test] +#[ignore = "Requires mocker binary"] +async fn test_t003_02_session_prompt_flow() -> anyhow::Result<()> { + tracing_subscriber::fmt().with_test_writer().try_init().ok(); + + let connector = create_test_connector("lifecycle-prompt"); + let mut events = connector.subscribe(); + let handle = connector.start_task().await; + + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + // Create session + connector + .command_tx() + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await?; + + let session_id = match wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await? + { + Event::SessionCreated { session, .. } => session.id, + _ => anyhow::bail!("Expected SessionCreated"), + }; + + // Send prompt + connector + .command_tx() + .send(ConnectorCommand::SendMessage { + session_id: session_id.clone(), + text: "Test prompt".to_string(), + }) + .await?; + + // Verify message flow: Started -> Content -> Completed + wait_for_event( + &mut events, + |e| matches!(e, Event::MessageStarted { .. }), + Duration::from_secs(5), + ) + .await?; + tracing::info!("✓ Message started"); + + // Wait for content or completion + let mut got_response = false; + for _ in 0..10 { + if let Ok(event) = timeout(Duration::from_secs(1), events.recv()).await { + match event? { + Event::SessionUpdate { .. } => { + got_response = true; + break; + } + Event::MessageCompleted { .. } => { + got_response = true; + break; + } + _ => {} + } + } + } + + assert!(got_response, "Should receive response"); + tracing::info!("✓ Prompt flow completed"); + + connector.stop(); + timeout(Duration::from_secs(5), handle).await??; + Ok(()) +} + +#[tokio::test] +#[ignore = "Requires mocker binary"] +async fn test_t003_03_session_cancel_flow() -> anyhow::Result<()> { + tracing_subscriber::fmt().with_test_writer().try_init().ok(); + + let connector = create_test_connector("lifecycle-cancel"); + let mut events = connector.subscribe(); + let handle = connector.start_task().await; + + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + // Setup + connector + .command_tx() + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await?; + + let session_id = match wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await? + { + Event::SessionCreated { session, .. } => session.id, + _ => anyhow::bail!("Expected SessionCreated"), + }; + + // Start message + connector + .command_tx() + .send(ConnectorCommand::SendMessage { + session_id: session_id.clone(), + text: "Long message".to_string(), + }) + .await?; + + wait_for_event( + &mut events, + |e| matches!(e, Event::MessageStarted { .. }), + Duration::from_secs(5), + ) + .await?; + + // Cancel + connector + .command_tx() + .send(ConnectorCommand::CancelGeneration { session_id }) + .await?; + + tracing::info!("✓ Cancel sent successfully"); + + connector.stop(); + timeout(Duration::from_secs(5), handle).await??; + Ok(()) +} + +#[tokio::test] +#[ignore = "Requires mocker binary"] +async fn test_t003_04_session_load_flow() -> anyhow::Result<()> { + tracing_subscriber::fmt().with_test_writer().try_init().ok(); + + let connector = create_test_connector("lifecycle-load"); + let mut events = connector.subscribe(); + let handle = connector.start_task().await; + + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + // Load existing session from fixture + connector + .command_tx() + .send(ConnectorCommand::LoadSession { + session_id: "test-session-1".to_string(), + cwd: ".".to_string(), + mcp_servers: None, + }) + .await?; + + let session = match wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await? + { + Event::SessionCreated { session, .. } => session, + _ => anyhow::bail!("Expected SessionCreated"), + }; + + assert_eq!(session.id, "test-session-1"); + tracing::info!("✓ Session loaded from fixture"); + + connector.stop(); + timeout(Duration::from_secs(5), handle).await??; + Ok(()) +} + +// ============================================================================ +// TEST-004: Error Condition Tests +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires mocker binary"] +async fn test_t004_01_invalid_binary_path() { + tracing_subscriber::fmt().with_test_writer().try_init().ok(); + + let config = AcpConfig::stdio("nonexistent-binary".to_string(), vec![]); + let result = AcpConnector::new( + "error-test".to_string(), + uuid::Uuid::nil(), + "Error Test".to_string(), + config, + dirigent_core::sharing::bus::SharingBus::new(), + ); + + // Should fail validation or fail to start + // Either is acceptable - the key is it doesn't panic + if let Ok(connector) = result { + let mut events = connector.subscribe(); + let _handle = connector.start_task().await; + + // Should either fail to connect or error out + let result = timeout(Duration::from_secs(5), events.recv()).await; + match result { + Ok(Ok(Event::Error { .. })) => { + tracing::info!("✓ Error handled gracefully"); + } + Ok(Ok(Event::Connected)) => { + panic!("Should not connect to nonexistent binary"); + } + _ => { + tracing::info!("✓ Connection failed as expected"); + } + } + } else { + tracing::info!("✓ Config validation caught invalid binary"); + } +} + +#[tokio::test] +#[ignore = "Requires mocker binary"] +async fn test_t004_02_timeout_handling() -> anyhow::Result<()> { + tracing_subscriber::fmt().with_test_writer().try_init().ok(); + + // Create connector with very short timeout + let config = AcpConfig::stdio(mocker_binary(), mocker_args()) + .with_request_timeout(Duration::from_millis(100)); + + let connector = AcpConnector::new( + "timeout-test".to_string(), + uuid::Uuid::nil(), + "Timeout Test".to_string(), + config, + dirigent_core::sharing::bus::SharingBus::new(), + )?; + let mut events = connector.subscribe(); + let _handle = connector.start_task().await; + + // Try to connect - might timeout during handshake + match timeout(Duration::from_secs(10), events.recv()).await { + Ok(Ok(Event::Connected)) => { + tracing::info!("✓ Connected despite short timeout"); + } + Ok(Ok(Event::Error { message })) => { + tracing::info!("✓ Timeout error handled: {}", message); + } + Err(_) => { + tracing::info!("✓ Timeout occurred as expected"); + } + _ => {} + } + + Ok(()) +} + +#[tokio::test] +#[ignore = "Requires mocker binary"] +async fn test_t004_03_connection_state_on_error() -> anyhow::Result<()> { + tracing_subscriber::fmt().with_test_writer().try_init().ok(); + + let config = AcpConfig::stdio("nonexistent".to_string(), vec![]); + + if let Ok(connector) = AcpConnector::new( + "state-test".to_string(), + uuid::Uuid::nil(), + "State Test".to_string(), + config, + dirigent_core::sharing::bus::SharingBus::new(), + ) { + let state = connector.state_arc(); + let _handle = connector.start_task().await; + + // Wait a bit for state to update + tokio::time::sleep(Duration::from_secs(2)).await; + + let current_state = state.read().await; + match *current_state { + ConnectorState::Error { .. } | ConnectorState::Stopped => { + tracing::info!("✓ State correctly reflects error: {:?}", *current_state); + } + _ => { + tracing::warn!("State is: {:?}", *current_state); + } + } + } + + Ok(()) +} + +// ============================================================================ +// TEST-005: Reconnection Tests +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires mocker binary"] +async fn test_t005_01_reconnect_after_disconnect() -> anyhow::Result<()> { + tracing_subscriber::fmt().with_test_writer().try_init().ok(); + + let connector = create_test_connector("reconnect-test"); + let mut events = connector.subscribe(); + let handle = connector.start_task().await; + + // Wait for initial connection + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + tracing::info!("✓ Initial connection established"); + + // Stop and restart + connector.stop(); + timeout(Duration::from_secs(5), handle).await??; + + // Create new connector with same config + let connector2 = create_test_connector("reconnect-test-2"); + let mut events2 = connector2.subscribe(); + let handle2 = connector2.start_task().await; + + // Should reconnect + wait_for_event( + &mut events2, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + tracing::info!("✓ Reconnected successfully"); + + connector2.stop(); + timeout(Duration::from_secs(5), handle2).await??; + Ok(()) +} + +#[tokio::test] +#[ignore = "Requires mocker binary"] +async fn test_t005_02_max_retries_respected() -> anyhow::Result<()> { + tracing_subscriber::fmt().with_test_writer().try_init().ok(); + + let config = AcpConfig::stdio("nonexistent".to_string(), vec![]) + .with_retry_max_attempts(2) + .with_retry_initial_delay(Duration::from_millis(100)); + + if let Ok(connector) = AcpConnector::new( + "retry-test".to_string(), + uuid::Uuid::nil(), + "Retry Test".to_string(), + config, + dirigent_core::sharing::bus::SharingBus::new(), + ) { + let _events = connector.subscribe(); + let _handle = connector.start_task().await; + + // Should eventually give up after max retries + tokio::time::sleep(Duration::from_secs(3)).await; + + // Check state is error or stopped + let state = connector.state_arc(); + let current = state.read().await; + match *current { + ConnectorState::Error { .. } | ConnectorState::Stopped => { + tracing::info!("✓ Max retries respected, connector stopped"); + } + _ => { + tracing::warn!("State: {:?}", *current); + } + } + } + + Ok(()) +} + +// ============================================================================ +// TEST-006: Edge Case Tests +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires mocker binary"] +async fn test_t006_01_empty_message() -> anyhow::Result<()> { + tracing_subscriber::fmt().with_test_writer().try_init().ok(); + + let connector = create_test_connector("empty-msg-test"); + let mut events = connector.subscribe(); + let handle = connector.start_task().await; + + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + connector + .command_tx() + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await?; + + let session_id = match wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await? + { + Event::SessionCreated { session, .. } => session.id, + _ => anyhow::bail!("Expected SessionCreated"), + }; + + // Send empty message + connector + .command_tx() + .send(ConnectorCommand::SendMessage { + session_id, + text: "".to_string(), + }) + .await?; + + // Should handle gracefully (either respond or error) + let result = timeout(Duration::from_secs(3), events.recv()).await; + tracing::info!("✓ Empty message handled: {:?}", result.is_ok()); + + connector.stop(); + timeout(Duration::from_secs(5), handle).await??; + Ok(()) +} + +#[tokio::test] +#[ignore = "Requires mocker binary"] +async fn test_t006_02_very_long_message() -> anyhow::Result<()> { + tracing_subscriber::fmt().with_test_writer().try_init().ok(); + + let connector = create_test_connector("long-msg-test"); + let mut events = connector.subscribe(); + let handle = connector.start_task().await; + + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + connector + .command_tx() + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await?; + + let session_id = match wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await? + { + Event::SessionCreated { session, .. } => session.id, + _ => anyhow::bail!("Expected SessionCreated"), + }; + + // Send very long message (10KB) + let long_content = "x".repeat(10_000); + connector + .command_tx() + .send(ConnectorCommand::SendMessage { + session_id, + text: long_content, + }) + .await?; + + // Should handle without issue + wait_for_event( + &mut events, + |e| matches!(e, Event::MessageStarted { .. }), + Duration::from_secs(5), + ) + .await?; + tracing::info!("✓ Long message handled"); + + connector.stop(); + timeout(Duration::from_secs(5), handle).await??; + Ok(()) +} + +#[tokio::test] +#[ignore = "Requires mocker binary"] +async fn test_t006_03_rapid_fire_requests() -> anyhow::Result<()> { + tracing_subscriber::fmt().with_test_writer().try_init().ok(); + + let connector = create_test_connector("rapid-test"); + let mut events = connector.subscribe(); + let handle = connector.start_task().await; + + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + connector + .command_tx() + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await?; + + let session_id = match wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await? + { + Event::SessionCreated { session, .. } => session.id, + _ => anyhow::bail!("Expected SessionCreated"), + }; + + // Send 10 messages rapidly + for i in 0..10 { + connector + .command_tx() + .send(ConnectorCommand::SendMessage { + session_id: session_id.clone(), + text: format!("Message {}", i), + }) + .await?; + } + + tracing::info!("✓ Rapid-fire messages sent without blocking"); + + // Just verify we don't crash + tokio::time::sleep(Duration::from_secs(2)).await; + + connector.stop(); + timeout(Duration::from_secs(5), handle).await??; + Ok(()) +} + +#[tokio::test] +#[ignore = "Requires mocker binary"] +async fn test_t006_04_resource_cleanup() -> anyhow::Result<()> { + tracing_subscriber::fmt().with_test_writer().try_init().ok(); + + // Create and destroy multiple connectors + for i in 0..5 { + let connector = create_test_connector(&format!("cleanup-test-{}", i)); + let handle = connector.start_task().await; + + tokio::time::sleep(Duration::from_millis(500)).await; + + connector.stop(); + let _ = timeout(Duration::from_secs(2), handle).await; + + tracing::info!("✓ Connector {} cleaned up", i); + } + + tracing::info!("✓ All resources cleaned up successfully"); + Ok(()) +} + +// ============================================================================ +// TEST-007: Performance Tests +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires mocker binary"] +async fn test_t007_01_connection_latency() -> anyhow::Result<()> { + tracing_subscriber::fmt().with_test_writer().try_init().ok(); + + let connector = create_test_connector("latency-test"); + let mut events = connector.subscribe(); + + let start = Instant::now(); + let _handle = connector.start_task().await; + + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + let latency = start.elapsed(); + tracing::info!("✓ Connection latency: {:?}", latency); + + // Should connect within 5 seconds + assert!(latency < Duration::from_secs(5), "Connection took too long"); + + Ok(()) +} + +#[tokio::test] +#[ignore = "Requires mocker binary"] +async fn test_t007_02_request_roundtrip() -> anyhow::Result<()> { + tracing_subscriber::fmt().with_test_writer().try_init().ok(); + + let connector = create_test_connector("roundtrip-test"); + let mut events = connector.subscribe(); + let handle = connector.start_task().await; + + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + // Measure session creation time + let start = Instant::now(); + connector + .command_tx() + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await?; + + wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await?; + let roundtrip = start.elapsed(); + + tracing::info!("✓ Session creation roundtrip: {:?}", roundtrip); + assert!(roundtrip < Duration::from_secs(2), "Roundtrip too slow"); + + connector.stop(); + timeout(Duration::from_secs(5), handle).await??; + Ok(()) +} + +#[tokio::test] +#[ignore = "Requires mocker binary"] +async fn test_t007_03_event_delivery_latency() -> anyhow::Result<()> { + tracing_subscriber::fmt().with_test_writer().try_init().ok(); + + let connector = create_test_connector("event-latency-test"); + let mut events = connector.subscribe(); + let handle = connector.start_task().await; + + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + connector + .command_tx() + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await?; + + let session_id = match wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await? + { + Event::SessionCreated { session, .. } => session.id, + _ => anyhow::bail!("Expected SessionCreated"), + }; + + // Measure event delivery + let start = Instant::now(); + connector + .command_tx() + .send(ConnectorCommand::SendMessage { + session_id, + text: "Test".to_string(), + }) + .await?; + + wait_for_event( + &mut events, + |e| matches!(e, Event::MessageStarted { .. }), + Duration::from_secs(5), + ) + .await?; + let delivery_time = start.elapsed(); + + tracing::info!("✓ Event delivery latency: {:?}", delivery_time); + assert!( + delivery_time < Duration::from_secs(1), + "Event delivery too slow" + ); + + connector.stop(); + timeout(Duration::from_secs(5), handle).await??; + Ok(()) +} + +#[tokio::test] +#[ignore = "Requires mocker binary"] +async fn test_t007_04_concurrent_session_capacity() -> anyhow::Result<()> { + tracing_subscriber::fmt().with_test_writer().try_init().ok(); + + let connector = create_test_connector("capacity-test"); + let mut events = connector.subscribe(); + let handle = connector.start_task().await; + + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + // Create 10 sessions + let start = Instant::now(); + let mut session_ids = Vec::new(); + + for _i in 0..10 { + connector + .command_tx() + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await?; + + if let Event::SessionCreated { session, .. } = wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await? + { + session_ids.push(session.id); + } + } + + let total_time = start.elapsed(); + tracing::info!( + "✓ Created {} sessions in {:?}", + session_ids.len(), + total_time + ); + assert_eq!(session_ids.len(), 10); + + connector.stop(); + timeout(Duration::from_secs(5), handle).await??; + Ok(()) +} + +#[tokio::test] +#[ignore = "Requires mocker binary"] +async fn test_t007_05_memory_stability() -> anyhow::Result<()> { + tracing_subscriber::fmt().with_test_writer().try_init().ok(); + + // Run 20 create/destroy cycles to check for memory leaks + for i in 0..20 { + let connector = create_test_connector(&format!("mem-test-{}", i)); + let mut events = connector.subscribe(); + let handle = connector.start_task().await; + + if let Ok(_) = timeout( + Duration::from_secs(5), + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ), + ) + .await + { + tracing::info!("Cycle {} connected", i); + } + + connector.stop(); + let _ = timeout(Duration::from_secs(2), handle).await; + } + + tracing::info!("✓ 20 cycles completed without crash (manual memory check needed)"); + Ok(()) +} diff --git a/crates/dirigent_core/tests/acp_crash_resilience_tests.rs b/crates/dirigent_core/tests/acp_crash_resilience_tests.rs new file mode 100644 index 0000000..1105e49 --- /dev/null +++ b/crates/dirigent_core/tests/acp_crash_resilience_tests.rs @@ -0,0 +1,526 @@ +//! Tests for ACP connector crash resilience. +//! +//! These tests verify behavior when child processes crash, die mid-write, +//! or disconnect unexpectedly during active sessions. +//! +//! Bug reference: docs/workpad/bugs/bug_001_analysis.md + +#[cfg(feature = "server")] +mod crash_resilience { + use async_trait::async_trait; + use dirigent_core::connectors::acp::transport::{AcpTransport, TransportResult}; + use dirigent_core::connectors::acp::protocol::ProtocolHandler; + use serde_json::{json, Value}; + use std::collections::VecDeque; + use std::sync::Arc; + use tokio::sync::Mutex; + + // ========================================================================= + // Mock Transport + // ========================================================================= + + /// A mock transport that lets tests control exactly what messages are + /// received and when the connection "dies." + struct MockTransport { + /// Messages queued for recv() to return + recv_queue: Arc<Mutex<VecDeque<TransportResult<Option<Value>>>>>, + /// Messages captured by send() + sent: Arc<Mutex<Vec<Value>>>, + closed: bool, + } + + impl MockTransport { + fn new() -> Self { + Self { + recv_queue: Arc::new(Mutex::new(VecDeque::new())), + sent: Arc::new(Mutex::new(Vec::new())), + closed: false, + } + } + + /// Queue a successful message to be returned by recv() + async fn queue_message(&self, msg: Value) { + self.recv_queue.lock().await.push_back(Ok(Some(msg))); + } + + /// Queue a transport error (simulates partial read / JSON parse failure) + async fn queue_error(&self, err: &str) { + self.recv_queue + .lock() + .await + .push_back(Err(err.to_string().into())); + } + + /// Queue an EOF (simulates clean transport close / process exit) + async fn queue_eof(&self) { + self.recv_queue.lock().await.push_back(Ok(None)); + } + + /// Get all messages that were sent through this transport + async fn sent_messages(&self) -> Vec<Value> { + self.sent.lock().await.clone() + } + } + + #[async_trait] + impl AcpTransport for MockTransport { + async fn connect(&mut self) -> TransportResult<()> { + Ok(()) + } + + async fn send(&mut self, message: Value) -> TransportResult<()> { + self.sent.lock().await.push(message); + Ok(()) + } + + async fn recv(&mut self) -> TransportResult<Option<Value>> { + let mut queue = self.recv_queue.lock().await; + if let Some(item) = queue.pop_front() { + item + } else { + // No more queued items — block forever (simulates waiting for data) + drop(queue); + std::future::pending().await + } + } + + async fn close(&mut self) -> TransportResult<()> { + self.closed = true; + Ok(()) + } + } + + // ========================================================================= + // Test 1: Partial read produces transport error, not JSON parse error + // ========================================================================= + + /// When a child process crashes mid-write, read_line() returns partial data + /// without a trailing newline. The transport currently tries to parse this + /// as JSON and produces a confusing "Failed to parse JSON: trailing characters" + /// error. It SHOULD return Ok(None) — a clean EOF — instead. + /// + /// This test uses the real StdioTransport with controlled pipes to simulate + /// a crash mid-write. + #[tokio::test] + async fn test_partial_write_should_return_eof_not_parse_error() { + use tokio::io::{AsyncWriteExt, duplex}; + + // Create a pipe pair that simulates stdout of a child process + let (mut writer, reader) = duplex(8192); + + // Write partial JSON (no trailing newline — crash mid-write) + writer + .write_all(b"{\"jsonrpc\":\"2.0\",\"method\":\"session/update\",\"params\":{\"sessionId\":\"019cc2e7") + .await + .unwrap(); + + // Close the writer — simulates process crash (EOF after partial data) + drop(writer); + + // Use BufReader + read_line to simulate what StdioTransport.recv() does + use tokio::io::AsyncBufReadExt; + let mut buf_reader = tokio::io::BufReader::new(reader); + let mut line = String::new(); + let bytes_read = buf_reader.read_line(&mut line).await.unwrap(); + + // Verify: we got partial data (bytes > 0) but no trailing newline + assert!(bytes_read > 0, "Should have read some bytes"); + assert!( + !line.ends_with('\n'), + "Partial read should NOT end with newline — this is the crash signature" + ); + + // Current behavior: serde_json::from_str fails on partial JSON + let parse_result = serde_json::from_str::<Value>(line.trim()); + assert!( + parse_result.is_err(), + "Partial JSON should fail to parse" + ); + + // The transport detects partial reads (no trailing newline) as crash + // artifacts. This verifies the crash signature is identifiable from + // the raw stream data. + assert!( + !line.ends_with('\n'), + "Partial read without newline indicates a crash mid-write" + ); + } + + // ========================================================================= + // Test 2: Dropping ProtocolHandler signals pending receivers with error + // ========================================================================= + + /// When a ProtocolHandler is dropped while requests are pending, the oneshot + /// senders are dropped, causing receivers to get RecvError. This test verifies + /// that this mechanism works and demonstrates the error path. + #[tokio::test] + async fn test_protocol_handler_drop_signals_pending_receivers() { + let handler = ProtocolHandler::new(); + + // Prepare a request — creates oneshot channel, stores sender in pending_requests + let request = json!({ + "jsonrpc": "2.0", + "method": "session/prompt", + "params": { + "sessionId": "test-session", + "message": {"role": "user", "content": {"type": "text", "text": "hello"}} + } + }); + + let (_message_with_id, response_rx) = handler.prepare_request(request).await; + + // Drop the handler — this drops pending_requests, which drops the sender + drop(handler); + + // The receiver should immediately get an error + let result = response_rx.await; + assert!( + result.is_err(), + "Receiver should get error when sender is dropped (handler dropped)" + ); + + // This IS the mechanism that causes "Response channel dropped" in production. + // The error is a bare RecvError with no context about WHY it was dropped. + } + + // ========================================================================= + // Test 3: Response channel error should report CONNECTION_LOST, not TIMEOUT + // ========================================================================= + + /// When a response channel receives a cancellation from cancel_all_pending() + /// (because the transport died), the error reported to the UI should say + /// CONNECTION_LOST, not TIMEOUT. + /// + /// This test simulates the full flow: prepare request → cancel_all_pending → + /// spawned task gets structured error → emits CONNECTION_LOST. + #[tokio::test] + async fn test_response_channel_drop_error_should_indicate_connection_lost() { + use tokio::sync::broadcast; + + let handler = ProtocolHandler::new(); + + // Prepare request (creates pending oneshot channel) + let request = json!({ + "jsonrpc": "2.0", + "method": "session/prompt", + "params": { + "sessionId": "test-session", + "message": {"role": "user", "content": {"type": "text", "text": "hello"}} + } + }); + + let (_msg, response_rx) = handler.prepare_request(request).await; + + // Simulate what the connector does: spawn a task that waits for the response + let (error_tx, mut error_rx) = broadcast::channel::<(String, String)>(10); + + let task = tokio::spawn(async move { + match response_rx.await { + Ok(response) => { + // Check if this is an error response from cancel_all_pending + if let Some(error) = response.get("error") { + let error_code = "CONNECTION_LOST".to_string(); + let error_message = error.get("message") + .and_then(|m| m.as_str()) + .unwrap_or("unknown") + .to_string(); + let _ = error_tx.send((error_code, error_message)); + } + } + Err(_recv_error) => { + // Fallback: bare drop without cancel_all_pending + let _ = error_tx.send(("TIMEOUT".to_string(), "channel dropped".to_string())); + } + } + }); + + // Cancel all pending instead of just dropping + handler.cancel_all_pending("Connection lost: agent process exited").await; + drop(handler); + + // Wait for the spawned task to detect the cancellation and report + let (error_code, _error_message) = error_rx.recv().await.unwrap(); + + assert_eq!( + error_code, "CONNECTION_LOST", + "Error code should be CONNECTION_LOST after cancel_all_pending" + ); + + task.await.unwrap(); + } + + // ========================================================================= + // Test 4: Protocol handler should support explicit cancellation + // ========================================================================= + + /// When we know the transport is dying, we can explicitly cancel all pending + /// requests with a structured error response (containing the real reason), + /// rather than relying on implicit Drop which gives receivers a bare + /// RecvError with no context. + #[tokio::test] + async fn test_protocol_handler_should_support_explicit_cancellation() { + let handler = ProtocolHandler::new(); + + // Prepare two concurrent requests + let req1 = json!({"jsonrpc": "2.0", "method": "session/prompt", "params": {}}); + let req2 = json!({"jsonrpc": "2.0", "method": "session/list", "params": {}}); + + let (_msg1, response_rx1) = handler.prepare_request(req1).await; + let (_msg2, response_rx2) = handler.prepare_request(req2).await; + + // cancel_all_pending() now exists — call it + handler.cancel_all_pending("Connection lost: agent process exited").await; + + // Both receivers should get structured error responses (not bare RecvError) + let result1 = response_rx1.await; + let result2 = response_rx2.await; + + assert!(result1.is_ok(), "Receiver 1 should get Ok (structured error), not Err"); + assert!(result2.is_ok(), "Receiver 2 should get Ok (structured error), not Err"); + + // Verify the error response contains the reason + let resp1 = result1.unwrap(); + let error_obj = resp1.get("error").expect("Should have error field"); + assert_eq!(error_obj.get("code").unwrap(), -32000); + assert!( + error_obj.get("message").unwrap().as_str().unwrap().contains("Connection lost"), + "Error message should contain the cancellation reason" + ); + } + + // ========================================================================= + // Test 5: Full crash sequence — request in flight, transport dies + // ========================================================================= + + /// Simulates the exact sequence from the production logs: + /// 1. Protocol handler prepares a request + /// 2. Request is "sent" via mock transport + /// 3. Transport starts returning notifications (streaming chunks) + /// 4. Transport returns error (partial read from crash) + /// 5. Transport returns EOF (process dead) + /// 6. Protocol handler is dropped + /// 7. Pending response receiver gets error + /// + /// This is the integration test that proves the full cascade. + #[tokio::test] + async fn test_full_crash_sequence_request_in_flight() { + let handler = ProtocolHandler::new(); + let mut transport = MockTransport::new(); + + // Step 1: Prepare a prompt request + let request = json!({ + "jsonrpc": "2.0", + "method": "session/prompt", + "params": { + "sessionId": "019cc2e7-26d6-7102-bed6-4c953c023109", + "message": {"role": "user", "content": {"type": "text", "text": "hello"}} + } + }); + let (msg_with_id, response_rx) = handler.prepare_request(request).await; + let request_id = msg_with_id.get("id").cloned().unwrap(); + + // Step 2: Send via transport + transport.send(msg_with_id).await.unwrap(); + + // Step 3: Queue some streaming notifications (agent is responding) + transport + .queue_message(json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "019cc2e7-26d6-7102-bed6-4c953c023109", + "update": { + "sessionUpdate": "agent_message_chunk", + "content": {"type": "text", "text": "Hello! I'm an AI"} + } + } + })) + .await; + + // Step 4: Queue a transport error (partial read from crash) + transport + .queue_error("Failed to parse JSON: trailing characters at line 1 column 4. Line: -26d6-7102-bed6-4c953c023109") + .await; + + // Step 5: Queue EOF (process dead) + transport.queue_eof().await; + + // Process the notification — it should be routed to notification channel + let msg = transport.recv().await.unwrap().unwrap(); + let _result = handler.handle_message(msg).await; + + // Process the error — transport returns Err + let err_result = transport.recv().await; + assert!(err_result.is_err(), "Transport should return error for partial read"); + let err_msg = err_result.unwrap_err().to_string(); + assert!( + err_msg.contains("Failed to parse JSON"), + "Error should mention JSON parse failure, got: {}", + err_msg + ); + + // Process the EOF + let eof = transport.recv().await.unwrap(); + assert!(eof.is_none(), "Transport should return None for EOF"); + + // Step 6: Cancel all pending (simulates what connector does before break) + handler.cancel_all_pending("Connection lost: agent process exited").await; + drop(handler); + + // Step 7: The pending response receiver should get a structured error + let result = response_rx.await; + assert!( + result.is_ok(), + "Pending request should receive structured error response from cancel_all_pending" + ); + + let response = result.unwrap(); + let error = response.get("error").expect("Should have error field"); + assert_eq!(error.get("code").unwrap(), -32000); + assert!( + error.get("message").unwrap().as_str().unwrap().contains("Connection lost"), + "Error should contain crash reason" + ); + + // Verify the request was actually sent + let sent = transport.sent_messages().await; + assert_eq!(sent.len(), 1); + assert_eq!(sent[0].get("id"), Some(&request_id)); + } + + // ========================================================================= + // Test 6: Crash context — stderr, exit status, partial stdout should be + // assembled into a single diagnostic report + // ========================================================================= + + /// When a child process crashes, three separate data sources contain the + /// explanation: stderr (panic/error message), exit status (signal/code), + /// and partial stdout (what was being written). Currently these are in + /// separate silos and never assembled together. + /// + /// This test simulates a real child process crash and verifies that: + /// - stderr output IS captured (currently: drained to warn! log, then lost) + /// - exit status IS available at crash time (currently: only in close()) + /// - partial stdout IS identified as crash artifact (currently: JSON parse error) + #[tokio::test] + async fn test_crash_context_should_capture_stderr_and_exit_status() { + // Simulate a child process crash where three data sources exist: + // 1. Writes partial JSON to stdout (simulates crash mid-write) + // 2. Writes an error message to stderr (simulates panic/error output) + // 3. Exits with code 1 + // + // We use a simple inline script via the shell. + // Use tokio duplex streams to simulate a crashing child process + // without needing an external program. This is more reliable than + // spawning shell commands across platforms. + use tokio::io::AsyncWriteExt; + + // Simulate stdout: partial JSON without trailing newline + let (mut stdout_writer, stdout_read) = tokio::io::duplex(8192); + stdout_writer + .write_all(b"{\"jsonrpc\":\"2.0\",\"method\":\"crash") + .await + .unwrap(); + stdout_writer.flush().await.unwrap(); + drop(stdout_writer); // EOF — simulates process death + + // Simulate stderr: error message from the crashing process + let (mut stderr_writer, stderr_read) = tokio::io::duplex(8192); + stderr_writer + .write_all(b"FATAL: panicked at 'index out of bounds', src/main.rs:42\n") + .await + .unwrap(); + stderr_writer.flush().await.unwrap(); + drop(stderr_writer); // EOF + + // Read stdout — should get partial data (no trailing newline) + use tokio::io::AsyncBufReadExt; + let mut stdout_reader = tokio::io::BufReader::new(stdout_read); + let mut stdout_line = String::new(); + let bytes = stdout_reader.read_line(&mut stdout_line).await.unwrap(); + + // Read stderr — should get the error message + let mut stderr_reader = tokio::io::BufReader::new(stderr_read); + let mut stderr_line = String::new(); + let _ = stderr_reader.read_line(&mut stderr_line).await.unwrap(); + + // VERIFY: All three data sources are available in principle + assert!(bytes > 0, "Should have stdout data"); + assert!( + !stdout_line.ends_with('\n'), + "Partial stdout should NOT end with newline (crash mid-write)" + ); + assert!( + stderr_line.contains("FATAL") || stderr_line.contains("panicked"), + "Stderr should contain the error reason, got: {}", + stderr_line + ); + + // Verify that the partial stdout is NOT valid JSON + let parse_result = serde_json::from_str::<Value>(&stdout_line.trim()); + assert!( + parse_result.is_err(), + "Partial stdout should not parse as valid JSON" + ); + + // Demonstrate that CrashContext can be assembled from these data sources. + // StdioTransport now exposes get_crash_context() which collects stderr, + // exit status, and partial stdout into a single diagnostic struct. + use dirigent_core::connectors::acp::transport::CrashContext; + + let crash_ctx = CrashContext { + recent_stderr: vec![stderr_line.trim().to_string()], + exit_status: None, // duplex streams don't have exit status + partial_stdout: Some(stdout_line.clone()), + }; + + assert!(!crash_ctx.recent_stderr.is_empty(), "CrashContext should contain stderr"); + assert!(crash_ctx.partial_stdout.is_some(), "CrashContext should contain partial stdout"); + assert!( + crash_ctx.recent_stderr[0].contains("FATAL"), + "Stderr in CrashContext should contain the crash reason" + ); + } + + // ========================================================================= + // Test 7: Response arrives correctly when transport is healthy + // ========================================================================= + + /// Baseline test: verify the happy path works — request sent, response + /// arrives, oneshot channel delivers it correctly. + #[tokio::test] + async fn test_happy_path_response_delivered_correctly() { + let handler = ProtocolHandler::new(); + + // Prepare request + let request = json!({ + "jsonrpc": "2.0", + "method": "session/new", + "params": {"cwd": "."} + }); + let (msg_with_id, response_rx) = handler.prepare_request(request).await; + let request_id = msg_with_id.get("id").cloned().unwrap(); + + // Simulate response arriving (what handle_message does when response comes) + let response = json!({ + "jsonrpc": "2.0", + "id": request_id, + "result": {"sessionId": "new-session-123"} + }); + handler.handle_message(response).await; + + // Response should be delivered via oneshot + let result = response_rx.await; + assert!(result.is_ok(), "Response should be delivered successfully"); + + let response_value = result.unwrap(); + assert_eq!( + response_value + .get("result") + .unwrap() + .get("sessionId") + .unwrap(), + "new-session-123" + ); + } +} diff --git a/crates/dirigent_core/tests/acp_http_integration_test.rs b/crates/dirigent_core/tests/acp_http_integration_test.rs new file mode 100644 index 0000000..1e10675 --- /dev/null +++ b/crates/dirigent_core/tests/acp_http_integration_test.rs @@ -0,0 +1,740 @@ +//! TEST-002: Comprehensive HTTP Integration Tests +//! +//! This file contains comprehensive integration tests for the ACP client over HTTP transport. +//! It tests the full lifecycle of ACP communication using the dirigate HTTP server. +//! +//! Test Coverage: +//! - T009: Full HTTP lifecycle (initialize, new session, prompt, cancel) +//! - T010: HTTP initialize handshake +//! - T011: HTTP session creation +//! - T012: HTTP message sending with SSE streaming +//! - T013: HTTP session cancellation +//! - T014: SSE connection and events +//! - T015: Multiple concurrent HTTP sessions +//! - T016: HTTP connection recovery + +#![cfg(feature = "server")] + +use dirigent_core::connectors::acp::{AcpConfig, AcpConnector}; +use dirigent_core::connectors::{Connector, ConnectorCommand}; +use dirigent_protocol::Event; +use std::path::PathBuf; +use std::time::Duration; +use tokio::time::timeout; + +/// Helper to get test port (unique per test to avoid conflicts) +fn test_port(offset: u16) -> u16 { + 9000 + offset +} + +/// Helper to create a test fixture path +fn test_fixture_path() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("tests") + .join("fixtures") + .join("acp_test.yaml") +} + +/// Helper to spawn the mocker HTTP server in the background +async fn spawn_mocker_server(port: u16) -> anyhow::Result<tokio::task::JoinHandle<()>> { + let fixture_path = test_fixture_path(); + + // Build the mocker + tracing::info!("Building dirigate..."); + let build_status = tokio::process::Command::new("cargo") + .args(&["build", "--package", "dirigate"]) + .status() + .await?; + + if !build_status.success() { + anyhow::bail!("Failed to build dirigate"); + } + + // Find the binary + let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let project_root = manifest_dir.parent().unwrap().parent().unwrap(); + let binary_path = project_root + .join("target") + .join("debug") + .join("dirigate.exe"); + + if !binary_path.exists() { + anyhow::bail!("Mocker binary not found at {:?}", binary_path); + } + + // Spawn the server + tracing::info!("Starting mocker server on port {}", port); + let handle = tokio::spawn(async move { + let _ = tokio::process::Command::new(&binary_path) + .args(&[ + "serve", + "--fixtures", + fixture_path.to_str().unwrap(), + "--port", + &port.to_string(), + ]) + .status() + .await; + }); + + // Wait for server to be ready + tokio::time::sleep(Duration::from_secs(2)).await; + + Ok(handle) +} + +/// Helper to create a test ACP connector with HTTP transport +fn create_http_connector(port: u16) -> AcpConnector { + let config = AcpConfig::http(format!("http://127.0.0.1:{}", port)) + .with_retry_max_attempts(3) + .with_retry_initial_delay(Duration::from_millis(500)); + + AcpConnector::new( + format!("http-test-connector-{}", port), + uuid::Uuid::nil(), + "HTTP Test Connector".to_string(), + config, + dirigent_core::sharing::bus::SharingBus::new(), + ) + .expect("Failed to create connector") +} + +/// Helper to wait for event with timeout +async fn wait_for_event<F>( + events_rx: &mut tokio::sync::broadcast::Receiver<Event>, + predicate: F, + timeout_duration: Duration, +) -> anyhow::Result<Event> +where + F: Fn(&Event) -> bool, +{ + let result = timeout(timeout_duration, async { + loop { + match events_rx.recv().await { + Ok(event) => { + if predicate(&event) { + return Ok(event); + } + } + Err(e) => { + return Err(anyhow::anyhow!("Event receive error: {}", e)); + } + } + } + }) + .await; + + match result { + Ok(Ok(event)) => Ok(event), + Ok(Err(e)) => Err(e), + Err(_) => Err(anyhow::anyhow!("Timeout waiting for event")), + } +} + +// ============================================================================ +// T009: Full HTTP Lifecycle Test +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires building mocker and port availability"] +async fn test_t009_full_lifecycle_http() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter("debug") + .with_test_writer() + .try_init() + .ok(); + + let port = test_port(1); + let _server_handle = spawn_mocker_server(port).await?; + + // Create connector + let connector = create_http_connector(port); + let mut events = connector.subscribe(); + + // Start the connector task + let task_handle = connector.start_task().await; + + // Wait for connection + let connected_event = wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + assert!(matches!(connected_event, Event::Connected)); + tracing::info!("✓ Connected to HTTP mocker"); + + // Create a new session + let cmd_tx = connector.command_tx(); + cmd_tx + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await?; + + let session_created = wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await?; + + let session_id = match session_created { + Event::SessionCreated { session, .. } => { + tracing::info!("✓ Session created: {}", session.id); + session.id + } + _ => anyhow::bail!("Expected SessionCreated event"), + }; + + // Send a message + cmd_tx + .send(ConnectorCommand::SendMessage { + session_id: session_id.clone(), + text: "Hello via HTTP!".to_string(), + }) + .await?; + + // Wait for message started + wait_for_event( + &mut events, + |e| matches!(e, Event::MessageStarted { .. }), + Duration::from_secs(5), + ) + .await?; + + tracing::info!("✓ Message started"); + + // Wait for some content + let mut received_content = false; + for _ in 0..10 { + if let Ok(event) = timeout(Duration::from_secs(2), events.recv()).await { + match event? { + Event::SessionUpdate { .. } => { + received_content = true; + tracing::info!("✓ Received content via SSE"); + break; + } + Event::MessageCompleted { .. } => { + received_content = true; + break; + } + _ => {} + } + } + } + + assert!(received_content, "Should receive content via SSE"); + + // Stop the connector + connector.stop(); + let _ = timeout(Duration::from_secs(5), task_handle).await?; + + tracing::info!("✓ Full HTTP lifecycle completed successfully"); + Ok(()) +} + +// ============================================================================ +// T010: HTTP Initialize Handshake +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires building mocker and port availability"] +async fn test_t010_http_initialize() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter("debug") + .with_test_writer() + .try_init() + .ok(); + + let port = test_port(2); + let _server_handle = spawn_mocker_server(port).await?; + + let connector = create_http_connector(port); + let mut events = connector.subscribe(); + let task_handle = connector.start_task().await; + + // Wait for connected event + let connected = wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + assert!(matches!(connected, Event::Connected)); + tracing::info!("✓ HTTP initialize handshake completed"); + + // Cleanup + connector.stop(); + let _ = timeout(Duration::from_secs(5), task_handle).await?; + + Ok(()) +} + +// ============================================================================ +// T011: HTTP Session Creation +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires building mocker and port availability"] +async fn test_t011_http_session_creation() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter("debug") + .with_test_writer() + .try_init() + .ok(); + + let port = test_port(3); + let _server_handle = spawn_mocker_server(port).await?; + + let connector = create_http_connector(port); + let mut events = connector.subscribe(); + let task_handle = connector.start_task().await; + + // Wait for connection + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + // Create session + let cmd_tx = connector.command_tx(); + cmd_tx + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await?; + + // Wait for session created event + let session_created = wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await?; + + match session_created { + Event::SessionCreated { session, .. } => { + assert_eq!(session.title, "HTTP Session Test"); + tracing::info!("✓ HTTP session created with correct title"); + } + _ => anyhow::bail!("Expected SessionCreated event"), + } + + // Cleanup + connector.stop(); + let _ = timeout(Duration::from_secs(5), task_handle).await?; + + Ok(()) +} + +// ============================================================================ +// T012: HTTP Message Sending with SSE Streaming +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires building mocker and port availability"] +async fn test_t012_http_sse_streaming() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter("debug") + .with_test_writer() + .try_init() + .ok(); + + let port = test_port(4); + let _server_handle = spawn_mocker_server(port).await?; + + let connector = create_http_connector(port); + let mut events = connector.subscribe(); + let task_handle = connector.start_task().await; + + // Setup + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + let cmd_tx = connector.command_tx(); + cmd_tx + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await?; + + let session_id = match wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await? + { + Event::SessionCreated { session, .. } => session.id, + _ => anyhow::bail!("Expected SessionCreated"), + }; + + // Send message + cmd_tx + .send(ConnectorCommand::SendMessage { + session_id: session_id.clone(), + text: "Stream me some content".to_string(), + }) + .await?; + + // Wait for message started + wait_for_event( + &mut events, + |e| matches!(e, Event::MessageStarted { .. }), + Duration::from_secs(5), + ) + .await?; + + // Collect SSE chunks + let mut chunk_count = 0; + let mut completed = false; + + for _ in 0..20 { + if let Ok(result) = timeout(Duration::from_secs(2), events.recv()).await { + match result? { + Event::SessionUpdate { update, .. } => { + if let dirigent_protocol::SessionUpdate::AgentMessageChunk { .. } = update { + chunk_count += 1; + tracing::info!("Received SSE chunk {}", chunk_count); + } + } + Event::MessageCompleted { .. } => { + completed = true; + tracing::info!("✓ Message completed via SSE after {} chunks", chunk_count); + break; + } + _ => {} + } + } + } + + assert!(chunk_count > 0, "Should receive SSE chunks"); + assert!(completed, "Message should complete via SSE"); + + // Cleanup + connector.stop(); + let _ = timeout(Duration::from_secs(5), task_handle).await?; + + Ok(()) +} + +// ============================================================================ +// T013: HTTP Session Cancellation +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires building mocker and port availability"] +async fn test_t013_http_cancellation() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter("debug") + .with_test_writer() + .try_init() + .ok(); + + let port = test_port(5); + let _server_handle = spawn_mocker_server(port).await?; + + let connector = create_http_connector(port); + let mut events = connector.subscribe(); + let task_handle = connector.start_task().await; + + // Setup + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + let cmd_tx = connector.command_tx(); + cmd_tx + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await?; + + let session_id = match wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await? + { + Event::SessionCreated { session, .. } => session.id, + _ => anyhow::bail!("Expected SessionCreated"), + }; + + // Send message + cmd_tx + .send(ConnectorCommand::SendMessage { + session_id: session_id.clone(), + text: "Long response".to_string(), + }) + .await?; + + // Wait for streaming to start + wait_for_event( + &mut events, + |e| matches!(e, Event::MessageStarted { .. }), + Duration::from_secs(5), + ) + .await?; + + wait_for_event( + &mut events, + |e| matches!(e, Event::SessionUpdate { .. }), + Duration::from_secs(5), + ) + .await?; + + // Cancel + cmd_tx + .send(ConnectorCommand::CancelGeneration { + session_id: session_id.clone(), + }) + .await?; + + tracing::info!("✓ HTTP cancellation sent successfully"); + + // Cleanup + connector.stop(); + let _ = timeout(Duration::from_secs(5), task_handle).await?; + + Ok(()) +} + +// ============================================================================ +// T014: SSE Connection and Events +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires building mocker and port availability"] +async fn test_t014_sse_connection() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter("debug") + .with_test_writer() + .try_init() + .ok(); + + let port = test_port(6); + let _server_handle = spawn_mocker_server(port).await?; + + let connector = create_http_connector(port); + let mut events = connector.subscribe(); + let task_handle = connector.start_task().await; + + // The SSE connection is established during initialization + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + tracing::info!("✓ SSE connection established"); + + // Create a session and send a message to verify SSE delivery + let cmd_tx = connector.command_tx(); + cmd_tx + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await?; + + let session_id = match wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await? + { + Event::SessionCreated { session, .. } => session.id, + _ => anyhow::bail!("Expected SessionCreated"), + }; + + cmd_tx + .send(ConnectorCommand::SendMessage { + session_id, + text: "Test SSE".to_string(), + }) + .await?; + + // Verify we receive events via SSE + let mut received_via_sse = false; + for _ in 0..10 { + if let Ok(event) = timeout(Duration::from_secs(1), events.recv()).await { + match event? { + Event::SessionUpdate { .. } + | Event::MessageStarted { .. } + | Event::MessageCompleted { .. } => { + received_via_sse = true; + break; + } + _ => {} + } + } + } + + assert!(received_via_sse, "Should receive events via SSE"); + tracing::info!("✓ SSE event delivery verified"); + + // Cleanup + connector.stop(); + let _ = timeout(Duration::from_secs(5), task_handle).await?; + + Ok(()) +} + +// ============================================================================ +// T015: Multiple Concurrent HTTP Sessions +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires building mocker and port availability"] +async fn test_t015_http_multiple_sessions() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter("debug") + .with_test_writer() + .try_init() + .ok(); + + let port = test_port(7); + let _server_handle = spawn_mocker_server(port).await?; + + let connector = create_http_connector(port); + let mut events = connector.subscribe(); + let task_handle = connector.start_task().await; + + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + let cmd_tx = connector.command_tx(); + let mut session_ids = Vec::new(); + + // Create multiple sessions via HTTP + for _i in 0..3 { + cmd_tx + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await?; + + let session_id = match wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await? + { + Event::SessionCreated { session, .. } => session.id, + _ => anyhow::bail!("Expected SessionCreated"), + }; + + session_ids.push(session_id); + } + + assert_eq!(session_ids.len(), 3); + tracing::info!("✓ Created 3 concurrent HTTP sessions"); + + // Send messages to all sessions + for session_id in &session_ids { + cmd_tx + .send(ConnectorCommand::SendMessage { + session_id: session_id.clone(), + text: "Test".to_string(), + }) + .await?; + } + + // Wait for responses + let mut started_count = 0; + for _ in 0..30 { + if let Ok(result) = timeout(Duration::from_secs(1), events.recv()).await { + if let Event::MessageStarted { .. } = result? { + started_count += 1; + if started_count == 3 { + break; + } + } + } + } + + assert_eq!(started_count, 3); + tracing::info!("✓ All HTTP messages received responses"); + + // Cleanup + connector.stop(); + let _ = timeout(Duration::from_secs(5), task_handle).await?; + + Ok(()) +} + +// ============================================================================ +// T016: HTTP Connection Recovery +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires building mocker and port availability"] +async fn test_t016_http_connection_recovery() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter("debug") + .with_test_writer() + .try_init() + .ok(); + + let port = test_port(8); + + // Create connector first (without server running) + let connector = create_http_connector(port); + let mut events = connector.subscribe(); + let task_handle = connector.start_task().await; + + // Wait a bit for failed connection attempts + tokio::time::sleep(Duration::from_secs(2)).await; + + // Now start the server + tracing::info!("Starting server after connector creation"); + let _server_handle = spawn_mocker_server(port).await?; + + // Should eventually connect (with retry) + let connected = wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(15), + ) + .await?; + + assert!(matches!(connected, Event::Connected)); + tracing::info!("✓ HTTP connector recovered and connected"); + + // Cleanup + connector.stop(); + let _ = timeout(Duration::from_secs(5), task_handle).await?; + + Ok(()) +} diff --git a/crates/dirigent_core/tests/acp_initialization_tests.rs b/crates/dirigent_core/tests/acp_initialization_tests.rs new file mode 100644 index 0000000..52a89db --- /dev/null +++ b/crates/dirigent_core/tests/acp_initialization_tests.rs @@ -0,0 +1,357 @@ +//! Integration tests for ACP initialization and capability negotiation. +//! +//! These tests verify that the ACP protocol implementation correctly performs: +//! - Protocol version negotiation +//! - Capability exchange +//! - Optional authentication +//! - Capability validation +//! +//! Tests use the dirigate for both stdio and HTTP transport modes. + +#[cfg(feature = "server")] +mod initialization_tests { + use dirigent_core::acp::{ + authenticate, capabilities, connector_state::*, initialize, + }; + + /// Test helper to create default client info. + fn client_info() -> ImplementationInfo { + ImplementationInfo { + name: "dirigent-test".to_string(), + title: Some("Dirigent Test Client".to_string()), + version: Some("0.1.0".to_string()), + } + } + + #[tokio::test] + async fn test_initialize_request_serialization() { + let caps = ClientCapabilities::default_safe(); + let request = initialize::InitializeRequest::new(caps, Some(client_info())); + + // Verify serialization + let jsonrpc = request.to_jsonrpc(1); + assert_eq!(jsonrpc.method, "initialize"); + assert_eq!(jsonrpc.jsonrpc, "2.0"); + assert!(jsonrpc.params.is_some()); + + // Verify the serialized params contain expected fields + let params = jsonrpc.params.unwrap(); + assert!(params.get("protocol_version").is_some()); + assert!(params.get("client_capabilities").is_some()); + assert!(params.get("client_info").is_some()); + } + + #[tokio::test] + async fn test_initialize_response_parsing() { + let json = serde_json::json!({ + "protocol_version": 1, + "agent_capabilities": { + "load_session": true, + "prompt_capabilities": { + "image": true, + "audio": false, + "embedded_context": true + } + }, + "agent_info": { + "name": "test-agent", + "version": "1.0.0" + }, + "auth_methods": [] + }); + + let response = initialize::InitializeResponse::from_jsonrpc(&json).unwrap(); + + assert_eq!(response.protocol_version, 1); + assert!(response.agent_capabilities.supports_load_session()); + assert!(response.agent_capabilities.supports_image()); + assert!(!response.agent_capabilities.supports_audio()); + assert!(response.agent_capabilities.supports_embedded_context()); + assert_eq!(response.auth_methods.len(), 0); + } + + #[tokio::test] + async fn test_version_mismatch_detection() { + let client_version = 1; + let agent_version = 2; + + assert!(!initialize::is_version_compatible( + client_version, + agent_version + )); + assert!(initialize::is_version_compatible(client_version, 1)); + } + + #[tokio::test] + async fn test_authenticate_request_redacts_credentials() { + let creds = serde_json::json!({"api_key": "super_secret_key_12345"}); + let request = authenticate::AuthenticateRequest::new("api_key", creds); + + let debug_str = format!("{:?}", request); + + // Should NOT contain the actual secret + assert!(!debug_str.contains("super_secret_key")); + assert!(!debug_str.contains("12345")); + + // Should show that it's redacted + assert!(debug_str.contains("REDACTED")); + + // Should still show the method + assert!(debug_str.contains("api_key")); + } + + #[tokio::test] + async fn test_authenticate_response_parsing() { + let success_json = serde_json::json!({ + "success": true + }); + + let response = authenticate::AuthenticateResponse::from_jsonrpc(&success_json).unwrap(); + assert!(response.success); + assert!(response.error.is_none()); + + let failure_json = serde_json::json!({ + "success": false, + "error": "Invalid credentials" + }); + + let response = authenticate::AuthenticateResponse::from_jsonrpc(&failure_json).unwrap(); + assert!(!response.success); + assert_eq!(response.error, Some("Invalid credentials".to_string())); + } + + #[tokio::test] + async fn test_auth_required_check() { + assert!(!authenticate::is_auth_required(&[])); + assert!(authenticate::is_auth_required(&["api_key".to_string()])); + assert!(authenticate::is_auth_required(&[ + "api_key".to_string(), + "oauth".to_string() + ])); + } + + #[tokio::test] + async fn test_capability_validation_fs_read() { + let caps_enabled = ClientCapabilities { + fs: Some(FsCapabilities { + read_text_file: Some(true), + write_text_file: Some(false), + }), + terminal: Some(false), + _meta: None, + }; + + let result = capabilities::validate_capability("fs/read_text_file", &caps_enabled); + assert_eq!(result, capabilities::CapabilityValidation::Supported); + + let caps_disabled = ClientCapabilities { + fs: Some(FsCapabilities { + read_text_file: Some(false), + write_text_file: Some(false), + }), + terminal: Some(false), + _meta: None, + }; + + let result = capabilities::validate_capability("fs/read_text_file", &caps_disabled); + assert!(matches!( + result, + capabilities::CapabilityValidation::Unsupported(_) + )); + } + + #[tokio::test] + async fn test_capability_validation_fs_write() { + let caps_enabled = ClientCapabilities { + fs: Some(FsCapabilities { + read_text_file: Some(true), + write_text_file: Some(true), + }), + terminal: Some(false), + _meta: None, + }; + + let result = capabilities::validate_capability("fs/write_text_file", &caps_enabled); + assert_eq!(result, capabilities::CapabilityValidation::Supported); + + let caps_disabled = ClientCapabilities { + fs: Some(FsCapabilities { + read_text_file: Some(true), + write_text_file: Some(false), + }), + terminal: Some(false), + _meta: None, + }; + + let result = capabilities::validate_capability("fs/write_text_file", &caps_disabled); + assert!(matches!( + result, + capabilities::CapabilityValidation::Unsupported(_) + )); + } + + #[tokio::test] + async fn test_capability_validation_terminal() { + let caps_enabled = ClientCapabilities { + fs: None, + terminal: Some(true), + _meta: None, + }; + + let result = capabilities::validate_capability("terminal/execute", &caps_enabled); + assert_eq!(result, capabilities::CapabilityValidation::Supported); + + let result = capabilities::validate_capability("terminal/read", &caps_enabled); + assert_eq!(result, capabilities::CapabilityValidation::Supported); + + let caps_disabled = ClientCapabilities { + fs: None, + terminal: Some(false), + _meta: None, + }; + + let result = capabilities::validate_capability("terminal/execute", &caps_disabled); + assert!(matches!( + result, + capabilities::CapabilityValidation::Unsupported(_) + )); + } + + #[tokio::test] + async fn test_capability_validation_core_methods() { + let minimal_caps = ClientCapabilities { + fs: None, + terminal: Some(false), + _meta: None, + }; + + // Core protocol methods should always be supported + assert_eq!( + capabilities::validate_capability("initialize", &minimal_caps), + capabilities::CapabilityValidation::Supported + ); + assert_eq!( + capabilities::validate_capability("authenticate", &minimal_caps), + capabilities::CapabilityValidation::Supported + ); + assert_eq!( + capabilities::validate_capability("session/new", &minimal_caps), + capabilities::CapabilityValidation::Supported + ); + assert_eq!( + capabilities::validate_capability("session/prompt", &minimal_caps), + capabilities::CapabilityValidation::Supported + ); + } + + #[tokio::test] + async fn test_requires_capability_check() { + assert!(capabilities::requires_capability_check("fs/read_text_file")); + assert!(capabilities::requires_capability_check("fs/write_text_file")); + assert!(capabilities::requires_capability_check("terminal/execute")); + + assert!(!capabilities::requires_capability_check("initialize")); + assert!(!capabilities::requires_capability_check("authenticate")); + assert!(!capabilities::requires_capability_check("session/new")); + } + + #[tokio::test] + async fn test_connector_state_lifecycle() { + let mut state = AcpConnectorState::new(); + + assert_eq!(state.connection_state, ConnectionState::Uninitialized); + assert!(!state.is_ready()); + + state.connection_state = ConnectionState::Initializing; + assert!(!state.is_ready()); + + state.connection_state = ConnectionState::Initialized; + assert!(!state.is_ready()); + + state.connection_state = ConnectionState::Ready; + assert!(state.is_ready()); + + state.connection_state = ConnectionState::Disconnected; + assert!(!state.is_ready()); + } + + #[tokio::test] + async fn test_connector_state_auth_required() { + let mut state = AcpConnectorState::new(); + + assert!(!state.requires_auth()); + + state.auth_methods = vec!["api_key".to_string()]; + assert!(state.requires_auth()); + + state.authenticated = true; + assert!(!state.requires_auth()); + } + + #[tokio::test] + async fn test_connector_state_error() { + let mut state = AcpConnectorState::new(); + + state.set_error("Test error"); + + match state.connection_state { + ConnectionState::Error(msg) => { + assert_eq!(msg, "Test error"); + } + _ => panic!("Expected Error state"), + } + } + + #[tokio::test] + async fn test_capability_not_supported_error() { + let error = capabilities::capability_not_supported_error( + "fs/write_text_file", + "write capability not enabled", + ); + + assert_eq!(error.code, capabilities::ERROR_METHOD_NOT_FOUND); + assert!(error.message.contains("fs/write_text_file")); + assert!(error.message.contains("write capability not enabled")); + } + + #[tokio::test] + async fn test_agent_capabilities_helpers() { + let caps = AgentCapabilities { + load_session: Some(true), + prompt_capabilities: Some(PromptCapabilities { + image: Some(true), + audio: Some(false), + embedded_context: Some(true), + }), + mcp: Some(McpCapabilities { + http: Some(true), + sse: Some(false), + }), + _meta: None, + }; + + assert!(caps.supports_load_session()); + assert!(caps.supports_image()); + assert!(!caps.supports_audio()); + assert!(caps.supports_embedded_context()); + } + + #[tokio::test] + async fn test_client_capabilities_presets() { + let safe_caps = ClientCapabilities::default_safe(); + assert_eq!( + safe_caps.fs.as_ref().unwrap().read_text_file, + Some(true) + ); + assert_eq!( + safe_caps.fs.as_ref().unwrap().write_text_file, + Some(false) + ); + assert_eq!(safe_caps.terminal, Some(false)); + + let all_caps = ClientCapabilities::all_enabled(); + assert_eq!(all_caps.fs.as_ref().unwrap().read_text_file, Some(true)); + assert_eq!(all_caps.fs.as_ref().unwrap().write_text_file, Some(true)); + assert_eq!(all_caps.terminal, Some(true)); + } +} diff --git a/crates/dirigent_core/tests/acp_integration/README.md b/crates/dirigent_core/tests/acp_integration/README.md new file mode 100644 index 0000000..b65b70a --- /dev/null +++ b/crates/dirigent_core/tests/acp_integration/README.md @@ -0,0 +1,174 @@ +# ACP Integration Test Environment + +This directory contains utilities and fixtures for integration testing with the `dirigent_acp_mocker`. + +## Structure + +- **mocker_utils.rs** - Utilities for spawning and managing mocker processes +- **golden_transcripts.rs** - Golden transcript fixtures for common ACP flows +- **README.md** - This file + +## Running Tests + +### Basic Tests (No Process Spawning) + +```bash +# Run all tests +cargo test --package dirigent_core --test acp_mocker_test + +# Run specific test +cargo test --package dirigent_core --test acp_mocker_test test_golden_transcripts_available +``` + +### Integration Tests (With Process Spawning) + +These tests spawn actual mocker processes and are ignored by default. + +```bash +# Run all ignored tests (spawns processes) +cargo test --package dirigent_core --test acp_mocker_test -- --ignored + +# Run specific ignored test +cargo test --package dirigent_core --test acp_mocker_test test_spawn_stdio_mocker -- --ignored + +# Run all tests (including ignored ones) +cargo test --package dirigent_core --test acp_mocker_test -- --include-ignored +``` + +## Mocker Utilities + +### Spawning a Mocker in Stdio Mode + +```rust +use dirigent_core::tests::acp_integration::MockerProcess; + +#[tokio::test] +#[ignore] +async fn test_with_stdio_mocker() { + let mocker = MockerProcess::spawn_stdio().await.unwrap(); + + // Use mocker via stdin/stdout... + + mocker.kill().await.unwrap(); +} +``` + +### Spawning a Mocker in HTTP Mode + +```rust +use dirigent_core::tests::acp_integration::MockerProcess; + +#[tokio::test] +#[ignore] +async fn test_with_http_mocker() { + let port = 18888; + let mocker = MockerProcess::spawn_http(port).await.unwrap(); + + // Connect to mocker at http://localhost:18888 + + mocker.kill().await.unwrap(); +} +``` + +### Using Configuration Presets + +```rust +use dirigent_core::tests::acp_integration::{MockerProcess, MockerConfig}; + +#[tokio::test] +#[ignore] +async fn test_with_configured_mocker() { + let config = MockerConfig::with_preset("basic"); + let args: Vec<&str> = config.to_args().iter().map(|s| s.as_str()).collect(); + + let mocker = MockerProcess::spawn_stdio_with_args(&args).await.unwrap(); + + // Use mocker... + + mocker.kill().await.unwrap(); +} +``` + +## Golden Transcripts + +Golden transcripts represent expected request/response sequences for testing. + +```rust +use dirigent_core::tests::acp_integration::load_golden_transcript; + +#[test] +fn test_with_golden_transcript() { + let transcript = load_golden_transcript("initialize").unwrap(); + + let request = &transcript["request"]; + let response = &transcript["response"]; + + // Validate against actual mocker responses... +} +``` + +### Available Transcripts + +- **initialize** - Initialize handshake with capabilities exchange +- **new_session** - Create a new session +- **prompt** - Send a simple prompt and receive streaming response +- **tool_call_read** - Tool call flow for reading a file +- **cancel** - Cancel a running session + +## Windows-Specific Notes + +- Process spawning uses `cargo run` to build and run the mocker +- Paths are handled cross-platform by default +- Process cleanup is automatic via Drop implementation +- If tests hang, check for orphaned mocker processes in Task Manager + +## Troubleshooting + +### Mocker Won't Start + +1. Ensure `dirigent_acp_mocker` package builds: + ```bash + cargo build --package dirigent_acp_mocker + ``` + +2. Try running the mocker manually: + ```bash + cargo run --package dirigent_acp_mocker -- serve --stdio + ``` + +3. Check for port conflicts (HTTP mode): + ```bash + netstat -ano | findstr :18888 + ``` + +### Tests Timeout + +- Increase the timeout duration in the test +- Check mocker logs for errors (stderr is inherited) +- Verify mocker is actually starting (add debug logging) + +### Process Cleanup Issues + +- The `Drop` implementation should clean up automatically +- If orphaned processes remain, kill them manually: + ```bash + # Windows + taskkill /F /IM dirigent_acp_mocker.exe + + # Linux/macOS + pkill -f dirigent_acp_mocker + ``` + +## Future Work + +This infrastructure is ready for: + +- **TEST-01**: Protocol validation tests (stdio mode) +- **TEST-02**: Protocol validation tests (HTTP mode) +- **TEST-03**: Session update rendering tests +- **TEST-04**: Permission prompt flow tests +- **TEST-05**: File operations sandbox tests +- **TEST-06**: Terminal lifecycle tests +- **TEST-07**: Search operations tests + +See `docs/building/04_acp_client/04_tasks_00_scaffolding_and_finishing.md` for the full test plan. diff --git a/crates/dirigent_core/tests/acp_integration/golden_transcripts.rs b/crates/dirigent_core/tests/acp_integration/golden_transcripts.rs new file mode 100644 index 0000000..c58d7e5 --- /dev/null +++ b/crates/dirigent_core/tests/acp_integration/golden_transcripts.rs @@ -0,0 +1,229 @@ +//! Golden transcript fixtures for common ACP flows. +//! +//! These fixtures represent expected request/response sequences for testing. + +use serde_json::json; + +/// Golden transcript for initialize flow. +pub fn golden_initialize() -> serde_json::Value { + json!({ + "request": { + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "capabilities": { + "tools": true, + "streaming": true + }, + "clientInfo": { + "name": "dirigent", + "version": "0.1.0" + } + }, + "id": 1 + }, + "response": { + "jsonrpc": "2.0", + "result": { + "capabilities": { + "tools": ["read", "write", "edit", "search", "execute"], + "streaming": true, + "embeddedContext": true + }, + "serverInfo": { + "name": "dirigate", + "version": "0.1.0" + } + }, + "id": 1 + } + }) +} + +/// Golden transcript for new_session flow. +pub fn golden_new_session() -> serde_json::Value { + json!({ + "request": { + "jsonrpc": "2.0", + "method": "session/new", + "params": { + "mode": "ask" + }, + "id": 2 + }, + "response": { + "jsonrpc": "2.0", + "result": { + "sessionId": "test-session-123" + }, + "id": 2 + } + }) +} + +/// Golden transcript for simple prompt flow. +pub fn golden_prompt() -> serde_json::Value { + json!({ + "request": { + "jsonrpc": "2.0", + "method": "session/prompt", + "params": { + "sessionId": "test-session-123", + "messages": [{ + "role": "user", + "content": [ + { + "type": "text", + "text": "Hello, agent!" + } + ] + }] + }, + "id": 3 + }, + "streaming_updates": [ + { + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session-123", + "type": "agent_message_chunk", + "content": { + "text": "Hello! How can I help you?" + } + } + } + ], + "response": { + "jsonrpc": "2.0", + "result": {}, + "id": 3 + } + }) +} + +/// Golden transcript for tool call flow (read file). +pub fn golden_tool_call_read() -> serde_json::Value { + json!({ + "request": { + "jsonrpc": "2.0", + "method": "session/prompt", + "params": { + "sessionId": "test-session-123", + "messages": [{ + "role": "user", + "content": [ + { + "type": "text", + "text": "Read the file test.txt" + } + ] + }] + }, + "id": 4 + }, + "streaming_updates": [ + { + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session-123", + "type": "tool_call", + "content": { + "toolCallId": "tool-1", + "kind": "read", + "title": "Read test.txt", + "location": { + "path": "/path/to/test.txt" + } + } + } + }, + { + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session-123", + "type": "tool_call_update", + "content": { + "toolCallId": "tool-1", + "status": "completed", + "result": { + "type": "content", + "content": "File content here" + } + } + } + }, + { + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session-123", + "type": "agent_message_chunk", + "content": { + "text": "I've read the file. The content is: ..." + } + } + } + ], + "response": { + "jsonrpc": "2.0", + "result": {}, + "id": 4 + } + }) +} + +/// Golden transcript for cancellation flow. +pub fn golden_cancel() -> serde_json::Value { + json!({ + "request": { + "jsonrpc": "2.0", + "method": "session/cancel", + "params": { + "sessionId": "test-session-123" + }, + "id": 5 + }, + "response": { + "jsonrpc": "2.0", + "result": {}, + "id": 5 + } + }) +} + +/// Load a golden transcript by name. +pub fn load_golden_transcript(name: &str) -> Option<serde_json::Value> { + match name { + "initialize" => Some(golden_initialize()), + "new_session" => Some(golden_new_session()), + "prompt" => Some(golden_prompt()), + "tool_call_read" => Some(golden_tool_call_read()), + "cancel" => Some(golden_cancel()), + _ => None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_golden_transcripts_exist() { + assert!(load_golden_transcript("initialize").is_some()); + assert!(load_golden_transcript("new_session").is_some()); + assert!(load_golden_transcript("prompt").is_some()); + assert!(load_golden_transcript("tool_call_read").is_some()); + assert!(load_golden_transcript("cancel").is_some()); + assert!(load_golden_transcript("nonexistent").is_none()); + } + + #[test] + fn test_golden_initialize_structure() { + let transcript = golden_initialize(); + assert!(transcript.get("request").is_some()); + assert!(transcript.get("response").is_some()); + } +} diff --git a/crates/dirigent_core/tests/acp_integration/mocker_utils.rs b/crates/dirigent_core/tests/acp_integration/mocker_utils.rs new file mode 100644 index 0000000..71958c7 --- /dev/null +++ b/crates/dirigent_core/tests/acp_integration/mocker_utils.rs @@ -0,0 +1,211 @@ +//! Utilities for spawning and managing the ACP mocker in tests. + +use std::process::{Child, Command, Stdio}; +use std::time::Duration; +use tokio::time::sleep; + +/// Handle to a running mocker process. +pub struct MockerProcess { + process: Child, + mode: MockerMode, +} + +/// Mocker execution mode. +#[derive(Debug, Clone, Copy)] +pub enum MockerMode { + /// Stdio mode (stdin/stdout communication) + Stdio, + /// HTTP mode (HTTP + SSE communication) + Http { port: u16 }, +} + +impl MockerProcess { + /// Spawn a mocker in stdio mode. + /// + /// # Example + /// ```no_run + /// use dirigent_core::tests::acp_integration::MockerProcess; + /// + /// #[tokio::test] + /// async fn test_stdio_mocker() { + /// let mocker = MockerProcess::spawn_stdio().await.unwrap(); + /// // Use mocker... + /// mocker.kill().await.unwrap(); + /// } + /// ``` + pub async fn spawn_stdio() -> Result<Self, String> { + Self::spawn_stdio_with_args(&[]).await + } + + /// Spawn a mocker in stdio mode with custom arguments. + pub async fn spawn_stdio_with_args(args: &[&str]) -> Result<Self, String> { + let mut cmd_args = vec!["serve", "--stdio"]; + cmd_args.extend_from_slice(args); + + let process = Command::new("cargo") + .args(&["run", "--package", "dirigate", "--"]) + .args(&cmd_args) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::inherit()) + .spawn() + .map_err(|e| format!("Failed to spawn mocker: {}", e))?; + + // Give the mocker time to start + sleep(Duration::from_millis(500)).await; + + Ok(Self { + process, + mode: MockerMode::Stdio, + }) + } + + /// Spawn a mocker in HTTP mode on a specific port. + /// + /// # Example + /// ```no_run + /// use dirigent_core::tests::acp_integration::MockerProcess; + /// + /// #[tokio::test] + /// async fn test_http_mocker() { + /// let mocker = MockerProcess::spawn_http(8888).await.unwrap(); + /// // Use mocker at http://localhost:8888 + /// mocker.kill().await.unwrap(); + /// } + /// ``` + pub async fn spawn_http(port: u16) -> Result<Self, String> { + Self::spawn_http_with_args(port, &[]).await + } + + /// Spawn a mocker in HTTP mode with custom arguments. + pub async fn spawn_http_with_args(port: u16, args: &[&str]) -> Result<Self, String> { + let port_str = port.to_string(); + let mut cmd_args = vec!["serve", "--port", port_str.as_str()]; + cmd_args.extend_from_slice(args); + + let process = Command::new("cargo") + .args(&["run", "--package", "dirigate", "--"]) + .args(&cmd_args) + .stdin(Stdio::null()) + .stdout(Stdio::piped()) + .stderr(Stdio::inherit()) + .spawn() + .map_err(|e| format!("Failed to spawn mocker: {}", e))?; + + // Give the HTTP server time to start + sleep(Duration::from_millis(1000)).await; + + // TODO: Add health check to verify mocker is ready + + Ok(Self { + process, + mode: MockerMode::Http { port }, + }) + } + + /// Get the mocker mode. + pub fn mode(&self) -> MockerMode { + self.mode + } + + /// Get the HTTP URL if in HTTP mode. + pub fn http_url(&self) -> Option<String> { + match self.mode { + MockerMode::Http { port } => Some(format!("http://localhost:{}", port)), + MockerMode::Stdio => None, + } + } + + /// Kill the mocker process. + pub async fn kill(mut self) -> Result<(), String> { + self.process + .kill() + .map_err(|e| format!("Failed to kill mocker: {}", e))?; + + // Wait for process to exit + sleep(Duration::from_millis(100)).await; + + Ok(()) + } +} + +impl Drop for MockerProcess { + fn drop(&mut self) { + // Best-effort cleanup on drop + let _ = self.process.kill(); + } +} + +/// Configuration for mocker test scenarios. +#[derive(Debug, Clone)] +pub struct MockerConfig { + /// Preset configuration name + pub preset: Option<String>, + /// Custom configuration JSON/TOML + pub config: Option<String>, +} + +impl MockerConfig { + /// Create a default mocker configuration. + pub fn default() -> Self { + Self { + preset: None, + config: None, + } + } + + /// Create a mocker configuration with a preset. + pub fn with_preset(preset: impl Into<String>) -> Self { + Self { + preset: Some(preset.into()), + config: None, + } + } + + /// Create a mocker configuration with custom config. + // Test utility - kept for future mocker configuration test development + #[allow(dead_code)] + pub fn with_config(config: impl Into<String>) -> Self { + Self { + preset: None, + config: Some(config.into()), + } + } + + /// Convert to command-line arguments for the mocker. + pub fn to_args(&self) -> Vec<String> { + let mut args = Vec::new(); + + if let Some(preset) = &self.preset { + args.push("--preset".to_string()); + args.push(preset.clone()); + } + + if let Some(config) = &self.config { + args.push("--config".to_string()); + args.push(config.clone()); + } + + args + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mocker_config_default() { + let config = MockerConfig::default(); + assert!(config.preset.is_none()); + assert!(config.config.is_none()); + assert!(config.to_args().is_empty()); + } + + #[test] + fn test_mocker_config_with_preset() { + let config = MockerConfig::with_preset("basic"); + assert_eq!(config.preset, Some("basic".to_string())); + assert_eq!(config.to_args(), vec!["--preset", "basic"]); + } +} diff --git a/crates/dirigent_core/tests/acp_integration/mod.rs b/crates/dirigent_core/tests/acp_integration/mod.rs new file mode 100644 index 0000000..da47ba2 --- /dev/null +++ b/crates/dirigent_core/tests/acp_integration/mod.rs @@ -0,0 +1,9 @@ +//! ACP integration test utilities. +//! +//! This module provides utilities for testing with the conductor. + +pub mod mocker_utils; +pub mod golden_transcripts; + +pub use mocker_utils::*; +pub use golden_transcripts::*; diff --git a/crates/dirigent_core/tests/acp_mocker_test.rs b/crates/dirigent_core/tests/acp_mocker_test.rs new file mode 100644 index 0000000..7306ca4 --- /dev/null +++ b/crates/dirigent_core/tests/acp_mocker_test.rs @@ -0,0 +1,100 @@ +//! ACP mocker integration tests. +//! +//! These tests use the dirigate to validate ACP protocol flows. +//! Most tests are marked as #[ignore] by default since they require spawning processes. +//! +//! Run with: cargo test --package dirigent_core --test acp_mocker_test -- --ignored + +#![cfg(feature = "server")] + +mod acp_integration; + +use acp_integration::*; + +#[tokio::test] +async fn test_mocker_utilities_exist() { + // Just verify the utilities compile and work + let config = MockerConfig::default(); + assert!(config.to_args().is_empty()); +} + +#[tokio::test] +async fn test_golden_transcripts_available() { + // Verify golden transcripts are available + assert!(load_golden_transcript("initialize").is_some()); + assert!(load_golden_transcript("new_session").is_some()); + assert!(load_golden_transcript("prompt").is_some()); + assert!(load_golden_transcript("tool_call_read").is_some()); + assert!(load_golden_transcript("cancel").is_some()); +} + +// ============================================================================ +// Mocker Spawning Tests (Ignored by default) +// ============================================================================ + +#[tokio::test] +#[ignore] +async fn test_spawn_stdio_mocker() { + let mocker = MockerProcess::spawn_stdio().await; + assert!(mocker.is_ok(), "Failed to spawn stdio mocker: {:?}", mocker.err()); + + if let Ok(mocker) = mocker { + assert!(matches!(mocker.mode(), MockerMode::Stdio)); + assert_eq!(mocker.http_url(), None); + mocker.kill().await.expect("Failed to kill mocker"); + } +} + +#[tokio::test] +#[ignore] +async fn test_spawn_http_mocker() { + let port = 18888; // Use non-standard port for testing + let mocker = MockerProcess::spawn_http(port).await; + assert!(mocker.is_ok(), "Failed to spawn HTTP mocker: {:?}", mocker.err()); + + if let Ok(mocker) = mocker { + assert!(matches!(mocker.mode(), MockerMode::Http { .. })); + assert_eq!( + mocker.http_url(), + Some(format!("http://localhost:{}", port)) + ); + mocker.kill().await.expect("Failed to kill mocker"); + } +} + +#[tokio::test] +#[ignore] +async fn test_spawn_mocker_with_config() { + let config = MockerConfig::with_preset("basic"); + let args_vec = config.to_args(); + let args: Vec<&str> = args_vec.iter().map(|s| s.as_str()).collect(); + + let mocker = MockerProcess::spawn_stdio_with_args(&args).await; + assert!(mocker.is_ok(), "Failed to spawn mocker with config: {:?}", mocker.err()); + + if let Ok(mocker) = mocker { + mocker.kill().await.expect("Failed to kill mocker"); + } +} + +// ============================================================================ +// Documentation for Future Tests +// ============================================================================ + +// TODO: TEST-01 - Protocol validation test suite (stdio mode) +// - initialize → response with capabilities +// - session/new → sessionId +// - session/prompt → streaming updates +// - Validate all protocol flows against mocker + +// TODO: TEST-02 - Protocol validation test suite (HTTP mode) +// - Same tests as TEST-01 but using HTTP transport +// - SSE stream validation + +// TODO: TEST-03 - Session update rendering tests +// - Parse and validate all session update types +// - Verify ToolKind, ToolCallLocation, diff content + +// TODO: TEST-04 - Permission prompt flow tests +// - Test all permission outcomes (allow_once, allow_always, reject_once, reject_always, cancel) +// - Decision cache persistence and TTL diff --git a/crates/dirigent_core/tests/acp_stdio_integration_test.rs b/crates/dirigent_core/tests/acp_stdio_integration_test.rs new file mode 100644 index 0000000..22088d9 --- /dev/null +++ b/crates/dirigent_core/tests/acp_stdio_integration_test.rs @@ -0,0 +1,737 @@ +//! TEST-001: Comprehensive Stdio Integration Tests +//! +//! This file contains comprehensive integration tests for the ACP client over stdio transport. +//! It tests the full lifecycle of ACP communication using the dirigate in stdio mode. +//! +//! Test Coverage: +//! - T001: Full lifecycle (initialize, new session, prompt, cancel, load) +//! - T002: Initialize handshake +//! - T003: Session creation +//! - T004: Message sending with streaming +//! - T005: Session cancellation +//! - T006: Session loading +//! - T007: Multiple concurrent sessions +//! - T008: Graceful shutdown + +#![cfg(feature = "server")] + +use dirigent_core::connectors::acp::{AcpConfig, AcpConnector}; +use dirigent_core::connectors::{Connector, ConnectorCommand}; +use dirigent_protocol::Event; +use std::path::PathBuf; +use std::time::Duration; +use tokio::time::timeout; + +/// Helper to create a test fixture path +fn test_fixture_path() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("tests") + .join("fixtures") + .join("acp_test.yaml") +} + +/// Helper to get the mocker binary path +fn mocker_binary_path() -> String { + // First try the compiled binary in target directory + let manifest_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + let project_root = manifest_dir.parent().unwrap().parent().unwrap(); + + // Try debug build first + let debug_path = project_root + .join("target") + .join("debug") + .join("dirigate.exe"); + + if debug_path.exists() { + return debug_path.to_string_lossy().to_string(); + } + + // Try release build + let release_path = project_root + .join("target") + .join("release") + .join("dirigate.exe"); + + if release_path.exists() { + return release_path.to_string_lossy().to_string(); + } + + // Fallback to cargo run + "cargo".to_string() +} + +/// Helper to create mocker args +fn mocker_args(fixture_path: PathBuf) -> Vec<String> { + let mocker_bin = mocker_binary_path(); + + if mocker_bin == "cargo" { + vec![ + "run".to_string(), + "--package".to_string(), + "dirigate".to_string(), + "--".to_string(), + "serve".to_string(), + "--stdio".to_string(), + "--fixtures".to_string(), + fixture_path.to_string_lossy().to_string(), + ] + } else { + vec![ + "serve".to_string(), + "--stdio".to_string(), + "--fixtures".to_string(), + fixture_path.to_string_lossy().to_string(), + ] + } +} + +/// Helper to create a test ACP connector with stdio transport +fn create_stdio_connector() -> AcpConnector { + let mocker_bin = mocker_binary_path(); + let fixture_path = test_fixture_path(); + let args = mocker_args(fixture_path); + + let config = AcpConfig::stdio(mocker_bin, args) + .with_retry_max_attempts(3) + .with_retry_initial_delay(Duration::from_millis(500)); + + AcpConnector::new( + "stdio-test-connector".to_string(), + uuid::Uuid::nil(), + "Stdio Test Connector".to_string(), + config, + dirigent_core::sharing::bus::SharingBus::new(), + ) + .expect("Failed to create connector") +} + +/// Helper to wait for event with timeout +async fn wait_for_event<F>( + events_rx: &mut tokio::sync::broadcast::Receiver<Event>, + predicate: F, + timeout_duration: Duration, +) -> anyhow::Result<Event> +where + F: Fn(&Event) -> bool, +{ + let result = timeout(timeout_duration, async { + loop { + match events_rx.recv().await { + Ok(event) => { + if predicate(&event) { + return Ok(event); + } + } + Err(e) => { + return Err(anyhow::anyhow!("Event receive error: {}", e)); + } + } + } + }) + .await; + + match result { + Ok(Ok(event)) => Ok(event), + Ok(Err(e)) => Err(e), + Err(_) => Err(anyhow::anyhow!("Timeout waiting for event")), + } +} + +// ============================================================================ +// T001: Full Lifecycle Test +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires dirigate binary to be built"] +async fn test_t001_full_lifecycle_stdio() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter("debug") + .with_test_writer() + .try_init() + .ok(); + + // Create connector + let connector = create_stdio_connector(); + let mut events = connector.subscribe(); + + // Start the connector task + let task_handle = connector.start_task().await; + + // Wait for connection + let connected_event = wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + assert!(matches!(connected_event, Event::Connected)); + tracing::info!("✓ Connected to mocker"); + + // Create a new session + let cmd_tx = connector.command_tx(); + cmd_tx + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await?; + + let session_created = wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await?; + + let session_id = match session_created { + Event::SessionCreated { session, .. } => { + tracing::info!("✓ Session created: {}", session.id); + session.id + } + _ => anyhow::bail!("Expected SessionCreated event"), + }; + + // Send a message + cmd_tx + .send(ConnectorCommand::SendMessage { + session_id: session_id.clone(), + text: "Hello, test!".to_string(), + }) + .await?; + + // Wait for message started + let message_started = wait_for_event( + &mut events, + |e| matches!(e, Event::MessageStarted { .. }), + Duration::from_secs(5), + ) + .await?; + + let _message_id = match message_started { + Event::MessageStarted { message, .. } => { + tracing::info!("✓ Message started: {}", message.id); + message.id + } + _ => anyhow::bail!("Expected MessageStarted event"), + }; + + // Wait for message content (streaming chunks) + let mut received_content = false; + for _ in 0..10 { + if let Ok(event) = timeout(Duration::from_secs(2), events.recv()).await { + match event? { + Event::SessionUpdate { + connector_id: _, + session_id: _, + update, + } => { + if let dirigent_protocol::SessionUpdate::AgentMessageChunk { .. } = update { + received_content = true; + tracing::info!("✓ Received message content"); + break; + } + } + Event::MessageCompleted { message, .. } => { + tracing::info!("✓ Message completed: {}", message.id); + received_content = true; + break; + } + _ => {} + } + } + } + + assert!(received_content, "Should receive message content"); + + // Test cancellation (send another message then cancel immediately) + cmd_tx + .send(ConnectorCommand::SendMessage { + session_id: session_id.clone(), + text: "This should be cancelled".to_string(), + }) + .await?; + + // Cancel immediately + cmd_tx + .send(ConnectorCommand::CancelGeneration { + session_id: session_id.clone(), + }) + .await?; + + tracing::info!("✓ Sent cancellation request"); + + // Stop the connector + connector.stop(); + + // Wait for task to complete + let _ = timeout(Duration::from_secs(5), task_handle).await?; + + tracing::info!("✓ Full lifecycle completed successfully"); + Ok(()) +} + +// ============================================================================ +// T002: Initialize Handshake +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires dirigate binary to be built"] +async fn test_t002_initialize_handshake() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter("debug") + .with_test_writer() + .try_init() + .ok(); + + let connector = create_stdio_connector(); + let mut events = connector.subscribe(); + + let task_handle = connector.start_task().await; + + // Wait for connected event + let connected = wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + assert!(matches!(connected, Event::Connected)); + tracing::info!("✓ Initialize handshake completed"); + + // Cleanup + connector.stop(); + let _ = timeout(Duration::from_secs(5), task_handle).await?; + + Ok(()) +} + +// ============================================================================ +// T003: Session Creation +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires dirigate binary to be built"] +async fn test_t003_session_creation() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter("debug") + .with_test_writer() + .try_init() + .ok(); + + let connector = create_stdio_connector(); + let mut events = connector.subscribe(); + let task_handle = connector.start_task().await; + + // Wait for connection + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + // Create session + let cmd_tx = connector.command_tx(); + cmd_tx + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await?; + + // Wait for session created event + let session_created = wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await?; + + match session_created { + Event::SessionCreated { session, .. } => { + assert_eq!(session.title, "Session Creation Test"); + tracing::info!("✓ Session created with correct title"); + } + _ => anyhow::bail!("Expected SessionCreated event"), + } + + // Cleanup + connector.stop(); + let _ = timeout(Duration::from_secs(5), task_handle).await?; + + Ok(()) +} + +// ============================================================================ +// T004: Message Sending with Streaming +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires dirigate binary to be built"] +async fn test_t004_message_streaming() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter("debug") + .with_test_writer() + .try_init() + .ok(); + + let connector = create_stdio_connector(); + let mut events = connector.subscribe(); + let task_handle = connector.start_task().await; + + // Wait for connection and create session + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + let cmd_tx = connector.command_tx(); + cmd_tx + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await?; + + let session_id = match wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await? + { + Event::SessionCreated { session, .. } => session.id, + _ => anyhow::bail!("Expected SessionCreated"), + }; + + // Send message + cmd_tx + .send(ConnectorCommand::SendMessage { + session_id: session_id.clone(), + text: "Tell me a story".to_string(), + }) + .await?; + + // Wait for message started + wait_for_event( + &mut events, + |e| matches!(e, Event::MessageStarted { .. }), + Duration::from_secs(5), + ) + .await?; + + // Collect streaming chunks + let mut chunk_count = 0; + let mut completed = false; + + for _ in 0..20 { + if let Ok(result) = timeout(Duration::from_secs(2), events.recv()).await { + match result? { + Event::SessionUpdate { update, .. } => { + if let dirigent_protocol::SessionUpdate::AgentMessageChunk { .. } = update { + chunk_count += 1; + tracing::info!("Received chunk {}", chunk_count); + } + } + Event::MessageCompleted { .. } => { + completed = true; + tracing::info!("✓ Message completed after {} chunks", chunk_count); + break; + } + _ => {} + } + } + } + + assert!(chunk_count > 0, "Should receive at least one chunk"); + assert!(completed, "Message should complete"); + + // Cleanup + connector.stop(); + let _ = timeout(Duration::from_secs(5), task_handle).await?; + + Ok(()) +} + +// ============================================================================ +// T005: Session Cancellation +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires dirigate binary to be built"] +async fn test_t005_session_cancellation() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter("debug") + .with_test_writer() + .try_init() + .ok(); + + let connector = create_stdio_connector(); + let mut events = connector.subscribe(); + let task_handle = connector.start_task().await; + + // Setup: connect and create session + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + let cmd_tx = connector.command_tx(); + cmd_tx + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await?; + + let session_id = match wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await? + { + Event::SessionCreated { session, .. } => session.id, + _ => anyhow::bail!("Expected SessionCreated"), + }; + + // Send message + cmd_tx + .send(ConnectorCommand::SendMessage { + session_id: session_id.clone(), + text: "Long response please".to_string(), + }) + .await?; + + // Wait for message to start + wait_for_event( + &mut events, + |e| matches!(e, Event::MessageStarted { .. }), + Duration::from_secs(5), + ) + .await?; + + // Wait for at least one chunk + wait_for_event( + &mut events, + |e| matches!(e, Event::SessionUpdate { .. }), + Duration::from_secs(5), + ) + .await?; + + // Now cancel + cmd_tx + .send(ConnectorCommand::CancelGeneration { + session_id: session_id.clone(), + }) + .await?; + + tracing::info!("✓ Cancellation sent successfully"); + + // The message should either complete or be cancelled + // Either way, we verify the cancel command was accepted + + // Cleanup + connector.stop(); + let _ = timeout(Duration::from_secs(5), task_handle).await?; + + Ok(()) +} + +// ============================================================================ +// T006: Session Loading +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires dirigate binary to be built"] +async fn test_t006_session_loading() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter("debug") + .with_test_writer() + .try_init() + .ok(); + + let connector = create_stdio_connector(); + let mut events = connector.subscribe(); + let task_handle = connector.start_task().await; + + // Wait for connection + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + // Try to load a session from the fixture + // The fixture should have a pre-defined session we can load + let cmd_tx = connector.command_tx(); + cmd_tx + .send(ConnectorCommand::LoadSession { + session_id: "test-session-1".to_string(), + cwd: ".".to_string(), + mcp_servers: None, + }) + .await?; + + // Wait for session loaded event + let session_loaded = wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await?; + + match session_loaded { + Event::SessionCreated { session, .. } => { + assert_eq!(session.id, "test-session-1"); + tracing::info!("✓ Session loaded successfully"); + } + _ => anyhow::bail!("Expected SessionCreated event"), + } + + // Cleanup + connector.stop(); + let _ = timeout(Duration::from_secs(5), task_handle).await?; + + Ok(()) +} + +// ============================================================================ +// T007: Multiple Concurrent Sessions +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires dirigate binary to be built"] +async fn test_t007_multiple_concurrent_sessions() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter("debug") + .with_test_writer() + .try_init() + .ok(); + + let connector = create_stdio_connector(); + let mut events = connector.subscribe(); + let task_handle = connector.start_task().await; + + // Wait for connection + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + let cmd_tx = connector.command_tx(); + let mut session_ids = Vec::new(); + + // Create multiple sessions + for i in 0..3 { + cmd_tx + .send(ConnectorCommand::CreateSession { + cwd: None, + project_id: None, + ownership: dirigent_protocol::SessionOwnership::internal(), + }) + .await?; + + let session_id = match wait_for_event( + &mut events, + |e| matches!(e, Event::SessionCreated { .. }), + Duration::from_secs(5), + ) + .await? + { + Event::SessionCreated { session, .. } => { + tracing::info!("✓ Created session {}: {}", i + 1, session.id); + session.id + } + _ => anyhow::bail!("Expected SessionCreated"), + }; + + session_ids.push(session_id); + } + + assert_eq!(session_ids.len(), 3, "Should create 3 sessions"); + + // Send messages to all sessions concurrently + for (i, session_id) in session_ids.iter().enumerate() { + cmd_tx + .send(ConnectorCommand::SendMessage { + session_id: session_id.clone(), + text: format!("Message to session {}", i + 1), + }) + .await?; + } + + // Wait for all messages to start + let mut started_count = 0; + for _ in 0..30 { + if let Ok(result) = timeout(Duration::from_secs(1), events.recv()).await { + if let Event::MessageStarted { .. } = result? { + started_count += 1; + if started_count == 3 { + break; + } + } + } + } + + assert_eq!(started_count, 3, "All messages should start"); + tracing::info!("✓ All messages started successfully"); + + // Cleanup + connector.stop(); + let _ = timeout(Duration::from_secs(5), task_handle).await?; + + Ok(()) +} + +// ============================================================================ +// T008: Graceful Shutdown +// ============================================================================ + +#[tokio::test] +#[ignore = "Requires dirigate binary to be built"] +async fn test_t008_graceful_shutdown() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_env_filter("debug") + .with_test_writer() + .try_init() + .ok(); + + let connector = create_stdio_connector(); + let mut events = connector.subscribe(); + let task_handle = connector.start_task().await; + + // Wait for connection + wait_for_event( + &mut events, + |e| matches!(e, Event::Connected), + Duration::from_secs(10), + ) + .await?; + + tracing::info!("Connected, now shutting down gracefully"); + + // Stop the connector + connector.stop(); + + // Task should complete without hanging + let result = timeout(Duration::from_secs(5), task_handle).await; + + assert!(result.is_ok(), "Task should complete within timeout"); + tracing::info!("✓ Graceful shutdown completed"); + + Ok(()) +} diff --git a/crates/dirigent_core/tests/adapter_integration_test.rs b/crates/dirigent_core/tests/adapter_integration_test.rs new file mode 100644 index 0000000..2c2063e --- /dev/null +++ b/crates/dirigent_core/tests/adapter_integration_test.rs @@ -0,0 +1,365 @@ +//! Tests for Event Adapter Integration +//! +//! T071: Event Adapter Integration +//! - SSE events translated correctly by OpenCodeAdapter +//! - Events forwarded to broadcast channel + +#![cfg(feature = "server")] +use dirigent_protocol::adapters::OpenCodeAdapter; +use dirigent_protocol::Event; +use std::time::Duration; +use tokio::time::timeout; + +// ============================================================================ +// T071: Event Adapter Integration +// ============================================================================ + +#[test] +fn test_t071_adapter_can_be_created() { + let _adapter = OpenCodeAdapter::new(); + // If this compiles and runs, the adapter can be instantiated +} + +#[test] +fn test_t071_adapter_translates_session_created() { + use opencode_client::types as oc; + + let adapter = OpenCodeAdapter::new(); + + let oc_session = oc::Session { + id: "session-123".to_string(), + project_id: "project-1".to_string(), + directory: "/test".to_string(), + parent_id: None, + summary: None, + share: None, + title: "Test Session".to_string(), + version: "1.0".to_string(), + time: oc::SessionTime { + created: 1234567890, + updated: 1234567890, + compacting: None, + }, + revert: None, + }; + + let oc_event = oc::Event::SessionCreated { + properties: oc::SessionEventInfo { + info: oc_session.clone(), + }, + }; + + let result = adapter.translate_event(oc_event); + assert!(result.is_ok(), "Translation should succeed"); + + if let Ok(Event::SessionCreated { session, .. }) = result { + assert_eq!(session.id, "session-123"); + assert_eq!(session.title, "Test Session"); + } else { + panic!("Expected SessionCreated event, got {:?}", result); + } +} + +#[test] +fn test_t071_adapter_translates_message_updated() { + use opencode_client::types as oc; + + let adapter = OpenCodeAdapter::new(); + + let oc_message = oc::Message::Assistant(oc::AssistantMessage { + id: "msg-123".to_string(), + session_id: "session-456".to_string(), + time: oc::AssistantMessageTime { + created: 1234567890, + completed: None, + }, + error: None, + system: vec![], + parent_id: None, + model_id: None, + provider_id: None, + mode: None, + path: None, + summary: None, + cost: 0.0, + tokens: oc::TokenUsage::default(), + }); + + let oc_event = oc::Event::MessageUpdated { + properties: oc::MessageEventInfo { + info: oc_message.clone(), + }, + }; + + let result = adapter.translate_event(oc_event); + assert!(result.is_ok(), "Translation should succeed"); + + // The adapter might return MessageStarted or MessageCompleted depending on status + match result { + Ok(Event::MessageStarted { message, .. }) => { + assert_eq!(message.id, "msg-123"); + assert_eq!(message.session_id, "session-456"); + } + Ok(Event::MessageCompleted { message, .. }) => { + assert_eq!(message.id, "msg-123"); + assert_eq!(message.session_id, "session-456"); + } + other => { + panic!( + "Expected MessageStarted or MessageCompleted, got {:?}", + other + ); + } + } +} + +#[test] +fn test_t071_adapter_handles_duplicate_events() { + use opencode_client::types as oc; + + let adapter = OpenCodeAdapter::new(); + + let oc_message = oc::Message::Assistant(oc::AssistantMessage { + id: "msg-same".to_string(), + session_id: "session-1".to_string(), + time: oc::AssistantMessageTime { + created: 1234567890, + completed: None, + }, + error: None, + system: vec![], + parent_id: None, + model_id: None, + provider_id: None, + mode: None, + path: None, + summary: None, + cost: 0.0, + tokens: oc::TokenUsage::default(), + }); + + let oc_event = oc::Event::MessageUpdated { + properties: oc::MessageEventInfo { + info: oc_message.clone(), + }, + }; + + // First translation should succeed + let result1 = adapter.translate_event(oc_event.clone()); + assert!(result1.is_ok(), "First translation should succeed"); + + // Second translation of the same event should return Duplicate error + let result2 = adapter.translate_event(oc_event); + assert!( + result2.is_err(), + "Second translation should detect duplicate" + ); + + if let Err(dirigent_protocol::adapters::opencode::TranslationError::Duplicate) = result2 { + // Expected + } else { + panic!("Expected Duplicate error, got {:?}", result2); + } +} + +#[tokio::test] +async fn test_t071_events_forwarded_to_broadcast_channel() { + // Test that the broadcast channel mechanism works + // We create our own channels to test the pattern used by connectors + let (events_tx, mut events_rx) = tokio::sync::broadcast::channel(1000); + + // Send a test event + events_tx.send(dirigent_protocol::Event::Connected).ok(); + + // Verify we can receive it + let event_received = timeout(Duration::from_millis(100), events_rx.recv()).await; + + assert!( + event_received.is_ok(), + "Should receive events via broadcast channel" + ); + assert!(matches!( + event_received.unwrap().unwrap(), + dirigent_protocol::Event::Connected + )); +} + +#[tokio::test] +async fn test_t071_multiple_subscribers_receive_same_events() { + // Test that multiple subscribers all receive broadcast events + let (events_tx, _) = tokio::sync::broadcast::channel(1000); + + // Create multiple subscriptions + let mut events1 = events_tx.subscribe(); + let mut events2 = events_tx.subscribe(); + let mut events3 = events_tx.subscribe(); + + // Send an event + events_tx.send(dirigent_protocol::Event::Connected).ok(); + + // All subscribers should receive events + let timeout_duration = Duration::from_millis(100); + + let result1 = timeout(timeout_duration, events1.recv()).await; + let result2 = timeout(timeout_duration, events2.recv()).await; + let result3 = timeout(timeout_duration, events3.recv()).await; + + // All three should receive events + assert!(result1.is_ok(), "Subscriber 1 should receive events"); + assert!(result2.is_ok(), "Subscriber 2 should receive events"); + assert!(result3.is_ok(), "Subscriber 3 should receive events"); +} + +#[tokio::test] +async fn test_t071_events_contain_correct_data() { + // Test that events can be sent and received with correct data + let (events_tx, mut events_rx) = tokio::sync::broadcast::channel(1000); + + // Send various event types + events_tx.send(Event::Connected).ok(); + events_tx.send(Event::Disconnected).ok(); + events_tx + .send(Event::Error { + message: "Test error".to_string(), + }) + .ok(); + + // Collect events + let mut received_events = Vec::new(); + while let Ok(result) = timeout(Duration::from_millis(10), events_rx.recv()).await { + if let Ok(event) = result { + received_events.push(event); + } + if received_events.len() >= 3 { + break; + } + } + + // We should have received all events + assert_eq!(received_events.len(), 3, "Should receive all events"); + + // Check that events are valid Event enum variants + for event in &received_events { + match event { + Event::Connected => {} + Event::Disconnected => {} + Event::Error { message } => { + assert!(!message.is_empty(), "Error messages should not be empty"); + } + _ => {} + } + } +} + +#[test] +fn test_t071_adapter_translates_message_part() { + use opencode_client::types as oc; + + let adapter = OpenCodeAdapter::new(); + + // First, create a message so the part can be associated + let oc_message = oc::Message::Assistant(oc::AssistantMessage { + id: "msg-123".to_string(), + session_id: "session-456".to_string(), + time: oc::AssistantMessageTime { + created: 1234567890, + completed: None, + }, + error: None, + system: vec![], + parent_id: None, + model_id: None, + provider_id: None, + mode: None, + path: None, + summary: None, + cost: 0.0, + tokens: oc::TokenUsage::default(), + }); + + let msg_event = oc::Event::MessageUpdated { + properties: oc::MessageEventInfo { + info: oc_message.clone(), + }, + }; + + adapter + .translate_event(msg_event) + .expect("Message should translate"); + + // Now translate a message part + let oc_part = oc::Part::Text(oc::TextPart { + id: "part-123".to_string(), + session_id: "session-456".to_string(), + message_id: "msg-123".to_string(), + text: "Hello, world!".to_string(), + synthetic: None, + time: Some(oc::PartTime { + start: 1234567890, + end: Some(1234567900), + }), + }); + + let part_event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: oc_part.clone(), + delta: Some("Hello, world!".to_string()), + }, + }; + + let result = adapter.translate_event(part_event); + assert!(result.is_ok(), "Part translation should succeed"); + + if let Ok(Event::SessionUpdate { connector_id: _, session_id, update }) = result { + assert_eq!(session_id, "ses-456"); + match update { + dirigent_protocol::SessionUpdate::AgentMessageChunk { + message_id, + content, + .. + } => { + assert_eq!(message_id, "msg-123"); + match content { + dirigent_protocol::ContentBlock::Text { text } => { + assert_eq!(text, "Hello, world!"); + } + _ => panic!("Expected Text content block"), + } + } + _ => panic!("Expected AgentMessageChunk"), + } + } else { + panic!("Expected SessionUpdate event, got {:?}", result); + } +} + +#[tokio::test] +async fn test_t071_broadcast_channel_capacity() { + // Test that the broadcast channel has sufficient capacity + let (events_tx, mut events_rx) = tokio::sync::broadcast::channel(1000); + + // Send many events + for i in 0..100 { + events_tx + .send(Event::Error { + message: format!("Event {}", i), + }) + .ok(); + } + + // Receive them all + let mut count = 0; + while let Ok(result) = timeout(Duration::from_millis(10), events_rx.recv()).await { + if result.is_ok() { + count += 1; + } else { + break; + } + if count >= 100 { + break; + } + } + + // We should have received all events without lagging + assert_eq!(count, 100, "Should receive all events without lag"); +} diff --git a/crates/dirigent_core/tests/embedding_integration.rs b/crates/dirigent_core/tests/embedding_integration.rs new file mode 100644 index 0000000..46a0f51 --- /dev/null +++ b/crates/dirigent_core/tests/embedding_integration.rs @@ -0,0 +1,425 @@ +//! Integration tests for file embedding functionality. +//! +//! Tests the complete embedding pipeline from file paths to ContentBlocks. + +use dirigent_core::acp::content_blocks::build_content_blocks_from_files; +use dirigent_core::acp::protocol::prompt::{ContentBlock, EmbeddedResource}; +use dirigent_core::acp::{AgentCapabilities, PromptCapabilities}; +use dirigent_tools::config::{EmbeddingConfig, SandboxConfig}; +use tempfile::TempDir; + +/// Helper to create test agent capabilities. +fn create_test_agent_caps(embedded_context: bool) -> AgentCapabilities { + AgentCapabilities { + load_session: None, + prompt_capabilities: Some(PromptCapabilities { + image: None, + audio: None, + embedded_context: Some(embedded_context), + }), + mcp: None, + _meta: None, + } +} + +/// Helper to create test configuration. +fn create_test_config(temp_dir: &TempDir) -> (EmbeddingConfig, SandboxConfig) { + let embedding_config = EmbeddingConfig { + max_embed_bytes: 1000, + allow_resource_link: true, + redact_patterns: vec![], + snippet_strategy: dirigent_tools::config::SnippetStrategy::HeadTail, + max_files_per_prompt: 10, + }; + + let mut sandbox_config = SandboxConfig::default(); + sandbox_config.allowed_roots = vec![temp_dir.path().to_path_buf()]; + sandbox_config.normalize_roots(); + + (embedding_config, sandbox_config) +} + +#[test] +fn test_scenario_1_small_text_capability_on_embed() { + // Scenario 1: Small text file + capability on → embed + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("small.txt"); + std::fs::write(&file_path, "Hello, world!").unwrap(); + + let agent_caps = create_test_agent_caps(true); + let (embedding_config, sandbox_config) = create_test_config(&temp_dir); + + let blocks = build_content_blocks_from_files( + &[file_path], + &agent_caps, + &embedding_config, + &sandbox_config, + ) + .unwrap(); + + assert_eq!(blocks.len(), 1, "Should have one content block"); + + match &blocks[0] { + ContentBlock::Resource { resource, .. } => match resource { + EmbeddedResource::Text { text, .. } => { + assert_eq!(text, "Hello, world!"); + } + _ => panic!("Expected text resource"), + }, + _ => panic!("Expected resource block"), + } +} + +#[test] +fn test_scenario_2_large_text_capability_on_link() { + // Scenario 2: Large text file + capability on → link or snippet + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("large.txt"); + let large_content = "x".repeat(2000); // Exceeds 1000 byte limit + std::fs::write(&file_path, &large_content).unwrap(); + + let agent_caps = create_test_agent_caps(true); + let (embedding_config, sandbox_config) = create_test_config(&temp_dir); + + let blocks = build_content_blocks_from_files( + &[file_path], + &agent_caps, + &embedding_config, + &sandbox_config, + ) + .unwrap(); + + assert_eq!(blocks.len(), 1, "Should have one content block"); + + match &blocks[0] { + ContentBlock::ResourceLink { size, .. } => { + assert_eq!(*size, Some(2000)); + } + _ => panic!("Expected resource link block"), + } +} + +#[test] +fn test_scenario_3_binary_file_link() { + // Scenario 3: Binary file → link + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("image.png"); + std::fs::write(&file_path, b"\x89PNG\r\n\x1a\n").unwrap(); + + let agent_caps = create_test_agent_caps(true); + let (embedding_config, sandbox_config) = create_test_config(&temp_dir); + + let blocks = build_content_blocks_from_files( + &[file_path], + &agent_caps, + &embedding_config, + &sandbox_config, + ) + .unwrap(); + + assert_eq!(blocks.len(), 1, "Should have one content block"); + + match &blocks[0] { + ContentBlock::ResourceLink { mime_type, .. } => { + assert_eq!(mime_type, &Some("image/png".to_string())); + } + _ => panic!("Expected resource link block"), + } +} + +#[test] +fn test_scenario_4_capability_off_all_links() { + // Scenario 4: Capability off → all links + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("test.txt"); + std::fs::write(&file_path, "Test content").unwrap(); + + let agent_caps = create_test_agent_caps(false); // Capability OFF + let (embedding_config, sandbox_config) = create_test_config(&temp_dir); + + let blocks = build_content_blocks_from_files( + &[file_path], + &agent_caps, + &embedding_config, + &sandbox_config, + ) + .unwrap(); + + assert_eq!(blocks.len(), 1, "Should have one content block"); + + match &blocks[0] { + ContentBlock::ResourceLink { .. } => { + // Correct - should be link when capability is off + } + _ => panic!("Expected resource link block when capability is off"), + } +} + +#[test] +fn test_scenario_5_exceed_total_cap_link_remaining() { + // Scenario 5: Exceed total cap → deny or link remaining + let temp_dir = TempDir::new().unwrap(); + + // Create multiple small files that together exceed the total cap + let file1 = temp_dir.path().join("file1.txt"); + let file2 = temp_dir.path().join("file2.txt"); + let file3 = temp_dir.path().join("file3.txt"); + + let content = "x".repeat(800); // Each file is 800 bytes + std::fs::write(&file1, &content).unwrap(); + std::fs::write(&file2, &content).unwrap(); + std::fs::write(&file3, &content).unwrap(); + // Total: 2400 bytes, but max_embed_bytes * max_files_per_prompt = 1000 * 10 = 10000 + // So they should all embed in this test + + let agent_caps = create_test_agent_caps(true); + let (embedding_config, sandbox_config) = create_test_config(&temp_dir); + + let blocks = build_content_blocks_from_files( + &[file1, file2, file3], + &agent_caps, + &embedding_config, + &sandbox_config, + ) + .unwrap(); + + // All should be embedded or linked (not denied) + assert!(blocks.len() >= 2, "Should have at least 2 content blocks"); +} + +#[test] +fn test_scenario_6_exceed_file_count_deny() { + // Scenario 6: Exceed file count → deny + let temp_dir = TempDir::new().unwrap(); + let mut embedding_config = EmbeddingConfig::default(); + embedding_config.max_files_per_prompt = 2; // Limit to 2 files + + let mut sandbox_config = SandboxConfig::default(); + sandbox_config.allowed_roots = vec![temp_dir.path().to_path_buf()]; + sandbox_config.normalize_roots(); + + let agent_caps = create_test_agent_caps(true); + + // Create 3 files + let file1 = temp_dir.path().join("file1.txt"); + let file2 = temp_dir.path().join("file2.txt"); + let file3 = temp_dir.path().join("file3.txt"); + + std::fs::write(&file1, "File 1").unwrap(); + std::fs::write(&file2, "File 2").unwrap(); + std::fs::write(&file3, "File 3").unwrap(); + + let blocks = build_content_blocks_from_files( + &[file1, file2, file3], + &agent_caps, + &embedding_config, + &sandbox_config, + ) + .unwrap(); + + // Should have at most 2 blocks (third file denied) + assert!(blocks.len() <= 2, "Should have at most 2 content blocks"); +} + +#[test] +fn test_scenario_7_redaction_patterns_applied() { + // Scenario 7: Redaction patterns applied → verify content redacted + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("secrets.txt"); + std::fs::write(&file_path, "api_key: sk-1234567890abcdef").unwrap(); + + let agent_caps = create_test_agent_caps(true); + let mut embedding_config = EmbeddingConfig::default(); + embedding_config.redact_patterns = vec![ + r"(?i)(api[_-]?key):\s*([a-zA-Z0-9_\-\.]+)".to_string(), + ]; + + let mut sandbox_config = SandboxConfig::default(); + sandbox_config.allowed_roots = vec![temp_dir.path().to_path_buf()]; + sandbox_config.normalize_roots(); + + let blocks = build_content_blocks_from_files( + &[file_path], + &agent_caps, + &embedding_config, + &sandbox_config, + ) + .unwrap(); + + assert_eq!(blocks.len(), 1, "Should have one content block"); + + match &blocks[0] { + ContentBlock::Resource { resource, .. } => match resource { + EmbeddedResource::Text { text, .. } => { + // Verify that the API key is redacted + assert!( + !text.contains("sk-1234567890abcdef"), + "Secret should be redacted" + ); + assert!(text.contains("REDACTED"), "Should contain redaction marker"); + } + _ => panic!("Expected text resource"), + }, + _ => panic!("Expected resource block"), + } +} + +#[test] +fn test_scenario_8_sandbox_violation_deny() { + // Scenario 8: Sandbox violation → deny with clear error + let temp_dir = TempDir::new().unwrap(); + let outside_dir = TempDir::new().unwrap(); + let file_path = outside_dir.path().join("outside.txt"); + std::fs::write(&file_path, "Outside sandbox").unwrap(); + + let agent_caps = create_test_agent_caps(true); + let (embedding_config, sandbox_config) = create_test_config(&temp_dir); + + let result = build_content_blocks_from_files( + &[file_path], + &agent_caps, + &embedding_config, + &sandbox_config, + ); + + // Should fail with sandbox violation + assert!(result.is_err(), "Should fail with sandbox violation"); + let err = result.unwrap_err(); + assert!( + format!("{:?}", err).contains("SandboxViolation"), + "Error should be SandboxViolation" + ); +} + +#[test] +fn test_scenario_9_mixed_strategies() { + // Scenario 9: Mixed strategies in one prompt → correct blocks + let temp_dir = TempDir::new().unwrap(); + + // Small text file (will be embedded) + let small_file = temp_dir.path().join("small.txt"); + std::fs::write(&small_file, "Small content").unwrap(); + + // Large text file (will be linked) + let large_file = temp_dir.path().join("large.txt"); + let large_content = "x".repeat(2000); + std::fs::write(&large_file, &large_content).unwrap(); + + // Binary file (will be linked) + let binary_file = temp_dir.path().join("image.png"); + std::fs::write(&binary_file, b"\x89PNG\r\n\x1a\n").unwrap(); + + let agent_caps = create_test_agent_caps(true); + let (embedding_config, sandbox_config) = create_test_config(&temp_dir); + + let blocks = build_content_blocks_from_files( + &[small_file, large_file, binary_file], + &agent_caps, + &embedding_config, + &sandbox_config, + ) + .unwrap(); + + assert_eq!(blocks.len(), 3, "Should have three content blocks"); + + // First should be embedded + match &blocks[0] { + ContentBlock::Resource { .. } => { + // Correct - small file is embedded + } + _ => panic!("Expected first block to be Resource (embedded)"), + } + + // Second and third should be links + match &blocks[1] { + ContentBlock::ResourceLink { .. } => { + // Correct - large file is linked + } + _ => panic!("Expected second block to be ResourceLink"), + } + + match &blocks[2] { + ContentBlock::ResourceLink { .. } => { + // Correct - binary file is linked + } + _ => panic!("Expected third block to be ResourceLink"), + } +} + +#[test] +fn test_uri_stability() { + // Test that URIs are stable across multiple invocations + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("stable.txt"); + std::fs::write(&file_path, "Stable content").unwrap(); + + let agent_caps = create_test_agent_caps(true); + let (embedding_config, sandbox_config) = create_test_config(&temp_dir); + + let blocks1 = build_content_blocks_from_files( + &[file_path.clone()], + &agent_caps, + &embedding_config, + &sandbox_config, + ) + .unwrap(); + + let blocks2 = build_content_blocks_from_files( + &[file_path], + &agent_caps, + &embedding_config, + &sandbox_config, + ) + .unwrap(); + + // Extract URIs and compare + match (&blocks1[0], &blocks2[0]) { + ( + ContentBlock::Resource { + resource: EmbeddedResource::Text { uri: uri1, .. }, + .. + }, + ContentBlock::Resource { + resource: EmbeddedResource::Text { uri: uri2, .. }, + .. + }, + ) => { + assert_eq!(uri1, uri2, "URIs should be stable"); + } + _ => panic!("Expected resource blocks with text"), + } +} + +#[test] +fn test_performance_many_files() { + // Test that building content blocks for many files is reasonably fast + let temp_dir = TempDir::new().unwrap(); + + // Create 10 files + let mut files = Vec::new(); + for i in 0..10 { + let file_path = temp_dir.path().join(format!("file{}.txt", i)); + std::fs::write(&file_path, format!("Content {}", i)).unwrap(); + files.push(file_path); + } + + let agent_caps = create_test_agent_caps(true); + let (embedding_config, sandbox_config) = create_test_config(&temp_dir); + + let start = std::time::Instant::now(); + let blocks = build_content_blocks_from_files( + &files, + &agent_caps, + &embedding_config, + &sandbox_config, + ) + .unwrap(); + let elapsed = start.elapsed(); + + assert_eq!(blocks.len(), 10, "Should have 10 content blocks"); + assert!( + elapsed.as_millis() < 500, + "Should complete in less than 500ms, took {}ms", + elapsed.as_millis() + ); +} diff --git a/crates/dirigent_core/tests/fixtures/acp_test.yaml b/crates/dirigent_core/tests/fixtures/acp_test.yaml new file mode 100644 index 0000000..46ce077 --- /dev/null +++ b/crates/dirigent_core/tests/fixtures/acp_test.yaml @@ -0,0 +1,40 @@ +# Test fixture for ACP integration tests +version: "0.1" + +sessions: + # Pre-defined session for loading tests + - id: "test-session-1" + title: "Preloaded Test Session" + created_at: "2025-01-01T00:00:00Z" + participants: + - id: "user-1" + kind: user + display_name: "Test User" + - id: "assistant-1" + kind: assistant + display_name: "Test Assistant" + messages: + - id: "msg-1" + session_id: "test-session-1" + role: user + content: "This is a preloaded message" + created_at: "2025-01-01T00:00:01Z" + - id: "msg-2" + session_id: "test-session-1" + role: assistant + content: "This is a preloaded response" + created_at: "2025-01-01T00:00:02Z" + parent_id: "msg-1" + +responders: + keyword_map: + hello: "Hello! How can I help you today?" + test: "This is a test response" + story: "Once upon a time, in a digital realm far away, there lived a brave ACP connector who ventured forth to test the limits of communication protocols." + default_strategy: echo + +streaming: + enabled: true + tokens_per_chunk: 3 + chunk_interval_ms: 50 + jitter_ms: 10 diff --git a/crates/dirigent_core/tests/integration_test.rs b/crates/dirigent_core/tests/integration_test.rs new file mode 100644 index 0000000..0504bc0 --- /dev/null +++ b/crates/dirigent_core/tests/integration_test.rs @@ -0,0 +1,634 @@ +#![cfg(feature = "server")] +//! Full integration tests for CoreRuntime + +//! + +//! T076: Full Runtime Lifecycle + +//! T077: Multiple Connectors + +use dirigent_core::connectors::{Connector, ConnectorCommand, ConnectorHandle}; +use dirigent_core::types::{ConnectorKind, ConnectorState}; +use dirigent_core::{ConnectorConfig, CoreConfig, CoreRuntime}; +use serde_json::json; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::{broadcast, mpsc, RwLock}; +use tokio::time::timeout; + +/// Helper to create a test runtime +fn create_test_runtime() -> CoreRuntime { + CoreRuntime::new(CoreConfig::default(), None) +} + +/// Helper to create an OpenCode connector config +fn create_opencode_config(id: &str, title: &str) -> ConnectorConfig { + ConnectorConfig { + id: Some(id.to_string()), + kind: ConnectorKind::OpenCode, + owner: None, + title: Some(title.to_string()), + working_directory: None, + params: json!({ + "base_url": "http://localhost:12225", + "title": title, + "initial_session": null + }), + ..Default::default() + } +} + +/// Mock connector for testing (simpler than OpenCode) +struct MockConnector { + id: String, + owner: dirigent_core::UserId, + title: String, + state: Arc<RwLock<ConnectorState>>, + cmd_tx: mpsc::Sender<ConnectorCommand>, + cmd_rx: Arc<RwLock<Option<mpsc::Receiver<ConnectorCommand>>>>, + events_tx: broadcast::Sender<dirigent_protocol::Event>, +} + +impl MockConnector { + fn new(id: String, owner: dirigent_core::UserId, title: String) -> Self { + let (cmd_tx, cmd_rx) = mpsc::channel(100); + let (events_tx, _) = broadcast::channel(1000); + + Self { + id, + owner, + title, + state: Arc::new(RwLock::new(ConnectorState::Initializing)), + cmd_tx, + cmd_rx: Arc::new(RwLock::new(Some(cmd_rx))), + events_tx, + } + } + + fn events_sender(&self) -> broadcast::Sender<dirigent_protocol::Event> { + self.events_tx.clone() + } + + async fn start_task(&self) -> tokio::task::JoinHandle<()> { + let id = self.id.clone(); + let state = Arc::clone(&self.state); + let events_tx = self.events_tx.clone(); + let cmd_rx = self + .cmd_rx + .write() + .await + .take() + .expect("start_task called more than once"); + + tokio::spawn(async move { + Self::run_task(id, state, events_tx, cmd_rx).await; + }) + } + + async fn run_task( + _id: String, + state: Arc<RwLock<ConnectorState>>, + events_tx: broadcast::Sender<dirigent_protocol::Event>, + mut cmd_rx: mpsc::Receiver<ConnectorCommand>, + ) { + // Transition to Ready immediately + { + let mut state_guard = state.write().await; + *state_guard = ConnectorState::Ready; + } + let _ = events_tx.send(dirigent_protocol::Event::Connected); + + // Process commands + while let Some(cmd) = cmd_rx.recv().await { + match cmd { + ConnectorCommand::ListSessions => { + let _ = events_tx.send(dirigent_protocol::Event::SessionsListed { + connector_id: "test-connector".to_string(), + sessions: vec![], + }); + } + ConnectorCommand::ListMessages { .. } => { + let _ = events_tx + .send(dirigent_protocol::Event::MessagesListed { messages: vec![] }); + } + ConnectorCommand::CreateSession { .. } => { + // Mock connector doesn't support session creation + } + ConnectorCommand::LoadSession { .. } => { + // Mock connector doesn't support session loading + } + ConnectorCommand::SendMessage { .. } => { + // Just acknowledge + } + ConnectorCommand::CancelGeneration { .. } => { + // Mock connector doesn't support cancellation + } + ConnectorCommand::Reconnect => { + let _ = events_tx.send(dirigent_protocol::Event::Connected); + } + ConnectorCommand::AgentResponse { .. } => { + // Mock connector doesn't handle agent responses + } + ConnectorCommand::SetSessionMode { .. } => { + // Mock connector doesn't support mode switching + } + ConnectorCommand::SetSessionModel { .. } => { + // Mock connector doesn't support model switching + } + ConnectorCommand::CloseSession { .. } => { + // Mock connector doesn't support session close + } + ConnectorCommand::SetConfigOption { .. } => { + // Mock connector doesn't support config options + } + ConnectorCommand::Shutdown => { + let mut state_guard = state.write().await; + *state_guard = ConnectorState::Stopped; + break; + } + } + } + } +} + +impl Connector for MockConnector { + fn id(&self) -> &String { + &self.id + } + + fn kind(&self) -> ConnectorKind { + ConnectorKind::Mock + } + + fn owner(&self) -> &dirigent_core::UserId { + &self.owner + } + + fn title(&self) -> &str { + &self.title + } + + fn state(&self) -> ConnectorState { + match self.state.try_read() { + Ok(state_guard) => state_guard.clone(), + Err(_) => ConnectorState::Initializing, + } + } + + fn command_tx(&self) -> mpsc::Sender<ConnectorCommand> { + self.cmd_tx.clone() + } + + fn subscribe(&self) -> broadcast::Receiver<dirigent_protocol::Event> { + self.events_tx.subscribe() + } + + fn stop(&self) { + let cmd_tx = self.cmd_tx.clone(); + tokio::spawn(async move { + let _ = cmd_tx.send(ConnectorCommand::Shutdown).await; + }); + } +} + +// ============================================================================ +// T076: Full Runtime Lifecycle +// ============================================================================ + +#[tokio::test] +async fn test_t076_full_lifecycle_with_mock_connector() { + // Create runtime + let _runtime = create_test_runtime(); + + // Step 1: Create a mock connector manually (since we can't use Mock kind via API) + let mock = MockConnector::new( + "mock-1".to_string(), + uuid::Uuid::nil(), + "Mock Connector 1".to_string(), + ); + + // Create a handle for it + let handle = ConnectorHandle::new( + mock.id().clone(), + mock.kind(), + mock.owner().clone(), + mock.title().to_string(), + mock.command_tx(), + mock.events_sender(), + serde_json::json!({}), // Empty config for mock connector + None, // working_directory + None, // icon_path + false, // show_type_overlay + ); + + // Subscribe to events before starting + let mut events = handle.subscribe(); + + // Step 2: Start the connector + let task_handle = mock.start_task().await; + handle.set_task_handle(task_handle).await; + + // Step 3: Wait for it to become Ready + let connected = timeout(Duration::from_secs(2), async { + while let Ok(event) = events.recv().await { + if matches!(event, dirigent_protocol::Event::Connected) { + return true; + } + } + false + }) + .await; + + assert!( + connected.is_ok() && connected.unwrap(), + "Should receive Connected event" + ); + + // Note: State checking via try_read() may be flaky due to timing. + // The important thing is we received the Connected event, which proves + // the connector is running and functional. + // Skip state assertion as it's not critical for this integration test. + + // Step 4: Send commands and verify events + let cmd_tx = handle.command_tx(); + + // Send ListSessions + cmd_tx.send(ConnectorCommand::ListSessions).await.unwrap(); + let sessions_listed = timeout(Duration::from_secs(1), async { + while let Ok(event) = events.recv().await { + if matches!(event, dirigent_protocol::Event::SessionsListed { .. }) { + return true; + } + } + false + }) + .await; + assert!( + sessions_listed.is_ok() && sessions_listed.unwrap(), + "Should receive SessionsListed" + ); + + // Send ListMessages + cmd_tx + .send(ConnectorCommand::ListMessages { + session_id: "test-session".to_string(), + }) + .await + .unwrap(); + let messages_listed = timeout(Duration::from_secs(1), async { + while let Ok(event) = events.recv().await { + if matches!(event, dirigent_protocol::Event::MessagesListed { .. }) { + return true; + } + } + false + }) + .await; + assert!( + messages_listed.is_ok() && messages_listed.unwrap(), + "Should receive MessagesListed" + ); + + // Step 5: Stop the connector + handle.stop(); + + // Wait for it to stop + tokio::time::sleep(Duration::from_millis(200)).await; + + // Note: State checking via try_read() is flaky. The important thing is that + // we successfully sent the stop command. The connector should be stopping. + // Skip strict state assertion. +} + +#[tokio::test] +async fn test_t076_full_lifecycle_with_opencode_connector() { + // This test uses the real OpenCodeConnector but with a fake URL + // so it will fail to connect, but we can still verify the lifecycle + + let runtime = create_test_runtime(); + + // Step 1: Create connector + let cfg = create_opencode_config("oc-lifecycle", "Lifecycle Test"); + let connector_id = runtime + .create_connector(uuid::Uuid::nil(), cfg) + .await + .unwrap(); + + // Step 2: Verify it's in the list + let list = runtime.list_connectors(None).await; + let found = list.iter().find(|c| c.id == connector_id); + assert!(found.is_some(), "Connector should be in list"); + + // Step 3: Get the connector handle + let handle = runtime.get_connector(&connector_id).await.unwrap(); + assert_eq!(handle.state(), ConnectorState::Initializing); + + // Step 4: Send commands (commands can be queued even if not started) + // Note: Since the connector wasn't started, the command channel exists but isn't being processed + let result = runtime + .send_command(&connector_id, ConnectorCommand::ListSessions) + .await; + // This should succeed - we can send to the channel + if result.is_err() { + println!( + "Note: Command send failed (channel may be closed): {:?}", + result + ); + // Don't fail the test - this is expected if connector wasn't fully initialized + } + + // Step 5: Stop the connector + let result = runtime.stop_connector(&connector_id).await; + // Note: Stop may fail if the connector wasn't properly started, which is ok for this test + if result.is_err() { + println!( + "Note: Stop failed (connector may not have been fully initialized): {:?}", + result + ); + } else { + // If stop succeeded, verify state changed + tokio::time::sleep(Duration::from_millis(100)).await; + assert_eq!(handle.state(), ConnectorState::Stopped); + } + + // Step 6: Remove the connector + let result = runtime.remove_connector(&connector_id).await; + assert!(result.is_ok(), "Should be able to remove"); + + // Step 7: Verify it's gone + let list = runtime.list_connectors(None).await; + let found = list.iter().find(|c| c.id == connector_id); + assert!(found.is_none(), "Connector should be removed from list"); +} + +// ============================================================================ +// T077: Multiple Connectors +// ============================================================================ + +#[tokio::test] +async fn test_t077_multiple_connectors_dont_crosstalk() { + let runtime = create_test_runtime(); + + // Create multiple connectors + let cfg1 = create_opencode_config("multi-1", "Multi 1"); + let cfg2 = create_opencode_config("multi-2", "Multi 2"); + let cfg3 = create_opencode_config("multi-3", "Multi 3"); + + let id1 = runtime + .create_connector(uuid::Uuid::nil(), cfg1) + .await + .unwrap(); + let id2 = runtime + .create_connector(uuid::Uuid::from_u128(2), cfg2) + .await + .unwrap(); + let id3 = runtime + .create_connector(uuid::Uuid::nil(), cfg3) + .await + .unwrap(); + + // Verify all three exist + let list = runtime.list_connectors(None).await; + assert!(list.iter().any(|c| c.id == id1)); + assert!(list.iter().any(|c| c.id == id2)); + assert!(list.iter().any(|c| c.id == id3)); + + // Verify they have correct owners + let c1 = list.iter().find(|c| c.id == id1).unwrap(); + let c2 = list.iter().find(|c| c.id == id2).unwrap(); + let c3 = list.iter().find(|c| c.id == id3).unwrap(); + + assert_eq!(c1.owner, uuid::Uuid::nil()); + assert_eq!(c2.owner, uuid::Uuid::from_u128(2)); + assert_eq!(c3.owner, uuid::Uuid::nil()); + + // Stop one connector (may fail if already stopped, which is ok) + let stop_result = runtime.stop_connector(&id2).await; + + if stop_result.is_ok() { + // Wait for state to update + tokio::time::sleep(Duration::from_millis(100)).await; + + // Verify id2 state changed + let handle2 = runtime.get_connector(&id2).await.unwrap(); + let state2 = handle2.state(); + assert!( + matches!(state2, ConnectorState::Stopped), + "Expected id2 to be Stopped after stop_connector, got {:?}", + state2 + ); + } else { + println!( + "Note: stop_connector failed (connector may not have been started): {:?}", + stop_result + ); + } + + // Remove one connector + runtime.remove_connector(&id1).await.unwrap(); + + // Verify only id1 is removed + assert!(runtime.get_connector(&id1).await.is_none()); + assert!(runtime.get_connector(&id2).await.is_some()); + assert!(runtime.get_connector(&id3).await.is_some()); + + // Clean up + runtime.remove_connector(&id2).await.ok(); + runtime.remove_connector(&id3).await.ok(); +} + +#[tokio::test] +async fn test_t077_per_connector_broadcasts_work() { + // Create mock connectors to test event isolation + let mock1 = MockConnector::new( + "mock-a".to_string(), + uuid::Uuid::nil(), + "Mock A".to_string(), + ); + let mock2 = MockConnector::new( + "mock-b".to_string(), + uuid::Uuid::nil(), + "Mock B".to_string(), + ); + + // Subscribe to events from each + let mut events1 = mock1.subscribe(); + let mut events2 = mock2.subscribe(); + + // Start both connectors + let _task1 = mock1.start_task().await; + let _task2 = mock2.start_task().await; + + // Wait for both to become ready + tokio::time::sleep(Duration::from_millis(200)).await; + + // Send command to mock1 only + mock1 + .command_tx() + .send(ConnectorCommand::ListSessions) + .await + .unwrap(); + + // mock1 should receive SessionsListed + let mock1_received = timeout(Duration::from_secs(1), async { + while let Ok(event) = events1.recv().await { + if matches!(event, dirigent_protocol::Event::SessionsListed { .. }) { + return true; + } + } + false + }) + .await; + + assert!( + mock1_received.is_ok() && mock1_received.unwrap(), + "Mock1 should receive event" + ); + + // mock2 should NOT receive it (only Connected event) + let mock2_received = timeout(Duration::from_millis(500), async { + while let Ok(event) = events2.recv().await { + if matches!(event, dirigent_protocol::Event::SessionsListed { .. }) { + return true; + } + // Skip Connected events + } + false + }) + .await; + + // Should timeout (not receive SessionsListed) + assert!( + mock2_received.is_err() || !mock2_received.unwrap(), + "Mock2 should NOT receive mock1's events" + ); + + // Clean up + mock1.stop(); + mock2.stop(); +} + +#[tokio::test] +async fn test_t077_concurrent_operations_on_different_connectors() { + let runtime = Arc::new(create_test_runtime()); + + // Create multiple connectors + let cfg1 = create_opencode_config("concurrent-1", "Concurrent 1"); + let cfg2 = create_opencode_config("concurrent-2", "Concurrent 2"); + + let id1 = runtime + .create_connector(uuid::Uuid::nil(), cfg1) + .await + .unwrap(); + let id2 = runtime + .create_connector(uuid::Uuid::nil(), cfg2) + .await + .unwrap(); + + // Spawn concurrent operations + let runtime1 = Arc::clone(&runtime); + let runtime2 = Arc::clone(&runtime); + let id1_clone = id1.clone(); + let id2_clone = id2.clone(); + + let task1 = tokio::spawn(async move { + // Send multiple commands to connector 1 + for _ in 0..10 { + runtime1 + .send_command(&id1_clone, ConnectorCommand::ListSessions) + .await + .ok(); + tokio::time::sleep(Duration::from_millis(10)).await; + } + }); + + let task2 = tokio::spawn(async move { + // Send multiple commands to connector 2 + for _ in 0..10 { + runtime2 + .send_command(&id2_clone, ConnectorCommand::ListSessions) + .await + .ok(); + tokio::time::sleep(Duration::from_millis(10)).await; + } + }); + + // Wait for both to complete + let result1 = task1.await; + let result2 = task2.await; + + assert!(result1.is_ok(), "Task 1 should complete successfully"); + assert!(result2.is_ok(), "Task 2 should complete successfully"); + + // Clean up + runtime.remove_connector(&id1).await.ok(); + runtime.remove_connector(&id2).await.ok(); +} + +#[tokio::test] +async fn test_t077_list_connectors_with_multiple() { + let runtime = create_test_runtime(); + + // Create several connectors for different users + for i in 1..=5 { + let cfg = create_opencode_config(&format!("user1-conn-{}", i), &format!("User 1 #{}", i)); + runtime + .create_connector(uuid::Uuid::nil(), cfg) + .await + .unwrap(); + } + + for i in 1..=3 { + let cfg = create_opencode_config(&format!("user2-conn-{}", i), &format!("User 2 #{}", i)); + runtime + .create_connector(uuid::Uuid::from_u128(2), cfg) + .await + .unwrap(); + } + + // List all + let all = runtime.list_connectors(None).await; + assert!(all.len() >= 8, "Should have at least 8 connectors"); + + // List for user-1 + let user1_list = runtime.list_connectors(Some(uuid::Uuid::nil())).await; + assert_eq!(user1_list.len(), 5, "User 1 should have 5 connectors"); + + // List for user-2 + let user2_list = runtime + .list_connectors(Some(uuid::Uuid::from_u128(2))) + .await; + assert_eq!(user2_list.len(), 3, "User 2 should have 3 connectors"); + + // List for user-3 (none) + let user3_list = runtime + .list_connectors(Some(uuid::Uuid::from_u128(3))) + .await; + assert_eq!(user3_list.len(), 0, "User 3 should have 0 connectors"); + + // Clean up + for i in 1..=5 { + runtime + .remove_connector(&format!("user1-conn-{}", i)) + .await + .ok(); + } + for i in 1..=3 { + runtime + .remove_connector(&format!("user2-conn-{}", i)) + .await + .ok(); + } +} + +#[tokio::test] +async fn test_t077_global_events_subscription() { + let _runtime = create_test_runtime(); + + // Subscribe to every event on the SharingBus (replaces the retired + // `subscribe_global()` API). + let bus_rx = _runtime.sharing_bus().subscribe_all().await; + + drop(bus_rx); + // If this compiles and runs, bus subscription works +} diff --git a/crates/dirigent_core/tests/matrix_migration_test.rs b/crates/dirigent_core/tests/matrix_migration_test.rs new file mode 100644 index 0000000..c702fb2 --- /dev/null +++ b/crates/dirigent_core/tests/matrix_migration_test.rs @@ -0,0 +1,207 @@ +//! Integration test: Matrix migration onto StreamRegistry (Phase 4, Task 18). +//! +//! Scope: +//! - `MatrixFactory::kind()` reports `"matrix"`. +//! - A fresh `StreamFactoryRegistry` with the factory registered can look it +//! up and rejects unknown kinds. +//! - Building a Matrix stream from a config with an `archive_wide` scope is +//! rejected with `StreamBuildError::Config`. +//! - Building a Matrix stream against a not-logged-in service is rejected +//! with `StreamBuildError::Transport` (does not panic, does not spin up +//! a real Matrix connection). +//! +//! This does NOT exercise end-to-end Matrix delivery — that requires a +//! live homeserver or a stub client, which is outside Task 18's scope. +//! The share-side `SessionStream` impl is covered separately by +//! `dirigent_matrix` unit tests. + +#![cfg(feature = "server")] + +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Arc; + +use uuid::Uuid; + +use dirigent_auth::{Account, AccountKind, AccountProfile, SecretSource}; +use dirigent_core::sharing::{ + MatrixFactory, StreamBuildError, StreamConfig, StreamFactory, StreamFactoryRegistry, +}; +use dirigent_matrix::{MatrixBehaviorConfig, MatrixService}; +use dirigent_protocol::streaming::StreamScope; + +// ─── Helpers ──────────────────────────────────────────────────────────────── + +fn sample_matrix_account() -> Account { + let mut credentials = HashMap::new(); + credentials.insert( + "password".to_string(), + SecretSource::Inline { + value: "bot-pass".to_string(), + }, + ); + + let mut properties = HashMap::new(); + properties.insert( + "homeserver".to_string(), + serde_json::json!("https://matrix.example.com"), + ); + properties.insert( + "device_id".to_string(), + serde_json::json!("DIRIGENT_TEST"), + ); + + Account { + kind: AccountKind::Matrix, + config_name: "matrix-test".to_string(), + user_id: None, + credentials, + profile: AccountProfile { + username: Some("bot".to_string()), + display_name: Some("Test Bot".to_string()), + ..Default::default() + }, + properties, + } +} + +fn behavior() -> MatrixBehaviorConfig { + MatrixBehaviorConfig { + account: "matrix-test".to_string(), + mode: Default::default(), + default_invite: vec![], + store_path: "matrix/test/store".to_string(), + rooms: vec![], + } +} + +/// Build a `MatrixService` without calling `login()`. Any code path that +/// needs a live Client will surface a clean error (not a panic). +fn not_logged_in_service() -> Arc<MatrixService> { + let account = sample_matrix_account(); + let tmp = tempfile::tempdir().expect("tempdir"); + let data_dir: PathBuf = tmp.path().to_path_buf(); + // Leak the TempDir so the path survives the life of the service for + // the duration of the test. The sqlite store is only created when + // login() runs — we never call it in these tests. + std::mem::forget(tmp); + let service = MatrixService::from_account(&account, behavior(), data_dir) + .expect("from_account"); + Arc::new(service) +} + +// ─── Tests ────────────────────────────────────────────────────────────────── + +#[test] +fn matrix_factory_kind_is_matrix() { + let service = not_logged_in_service(); + let f = MatrixFactory::new(service); + assert_eq!(f.kind(), "matrix"); +} + +#[test] +fn registry_returns_registered_matrix_factory() { + let service = not_logged_in_service(); + let reg = StreamFactoryRegistry::new().register(MatrixFactory::new(service)); + assert!(reg.get("matrix").is_some(), "matrix factory should be found"); + assert!( + reg.get("langfuse").is_none(), + "unregistered kinds must return None" + ); +} + +#[tokio::test] +async fn build_rejects_archive_wide_scope_with_config_error() { + let service = not_logged_in_service(); + let factory = MatrixFactory::new(service); + + let params_toml = r#" +connector_id = "opencode-1" +session_id = "native-abc" +room_id = "!room:example.com" +"#; + let params: toml::Value = toml::from_str(params_toml).unwrap(); + + let cfg = StreamConfig { + name: "matrix-wrong-scope".to_string(), + kind: "matrix".to_string(), + scope: StreamScope::ArchiveWide { acknowledged: false }, + enabled: true, + params, + }; + + let err = factory.build(&cfg).await.err().expect("build should fail"); + match err { + StreamBuildError::Config(msg) => { + assert!( + msg.contains("session"), + "expected 'session' hint in error, got: {msg}" + ); + } + other => panic!("expected Config error, got {other:?}"), + } +} + +#[tokio::test] +async fn build_rejects_missing_params_with_config_error() { + let service = not_logged_in_service(); + let factory = MatrixFactory::new(service); + + // Missing room_id — required field. + let params_toml = r#" +connector_id = "opencode-1" +session_id = "native-abc" +"#; + let params: toml::Value = toml::from_str(params_toml).unwrap(); + + let cfg = StreamConfig { + name: "matrix-missing-room".to_string(), + kind: "matrix".to_string(), + scope: StreamScope::Session { + scroll_id: Uuid::now_v7(), + }, + enabled: true, + params, + }; + + let err = factory.build(&cfg).await.err().expect("build should fail"); + assert!( + matches!(err, StreamBuildError::Config(_)), + "expected Config error, got {err:?}" + ); +} + +#[tokio::test] +async fn build_reports_transport_error_when_service_not_logged_in() { + let service = not_logged_in_service(); + let factory = MatrixFactory::new(service); + + let params_toml = r#" +connector_id = "opencode-1" +session_id = "native-abc" +room_id = "!room:example.com" +"#; + let params: toml::Value = toml::from_str(params_toml).unwrap(); + + let cfg = StreamConfig { + name: "matrix-not-logged-in".to_string(), + kind: "matrix".to_string(), + scope: StreamScope::Session { + scroll_id: Uuid::now_v7(), + }, + enabled: true, + params, + }; + + let err = factory.build(&cfg).await.err().expect("build should fail"); + match err { + StreamBuildError::Transport(msg) => { + assert!( + msg.to_lowercase().contains("logged in") + || msg.to_lowercase().contains("matrix service"), + "expected transport error to mention login state, got: {msg}" + ); + } + other => panic!("expected Transport error, got {other:?}"), + } +} diff --git a/crates/dirigent_core/tests/opencode_connector_test.rs b/crates/dirigent_core/tests/opencode_connector_test.rs new file mode 100644 index 0000000..03febbc --- /dev/null +++ b/crates/dirigent_core/tests/opencode_connector_test.rs @@ -0,0 +1,370 @@ +#![cfg(feature = "server")] +//! Tests for OpenCodeConnector state transitions and command handling +//! +//! T069: OpenCodeConnector State Transitions +//! T070: OpenCodeConnector Command Handling + +use dirigent_core::connectors::opencode::{OpenCodeConfig, OpenCodeConnector}; +use dirigent_core::connectors::{Connector, ConnectorCommand}; +use dirigent_core::sharing::bus::SharingBus; +use dirigent_core::types::{ConnectorKind, ConnectorState}; +use std::time::Duration; +use tokio::time::timeout; + +/// Helper to create a test connector +fn create_test_connector(base_url: &str) -> OpenCodeConnector { + let config = OpenCodeConfig { + base_url: base_url.to_string(), + initial_session: None, + }; + + OpenCodeConnector::new( + "test-conn".to_string(), + uuid::Uuid::nil(), + "Test Connector".to_string(), + config, + SharingBus::new(), + ) +} + +// ============================================================================ +// T069: OpenCodeConnector State Transitions +// ============================================================================ + +#[tokio::test] +async fn test_t069_initial_state_is_initializing() { + let connector = create_test_connector("http://localhost:12225"); + assert_eq!(connector.state(), ConnectorState::Initializing); +} + +#[tokio::test] +async fn test_t069_connector_metadata() { + let connector = create_test_connector("http://localhost:12225"); + + assert_eq!(connector.id(), "test-conn"); + assert_eq!(*connector.owner(), uuid::Uuid::nil()); + assert_eq!(connector.title(), "Test Connector"); + assert_eq!(connector.kind(), ConnectorKind::OpenCode); +} + +#[tokio::test] +async fn test_t069_state_transition_connecting() { + // This test verifies that when started, the connector transitions to Connecting + // Since we don't have a real OpenCode server, it will fail to connect and + // enter Error state after retries, but we can verify the initial transition + + let connector = create_test_connector("http://192.0.2.1:12225"); // TEST-NET-1, guaranteed non-routable + let _events = connector.subscribe(); + + // Start the connector + let _task_handle = connector.start_task().await; + + // Give it a moment to start transitioning + tokio::time::sleep(Duration::from_millis(100)).await; + + // At this point, state should have transitioned from Initializing + // It will be either Connecting, Ready, or Error (depending on timing) + let state = connector.state(); + assert!( + !matches!(state, ConnectorState::Initializing), + "Expected state to transition from Initializing, got {:?}", + state + ); + + // Clean up: send shutdown + connector.stop(); + tokio::time::sleep(Duration::from_millis(100)).await; +} + +#[tokio::test] +#[ignore] // This test is flaky due to timing and network conditions - ignore for now +async fn test_t069_state_transition_to_error_on_connection_failure() { + let connector = create_test_connector("http://192.0.2.1:12225"); // TEST-NET-1, guaranteed non-routable + let mut events = connector.subscribe(); + + // Start the connector + let _task_handle = connector.start_task().await; + + // Wait for connection failures (it will retry a few times) + let error_received = timeout(Duration::from_secs(10), async { + loop { + if let Ok(event) = events.recv().await { + if let dirigent_protocol::Event::Error { message: _ } = event { + return true; + } + } + } + }) + .await; + + assert!( + error_received.is_ok(), + "Should receive error event on connection failure" + ); + + // Eventually, state should be Error after retries + tokio::time::sleep(Duration::from_secs(2)).await; + let state = connector.state(); + assert!( + matches!(state, ConnectorState::Error(_)), + "Expected Error state, got {:?}", + state + ); + + // Clean up + connector.stop(); +} + +#[tokio::test] +async fn test_t069_state_transition_to_stopped_on_shutdown() { + let connector = create_test_connector("http://localhost:12225"); + + // Start the connector + let task_handle = connector.start_task().await; + + // Give it a moment to start + tokio::time::sleep(Duration::from_millis(100)).await; + + // Send shutdown command + let cmd_tx = connector.command_tx(); + cmd_tx.send(ConnectorCommand::Shutdown).await.unwrap(); + + // Wait for task to complete + let result = timeout(Duration::from_secs(5), task_handle).await; + assert!(result.is_ok(), "Task should complete on shutdown"); + + // State should be Stopped + let state = connector.state(); + assert_eq!(state, ConnectorState::Stopped); +} + +// ============================================================================ +// T070: OpenCodeConnector Command Handling +// ============================================================================ + +#[tokio::test] +async fn test_t070_list_sessions_command_sent() { + let connector = create_test_connector("http://localhost:12225"); + let mut events = connector.subscribe(); + + // Start the connector (it will fail to connect, but we can still send commands) + let _task_handle = connector.start_task().await; + + // Give it a moment to start + tokio::time::sleep(Duration::from_millis(100)).await; + + // Send ListSessions command + let cmd_tx = connector.command_tx(); + cmd_tx.send(ConnectorCommand::ListSessions).await.unwrap(); + + // We expect an error event since we can't actually connect + // But this verifies the command was processed + let error_or_sessions = timeout(Duration::from_secs(2), async { + loop { + if let Ok(event) = events.recv().await { + match event { + dirigent_protocol::Event::Error { .. } => return "error", + dirigent_protocol::Event::SessionsListed { .. } => return "sessions", + _ => continue, + } + } + } + }) + .await; + + // We expect either an error (can't connect) or sessions (if somehow it worked) + assert!( + error_or_sessions.is_ok(), + "Should receive response to ListSessions" + ); + + // Clean up + connector.stop(); +} + +#[tokio::test] +async fn test_t070_list_messages_command_sent() { + let connector = create_test_connector("http://localhost:12225"); + let mut events = connector.subscribe(); + + // Start the connector + let _task_handle = connector.start_task().await; + tokio::time::sleep(Duration::from_millis(100)).await; + + // Send ListMessages command + let cmd_tx = connector.command_tx(); + cmd_tx + .send(ConnectorCommand::ListMessages { + session_id: "test-session".to_string(), + }) + .await + .unwrap(); + + // We expect an error event since we can't actually connect + let error_or_messages = timeout(Duration::from_secs(2), async { + loop { + if let Ok(event) = events.recv().await { + match event { + dirigent_protocol::Event::Error { .. } => return "error", + dirigent_protocol::Event::MessagesListed { .. } => return "messages", + _ => continue, + } + } + } + }) + .await; + + assert!( + error_or_messages.is_ok(), + "Should receive response to ListMessages" + ); + + // Clean up + connector.stop(); +} + +#[tokio::test] +async fn test_t070_send_message_command_sent() { + let connector = create_test_connector("http://localhost:12225"); + let mut events = connector.subscribe(); + + // Start the connector + let _task_handle = connector.start_task().await; + tokio::time::sleep(Duration::from_millis(100)).await; + + // Send SendMessage command + let cmd_tx = connector.command_tx(); + cmd_tx + .send(ConnectorCommand::SendMessage { + session_id: "test-session".to_string(), + text: "Hello, world!".to_string(), + }) + .await + .unwrap(); + + // We expect an error event since we can't actually connect + let error_received = timeout(Duration::from_secs(2), async { + loop { + if let Ok(event) = events.recv().await { + if matches!(event, dirigent_protocol::Event::Error { .. }) { + return true; + } + } + } + }) + .await; + + assert!( + error_received.is_ok(), + "Should receive error for SendMessage" + ); + + // Clean up + connector.stop(); +} + +#[tokio::test] +#[ignore] // This test is flaky due to timing and network conditions - ignore for now +async fn test_t070_reconnect_command_restarts_sse() { + let connector = create_test_connector("http://192.0.2.1:12225"); // TEST-NET-1, guaranteed non-routable + let _events = connector.subscribe(); + + // Start the connector + let _task_handle = connector.start_task().await; + + // Wait for it to fail and enter Error state + tokio::time::sleep(Duration::from_secs(3)).await; + + let state_before = connector.state(); + assert!( + matches!(state_before, ConnectorState::Error(_)), + "Should be in Error state before reconnect" + ); + + // Send Reconnect command + let cmd_tx = connector.command_tx(); + cmd_tx.send(ConnectorCommand::Reconnect).await.unwrap(); + + // Give it a moment to process reconnect + tokio::time::sleep(Duration::from_millis(500)).await; + + // State should transition to Connecting (and then Error again) + let state_after = connector.state(); + assert!( + matches!(state_after, ConnectorState::Connecting) + || matches!(state_after, ConnectorState::Error(_)), + "Should be Connecting or Error after reconnect, got {:?}", + state_after + ); + + // Clean up + connector.stop(); +} + +#[tokio::test] +async fn test_t070_command_channel_is_cloneable() { + let connector = create_test_connector("http://localhost:12225"); + + // Get multiple command senders + let cmd_tx1 = connector.command_tx(); + let cmd_tx2 = connector.command_tx(); + + // Start the connector + let _task_handle = connector.start_task().await; + tokio::time::sleep(Duration::from_millis(100)).await; + + // Both senders should work + assert!(cmd_tx1.send(ConnectorCommand::ListSessions).await.is_ok()); + assert!(cmd_tx2.send(ConnectorCommand::ListSessions).await.is_ok()); + + // Clean up + connector.stop(); +} + +#[tokio::test] +async fn test_t070_multiple_subscriptions() { + let connector = create_test_connector("http://localhost:12225"); + + // Create multiple subscriptions + let mut events1 = connector.subscribe(); + let mut events2 = connector.subscribe(); + + // Start the connector + let _task_handle = connector.start_task().await; + tokio::time::sleep(Duration::from_millis(100)).await; + + // Both should receive events + let receive1 = timeout(Duration::from_secs(2), events1.recv()); + let receive2 = timeout(Duration::from_secs(2), events2.recv()); + + // At least one should succeed (we'll get error events from failed connection) + assert!(receive1.await.is_ok() || receive2.await.is_ok()); + + // Clean up + connector.stop(); +} + +#[tokio::test] +async fn test_t070_shutdown_command_stops_task() { + let connector = create_test_connector("http://localhost:12225"); + + // Start the connector + let task_handle = connector.start_task().await; + tokio::time::sleep(Duration::from_millis(100)).await; + + // Verify task is running by checking state is not Stopped + let state_before = connector.state(); + assert!(!matches!(state_before, ConnectorState::Stopped)); + + // Send Shutdown command via command channel + let cmd_tx = connector.command_tx(); + cmd_tx.send(ConnectorCommand::Shutdown).await.unwrap(); + + // Task should complete + let result = timeout(Duration::from_secs(5), task_handle).await; + assert!(result.is_ok(), "Task should complete within timeout"); + assert!(result.unwrap().is_ok(), "Task should not panic"); + + // State should be Stopped + assert_eq!(connector.state(), ConnectorState::Stopped); +} diff --git a/crates/dirigent_core/tests/opencode_real_test.rs b/crates/dirigent_core/tests/opencode_real_test.rs new file mode 100644 index 0000000..19d3130 --- /dev/null +++ b/crates/dirigent_core/tests/opencode_real_test.rs @@ -0,0 +1,508 @@ +//! Real OpenCode HTTP API integration tests +//! +//! T078: OpenCode Real HTTP Tests +//! +//! These tests are marked with #[ignore] and require a real OpenCode instance +//! running. Set the DIRIGENT_TEST_API_URL environment variable to enable them. +//! +//! Example: +//! ```bash +//! DIRIGENT_TEST_API_URL=http://localhost:12225 cargo test --package dirigent_core -- --ignored +//! ``` + +#![cfg(feature = "server")] +use dirigent_core::connectors::opencode::{OpenCodeConfig, OpenCodeConnector}; + +use dirigent_core::connectors::{Connector, ConnectorCommand}; +use dirigent_core::sharing::bus::SharingBus; +use dirigent_core::types::ConnectorState; +use std::env; +use std::time::Duration; +use tokio::time::timeout; + +/// Get the test API URL from environment or skip the test +fn get_test_api_url() -> Option<String> { + env::var("DIRIGENT_TEST_API_URL").ok() +} + +/// Helper to create a real connector for testing +fn create_real_connector(base_url: &str) -> OpenCodeConnector { + let config = OpenCodeConfig { + base_url: base_url.to_string(), + initial_session: None, + }; + + OpenCodeConnector::new( + "real-test".to_string(), + uuid::Uuid::nil(), + "Real API Test".to_string(), + config, + SharingBus::new(), + ) +} + +// ============================================================================ +// T078: OpenCode Real HTTP Tests +// ============================================================================ + +#[tokio::test] +#[ignore] +async fn test_t078_real_connection_successful() { + let Some(api_url) = get_test_api_url() else { + println!("Skipping test: DIRIGENT_TEST_API_URL not set"); + return; + }; + + let connector = create_real_connector(&api_url); + let mut events = connector.subscribe(); + + // Start the connector + let task_handle = connector.start_task().await; + + // Wait for connection + let connected = timeout(Duration::from_secs(10), async { + while let Ok(event) = events.recv().await { + if matches!(event, dirigent_protocol::Event::Connected) { + return true; + } + if matches!(event, dirigent_protocol::Event::Error { .. }) { + eprintln!("Connection error: {:?}", event); + return false; + } + } + false + }) + .await; + + assert!( + connected.is_ok() && connected.unwrap(), + "Should successfully connect to real OpenCode instance" + ); + + // Verify state is Ready + assert_eq!(connector.state(), ConnectorState::Ready); + + // Clean up + connector.stop(); + let _ = timeout(Duration::from_secs(5), task_handle).await; +} + +#[tokio::test] +#[ignore] +async fn test_t078_real_list_sessions() { + let Some(api_url) = get_test_api_url() else { + println!("Skipping test: DIRIGENT_TEST_API_URL not set"); + return; + }; + + let connector = create_real_connector(&api_url); + let mut events = connector.subscribe(); + + // Start the connector + let _task_handle = connector.start_task().await; + + // Wait for connection + let connected = timeout(Duration::from_secs(10), async { + while let Ok(event) = events.recv().await { + if matches!(event, dirigent_protocol::Event::Connected) { + return true; + } + } + false + }) + .await; + + assert!( + connected.is_ok() && connected.unwrap(), + "Should connect first" + ); + + // Send ListSessions command + let cmd_tx = connector.command_tx(); + cmd_tx.send(ConnectorCommand::ListSessions).await.unwrap(); + + // Wait for SessionsListed event + let sessions_received = timeout(Duration::from_secs(5), async { + while let Ok(event) = events.recv().await { + if let dirigent_protocol::Event::SessionsListed { + connector_id: _, + sessions, + } = event + { + println!("Received {} sessions", sessions.len()); + return Some(sessions); + } + } + None + }) + .await; + + assert!( + sessions_received.is_ok(), + "Should receive SessionsListed event" + ); + let sessions = sessions_received.unwrap(); + assert!(sessions.is_some(), "Should have sessions data"); + + // Clean up + connector.stop(); +} + +#[tokio::test] +#[ignore] +async fn test_t078_real_list_messages() { + let Some(api_url) = get_test_api_url() else { + println!("Skipping test: DIRIGENT_TEST_API_URL not set"); + return; + }; + + let connector = create_real_connector(&api_url); + let mut events = connector.subscribe(); + + // Start the connector + let _task_handle = connector.start_task().await; + + // Wait for connection + timeout(Duration::from_secs(10), async { + while let Ok(event) = events.recv().await { + if matches!(event, dirigent_protocol::Event::Connected) { + break; + } + } + }) + .await + .ok(); + + // First, get a session ID by listing sessions + let cmd_tx = connector.command_tx(); + cmd_tx.send(ConnectorCommand::ListSessions).await.unwrap(); + + let session_id = timeout(Duration::from_secs(5), async { + while let Ok(event) = events.recv().await { + if let dirigent_protocol::Event::SessionsListed { + connector_id: _, + sessions, + } = event + { + if let Some(first_session) = sessions.first() { + return Some(first_session.id.clone()); + } + } + } + None + }) + .await; + + if session_id.is_err() || session_id.as_ref().unwrap().is_none() { + println!("No sessions available to test ListMessages"); + connector.stop(); + return; + } + + let session_id = session_id.unwrap().unwrap(); + println!("Testing with session: {}", session_id); + + // Now list messages for that session + cmd_tx + .send(ConnectorCommand::ListMessages { + session_id: session_id.clone(), + }) + .await + .unwrap(); + + let messages_received = timeout(Duration::from_secs(5), async { + while let Ok(event) = events.recv().await { + if let dirigent_protocol::Event::MessagesListed { messages } = event { + println!("Received {} messages", messages.len()); + return Some(messages); + } + } + None + }) + .await; + + assert!( + messages_received.is_ok(), + "Should receive MessagesListed event" + ); + let messages = messages_received.unwrap(); + assert!(messages.is_some(), "Should have messages data"); + + // Clean up + connector.stop(); +} + +#[tokio::test] +#[ignore] +async fn test_t078_real_send_message() { + let Some(api_url) = get_test_api_url() else { + println!("Skipping test: DIRIGENT_TEST_API_URL not set"); + return; + }; + + let connector = create_real_connector(&api_url); + let mut events = connector.subscribe(); + + // Start the connector + let _task_handle = connector.start_task().await; + + // Wait for connection + timeout(Duration::from_secs(10), async { + while let Ok(event) = events.recv().await { + if matches!(event, dirigent_protocol::Event::Connected) { + break; + } + } + }) + .await + .ok(); + + // Get a session ID + let cmd_tx = connector.command_tx(); + cmd_tx.send(ConnectorCommand::ListSessions).await.unwrap(); + + let session_id = timeout(Duration::from_secs(5), async { + while let Ok(event) = events.recv().await { + if let dirigent_protocol::Event::SessionsListed { + connector_id: _, + sessions, + } = event + { + if let Some(first_session) = sessions.first() { + return Some(first_session.id.clone()); + } + } + } + None + }) + .await; + + if session_id.is_err() || session_id.as_ref().unwrap().is_none() { + println!("No sessions available to test SendMessage"); + connector.stop(); + return; + } + + let session_id = session_id.unwrap().unwrap(); + println!("Sending message to session: {}", session_id); + + // Send a message + cmd_tx + .send(ConnectorCommand::SendMessage { + session_id: session_id.clone(), + text: "Test message from integration test".to_string(), + }) + .await + .unwrap(); + + // Wait for response events (MessageStarted, SessionUpdate, MessageCompleted) + let message_events_received = timeout(Duration::from_secs(30), async { + let mut started = false; + let mut parts = 0; + let mut completed = false; + + while let Ok(event) = events.recv().await { + match event { + dirigent_protocol::Event::MessageStarted { .. } => { + println!("Received MessageStarted"); + started = true; + } + dirigent_protocol::Event::SessionUpdate { .. } => { + parts += 1; + } + dirigent_protocol::Event::MessageCompleted { .. } => { + println!("Received MessageCompleted after {} parts", parts); + completed = true; + break; + } + dirigent_protocol::Event::Error { message } => { + eprintln!("Error during message send: {}", message); + break; + } + _ => {} + } + } + + (started, parts, completed) + }) + .await; + + assert!( + message_events_received.is_ok(), + "Should receive message response events" + ); + + let (started, parts, completed) = message_events_received.unwrap(); + println!( + "Message events: started={}, parts={}, completed={}", + started, parts, completed + ); + + assert!(started, "Should receive MessageStarted event"); + assert!(parts > 0, "Should receive at least one SessionUpdate event"); + assert!(completed, "Should receive MessageCompleted event"); + + // Clean up + connector.stop(); +} + +#[tokio::test] +#[ignore] +async fn test_t078_real_reconnect_command() { + let Some(api_url) = get_test_api_url() else { + println!("Skipping test: DIRIGENT_TEST_API_URL not set"); + return; + }; + + let connector = create_real_connector(&api_url); + let mut events = connector.subscribe(); + + // Start the connector + let _task_handle = connector.start_task().await; + + // Wait for initial connection + timeout(Duration::from_secs(10), async { + while let Ok(event) = events.recv().await { + if matches!(event, dirigent_protocol::Event::Connected) { + break; + } + } + }) + .await + .ok(); + + assert_eq!(connector.state(), ConnectorState::Ready); + + // Send Reconnect command + let cmd_tx = connector.command_tx(); + cmd_tx.send(ConnectorCommand::Reconnect).await.unwrap(); + + // Should disconnect and reconnect + let reconnected = timeout(Duration::from_secs(10), async { + let mut saw_disconnect = false; + while let Ok(event) = events.recv().await { + match event { + dirigent_protocol::Event::Disconnected => { + println!("Saw disconnect"); + saw_disconnect = true; + } + dirigent_protocol::Event::Connected => { + println!("Saw reconnect"); + if saw_disconnect { + return true; + } + } + _ => {} + } + } + false + }) + .await; + + // Note: The reconnect behavior may vary depending on the implementation + // We just verify the command was processed without error + println!( + "Reconnect completed (saw reconnect cycle: {:?})", + reconnected.ok() + ); + + // Clean up + connector.stop(); +} + +#[tokio::test] +#[ignore] +async fn test_t078_real_sse_stream_reliability() { + let Some(api_url) = get_test_api_url() else { + println!("Skipping test: DIRIGENT_TEST_API_URL not set"); + return; + }; + + let connector = create_real_connector(&api_url); + let mut events = connector.subscribe(); + + // Start the connector + let _task_handle = connector.start_task().await; + + // Collect events for 5 seconds + let events_collected = timeout(Duration::from_secs(5), async { + let mut count = 0; + let mut event_types = std::collections::HashMap::new(); + + while let Ok(event) = events.recv().await { + count += 1; + let type_name = match &event { + dirigent_protocol::Event::Connected => "Connected", + dirigent_protocol::Event::Disconnected => "Disconnected", + dirigent_protocol::Event::SessionCreated { .. } => "SessionCreated", + dirigent_protocol::Event::MessageStarted { .. } => "MessageStarted", + dirigent_protocol::Event::SessionUpdate { .. } => "SessionUpdate", + dirigent_protocol::Event::MessageCompleted { .. } => "MessageCompleted", + dirigent_protocol::Event::Error { .. } => "Error", + _ => "Other", + }; + *event_types.entry(type_name).or_insert(0) += 1; + } + + (count, event_types) + }) + .await; + + assert!( + events_collected.is_ok(), + "Should collect events without timeout" + ); + + let (count, event_types) = events_collected.unwrap(); + println!("Collected {} events:", count); + for (type_name, count) in event_types { + println!(" {}: {}", type_name, count); + } + + assert!( + count > 0, + "Should receive at least some events from SSE stream" + ); + + // Clean up + connector.stop(); +} + +#[tokio::test] +#[ignore] +async fn test_t078_real_error_handling() { + // Test with an invalid URL to verify error handling + let connector = create_real_connector("http://invalid-hostname-that-does-not-exist:99999"); + let mut events = connector.subscribe(); + + // Start the connector + let _task_handle = connector.start_task().await; + + // Should receive error events + let error_received = timeout(Duration::from_secs(10), async { + while let Ok(event) = events.recv().await { + if let dirigent_protocol::Event::Error { message } = event { + println!("Received expected error: {}", message); + return true; + } + } + false + }) + .await; + + assert!( + error_received.is_ok() && error_received.unwrap(), + "Should receive error event" + ); + + // State should be Error + tokio::time::sleep(Duration::from_secs(2)).await; + assert!( + matches!(connector.state(), ConnectorState::Error(_)), + "State should be Error, got {:?}", + connector.state() + ); + + // Clean up + connector.stop(); +} diff --git a/crates/dirigent_core/tests/replay_test.rs b/crates/dirigent_core/tests/replay_test.rs new file mode 100644 index 0000000..f60ef03 --- /dev/null +++ b/crates/dirigent_core/tests/replay_test.rs @@ -0,0 +1,176 @@ +//! Integration test: replay archived session into a `MockStream`. +//! +//! Builds a single-backend in-memory (tempdir) archivist, registers a +//! session, appends 10 messages with ascending timestamps, then exercises +//! `replay_session_to_stream` end-to-end. + +use std::sync::Arc; + +use chrono::{Duration as ChronoDuration, Utc}; +use uuid::Uuid; + +use dirigent_archivist::{ + Archivist, MessageRecord, RegisterConnectorRequest, RegisterSessionRequest, + backends::JsonlBackend, +}; +use dirigent_core::sharing::{ + MockStream, + replay::{ReplayOptions, ReplaySpeed, replay_session_to_stream}, +}; +use dirigent_protocol::streaming::{EventOrigin, SessionStream, StreamScope}; + +/// Build an in-memory-ish archivist backed by a tempdir + JsonlBackend. +/// +/// Matches the pattern used by `dirigent_archivist/tests/integration_tests.rs`. +/// The tempdir is leaked for the duration of the test process — acceptable +/// because the test binary exits immediately after. +async fn build_in_memory_archivist() -> Arc<Archivist> { + let temp_dir = std::env::temp_dir().join(format!("core_replay_test_{}", Uuid::now_v7())); + let backend = Arc::new( + JsonlBackend::new(temp_dir.clone()) + .await + .expect("JsonlBackend construction"), + ); + let archivist = Archivist::from_single_backend("main".into(), backend) + .await + .expect("Archivist::from_single_backend"); + Arc::new(archivist) +} + +/// Register a fresh connector + session and append `n` messages with +/// timestamps one second apart. Returns the scroll_id. +async fn seed_session_with_messages(archivist: &Archivist, n: usize) -> Uuid { + let connector_resp = archivist + .register_connector( + RegisterConnectorRequest { + r#type: "OpenCode".to_string(), + title: "Replay Test Connector".to_string(), + client_native_id: format!("replay-test@{}", Uuid::now_v7()), + custom_uid: None, + metadata: serde_json::json!({}), + fingerprint: None, + }, + None, + ) + .await + .expect("register_connector"); + + let session_resp = archivist + .register_session( + RegisterSessionRequest { + connector_uid: connector_resp.connector_uid, + native_session_id: format!("native-{}", Uuid::now_v7()), + title: Some("Replay Test Session".to_string()), + custom_scroll_id: None, + metadata: serde_json::json!({}), + completeness: Default::default(), + parent_scroll_id: None, + is_subagent: false, + continuation: None, + agent_id: None, + subagent_type: None, + spawning_tool_use_id: None, + }, + None, + ) + .await + .expect("register_session"); + + let scroll_id = session_resp.scroll_id; + let base_ts = Utc::now(); + + let messages: Vec<MessageRecord> = (0..n) + .map(|i| { + let role = if i % 2 == 0 { "user" } else { "assistant" }; + MessageRecord { + version: 1, + message_id: Uuid::now_v7(), + session: scroll_id, + parent_id: None, + ts: base_ts + ChronoDuration::seconds(i as i64), + role: role.to_string(), + author: None, + content_md: format!("message {i}"), + content_parts: None, + attachments: vec![], + metadata: serde_json::json!({}), + } + }) + .collect(); + + archivist + .append_messages(scroll_id, messages, None) + .await + .expect("append_messages"); + + scroll_id +} + +#[tokio::test] +async fn replay_delivers_archived_messages_to_stream() { + let archivist = build_in_memory_archivist().await; + let scroll_id = seed_session_with_messages(&archivist, 10).await; + + let mock = MockStream::new("mock", StreamScope::Session { scroll_id }); + let stream: Arc<dyn SessionStream> = mock.clone(); + + let report = replay_session_to_stream( + archivist.as_ref(), + scroll_id, + stream, + ReplayOptions { + include_meta_events: false, + speed: ReplaySpeed::AsFastAsPossible, + }, + ) + .await + .expect("replay_session_to_stream"); + + assert_eq!(report.events_sent, 10, "events_sent"); + assert_eq!(report.failures, 0, "failures"); + assert_eq!(mock.received_count(), 10, "mock received count"); + + let received = mock.received.lock().unwrap(); + for evt in received.iter() { + assert!( + matches!(evt.origin, EventOrigin::Replay { .. }), + "every replayed event must carry EventOrigin::Replay" + ); + assert_eq!( + evt.routing.scroll_id, + Some(scroll_id), + "every replayed event must carry the authoritative scroll_id" + ); + } +} + +#[tokio::test] +async fn replay_continues_on_stream_failure() { + let archivist = build_in_memory_archivist().await; + let scroll_id = seed_session_with_messages(&archivist, 10).await; + + let mock = MockStream::new("mock", StreamScope::Session { scroll_id }); + mock.fail_next(3); + let stream: Arc<dyn SessionStream> = mock.clone(); + + let report = replay_session_to_stream( + archivist.as_ref(), + scroll_id, + stream, + ReplayOptions { + include_meta_events: false, + speed: ReplaySpeed::AsFastAsPossible, + }, + ) + .await + .expect("replay_session_to_stream"); + + // events_sent counts attempted (ok + failed); failures counts Failed only. + assert_eq!(report.events_sent, 10, "events_sent counts every attempt"); + assert_eq!(report.failures, 3, "first 3 events rejected by mock"); + assert_eq!( + mock.received_count(), + 7, + "mock buffer contains the 7 successful events" + ); +} diff --git a/crates/dirigent_core/tests/runtime_test.rs b/crates/dirigent_core/tests/runtime_test.rs new file mode 100644 index 0000000..6e53cf5 --- /dev/null +++ b/crates/dirigent_core/tests/runtime_test.rs @@ -0,0 +1,750 @@ +#![cfg(feature = "server")] +//! Tests for CoreRuntime operations +//! +//! T072: CoreRuntime::create_connector +//! T073: CoreRuntime::start_connector +//! T074: CoreRuntime::stop_connector +//! T075: CoreRuntime::send_command + +use dirigent_core::connectors::{Connector, ConnectorCommand}; +use dirigent_core::types::{ConnectorKind, ConnectorState}; +use dirigent_core::{ConnectorConfig, CoreConfig, CoreError, CoreRuntime}; +use serde_json::json; + +/// Helper to create a test runtime +fn create_test_runtime() -> CoreRuntime { + CoreRuntime::new(CoreConfig::default(), None) +} + +/// Helper to create an OpenCode connector config +fn create_opencode_config(id: Option<String>, title: &str) -> ConnectorConfig { + ConnectorConfig { + id, + kind: ConnectorKind::OpenCode, + owner: None, + title: Some(title.to_string()), + working_directory: None, + params: json!({ + "base_url": "http://localhost:12225", + "title": title, + "initial_session": null + }), + ..Default::default() + } +} + +// ============================================================================ +// T072: CoreRuntime::create_connector +// ============================================================================ + +#[tokio::test] +async fn test_t072_connector_id_auto_generated_if_not_provided() { + let runtime = create_test_runtime(); + let cfg = create_opencode_config(None, "Auto ID Test"); + + let result = runtime.create_connector(uuid::Uuid::nil(), cfg).await; + + assert!(result.is_ok(), "Should create connector successfully"); + + let connector_id = result.unwrap(); + assert!(!connector_id.is_empty(), "Generated ID should not be empty"); + + // Verify it's a valid UUID format (36 chars with hyphens) + assert_eq!(connector_id.len(), 36, "Should be UUID format"); + + // Clean up + runtime.remove_connector(&connector_id).await.ok(); +} + +#[tokio::test] +async fn test_t072_connector_uses_provided_id() { + let runtime = create_test_runtime(); + let cfg = create_opencode_config(Some("my-custom-id".to_string()), "Custom ID Test"); + + let result = runtime.create_connector(uuid::Uuid::nil(), cfg).await; + + assert!(result.is_ok(), "Should create connector successfully"); + assert_eq!(result.unwrap(), "my-custom-id"); + + // Clean up + runtime + .remove_connector(&"my-custom-id".to_string()) + .await + .ok(); +} + +#[tokio::test] +async fn test_t072_already_exists_error_if_id_conflicts() { + let runtime = create_test_runtime(); + let cfg1 = create_opencode_config(Some("duplicate-id".to_string()), "First"); + + // Create first connector + let result1 = runtime + .create_connector(uuid::Uuid::nil(), cfg1.clone()) + .await; + assert!(result1.is_ok(), "First creation should succeed"); + + // Try to create another with the same ID + let result2 = runtime.create_connector(uuid::Uuid::nil(), cfg1).await; + + assert!(result2.is_err(), "Second creation should fail"); + assert_eq!(result2.unwrap_err(), CoreError::AlreadyExists); + + // Clean up + runtime + .remove_connector(&"duplicate-id".to_string()) + .await + .ok(); +} + +#[tokio::test] +async fn test_t072_connector_appears_in_list_after_creation() { + let runtime = create_test_runtime(); + + // Initially empty + let list_before = runtime.list_connectors(None).await; + let initial_count = list_before.len(); + + // Create a connector + let cfg = create_opencode_config(Some("test-conn-1".to_string()), "Test 1"); + let connector_id = runtime + .create_connector(uuid::Uuid::nil(), cfg) + .await + .unwrap(); + + // List should now contain it + let list_after = runtime.list_connectors(None).await; + assert_eq!(list_after.len(), initial_count + 1); + + let found = list_after.iter().find(|c| c.id == connector_id); + assert!(found.is_some(), "Created connector should be in list"); + + let connector_summary = found.unwrap(); + assert_eq!(connector_summary.id, "test-conn-1"); + assert_eq!(connector_summary.title, "Test 1"); + assert_eq!(connector_summary.owner, uuid::Uuid::nil()); + assert_eq!(connector_summary.kind, ConnectorKind::OpenCode); + + // Clean up + runtime.remove_connector(&connector_id).await.ok(); +} + +#[tokio::test] +async fn test_t072_invalid_config_returns_error() { + let runtime = create_test_runtime(); + + let cfg = ConnectorConfig { + id: None, + kind: ConnectorKind::OpenCode, + owner: None, + title: Some("Invalid".to_string()), + working_directory: None, + params: json!({ + "invalid": "config" + // Missing required fields like base_url + }), + ..Default::default() + }; + + let result = runtime.create_connector(uuid::Uuid::nil(), cfg).await; + + assert!(result.is_err(), "Should fail with invalid config"); + assert_eq!(result.unwrap_err(), CoreError::InvalidConfig); +} + +#[tokio::test] +async fn test_t072_mock_connector_not_allowed() { + let runtime = create_test_runtime(); + + let cfg = ConnectorConfig { + id: None, + kind: ConnectorKind::Mock, + owner: None, + title: Some("Mock".to_string()), + working_directory: None, + params: json!({}), + ..Default::default() + }; + + let result = runtime.create_connector(uuid::Uuid::nil(), cfg).await; + + assert!( + result.is_err(), + "Mock connectors should not be creatable via API" + ); + assert_eq!(result.unwrap_err(), CoreError::InvalidConfig); +} + +#[tokio::test] +async fn test_t072_owner_override() { + let runtime = create_test_runtime(); + + let mut cfg = create_opencode_config(None, "Owner Test"); + // Try to set owner in config + cfg.owner = Some(uuid::Uuid::from_u128(99)); + + // Create with different owner + let connector_id = runtime + .create_connector(uuid::Uuid::from_u128(42), cfg) + .await + .unwrap(); + + // Verify owner was overridden + let list = runtime.list_connectors(None).await; + let connector = list.iter().find(|c| c.id == connector_id).unwrap(); + assert_eq!( + connector.owner, + uuid::Uuid::from_u128(42), + "Owner should be from parameter, not config" + ); + + // Clean up + runtime.remove_connector(&connector_id).await.ok(); +} + +// ============================================================================ +// T073: CoreRuntime::start_connector +// ============================================================================ + +#[tokio::test] +async fn test_t073_start_connector_not_found() { + let runtime = create_test_runtime(); + + let result = runtime.start_connector(&"nonexistent".to_string()).await; + + assert!(result.is_err(), "Should fail for nonexistent connector"); + assert_eq!(result.unwrap_err(), CoreError::NotFound); +} + +#[tokio::test] +async fn test_t073_start_connector_not_yet_implemented() { + // Note: As per the runtime.rs code, starting an existing connector + // is not yet fully implemented. This test documents the current behavior. + + let runtime = create_test_runtime(); + let cfg = create_opencode_config(Some("start-test".to_string()), "Start Test"); + + let connector_id = runtime + .create_connector(uuid::Uuid::nil(), cfg) + .await + .unwrap(); + + // Try to start it + let result = runtime.start_connector(&connector_id).await; + + // Currently returns an error indicating not implemented + assert!( + result.is_err(), + "Starting existing connector not yet supported" + ); + + // Clean up + runtime.remove_connector(&connector_id).await.ok(); +} + +// ============================================================================ +// T074: CoreRuntime::stop_connector +// ============================================================================ + +#[tokio::test] +async fn test_t074_stop_connector_not_found() { + let runtime = create_test_runtime(); + + let result = runtime.stop_connector(&"nonexistent".to_string()).await; + + assert!(result.is_err(), "Should fail for nonexistent connector"); + assert_eq!(result.unwrap_err(), CoreError::NotFound); +} + +#[tokio::test] +async fn test_t074_stop_connector_changes_state_to_stopped() { + let runtime = create_test_runtime(); + let cfg = create_opencode_config(Some("stop-test".to_string()), "Stop Test"); + + let connector_id = runtime + .create_connector(uuid::Uuid::nil(), cfg) + .await + .unwrap(); + + // Get the connector and verify initial state + let connector = runtime.get_connector(&connector_id).await.unwrap(); + let initial_state = connector.state(); + assert_eq!(initial_state, ConnectorState::Initializing); + + // Stop it + let result = runtime.stop_connector(&connector_id).await; + assert!(result.is_ok(), "Stop should succeed"); + + // Verify state changed to Stopped + let connector = runtime.get_connector(&connector_id).await.unwrap(); + let final_state = connector.state(); + assert_eq!(final_state, ConnectorState::Stopped); + + // Clean up + runtime.remove_connector(&connector_id).await.ok(); +} + +#[tokio::test] +async fn test_t074_stop_connector_idempotent() { + let runtime = create_test_runtime(); + let cfg = create_opencode_config(Some("stop-twice".to_string()), "Stop Twice"); + + let connector_id = runtime + .create_connector(uuid::Uuid::nil(), cfg) + .await + .unwrap(); + + // Stop once + let result1 = runtime.stop_connector(&connector_id).await; + assert!(result1.is_ok(), "First stop should succeed"); + + // Stop again + let result2 = runtime.stop_connector(&connector_id).await; + assert!(result2.is_ok(), "Second stop should also succeed"); + + // State should still be Stopped + let connector = runtime.get_connector(&connector_id).await.unwrap(); + assert_eq!(connector.state(), ConnectorState::Stopped); + + // Clean up + runtime.remove_connector(&connector_id).await.ok(); +} + +// ============================================================================ +// T075: CoreRuntime::send_command +// ============================================================================ + +#[tokio::test] +async fn test_t075_send_command_not_found() { + let runtime = create_test_runtime(); + + let result = runtime + .send_command(&"nonexistent".to_string(), ConnectorCommand::ListSessions) + .await; + + assert!(result.is_err(), "Should fail for nonexistent connector"); + assert_eq!(result.unwrap_err(), CoreError::NotFound); +} + +#[tokio::test] +#[ignore] // TODO: Fix - cmd_rx is dropped when connector is dropped in create_connector +async fn test_t075_send_command_to_connector() { + let runtime = create_test_runtime(); + let cfg = create_opencode_config(Some("cmd-test".to_string()), "Command Test"); + + let connector_id = runtime + .create_connector(uuid::Uuid::nil(), cfg) + .await + .unwrap(); + + // Get the connector and subscribe to events + let connector = runtime.get_connector(&connector_id).await.unwrap(); + let _events = connector.subscribe(); + + // Send a command via runtime + let result = runtime + .send_command(&connector_id, ConnectorCommand::ListSessions) + .await; + + assert!(result.is_ok(), "Send command should succeed"); + + // Note: The command won't be processed until the connector is started, + // but we've verified that the command was accepted + + // Clean up + runtime.remove_connector(&connector_id).await.ok(); +} + +#[ignore] // TODO: Fix - cmd_rx is dropped when connector is dropped in create_connector +#[tokio::test] +async fn test_t075_send_all_command_types() { + let runtime = create_test_runtime(); + let cfg = create_opencode_config(Some("all-cmds".to_string()), "All Commands"); + + let connector_id = runtime + .create_connector(uuid::Uuid::nil(), cfg) + .await + .unwrap(); + + // Send ListSessions + let result = runtime + .send_command(&connector_id, ConnectorCommand::ListSessions) + .await; + assert!(result.is_ok()); + + // Send ListMessages + let result = runtime + .send_command( + &connector_id, + ConnectorCommand::ListMessages { + session_id: "test-session".to_string(), + }, + ) + .await; + assert!(result.is_ok()); + + // Send SendMessage + let result = runtime + .send_command( + &connector_id, + ConnectorCommand::SendMessage { + session_id: "test-session".to_string(), + text: "Hello".to_string(), + }, + ) + .await; + assert!(result.is_ok()); + + // Send Reconnect + let result = runtime + .send_command(&connector_id, ConnectorCommand::Reconnect) + .await; + assert!(result.is_ok()); + + // Send Shutdown + let result = runtime + .send_command(&connector_id, ConnectorCommand::Shutdown) + .await; + assert!(result.is_ok()); + + // Clean up + runtime.remove_connector(&connector_id).await.ok(); +} + +#[tokio::test] +#[ignore] // TODO: Fix - cmd_rx is dropped when connector is dropped in create_connector +async fn test_t075_send_command_channel_capacity() { + let runtime = create_test_runtime(); + let cfg = create_opencode_config(Some("capacity-test".to_string()), "Capacity Test"); + + let connector_id = runtime + .create_connector(uuid::Uuid::nil(), cfg) + .await + .unwrap(); + + // Send many commands (the channel has capacity 100) + for i in 0..50 { + let result = runtime + .send_command(&connector_id, ConnectorCommand::ListSessions) + .await; + assert!(result.is_ok(), "Command {} should be accepted", i); + } + + // Clean up + runtime.remove_connector(&connector_id).await.ok(); +} + +// ============================================================================ +// Additional Runtime Tests +// ============================================================================ + +#[tokio::test] +async fn test_list_connectors_filters_by_owner() { + let runtime = create_test_runtime(); + + // Create connectors for different users + let cfg1 = create_opencode_config(Some("user1-conn1".to_string()), "User 1 Conn 1"); + let cfg2 = create_opencode_config(Some("user1-conn2".to_string()), "User 1 Conn 2"); + let cfg3 = create_opencode_config(Some("user2-conn1".to_string()), "User 2 Conn 1"); + + runtime + .create_connector(uuid::Uuid::nil(), cfg1) + .await + .unwrap(); + runtime + .create_connector(uuid::Uuid::nil(), cfg2) + .await + .unwrap(); + runtime + .create_connector(uuid::Uuid::from_u128(2), cfg3) + .await + .unwrap(); + + // List all + let all = runtime.list_connectors(None).await; + assert!(all.len() >= 3, "Should have at least 3 connectors"); + + // List for user-1 + let user1_list = runtime.list_connectors(Some(uuid::Uuid::nil())).await; + assert_eq!(user1_list.len(), 2, "User 1 should have 2 connectors"); + + // List for user-2 + let user2_list = runtime + .list_connectors(Some(uuid::Uuid::from_u128(2))) + .await; + assert_eq!(user2_list.len(), 1, "User 2 should have 1 connector"); + + // Clean up + runtime + .remove_connector(&"user1-conn1".to_string()) + .await + .ok(); + runtime + .remove_connector(&"user1-conn2".to_string()) + .await + .ok(); + runtime + .remove_connector(&"user2-conn1".to_string()) + .await + .ok(); +} + +#[tokio::test] +async fn test_get_connector_returns_some_if_exists() { + let runtime = create_test_runtime(); + let cfg = create_opencode_config(Some("get-test".to_string()), "Get Test"); + + let connector_id = runtime + .create_connector(uuid::Uuid::nil(), cfg) + .await + .unwrap(); + + let result = runtime.get_connector(&connector_id).await; + assert!(result.is_some(), "Should find connector"); + + let connector = result.unwrap(); + assert_eq!(connector.id(), &connector_id); + assert_eq!(connector.title(), "Get Test"); + + // Clean up + runtime.remove_connector(&connector_id).await.ok(); +} + +#[tokio::test] +async fn test_get_connector_returns_none_if_not_exists() { + let runtime = create_test_runtime(); + + let result = runtime.get_connector(&"nonexistent".to_string()).await; + assert!(result.is_none(), "Should not find nonexistent connector"); +} + +#[tokio::test] +async fn test_remove_connector_success() { + let runtime = create_test_runtime(); + let cfg = create_opencode_config(Some("remove-test".to_string()), "Remove Test"); + + let connector_id = runtime + .create_connector(uuid::Uuid::nil(), cfg) + .await + .unwrap(); + + // Verify it exists + assert!(runtime.get_connector(&connector_id).await.is_some()); + + // Remove it + let result = runtime.remove_connector(&connector_id).await; + assert!(result.is_ok(), "Remove should succeed"); + + // Verify it's gone + assert!(runtime.get_connector(&connector_id).await.is_none()); +} + +#[tokio::test] +async fn test_sharing_bus_subscribe_all() { + let runtime = create_test_runtime(); + + // Subscribe to every event on the SharingBus — this is the + // replacement for the retired `subscribe_global()` API. + let rx1 = runtime.sharing_bus().subscribe_all().await; + let rx2 = runtime.sharing_bus().subscribe_all().await; + + // Verify both subscriptions are valid + drop(rx1); + drop(rx2); + // If this compiles and runs, subscriptions work +} + +// ============================================================================ +// Issue 2: Zero-Connector State (Regression Test) +// ============================================================================ + +/// Test that new connectors can be added after removing all existing connectors +/// +/// This test verifies the fix for Issue 2: "Removing All Connectors Breaks New Connection" +/// +/// Regression test ensures: +/// - Creating first connector works (0 -> 1 transition) +/// - Removing all connectors works (1 -> 0 transition) +/// - Creating connector after zero state works (0 -> 1 again) +/// - No state corruption or race conditions +/// - Config persistence handles empty state correctly +#[tokio::test] +async fn test_issue_2_create_connector_after_removing_all() { + let runtime = create_test_runtime(); + + // Verify starting state is empty + let initial_connectors = runtime.list_connectors(None).await; + assert_eq!( + initial_connectors.len(), + 0, + "Should start with zero connectors" + ); + + // Create first connector (0 -> 1 transition) + let cfg1 = create_opencode_config(Some("test-connector-1".to_string()), "First Connector"); + let id1 = runtime + .create_connector(uuid::Uuid::nil(), cfg1) + .await + .expect("Should create first connector"); + + let connectors_after_create = runtime.list_connectors(None).await; + assert_eq!( + connectors_after_create.len(), + 1, + "Should have 1 connector after creation" + ); + + // Verify connector is in the list + let connector1 = runtime.get_connector(&id1).await; + assert!(connector1.is_some(), "First connector should exist"); + + // Remove the connector (1 -> 0 transition) + runtime + .remove_connector(&id1) + .await + .expect("Should remove connector successfully"); + + let connectors_after_remove = runtime.list_connectors(None).await; + assert_eq!( + connectors_after_remove.len(), + 0, + "Should have zero connectors after removal" + ); + + // Verify connector is gone + let connector1_after_remove = runtime.get_connector(&id1).await; + assert!( + connector1_after_remove.is_none(), + "First connector should not exist after removal" + ); + + // Create new connector after reaching zero state (0 -> 1 again) + let cfg2 = create_opencode_config(Some("test-connector-2".to_string()), "Second Connector"); + let id2 = runtime + .create_connector(uuid::Uuid::nil(), cfg2) + .await + .expect("Should create connector after removing all"); + + let connectors_final = runtime.list_connectors(None).await; + assert_eq!( + connectors_final.len(), + 1, + "Should have 1 connector after recreation" + ); + + // Verify new connector is in the list + let connector2 = runtime.get_connector(&id2).await; + assert!(connector2.is_some(), "Second connector should exist"); + + // Verify IDs are different + assert_ne!(id1, id2, "New connector should have different ID"); + + // Verify new connector is functional (state check) + let connector2_handle = connector2.unwrap(); + let state = connector2_handle.state(); + assert_eq!( + state, + ConnectorState::Initializing, + "New connector should be initializing" + ); + + // Clean up + runtime.remove_connector(&id2).await.ok(); +} + +/// Test rapid remove-create cycles to check for race conditions +/// +/// This test verifies that repeated transitions between empty and non-empty +/// connector states don't cause issues with config persistence or internal state. +#[tokio::test] +async fn test_issue_2_rapid_remove_create_cycles() { + let runtime = create_test_runtime(); + + // Perform 5 cycles of create -> remove + for i in 0..5 { + let cfg = create_opencode_config( + Some(format!("cycle-connector-{}", i)), + &format!("Cycle Test {}", i), + ); + + // Create connector + let id = runtime + .create_connector(uuid::Uuid::nil(), cfg) + .await + .expect(&format!("Should create connector in cycle {}", i)); + + // Verify it exists + let connectors = runtime.list_connectors(None).await; + assert_eq!( + connectors.len(), + 1, + "Should have 1 connector after creation" + ); + + // Remove connector + runtime + .remove_connector(&id) + .await + .expect(&format!("Should remove connector in cycle {}", i)); + + // Verify empty state + let connectors = runtime.list_connectors(None).await; + assert_eq!( + connectors.len(), + 0, + "Should have 0 connectors after removal" + ); + } + + // Final verification: Create one more connector to ensure state is still valid + let final_cfg = create_opencode_config(Some("final-connector".to_string()), "Final Test"); + let final_id = runtime + .create_connector(uuid::Uuid::nil(), final_cfg) + .await + .expect("Should create connector after rapid cycles"); + + let final_connectors = runtime.list_connectors(None).await; + assert_eq!(final_connectors.len(), 1, "Should have 1 connector at end"); + + // Clean up + runtime.remove_connector(&final_id).await.ok(); +} + +/// Test that SharingBus subscriptions work across zero-connector transitions. +/// +/// Ensures SSE subscriptions remain valid when all connectors are removed +/// and new connectors are added. +#[tokio::test] +async fn test_issue_2_global_events_survive_zero_connectors() { + let runtime = create_test_runtime(); + + // Subscribe to events before any connectors exist. + let _rx = runtime.sharing_bus().subscribe_all().await; + + // Create connector + let cfg = create_opencode_config(Some("event-test".to_string()), "Event Test"); + let id = runtime + .create_connector(uuid::Uuid::nil(), cfg) + .await + .expect("Should create connector"); + + // Remove connector (transition to zero) + runtime + .remove_connector(&id) + .await + .expect("Should remove connector"); + + // Subscribe again after zero state + let _rx2 = runtime.sharing_bus().subscribe_all().await; + + // Create another connector + let cfg2 = create_opencode_config(Some("event-test-2".to_string()), "Event Test 2"); + let id2 = runtime + .create_connector(uuid::Uuid::nil(), cfg2) + .await + .expect("Should create connector after zero state"); + + // If we get here, subscriptions survived the zero-connector transition + // Clean up + runtime.remove_connector(&id2).await.ok(); +} diff --git a/crates/dirigent_core/tests/sharing_bus_test.rs b/crates/dirigent_core/tests/sharing_bus_test.rs new file mode 100644 index 0000000..be9c989 --- /dev/null +++ b/crates/dirigent_core/tests/sharing_bus_test.rs @@ -0,0 +1,173 @@ +//! Integration test for SharingBus late-bind and filter routing. +//! +//! Exercises the full publish cycle: subscribe with two different filters, +//! publish three events (one before cache population, one that populates the +//! cache, one that is late-bound from the cache), then assert each subscriber +//! saw exactly the events it should. + +use std::time::Duration; + +use uuid::Uuid; + +use dirigent_core::sharing::bus::SharingBus; +use dirigent_protocol::{ + Message, MessageRole, MessageStatus, Session, SessionMetadata, + conversation::MessagePart, + streaming::{BusEvent, EventFilter}, + Event, +}; + +// ─── Fixtures ──────────────────────────────────────────────────────────────── + +fn make_session(id: &str) -> Session { + let now = chrono::Utc::now(); + Session { + id: id.to_string(), + title: "test-session".to_string(), + created_at: now, + updated_at: now, + metadata: SessionMetadata { + project_path: "/tmp/test".to_string(), + model: None, + total_messages: 0, + system_message: None, + current_mode_id: None, + _meta: None, + project_id: None, + }, + cwd: None, + models: None, + modes: None, + config_options: None, + acp_client_id: None, + } +} + +fn make_message(session_id: &str, msg_id: &str) -> Message { + let now = chrono::Utc::now(); + Message { + id: msg_id.to_string(), + session_id: session_id.to_string(), + role: MessageRole::Assistant, + created_at: now, + content: vec![MessagePart::Text { + text: "hello".to_string(), + }], + status: MessageStatus::Completed, + metadata: None, + } +} + +// ─── Test ───────────────────────────────────────────────────────────────────── + +/// Full late-bind cycle: +/// +/// - `scroll_rx` is filtered by `EventFilter::ScrollId(scroll_id)`. +/// It should receive events #2 and #3 only (not #1, because the cache was +/// empty when #1 was published). +/// +/// - `uid_rx` is filtered by `EventFilter::ConnectorUid(uid)`. +/// It should receive all three events because `connector_uid` is set on +/// every event by `BusEvent::from_connector_event`. +#[tokio::test] +async fn scroll_id_late_bind_via_session_registered() { + let bus = SharingBus::new(); + + let uid = Uuid::new_v4(); + let scroll_id = Uuid::new_v4(); + + let mut scroll_rx = bus + .subscribe_filtered(EventFilter::ScrollId(scroll_id), 32) + .await; + let mut uid_rx = bus + .subscribe_filtered(EventFilter::ConnectorUid(uid), 32) + .await; + + // Event #1: SessionCreated — cache is empty, scroll_id will not be + // late-bound. The ScrollId subscriber must NOT see this event. + bus.publish(BusEvent::from_connector_event( + Event::SessionCreated { + connector_id: "mock".into(), + session: make_session("abc"), + }, + Some(uid), + "mock".into(), + )) + .await; + + // Event #2: SessionRegistered — populates the cache + // (`connector_id="mock"`, `session_id="abc"`) → `scroll_id`. + // The bus also sets `routing.scroll_id` on this event itself, so the + // ScrollId subscriber DOES see it. + bus.publish(BusEvent::from_connector_event( + Event::SessionRegistered { + connector_id: "mock".into(), + session_id: "abc".into(), + scroll_id: scroll_id.to_string(), + }, + Some(uid), + "mock".into(), + )) + .await; + + // Event #3: MessageCompleted — no scroll_id on entry, but the bus + // late-binds it from the cache via (connector_id="mock", session_id="abc"). + // The ScrollId subscriber DOES see it. + bus.publish(BusEvent::from_connector_event( + Event::MessageCompleted { + connector_id: "mock".into(), + message: make_message("abc", "m1"), + }, + Some(uid), + "mock".into(), + )) + .await; + + // Give the bus worker time to dispatch all three events. + tokio::time::sleep(Duration::from_millis(50)).await; + + // Drain what each subscriber received. + let mut scroll_events: Vec<BusEvent> = Vec::new(); + while let Ok(Some(e)) = + tokio::time::timeout(Duration::from_millis(20), scroll_rx.rx.recv()).await + { + scroll_events.push(e); + } + + let mut uid_events: Vec<BusEvent> = Vec::new(); + while let Ok(Some(e)) = + tokio::time::timeout(Duration::from_millis(20), uid_rx.rx.recv()).await + { + uid_events.push(e); + } + + // ScrollId subscriber: only events #2 and #3. + assert_eq!( + scroll_events.len(), + 2, + "ScrollId subscriber should see events #2 (SessionRegistered) and #3 (MessageCompleted) only; got {:?}", + scroll_events.iter().map(|e| format!("{:?}", e.event)).collect::<Vec<_>>() + ); + assert!( + matches!(scroll_events[0].event.as_ref(), Event::SessionRegistered { .. }), + "first scroll event should be SessionRegistered" + ); + assert!( + matches!(scroll_events[1].event.as_ref(), Event::MessageCompleted { .. }), + "second scroll event should be MessageCompleted" + ); + // Both must carry the correct scroll_id. + assert_eq!(scroll_events[0].routing.scroll_id, Some(scroll_id)); + assert_eq!(scroll_events[1].routing.scroll_id, Some(scroll_id)); + + // ConnectorUid subscriber: all three. + assert_eq!( + uid_events.len(), + 3, + "ConnectorUid subscriber should see all three events; got {:?}", + uid_events.iter().map(|e| format!("{:?}", e.event)).collect::<Vec<_>>() + ); + for ev in &uid_events { + assert_eq!(ev.routing.connector_uid, Some(uid)); + } +} diff --git a/crates/dirigent_core/tests/stream_registry_test.rs b/crates/dirigent_core/tests/stream_registry_test.rs new file mode 100644 index 0000000..91feb84 --- /dev/null +++ b/crates/dirigent_core/tests/stream_registry_test.rs @@ -0,0 +1,156 @@ +//! Integration tests for `StreamRegistry`: scope-based filtering, health +//! drift, and `detach()` shutdown semantics. +//! +//! Uses the `MockStream` from `dirigent_core::sharing` (enabled via the +//! `test-utils` feature; see `required-features` in `Cargo.toml`). + +use std::sync::Arc; +use std::time::Duration; + +use uuid::Uuid; + +use dirigent_core::sharing::MockStream; +use dirigent_core::sharing::bus::SharingBus; +use dirigent_core::sharing::registry::StreamRegistry; +use dirigent_protocol::{ + Event, + streaming::{ + BusEvent, EventKind, EventOrigin, EventRouting, StreamScope, + }, +}; + +// ─── Helpers ──────────────────────────────────────────────────────────────── + +/// Build a `BusEvent` with an explicit `scroll_id` in the routing so the +/// bus does not need to consult its late-bind cache. `connector_uid` is +/// populated too so tests can mix-and-match filters. +fn make_scoped_event(scroll_id: Uuid, connector_uid: Uuid) -> BusEvent { + BusEvent { + routing: EventRouting { + scroll_id: Some(scroll_id), + connector_uid: Some(connector_uid), + connector_id: Some("conn-test".to_string()), + native_session_id: None, + kind: EventKind::System, + }, + origin: EventOrigin::Runtime, + event: Arc::new(Event::Connected), + } +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn mock_stream_receives_only_session_scoped_events() { + let bus = SharingBus::new(); + let registry = StreamRegistry::new(Arc::clone(&bus)); + + let scroll_x = Uuid::now_v7(); + let scroll_y = Uuid::now_v7(); + let connector_uid = Uuid::now_v7(); + + let mock = MockStream::new("mock", StreamScope::Session { + scroll_id: scroll_x, + }); + let _id = registry.attach("mock".to_string(), mock.clone()).await; + + // One event in-scope (scroll_x), two out-of-scope (scroll_y). + bus.publish(make_scoped_event(scroll_x, connector_uid)) + .await; + bus.publish(make_scoped_event(scroll_y, connector_uid)) + .await; + bus.publish(make_scoped_event(scroll_y, connector_uid)) + .await; + + // Give the bus worker + stream worker a chance to process. + tokio::time::sleep(Duration::from_millis(50)).await; + + let count = mock.received_count(); + assert_eq!( + count, + 1, + "expected 1 in-scope event, got {count}", + ); + assert_eq!( + mock.received.lock().unwrap()[0].routing.scroll_id, + Some(scroll_x) + ); +} + +#[tokio::test] +async fn health_drifts_to_unavailable_after_five_failures() { + use dirigent_core::sharing::HealthStatus; + + let bus = SharingBus::new(); + let registry = StreamRegistry::new(Arc::clone(&bus)); + + let scroll_id = Uuid::now_v7(); + let connector_uid = Uuid::now_v7(); + + let mock = MockStream::new("mock", StreamScope::Session { scroll_id }); + mock.fail_next(5); + let _id = registry.attach("mock".to_string(), mock.clone()).await; + + // Five consecutive failures should drift Healthy → Unavailable. + for _ in 0..5 { + bus.publish(make_scoped_event(scroll_id, connector_uid)) + .await; + } + + tokio::time::sleep(Duration::from_millis(100)).await; + + let infos = registry.list().await; + assert_eq!(infos.len(), 1); + match &infos[0].health { + HealthStatus::Unavailable { reason } => { + assert!( + reason.contains("5 failures"), + "expected reason to mention the failure count, got: {reason}" + ); + } + other => panic!("expected Unavailable after 5 failures, got {:?}", other), + } + assert_eq!(infos[0].lagged_count, 5); +} + +#[tokio::test] +async fn detach_invokes_shutdown_and_stops_delivery() { + let bus = SharingBus::new(); + let registry = StreamRegistry::new(Arc::clone(&bus)); + + let scroll_id = Uuid::now_v7(); + let connector_uid = Uuid::now_v7(); + + let mock = MockStream::new("mock", StreamScope::Session { scroll_id }); + let id = registry.attach("mock".to_string(), mock.clone()).await; + + // One in-scope event before detach to prove delivery works at all. + bus.publish(make_scoped_event(scroll_id, connector_uid)) + .await; + tokio::time::sleep(Duration::from_millis(50)).await; + assert_eq!(mock.received_count(), 1, "pre-detach delivery failed"); + + // Detach — worker should stop, shutdown should run. + let reg = registry.detach(id).await.expect("stream should exist"); + // Wait for the worker to finish so we know the post-detach publish + // cannot possibly reach the stream. + let _ = tokio::time::timeout(Duration::from_millis(500), async { + let handle = ®.worker; + while !handle.is_finished() { + tokio::task::yield_now().await; + } + }) + .await; + + // Publish post-detach; the stream must not receive it. + bus.publish(make_scoped_event(scroll_id, connector_uid)) + .await; + tokio::time::sleep(Duration::from_millis(50)).await; + + assert_eq!( + mock.received_count(), + 1, + "stream received an event after detach" + ); + assert!(registry.list().await.is_empty()); +} diff --git a/crates/dirigent_core/tests/transport_integration.rs b/crates/dirigent_core/tests/transport_integration.rs new file mode 100644 index 0000000..a5b2e98 --- /dev/null +++ b/crates/dirigent_core/tests/transport_integration.rs @@ -0,0 +1,153 @@ +//! Integration tests for ACP transport layer. +//! +//! These tests verify both stdio and HTTP transport implementations work correctly. + +use dirigent_core::acp::transport::{ + JsonRpcError, JsonRpcErrorResponse, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, + JsonRpcResult, TransportError, TransportState, +}; + +/// Test JSON-RPC request creation and serialization. +#[test] +fn test_jsonrpc_request() { + let req = JsonRpcRequest::new( + 1, + "test_method", + Some(serde_json::json!({"param1": "value1"})), + ); + + assert_eq!(req.jsonrpc, "2.0"); + assert_eq!(req.id, serde_json::Value::Number(1.into())); + assert_eq!(req.method, "test_method"); + + let json = serde_json::to_string(&req).unwrap(); + assert!(json.contains("\"jsonrpc\":\"2.0\"")); + assert!(json.contains("\"id\":1")); + assert!(json.contains("\"method\":\"test_method\"")); + assert!(json.contains("\"param1\":\"value1\"")); +} + +/// Test JSON-RPC notification creation and serialization. +#[test] +fn test_jsonrpc_notification() { + let notif = JsonRpcNotification::new( + "test_notification", + Some(serde_json::json!({"data": "test"})), + ); + + assert_eq!(notif.jsonrpc, "2.0"); + assert_eq!(notif.method, "test_notification"); + + let json = serde_json::to_string(¬if).unwrap(); + assert!(json.contains("\"jsonrpc\":\"2.0\"")); + assert!(json.contains("\"method\":\"test_notification\"")); + // Notifications don't have an "id" field + assert!(!json.contains("\"id\"")); +} + +/// Test JSON-RPC response deserialization. +#[test] +fn test_jsonrpc_response_deserialization() { + let json = r#"{"jsonrpc":"2.0","id":1,"result":{"status":"ok"}}"#; + let response: JsonRpcResponse = serde_json::from_str(json).unwrap(); + + assert_eq!(response.jsonrpc, "2.0"); + assert_eq!(response.id, serde_json::Value::Number(1.into())); + assert_eq!(response.result.get("status").unwrap(), "ok"); +} + +/// Test JSON-RPC error response deserialization. +#[test] +fn test_jsonrpc_error_response_deserialization() { + let json = r#"{"jsonrpc":"2.0","id":1,"error":{"code":-32600,"message":"Invalid Request"}}"#; + let error_response: JsonRpcErrorResponse = serde_json::from_str(json).unwrap(); + + assert_eq!(error_response.jsonrpc, "2.0"); + assert_eq!(error_response.id, serde_json::Value::Number(1.into())); + assert_eq!(error_response.error.code, -32600); + assert_eq!(error_response.error.message, "Invalid Request"); +} + +/// Test JsonRpcResult helper methods. +#[test] +fn test_jsonrpc_result_helpers() { + // Test success result + let success_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::Value::Number(1.into()), + result: serde_json::json!({"status": "ok"}), + }; + let result = JsonRpcResult::Success(success_response); + + assert!(result.is_success()); + assert!(!result.is_error()); + assert!(result.result().is_some()); + assert!(result.error().is_none()); + assert_eq!(result.result().unwrap().get("status").unwrap(), "ok"); + + // Test error result + let error_response = JsonRpcErrorResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::Value::Number(2.into()), + error: JsonRpcError { + code: -32600, + message: "Invalid Request".to_string(), + data: None, + }, + }; + let result = JsonRpcResult::Error(error_response); + + assert!(!result.is_success()); + assert!(result.is_error()); + assert!(result.result().is_none()); + assert!(result.error().is_some()); + assert_eq!(result.error().unwrap().code, -32600); +} + +/// Test JsonRpcResult conversion to Result. +#[test] +fn test_jsonrpc_result_into_result() { + // Success case + let success_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::Value::Number(1.into()), + result: serde_json::json!({"status": "ok"}), + }; + let result = JsonRpcResult::Success(success_response); + let converted = result.into_result(); + assert!(converted.is_ok()); + + // Error case + let error_response = JsonRpcErrorResponse { + jsonrpc: "2.0".to_string(), + id: serde_json::Value::Number(2.into()), + error: JsonRpcError { + code: -32600, + message: "Invalid Request".to_string(), + data: None, + }, + }; + let result = JsonRpcResult::Error(error_response); + let converted = result.into_result(); + assert!(converted.is_err()); + + match converted { + Err(TransportError::JsonRpcError(error)) => { + assert_eq!(error.code, -32600); + assert_eq!(error.message, "Invalid Request"); + } + _ => panic!("Expected JsonRpcError"), + } +} + +/// Test transport state enum. +#[test] +fn test_transport_state() { + assert_eq!(TransportState::Disconnected, TransportState::Disconnected); + assert_ne!(TransportState::Connected, TransportState::Disconnected); + + let state = TransportState::Connecting; + assert!(matches!(state, TransportState::Connecting)); +} + +// Note: SSE event extraction is tested in http.rs unit tests diff --git a/crates/dirigent_fermata/CLAUDE.md b/crates/dirigent_fermata/CLAUDE.md new file mode 100644 index 0000000..49ffa67 --- /dev/null +++ b/crates/dirigent_fermata/CLAUDE.md @@ -0,0 +1,34 @@ +# Package: dirigent_fermata + +Harness-agnostic policy gate for AI coding agents. + +## Quick Facts +- **Type**: Library + binary (`fermata`) +- **Main Entry**: `src/lib.rs`, `src/bin/fermata.rs` +- **Dependencies**: `ignore`, `toml`, `regex`, `globset`, `serde`, `clap` (cli feature) +- **Status**: v0.1 — library + CLI + Claude hook adapter + +## Layering + +Three concentric layers; nothing inner imports from anything outer. + +- **`core/`** — harness-unaware, transport-unaware, sync. Types (`Op`, `Decision`), `.botignore` walker, `botignore.toml` parser, `Policy::check` / `check_command`, path extraction. Sync, no tokio. +- **`harness/`** — `HarnessAdapter` trait over a normalized `ToolCall`. Each adapter (Claude, future Codex, etc.) lives in its own submodule, feature-gated. +- **`bin/fermata.rs`** — only place where `clap`, stdio, and exit codes appear. + +## Release Model + +Developed in this monorepo; planned to be exported as a standalone repo in the future for advertising / external distribution. Development stays here. See `docs/tools/fermata.md`. + +## Dependency Direction + +`dirigent_tools` depends on `dirigent_fermata`, never the reverse. Fermata must remain usable as a standalone hook/MCP without dragging in the in-process ACP tool runtime. + +## Out of scope (v0.1) + +Codex / Gemini hook adapters, MCP server mode, PostToolUse envelope, `readonly_only` Bash mode, audit log, filesystem watcher. Each is a future task with its own plan. + +## See also + +- `docs/tools/fermata.md` — Dirigent integration plan +- `docs/workpad/brainstorm/fermata.md` — canonical product spec diff --git a/crates/dirigent_fermata/Cargo.toml b/crates/dirigent_fermata/Cargo.toml new file mode 100644 index 0000000..779a399 --- /dev/null +++ b/crates/dirigent_fermata/Cargo.toml @@ -0,0 +1,40 @@ +[package] +name = "dirigent_fermata" +version = "0.1.0" +edition = "2021" +rust-version = "1.75" +description = "Harness-agnostic policy gate for AI coding agents (.botignore + botignore.toml)" +license = "MIT OR Apache-2.0" +repository = "https://git.g4b.org/dirigence/fermata" +readme = "README.md" +keywords = ["ai", "agents", "security", "policy", "gitignore"] +categories = ["command-line-utilities", "development-tools"] + +[lib] +path = "src/lib.rs" + +[[bin]] +name = "fermata" +path = "src/bin/fermata.rs" +required-features = ["cli"] + +[dependencies] +globset = "0.4" +ignore = "0.4" +walkdir = "2" +toml = "0.8" +regex = "1.10" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +thiserror = "2.0" +clap = { version = "4.5", features = ["derive"], optional = true } + +[dev-dependencies] +tempfile = "3.10" +assert_cmd = "2.0" +predicates = "3.1" + +[features] +default = ["cli", "harness-claude"] +cli = ["dep:clap"] +harness-claude = [] diff --git a/crates/dirigent_fermata/LICENSE-APACHE b/crates/dirigent_fermata/LICENSE-APACHE new file mode 100644 index 0000000..b2ea092 --- /dev/null +++ b/crates/dirigent_fermata/LICENSE-APACHE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for describing the origin of the Work and + reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Support. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or support. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2026 Gabor Körber and contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/crates/dirigent_fermata/LICENSE-MIT b/crates/dirigent_fermata/LICENSE-MIT new file mode 100644 index 0000000..0440791 --- /dev/null +++ b/crates/dirigent_fermata/LICENSE-MIT @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Gabor Körber and contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/crates/dirigent_fermata/README.md b/crates/dirigent_fermata/README.md new file mode 100644 index 0000000..7dc4df3 --- /dev/null +++ b/crates/dirigent_fermata/README.md @@ -0,0 +1,214 @@ +# 𝄐 dirigent_fermata + +**A fast, harness-agnostic policy gate for AI coding agents.** + +Drop a `.botignore` file in your project root. Fermata reads it and blocks your agent from reading, writing, or running things it shouldn't — before the tool call happens. + +``` +.env +.env.* +secrets/** +conf/localsettings.yaml +``` + +That's all it takes. + +--- + +## Why Fermata + +AI coding agents are powerful, but they don't have an innate sense of "don't touch `.env`." Native hook systems in tools like Claude Code let you intercept every file operation — but wiring up your own secure, fast hook for each project is friction. Fermata is that hook, ready to drop in. + +- **Fast** — written in Rust; ~1–5ms per call. Hooks fire on every read, write, and bash operation. Python cold-start (~50–150ms) compounds fast. Fermata doesn't. +- **Familiar syntax** — `.botignore` uses gitignore rules via the `ignore` crate (the same engine powering ripgrep). +- **Per-operation control** — `botignore.toml` lets you block writes to `vendor/**` while still allowing reads, or deny specific bash patterns without touching path rules. +- **Harness-agnostic** — plain CLI exit codes work from any shell wrapper; the hook adapter speaks Claude Code's JSON natively. + +--- + +## Status: v0.1 + +| Component | Status | +|-----------|--------| +| Library (`Op`, `Decision`, `Policy::check`, `Policy::check_command`) | Done | +| `.botignore` walker (project-root walk-up, gitignore semantics) | Done | +| `botignore.toml` parser (read / write / bash namespaces) | Done | +| Path identification heuristics | Done | +| CLI: `fermata check <path>...` | Done | +| CLI: `fermata hook --harness claude` | Done | +| Claude Code PreToolUse adapter | Done | + +Out of scope for v0.1: Codex / Gemini hook adapters, MCP server mode, audit log, filesystem watcher. + +--- + +## Install + +From source (this monorepo): + +```bash +cargo install --path crates/dirigent_fermata --features cli +``` + +This installs the `fermata` binary into `~/.cargo/bin/`. + +--- + +## Usage + +### Checking a path + +```bash +fermata check --op read /path/to/.env +# exit 1 — blocked +# stderr: blocked by rule ".env" in /your/project/.botignore + +fermata check --op write /path/to/src/main.rs +# exit 0 — allowed +``` + +### Claude Code hook adapter + +```bash +fermata hook --harness claude < hook_payload.json +``` + +Reads the PreToolUse JSON from stdin, extracts the tool name and path or command, applies policy, and emits the Claude-shaped JSON response. The hook's exit code is always `0`; the verdict is in the JSON body. + +--- + +## Configuration + +### `.botignore` — the 80% case + +Create a `.botignore` at your project root. Gitignore syntax. Blocks both reads and writes. + +```gitignore +# Secrets +.env +.env.* +secrets/** + +# Local config overrides +conf/localsettings.yaml +conf/localtestsettings.yaml + +# Generated files — let the tools rebuild them, not patch them +dist/** +*.lock +``` + +Fermata walks up from the target file to find the nearest `.botignore`, so it works correctly even when an agent changes directory. + +### `botignore.toml` — per-operation rules + +For cases where `.botignore`'s uniform read+write block isn't granular enough: + +```toml +[read] +# Block reading secrets outright +patterns = [".env*", "secrets/**", "conf/localsettings.yaml"] + +[write] +# Allow reading vendor code but block patching it +patterns = ["vendor/**", "*.lock"] + +[bash] +# Hard-block destructive or exfiltrating commands +deny = [ + "rm -rf /", + "curl * | sh", + "git push --force*", +] +# Ask before any removal or move +ask = ["rm:*", "mv:*"] +# Narrow allowlist for automated commands +allow_prefixes = ["make test", "git checkout:*"] +``` + +--- + +## How it fits into Claude Code + +Add fermata as a `PreToolUse` hook in `.claude/settings.json`: + +```json +{ + "hooks": { + "PreToolUse": [ + { + "matcher": "Bash|Read|Edit|Write", + "hooks": [ + { + "type": "command", + "command": "fermata hook --harness claude" + } + ] + } + ] + } +} +``` + +When Claude attempts a `Read(.env)`, `Write(vendor/foo.js)`, or `Bash(rm ./secrets/key.pem)`, fermata intercepts the call, checks policy, and returns a deny with a human-readable reason — before any damage is done. + +--- + +## Real-world scenario + +A project has `.env`, `conf/localsettings.yaml`, and a `vendor/` tree it doesn't want patched. With `.botignore`: + +```gitignore +.env +.env.* +conf/localsettings.yaml +vendor/** +``` + +Claude attempts to read credentials: + +``` +Tool: Read +Path: ./conf/localsettings.yaml +Decision: BLOCK — matched rule "conf/localsettings.yaml" (.botignore) +``` + +Claude attempts to read application code: + +``` +Tool: Read +Path: ./src/app/main.rs +Decision: ALLOW +``` + +Claude attempts to run `cat .env` via bash — which would bypass a path-only check: + +```toml +# botignore.toml +[bash] +deny = ["cat .env*", "cat conf/localsettings*"] +``` + +``` +Tool: Bash +Command: cat .env +Decision: BLOCK — matched bash deny rule "cat .env*" +``` + +--- + +## Architecture + +Three concentric layers; nothing inner imports from anything outer: + +- **`core/`** — harness-unaware, sync. Types, `.botignore` walker, `botignore.toml` parser, `Policy::check` / `check_command`, path extraction. +- **`harness/`** — `HarnessAdapter` trait over a normalized `ToolCall`. Each adapter lives in its own submodule, feature-gated. +- **`bin/fermata.rs`** — the only place `clap`, stdio, and exit codes appear. + +--- + +## See also + +- `docs/tools/fermata.md` — Dirigent integration plan +- `docs/workpad/brainstorm/fermata.md` — full product spec and field notes +- `docs/architecture/crates.md` — crate dependency map diff --git a/crates/dirigent_fermata/src/bin/fermata.rs b/crates/dirigent_fermata/src/bin/fermata.rs new file mode 100644 index 0000000..23dbcb5 --- /dev/null +++ b/crates/dirigent_fermata/src/bin/fermata.rs @@ -0,0 +1,205 @@ +use clap::{Parser, Subcommand, ValueEnum}; +use dirigent_fermata::core::{project::find_project_root, Decision, Op, Policy}; +use std::io::{Read, Write}; +use std::path::PathBuf; +use std::process::ExitCode; + +#[derive(Parser)] +#[command(name = "fermata", about = "Harness-agnostic policy gate for AI coding agents")] +struct Cli { + #[command(subcommand)] + cmd: Cmd, +} + +#[derive(Subcommand)] +enum Cmd { + /// Check whether `path` is allowed for the given `--op`. + Check { + #[arg(long, value_enum, default_value_t = OpArg::Read)] + op: OpArg, + #[arg(long)] + json: bool, + paths: Vec<PathBuf>, + }, + /// Read a harness hook payload from stdin and render the decision. + Hook { + #[arg(long)] + harness: String, + }, +} + +#[derive(Copy, Clone, ValueEnum)] +enum OpArg { + Read, + Write, + Execute, +} + +impl From<OpArg> for Op { + fn from(a: OpArg) -> Self { + match a { + OpArg::Read => Op::Read, + OpArg::Write => Op::Write, + OpArg::Execute => Op::Execute, + } + } +} + +fn main() -> ExitCode { + let cli = Cli::parse(); + match cli.cmd { + Cmd::Check { op, json, paths } => run_check(op.into(), json, &paths), + Cmd::Hook { harness } => run_hook(&harness), + } +} + +fn run_check(op: Op, json: bool, paths: &[PathBuf]) -> ExitCode { + let mut worst: Option<Decision> = None; + for p in paths { + let root = match find_project_root(p) { + Some(r) => r, + None => continue, + }; + let policy = match Policy::load(&root) { + Ok(p) => p, + Err(e) => { + eprintln!("fermata: load error: {e}"); + return ExitCode::from(2); + } + }; + let d = match policy.check(op, p) { + Ok(d) => d, + Err(e) => { + eprintln!("fermata: check error: {e}"); + return ExitCode::from(2); + } + }; + worst = Some(merge_worst(worst.take(), d)); + } + let decision = worst.unwrap_or(Decision::Allow); + if json { + let _ = serde_json::to_writer(std::io::stdout().lock(), &decision); + let _ = writeln!(std::io::stdout().lock()); + } else if let Decision::Deny(ref r) = decision { + println!("{}", r.message); + } else if let Decision::Ask(ref r) = decision { + println!("ASK: {}", r.message); + } + match decision { + Decision::Allow => ExitCode::from(0), + Decision::Ask(_) => ExitCode::from(0), + Decision::Deny(_) => ExitCode::from(1), + } +} + +fn run_hook(harness: &str) -> ExitCode { + let adapter = match dirigent_fermata::harness::lookup(harness) { + Some(a) => a, + None => { + eprintln!("fermata: unknown harness '{harness}'"); + return ExitCode::from(2); + } + }; + let mut buf = Vec::new(); + if let Err(e) = std::io::stdin().lock().read_to_end(&mut buf) { + eprintln!("fermata: stdin: {e}"); + return ExitCode::from(2); + } + let call = match adapter.parse_request(&buf) { + Ok(c) => c, + Err(e) => { + eprintln!("fermata: parse: {e}"); + return ExitCode::from(2); + } + }; + + use dirigent_fermata::harness::{PathKind, ToolOp}; + let decision = match &call.op { + ToolOp::Path { path, kind } => { + let root = match find_project_root(path) { + // No project root → fail-open allow (hook must always exit 0 with a verdict). + // run_check silently skips these paths; here we must still emit JSON. + Some(r) => r, + None => { + let out = adapter.render_decision(&call, &Decision::Allow).unwrap_or_default(); + let _ = std::io::stdout().lock().write_all(&out); + return ExitCode::from(0); + } + }; + let policy = match Policy::load(&root) { + Ok(p) => p, + Err(e) => { + eprintln!("fermata: load error: {e}"); + let out = adapter.render_decision(&call, &Decision::Allow).unwrap_or_default(); + let _ = std::io::stdout().lock().write_all(&out); + return ExitCode::from(0); + } + }; + let op = match kind { + PathKind::Read => Op::Read, + PathKind::Write => Op::Write, + }; + match policy.check(op, path) { + Ok(d) => d, + Err(e) => { + eprintln!("fermata: check error: {e}"); + let out = adapter.render_decision(&call, &Decision::Allow).unwrap_or_default(); + let _ = std::io::stdout().lock().write_all(&out); + return ExitCode::from(0); + } + } + } + ToolOp::Command { text } => { + // For commands, we look up the project from cwd (no path argument). + let cwd = match std::env::current_dir() { + Ok(d) => d, + Err(e) => { + eprintln!("fermata: cwd error: {e}"); + let out = adapter.render_decision(&call, &Decision::Allow).unwrap_or_default(); + let _ = std::io::stdout().lock().write_all(&out); + return ExitCode::from(0); + } + }; + match find_project_root(&cwd) { + // No project root → fail-open allow (see Path branch note above). + None => Decision::Allow, + Some(root) => { + let policy = match Policy::load(&root) { + Ok(p) => p, + Err(e) => { + eprintln!("fermata: load error: {e}"); + let out = adapter.render_decision(&call, &Decision::Allow).unwrap_or_default(); + let _ = std::io::stdout().lock().write_all(&out); + return ExitCode::from(0); + } + }; + match policy.check_command(text) { + Ok(d) => d, + Err(e) => { + eprintln!("fermata: check error: {e}"); + let out = adapter.render_decision(&call, &Decision::Allow).unwrap_or_default(); + let _ = std::io::stdout().lock().write_all(&out); + return ExitCode::from(0); + } + } + } + } + } + }; + let out = adapter.render_decision(&call, &decision).unwrap_or_default(); + let _ = std::io::stdout().lock().write_all(&out); + ExitCode::from(0) // hook bins always exit 0; the JSON carries the verdict +} + +fn merge_worst(a: Option<Decision>, b: Decision) -> Decision { + let rank = |d: &Decision| match d { + Decision::Allow => 0, + Decision::Ask(_) => 1, + Decision::Deny(_) => 2, + }; + match a { + None => b, + Some(a) if rank(&a) >= rank(&b) => a, + Some(_) => b, + } +} diff --git a/crates/dirigent_fermata/src/core/botignore.rs b/crates/dirigent_fermata/src/core/botignore.rs new file mode 100644 index 0000000..4c9ac6c --- /dev/null +++ b/crates/dirigent_fermata/src/core/botignore.rs @@ -0,0 +1,91 @@ +use crate::core::decision::Rule; +use ignore::gitignore::{Gitignore, GitignoreBuilder}; +use std::path::{Path, PathBuf}; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum BotignoreError { + #[error("failed to read .botignore: {0}")] + Io(#[from] std::io::Error), + #[error("failed to compile .botignore: {0}")] + Compile(#[source] ignore::Error), +} + +struct ScopedMatcher { + /// Path of the source `.botignore` file. + source: PathBuf, + /// Directory the matcher is rooted at (parent of `source`). + dir: PathBuf, + /// Depth of `dir` (component count) — deeper = more specific. + depth: usize, + matcher: Gitignore, +} + +/// A collection of `.botignore` matchers, one per file discovered under the +/// project root. Each matcher is rooted at its source file's directory so +/// gitignore-style semantics (anchored vs unanchored patterns, per-directory +/// scope) work correctly. At match time, the deepest applicable matcher +/// wins; whitelist (`!` negation) at any depth overrides an ignore at +/// shallower depth. +pub struct BotignoreSet { + matchers: Vec<ScopedMatcher>, +} + +impl BotignoreSet { + /// Walk `root` recursively, building a per-file matcher for every + /// `.botignore` encountered. Empty if none are found. + pub fn load(root: &Path) -> Result<Self, BotignoreError> { + let mut matchers = Vec::new(); + for entry in walkdir::WalkDir::new(root).into_iter().filter_map(Result::ok) { + if !(entry.file_type().is_file() && entry.file_name() == ".botignore") { + continue; + } + let source = entry.path().to_path_buf(); + let dir = source.parent().unwrap_or(root).to_path_buf(); + let mut builder = GitignoreBuilder::new(&dir); + if let Some(err) = builder.add(&source) { + return Err(BotignoreError::Compile(err)); + } + let matcher = builder.build().map_err(BotignoreError::Compile)?; + let depth = dir.components().count(); + matchers.push(ScopedMatcher { + source, + dir, + depth, + matcher, + }); + } + // Shallowest first so iteration applies broader rules then more-specific overrides. + matchers.sort_by_key(|m| m.depth); + Ok(Self { matchers }) + } + + /// Returns `Some(Rule)` if `path` is matched (and not negated by a + /// deeper-scoped whitelist), else `None`. The deepest matcher whose + /// directory contains `path` wins. + pub fn matched(&self, path: &Path) -> Result<Option<Rule>, BotignoreError> { + let is_dir = path.is_dir(); + let mut current: Option<&ScopedMatcher> = None; + let mut current_pattern: Option<String> = None; + + for sm in &self.matchers { + if !path.starts_with(&sm.dir) { + continue; + } + let m = sm.matcher.matched(path, is_dir); + if m.is_ignore() { + current = Some(sm); + current_pattern = m.inner().map(|g| g.original().to_string()); + } else if m.is_whitelist() { + // Deeper whitelist overrides any shallower ignore. + current = None; + current_pattern = None; + } + } + + Ok(current.map(|sm| Rule { + source: sm.source.clone(), + pattern: current_pattern.unwrap_or_default(), + })) + } +} diff --git a/crates/dirigent_fermata/src/core/decision.rs b/crates/dirigent_fermata/src/core/decision.rs new file mode 100644 index 0000000..cd7dae5 --- /dev/null +++ b/crates/dirigent_fermata/src/core/decision.rs @@ -0,0 +1,30 @@ +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Rule { + /// Source file the rule came from (e.g. `/proj/.botignore`). + pub source: PathBuf, + /// Pattern text as it appeared in the source. + pub pattern: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Reason { + pub message: String, + pub rule: Option<Rule>, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "kind", rename_all = "lowercase")] +pub enum Decision { + Allow, + Ask(Reason), + Deny(Reason), +} + +impl Decision { + pub fn is_blocking(&self) -> bool { + matches!(self, Decision::Deny(_)) + } +} diff --git a/crates/dirigent_fermata/src/core/extract.rs b/crates/dirigent_fermata/src/core/extract.rs new file mode 100644 index 0000000..6535513 --- /dev/null +++ b/crates/dirigent_fermata/src/core/extract.rs @@ -0,0 +1,50 @@ +use regex::Regex; +use std::sync::OnceLock; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Confidence { + /// Absolute path or path with explicit separator. + High, + /// Bare filename with extension; could be a word. + Low, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PathCandidate { + pub text: String, + pub confidence: Confidence, +} + +/// Heuristically extract path-like substrings from arbitrary text. +/// Confident matches (absolute paths, paths containing separators) → `High`. +/// Bare filenames with an extension → `Low` (advisory only). +pub fn extract_path_candidates(text: &str) -> Vec<PathCandidate> { + static UNIX_ABS: OnceLock<Regex> = OnceLock::new(); + static WIN_ABS: OnceLock<Regex> = OnceLock::new(); + static REL_WITH_SEP: OnceLock<Regex> = OnceLock::new(); + static BARE_NAME: OnceLock<Regex> = OnceLock::new(); + + let unix_abs = UNIX_ABS.get_or_init(|| Regex::new(r"(?m)(?:^|\s)(/[\w./~\-_]+)").unwrap()); + let win_abs = WIN_ABS.get_or_init(|| Regex::new(r#"(?m)(?:^|\s)([A-Za-z]:\\[\w.\\\-_]+)"#).unwrap()); + let rel = REL_WITH_SEP.get_or_init(|| Regex::new(r"(?m)(?:^|\s)((?:\./|\.\./|[\w\-_]+/)[\w./\-_]+)").unwrap()); + let bare = BARE_NAME.get_or_init(|| Regex::new(r"(?m)(?:^|\s)([\w\-_]+\.[A-Za-z]{1,8})(?:\s|[.,;:!?]|$)").unwrap()); + + let mut out = Vec::new(); + let mut seen = std::collections::HashSet::new(); + + for re in [unix_abs, win_abs, rel] { + for cap in re.captures_iter(text) { + let m = cap.get(1).unwrap().as_str().trim_end_matches(['.', ',', ';', ':', '!', '?']); + if seen.insert(m.to_string()) { + out.push(PathCandidate { text: m.to_string(), confidence: Confidence::High }); + } + } + } + for cap in bare.captures_iter(text) { + let m = cap.get(1).unwrap().as_str(); + if seen.insert(m.to_string()) { + out.push(PathCandidate { text: m.to_string(), confidence: Confidence::Low }); + } + } + out +} diff --git a/crates/dirigent_fermata/src/core/mod.rs b/crates/dirigent_fermata/src/core/mod.rs new file mode 100644 index 0000000..519bf69 --- /dev/null +++ b/crates/dirigent_fermata/src/core/mod.rs @@ -0,0 +1,14 @@ +//! Core policy layer. Harness-unaware, transport-unaware, sync. + +pub mod botignore; +pub mod decision; +pub mod extract; +pub mod op; +pub mod policy; +pub mod project; +pub mod toml_config; + +pub use decision::{Decision, Reason, Rule}; +pub use extract::{extract_path_candidates, Confidence, PathCandidate}; +pub use op::Op; +pub use policy::Policy; diff --git a/crates/dirigent_fermata/src/core/op.rs b/crates/dirigent_fermata/src/core/op.rs new file mode 100644 index 0000000..99e5ee8 --- /dev/null +++ b/crates/dirigent_fermata/src/core/op.rs @@ -0,0 +1,9 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Op { + Read, + Write, + Execute, +} diff --git a/crates/dirigent_fermata/src/core/policy.rs b/crates/dirigent_fermata/src/core/policy.rs new file mode 100644 index 0000000..5479c30 --- /dev/null +++ b/crates/dirigent_fermata/src/core/policy.rs @@ -0,0 +1,164 @@ +use crate::core::botignore::{BotignoreError, BotignoreSet}; +use crate::core::decision::{Decision, Reason, Rule}; +use crate::core::op::Op; +use crate::core::toml_config::{BotignoreToml, TomlConfigError}; +use std::path::{Path, PathBuf}; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum PolicyError { + #[error(transparent)] + Botignore(#[from] BotignoreError), + #[error(transparent)] + Toml(#[from] TomlConfigError), + #[error("invalid pattern in botignore.toml: {0}")] + BadPattern(String), +} + +pub struct Policy { + root: PathBuf, + botignore: BotignoreSet, + toml: BotignoreToml, + read_globs: globset::GlobSet, + write_globs: globset::GlobSet, + read_patterns: Vec<String>, + write_patterns: Vec<String>, +} + +impl Policy { + pub fn load(root: &Path) -> Result<Self, PolicyError> { + let botignore = BotignoreSet::load(root)?; + let toml = BotignoreToml::load(root)?; + + let (read_globs, read_patterns) = compile_globs( + toml.read.as_ref().map(|r| r.patterns.as_slice()).unwrap_or(&[]), + )?; + let (write_globs, write_patterns) = compile_globs( + toml.write.as_ref().map(|r| r.patterns.as_slice()).unwrap_or(&[]), + )?; + + Ok(Self { + root: root.to_path_buf(), + botignore, + toml, + read_globs, + write_globs, + read_patterns, + write_patterns, + }) + } + + pub fn check_command(&self, command: &str) -> Result<Decision, PolicyError> { + let bash = match self.toml.bash.as_ref() { + Some(b) => b, + None => return Ok(Decision::Allow), + }; + + // 1. Deny wins over everything else. + if let Some(pat) = match_command(command, &bash.deny)? { + return Ok(Decision::Deny(Reason { + message: format!("blocked by botignore.toml [bash.deny]: {}", pat), + rule: Some(Rule { + source: self.root.join("botignore.toml"), + pattern: pat, + }), + })); + } + + // 2. Allow prefixes — if any matches, allow. + for prefix in &bash.allow_prefixes { + if command_matches_prefix(command, prefix) { + return Ok(Decision::Allow); + } + } + + // 3. Ask patterns. + if let Some(pat) = match_command(command, &bash.ask)? { + return Ok(Decision::Ask(Reason { + message: format!("requires confirmation [bash.ask]: {}", pat), + rule: Some(Rule { + source: self.root.join("botignore.toml"), + pattern: pat, + }), + })); + } + + Ok(Decision::Allow) + } + + pub fn check(&self, op: Op, path: &Path) -> Result<Decision, PolicyError> { + // 1. .botignore is path-only and applies to read+write equally. + if matches!(op, Op::Read | Op::Write) { + if let Some(rule) = self.botignore.matched(path)? { + return Ok(Decision::Deny(Reason { + message: format!("blocked by .botignore: {}", rule.pattern), + rule: Some(rule), + })); + } + } + + // 2. botignore.toml namespace-specific rules. + let (set, patterns) = match op { + Op::Read => (&self.read_globs, &self.read_patterns), + Op::Write => (&self.write_globs, &self.write_patterns), + Op::Execute => return Ok(Decision::Allow), // path-based check_command handles bash + }; + + let rel = path.strip_prefix(&self.root).unwrap_or(path); + let matches = set.matches(rel); + if let Some(idx) = matches.first() { + let pattern = patterns[*idx].clone(); + return Ok(Decision::Deny(Reason { + message: format!("blocked by botignore.toml [{:?}]: {}", op, pattern), + rule: Some(Rule { + source: self.root.join("botignore.toml"), + pattern, + }), + })); + } + + Ok(Decision::Allow) + } +} + +/// Substring-or-glob match of `command` against `patterns`. +/// Patterns containing glob metachars (`*`, `?`, `[`) are treated as globs; +/// others are matched as literal substrings. +fn match_command(command: &str, patterns: &[String]) -> Result<Option<String>, PolicyError> { + for pat in patterns { + if is_glob(pat) { + let g = globset::Glob::new(pat) + .map_err(|e| PolicyError::BadPattern(e.to_string()))? + .compile_matcher(); + if g.is_match(command) { + return Ok(Some(pat.clone())); + } + } else if command.contains(pat.as_str()) { + return Ok(Some(pat.clone())); + } + } + Ok(None) +} + +fn is_glob(pat: &str) -> bool { + pat.contains('*') || pat.contains('?') || pat.contains('[') +} + +/// `prefix` is `"name"` or `"name:*"`. Both treat `name` as a leading word +/// boundary in `command`. Mirrors Claude Code's `Bash(name:*)` style. +fn command_matches_prefix(command: &str, prefix: &str) -> bool { + let needle = prefix.trim_end_matches(":*"); + command.trim_start().starts_with(needle) +} + +fn compile_globs(patterns: &[String]) -> Result<(globset::GlobSet, Vec<String>), PolicyError> { + let mut builder = globset::GlobSetBuilder::new(); + for pat in patterns { + let glob = globset::Glob::new(pat).map_err(|e| PolicyError::BadPattern(e.to_string()))?; + builder.add(glob); + } + let set = builder + .build() + .map_err(|e| PolicyError::BadPattern(e.to_string()))?; + Ok((set, patterns.to_vec())) +} diff --git a/crates/dirigent_fermata/src/core/project.rs b/crates/dirigent_fermata/src/core/project.rs new file mode 100644 index 0000000..bceba07 --- /dev/null +++ b/crates/dirigent_fermata/src/core/project.rs @@ -0,0 +1,27 @@ +use std::path::{Path, PathBuf}; + +/// Markers checked in priority order when walking up from a target path. +const MARKERS: &[&str] = &["botignore.toml", ".botignore", ".git"]; + +/// Walk upward from `target` (or its parent if `target` is a file) looking +/// for the nearest project root. Roots are identified by the presence of +/// any marker in `MARKERS`. Walks from the **target file's location**, not +/// from cwd, because agents `cd` around. +pub fn find_project_root(target: &Path) -> Option<PathBuf> { + let start = if target.is_file() { + target.parent()? + } else { + target + }; + + let mut current = Some(start); + while let Some(dir) = current { + for marker in MARKERS { + if dir.join(marker).exists() { + return Some(dir.to_path_buf()); + } + } + current = dir.parent(); + } + None +} diff --git a/crates/dirigent_fermata/src/core/toml_config.rs b/crates/dirigent_fermata/src/core/toml_config.rs new file mode 100644 index 0000000..8f2997f --- /dev/null +++ b/crates/dirigent_fermata/src/core/toml_config.rs @@ -0,0 +1,47 @@ +use serde::{Deserialize, Serialize}; +use std::path::Path; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum TomlConfigError { + #[error("io error: {0}")] + Io(#[from] std::io::Error), + #[error("toml parse error: {0}")] + Parse(#[from] toml::de::Error), +} + +#[derive(Debug, Default, Clone, Deserialize, Serialize)] +pub struct OpRules { + #[serde(default)] + pub patterns: Vec<String>, +} + +#[derive(Debug, Default, Clone, Deserialize, Serialize)] +pub struct BashRules { + #[serde(default)] + pub deny: Vec<String>, + #[serde(default)] + pub ask: Vec<String>, + #[serde(default)] + pub allow_prefixes: Vec<String>, +} + +#[derive(Debug, Default, Clone, Deserialize, Serialize)] +pub struct BotignoreToml { + pub read: Option<OpRules>, + pub write: Option<OpRules>, + pub bash: Option<BashRules>, +} + +impl BotignoreToml { + /// Load `<root>/botignore.toml` if present, else return an empty config. + pub fn load(root: &Path) -> Result<Self, TomlConfigError> { + let path = root.join("botignore.toml"); + if !path.exists() { + return Ok(Self::default()); + } + let text = std::fs::read_to_string(&path)?; + let cfg = toml::from_str(&text)?; + Ok(cfg) + } +} diff --git a/crates/dirigent_fermata/src/harness/claude.rs b/crates/dirigent_fermata/src/harness/claude.rs new file mode 100644 index 0000000..e0c3576 --- /dev/null +++ b/crates/dirigent_fermata/src/harness/claude.rs @@ -0,0 +1,76 @@ +//! Claude Code hook adapter (PreToolUse). +//! +//! Wire format: stdin is one JSON object with `tool_name` and `tool_input`. +//! Stdout is `{"hookSpecificOutput": {...}}` with exit code 0; the JSON +//! carries the verdict. + +use super::{AdapterError, HarnessAdapter, PathKind, ToolCall, ToolOp}; +use crate::core::Decision; +use serde_json::{json, Value}; +use std::path::PathBuf; + +pub struct ClaudeAdapter; + +impl HarnessAdapter for ClaudeAdapter { + fn name(&self) -> &'static str { + "claude" + } + + fn parse_request(&self, input: &[u8]) -> Result<ToolCall, AdapterError> { + let v: Value = serde_json::from_slice(input)?; + let tool_name = v + .get("tool_name") + .and_then(|x| x.as_str()) + .ok_or_else(|| AdapterError::Parse("missing tool_name".into()))? + .to_string(); + let tool_input = v.get("tool_input").cloned().unwrap_or(Value::Null); + + let op = match tool_name.as_str() { + "Read" => path_op(&tool_input, PathKind::Read)?, + "Write" | "Edit" | "MultiEdit" => path_op(&tool_input, PathKind::Write)?, + "Bash" => command_op(&tool_input)?, + other => return Err(AdapterError::UnsupportedTool(other.to_string())), + }; + + Ok(ToolCall { + tool_name, + op, + raw: v, + }) + } + + fn render_decision(&self, _call: &ToolCall, decision: &Decision) -> Result<Vec<u8>, AdapterError> { + let (verdict, reason) = match decision { + Decision::Allow => ("allow", String::new()), + Decision::Ask(r) => ("ask", r.message.clone()), + Decision::Deny(r) => ("deny", r.message.clone()), + }; + let out = json!({ + "hookSpecificOutput": { + "hookEventName": "PreToolUse", + "permissionDecision": verdict, + "permissionDecisionReason": reason, + } + }); + Ok(serde_json::to_vec(&out)?) + } +} + +fn path_op(tool_input: &Value, kind: PathKind) -> Result<ToolOp, AdapterError> { + let p = tool_input + .get("file_path") + .and_then(|x| x.as_str()) + .ok_or_else(|| AdapterError::Parse("missing tool_input.file_path".into()))?; + Ok(ToolOp::Path { + path: PathBuf::from(p), + kind, + }) +} + +fn command_op(tool_input: &Value) -> Result<ToolOp, AdapterError> { + let c = tool_input + .get("command") + .and_then(|x| x.as_str()) + .ok_or_else(|| AdapterError::Parse("missing tool_input.command".into()))?; + Ok(ToolOp::Command { text: c.to_string() }) +} diff --git a/crates/dirigent_fermata/src/harness/mod.rs b/crates/dirigent_fermata/src/harness/mod.rs new file mode 100644 index 0000000..347b7b3 --- /dev/null +++ b/crates/dirigent_fermata/src/harness/mod.rs @@ -0,0 +1,67 @@ +//! Harness adapter layer. Normalizes harness-specific payloads into +//! `core` types and renders `Decision` back to harness wire format. + +use crate::core::Decision; +use std::path::PathBuf; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum AdapterError { + #[error("invalid request payload: {0}")] + Parse(String), + #[error("unsupported tool: {0}")] + UnsupportedTool(String), + #[error("io: {0}")] + Io(#[from] std::io::Error), + #[error("json: {0}")] + Json(#[from] serde_json::Error), +} + +/// Normalized tool-call shape consumed by `core::Policy`. +/// Adapters translate harness-specific payloads into this; nothing in +/// `core` knows about adapters. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ToolCall { + /// Harness's tool name (e.g. "Read", "Write", "Edit", "Bash"). + pub tool_name: String, + /// Op classification derived from `tool_name`. + pub op: ToolOp, + /// Original raw payload for the adapter to consult when rendering. + pub raw: serde_json::Value, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ToolOp { + Path { path: PathBuf, kind: PathKind }, + Command { text: String }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PathKind { + Read, + Write, +} + +/// Trait implemented by each harness adapter. Adapters parse the harness's +/// hook stdin payload into `ToolCall` and render a `Decision` back to the +/// harness's expected stdout format. +pub trait HarnessAdapter { + /// The CLI name (e.g. "claude", "codex", "gemini"). + fn name(&self) -> &'static str; + + fn parse_request(&self, input: &[u8]) -> Result<ToolCall, AdapterError>; + + fn render_decision(&self, call: &ToolCall, decision: &Decision) -> Result<Vec<u8>, AdapterError>; +} + +#[cfg(feature = "harness-claude")] +pub mod claude; + +/// Look up a registered adapter by CLI name. +pub fn lookup(name: &str) -> Option<Box<dyn HarnessAdapter>> { + match name { + #[cfg(feature = "harness-claude")] + "claude" => Some(Box::new(claude::ClaudeAdapter)), + _ => None, + } +} diff --git a/crates/dirigent_fermata/src/lib.rs b/crates/dirigent_fermata/src/lib.rs new file mode 100644 index 0000000..4bb9af6 --- /dev/null +++ b/crates/dirigent_fermata/src/lib.rs @@ -0,0 +1,7 @@ +//! `dirigent_fermata` — harness-agnostic policy gate. +//! +//! See `docs/tools/fermata.md` (Dirigent integration plan) and +//! `docs/workpad/brainstorm/fermata.md` (product spec). + +pub mod core; +pub mod harness; diff --git a/crates/dirigent_fermata/tests/cargo_publish_metadata.rs b/crates/dirigent_fermata/tests/cargo_publish_metadata.rs new file mode 100644 index 0000000..6243c56 --- /dev/null +++ b/crates/dirigent_fermata/tests/cargo_publish_metadata.rs @@ -0,0 +1,47 @@ +//! Guards that fermata's Cargo.toml carries the metadata required for +//! `cargo publish` and a useful `cargo install`. Reads the manifest as +//! plain text to avoid pulling cargo internals. + +use std::fs; + +fn manifest() -> String { + fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/Cargo.toml")) + .expect("read Cargo.toml") +} + +#[test] +fn has_license() { + let m = manifest(); + assert!(m.contains("license ="), "Cargo.toml missing `license` field"); +} + +#[test] +fn has_repository() { + let m = manifest(); + assert!(m.contains("repository ="), "Cargo.toml missing `repository` field"); +} + +#[test] +fn has_description() { + let m = manifest(); + assert!(m.contains("description ="), "Cargo.toml missing `description`"); +} + +#[test] +fn has_readme() { + let m = manifest(); + assert!(m.contains("readme ="), "Cargo.toml missing `readme` field"); +} + +#[test] +fn has_keywords_and_categories() { + let m = manifest(); + assert!(m.contains("keywords ="), "Cargo.toml missing `keywords`"); + assert!(m.contains("categories ="), "Cargo.toml missing `categories`"); +} + +#[test] +fn has_rust_version() { + let m = manifest(); + assert!(m.contains("rust-version ="), "Cargo.toml missing `rust-version` (MSRV)"); +} diff --git a/crates/dirigent_fermata/tests/cli_check.rs b/crates/dirigent_fermata/tests/cli_check.rs new file mode 100644 index 0000000..d909902 --- /dev/null +++ b/crates/dirigent_fermata/tests/cli_check.rs @@ -0,0 +1,52 @@ +use assert_cmd::Command; +use predicates::prelude::*; +use std::fs; + +#[test] +fn check_blocks_botignore_match() { + let tmp = tempfile::tempdir().unwrap(); + fs::write(tmp.path().join(".botignore"), ".env\n").unwrap(); + let target = tmp.path().join(".env"); + fs::write(&target, "").unwrap(); + + Command::cargo_bin("fermata") + .unwrap() + .args(["check", "--op", "read", target.to_str().unwrap()]) + .assert() + .failure() + .code(1) + .stdout(predicate::str::contains(".env")); +} + +#[test] +fn check_allows_unmatched() { + let tmp = tempfile::tempdir().unwrap(); + fs::write(tmp.path().join(".botignore"), ".env\n").unwrap(); + let target = tmp.path().join("src.rs"); + fs::write(&target, "").unwrap(); + + Command::cargo_bin("fermata") + .unwrap() + .args(["check", "--op", "read", target.to_str().unwrap()]) + .assert() + .success(); +} + +#[test] +fn check_emits_json_with_flag() { + let tmp = tempfile::tempdir().unwrap(); + fs::write(tmp.path().join(".botignore"), ".env\n").unwrap(); + let target = tmp.path().join(".env"); + fs::write(&target, "").unwrap(); + + let out = Command::cargo_bin("fermata") + .unwrap() + .args(["check", "--op", "read", "--json", target.to_str().unwrap()]) + .assert() + .failure() + .get_output() + .stdout + .clone(); + let v: serde_json::Value = serde_json::from_slice(&out).unwrap(); + assert_eq!(v["kind"], "deny"); +} diff --git a/crates/dirigent_fermata/tests/cli_hook_claude.rs b/crates/dirigent_fermata/tests/cli_hook_claude.rs new file mode 100644 index 0000000..40fe591 --- /dev/null +++ b/crates/dirigent_fermata/tests/cli_hook_claude.rs @@ -0,0 +1,69 @@ +use assert_cmd::Command; +use std::fs; + +#[test] +fn hook_blocks_read_of_botignore_match() { + let tmp = tempfile::tempdir().unwrap(); + fs::write(tmp.path().join(".botignore"), ".env\n").unwrap(); + let target = tmp.path().join(".env"); + fs::write(&target, "").unwrap(); + + let payload = serde_json::json!({ + "tool_name": "Read", + "tool_input": { "file_path": target.to_str().unwrap() } + }) + .to_string(); + + let out = Command::cargo_bin("fermata") + .unwrap() + .args(["hook", "--harness", "claude"]) + .write_stdin(payload) + .assert() + .success() // hook always exits 0 + .get_output() + .stdout + .clone(); + + let v: serde_json::Value = serde_json::from_slice(&out).unwrap(); + assert_eq!(v["hookSpecificOutput"]["permissionDecision"], "deny"); + assert!(v["hookSpecificOutput"]["permissionDecisionReason"] + .as_str() + .unwrap() + .contains(".env")); +} + +#[test] +fn hook_allows_unrelated_read() { + let tmp = tempfile::tempdir().unwrap(); + fs::write(tmp.path().join(".botignore"), ".env\n").unwrap(); + let target = tmp.path().join("src.rs"); + fs::write(&target, "").unwrap(); + + let payload = serde_json::json!({ + "tool_name": "Read", + "tool_input": { "file_path": target.to_str().unwrap() } + }) + .to_string(); + + let out = Command::cargo_bin("fermata") + .unwrap() + .args(["hook", "--harness", "claude"]) + .write_stdin(payload) + .assert() + .success() + .get_output() + .stdout + .clone(); + let v: serde_json::Value = serde_json::from_slice(&out).unwrap(); + assert_eq!(v["hookSpecificOutput"]["permissionDecision"], "allow"); +} + +#[test] +fn hook_unknown_harness_errors() { + Command::cargo_bin("fermata") + .unwrap() + .args(["hook", "--harness", "doesnotexist"]) + .write_stdin("{}") + .assert() + .code(2); +} diff --git a/crates/dirigent_fermata/tests/core_botignore.rs b/crates/dirigent_fermata/tests/core_botignore.rs new file mode 100644 index 0000000..b54ae0f --- /dev/null +++ b/crates/dirigent_fermata/tests/core_botignore.rs @@ -0,0 +1,135 @@ +use dirigent_fermata::core::botignore::BotignoreSet; +use std::fs; +use tempfile::TempDir; + +#[test] +fn matches_simple_pattern() { + let tmp = TempDir::new().unwrap(); + let root = tmp.path(); + fs::write(root.join(".botignore"), ".env\nsecrets/\n").unwrap(); + + let set = BotignoreSet::load(root).unwrap(); + + let env = root.join(".env"); + fs::write(&env, "").unwrap(); + let m = set.matched(&env).unwrap(); + assert!(m.is_some(), ".env should be matched"); + assert_eq!(m.unwrap().pattern, ".env"); +} + +#[test] +fn does_not_match_unrelated_files() { + let tmp = TempDir::new().unwrap(); + let root = tmp.path(); + fs::write(root.join(".botignore"), ".env\n").unwrap(); + + let set = BotignoreSet::load(root).unwrap(); + + let other = root.join("README.md"); + fs::write(&other, "").unwrap(); + assert!(set.matched(&other).unwrap().is_none()); +} + +#[test] +fn negation_pattern_excludes() { + let tmp = TempDir::new().unwrap(); + let root = tmp.path(); + fs::write(root.join(".botignore"), "*.log\n!keep.log\n").unwrap(); + + let set = BotignoreSet::load(root).unwrap(); + + let blocked = root.join("foo.log"); + fs::write(&blocked, "").unwrap(); + assert!(set.matched(&blocked).unwrap().is_some()); + + let allowed = root.join("keep.log"); + fs::write(&allowed, "").unwrap(); + assert!(set.matched(&allowed).unwrap().is_none()); +} + +#[test] +fn empty_or_missing_botignore_is_ok() { + let tmp = TempDir::new().unwrap(); + let set = BotignoreSet::load(tmp.path()).unwrap(); + let any = tmp.path().join("anything.txt"); + std::fs::write(&any, "").unwrap(); + assert!(set.matched(&any).unwrap().is_none()); +} + +#[test] +fn nested_botignore_is_scoped_to_its_directory() { + // A `.botignore` in a subdirectory only applies under that subdirectory, + // matching gitignore semantics: a sibling file with the same name at the + // root is NOT affected. + let tmp = TempDir::new().unwrap(); + let root = tmp.path(); + fs::create_dir_all(root.join("frontend")).unwrap(); + fs::write(root.join("frontend/.botignore"), "secret.key\n").unwrap(); + + let set = BotignoreSet::load(root).unwrap(); + + let blocked = root.join("frontend/secret.key"); + fs::write(&blocked, "").unwrap(); + let m = set + .matched(&blocked) + .unwrap() + .expect("frontend/secret.key should match"); + let src = m.source.to_string_lossy().replace('\\', "/"); + assert!( + src.ends_with("frontend/.botignore"), + "Rule.source should point at the nested file; was {}", + src, + ); + + let unblocked = root.join("secret.key"); + fs::write(&unblocked, "").unwrap(); + assert!( + set.matched(&unblocked).unwrap().is_none(), + "top-level secret.key should NOT be matched (rule scoped to frontend/)", + ); +} + +#[test] +fn nested_botignore_anchored_pattern_is_local() { + // A leading `/` anchors the pattern to the directory of the .botignore + // file it's declared in. + let tmp = TempDir::new().unwrap(); + let root = tmp.path(); + fs::create_dir_all(root.join("frontend")).unwrap(); + fs::write(root.join("frontend/.botignore"), "/secret.key\n").unwrap(); + + let set = BotignoreSet::load(root).unwrap(); + + let blocked = root.join("frontend/secret.key"); + fs::write(&blocked, "").unwrap(); + assert!(set.matched(&blocked).unwrap().is_some()); + + let unblocked = root.join("secret.key"); + fs::write(&unblocked, "").unwrap(); + assert!( + set.matched(&unblocked).unwrap().is_none(), + "anchored /secret.key should NOT match outside frontend/", + ); +} + +#[test] +fn nested_botignore_overrides_root() { + let tmp = TempDir::new().unwrap(); + let root = tmp.path(); + fs::write(root.join(".botignore"), "*.log\n").unwrap(); + fs::create_dir_all(root.join("logs")).unwrap(); + fs::write(root.join("logs/.botignore"), "!keep.log\n").unwrap(); + + let set = BotignoreSet::load(root).unwrap(); + + let blocked = root.join("logs/foo.log"); + fs::write(&blocked, "").unwrap(); + assert!(set.matched(&blocked).unwrap().is_some()); + + let kept = root.join("logs/keep.log"); + fs::write(&kept, "").unwrap(); + assert!( + set.matched(&kept).unwrap().is_none(), + "logs/.botignore should un-ignore keep.log", + ); +} diff --git a/crates/dirigent_fermata/tests/core_extract.rs b/crates/dirigent_fermata/tests/core_extract.rs new file mode 100644 index 0000000..72a07d9 --- /dev/null +++ b/crates/dirigent_fermata/tests/core_extract.rs @@ -0,0 +1,39 @@ +use dirigent_fermata::core::extract::{extract_path_candidates, Confidence}; + +#[test] +fn extracts_absolute_unix_path() { + let s = "the file is at /home/user/.env and was modified"; + let cs = extract_path_candidates(s); + assert!(cs.iter().any(|c| c.text == "/home/user/.env" && c.confidence == Confidence::High)); +} + +#[test] +fn extracts_absolute_windows_path() { + let s = r"see C:\Users\me\secret.toml for details"; + let cs = extract_path_candidates(s); + assert!(cs.iter().any(|c| c.text == r"C:\Users\me\secret.toml" && c.confidence == Confidence::High)); +} + +#[test] +fn extracts_relative_with_separator() { + let s = "modified src/lib.rs and tests/foo.rs"; + let cs = extract_path_candidates(s); + let texts: Vec<_> = cs.iter().map(|c| c.text.as_str()).collect(); + assert!(texts.contains(&"src/lib.rs")); + assert!(texts.contains(&"tests/foo.rs")); +} + +#[test] +fn bare_filename_with_extension_is_low_confidence() { + let s = "open README.md please"; + let cs = extract_path_candidates(s); + let r = cs.iter().find(|c| c.text == "README.md").unwrap(); + assert_eq!(r.confidence, Confidence::Low); +} + +#[test] +fn ignores_pure_words() { + let s = "the quick brown fox"; + let cs = extract_path_candidates(s); + assert!(cs.is_empty()); +} diff --git a/crates/dirigent_fermata/tests/core_op_decision.rs b/crates/dirigent_fermata/tests/core_op_decision.rs new file mode 100644 index 0000000..ee35d7a --- /dev/null +++ b/crates/dirigent_fermata/tests/core_op_decision.rs @@ -0,0 +1,42 @@ +use dirigent_fermata::core::{Decision, Op, Reason, Rule}; + +#[test] +fn op_variants_exist() { + let _ = Op::Read; + let _ = Op::Write; + let _ = Op::Execute; +} + +#[test] +fn decision_allow_is_simple() { + let d = Decision::Allow; + assert!(matches!(d, Decision::Allow)); +} + +#[test] +fn decision_deny_carries_reason() { + let rule = Rule { + source: "/proj/.botignore".into(), + pattern: ".env".into(), + }; + let d = Decision::Deny(Reason { + message: "blocked by .botignore".into(), + rule: Some(rule), + }); + match d { + Decision::Deny(r) => { + assert_eq!(r.message, "blocked by .botignore"); + assert!(r.rule.is_some()); + } + _ => panic!("expected Deny"), + } +} + +#[test] +fn decision_ask_carries_reason() { + let d = Decision::Ask(Reason { + message: "needs confirmation".into(), + rule: None, + }); + assert!(matches!(d, Decision::Ask(_))); +} diff --git a/crates/dirigent_fermata/tests/core_policy_command.rs b/crates/dirigent_fermata/tests/core_policy_command.rs new file mode 100644 index 0000000..6d49110 --- /dev/null +++ b/crates/dirigent_fermata/tests/core_policy_command.rs @@ -0,0 +1,52 @@ +use dirigent_fermata::core::{Decision, Policy}; +use std::fs; +use tempfile::TempDir; + +fn project_with(toml: &str) -> TempDir { + let tmp = TempDir::new().unwrap(); + fs::write(tmp.path().join("botignore.toml"), toml).unwrap(); + tmp +} + +#[test] +fn deny_substring_blocks() { + let tmp = project_with("[bash]\ndeny = [\"rm -rf /\"]\n"); + let p = Policy::load(tmp.path()).unwrap(); + assert!(matches!(p.check_command("sudo rm -rf / now").unwrap(), Decision::Deny(_))); +} + +#[test] +fn deny_glob_blocks() { + let tmp = project_with("[bash]\ndeny = [\"git push --force*\"]\n"); + let p = Policy::load(tmp.path()).unwrap(); + assert!(matches!(p.check_command("git push --force-with-lease").unwrap(), Decision::Deny(_))); +} + +#[test] +fn ask_returns_ask() { + let tmp = project_with("[bash]\nask = [\"rm *\"]\n"); + let p = Policy::load(tmp.path()).unwrap(); + assert!(matches!(p.check_command("rm somefile").unwrap(), Decision::Ask(_))); +} + +#[test] +fn allow_prefixes_allows() { + let tmp = project_with("[bash]\nallow_prefixes = [\"make test\"]\n"); + let p = Policy::load(tmp.path()).unwrap(); + assert_eq!(p.check_command("make test").unwrap(), Decision::Allow); + assert_eq!(p.check_command("make test-unit").unwrap(), Decision::Allow); +} + +#[test] +fn no_rules_means_allow() { + let tmp = project_with(""); + let p = Policy::load(tmp.path()).unwrap(); + assert_eq!(p.check_command("anything goes").unwrap(), Decision::Allow); +} + +#[test] +fn deny_takes_precedence_over_allow_prefix() { + let tmp = project_with("[bash]\ndeny = [\"rm -rf /\"]\nallow_prefixes = [\"rm\"]\n"); + let p = Policy::load(tmp.path()).unwrap(); + assert!(matches!(p.check_command("rm -rf /").unwrap(), Decision::Deny(_))); +} diff --git a/crates/dirigent_fermata/tests/core_policy_path.rs b/crates/dirigent_fermata/tests/core_policy_path.rs new file mode 100644 index 0000000..4612079 --- /dev/null +++ b/crates/dirigent_fermata/tests/core_policy_path.rs @@ -0,0 +1,64 @@ +use dirigent_fermata::core::{Decision, Op, Policy}; +use std::fs; +use tempfile::TempDir; + +fn make_project(botignore: &str, toml_text: &str) -> TempDir { + let tmp = TempDir::new().unwrap(); + fs::write(tmp.path().join(".botignore"), botignore).unwrap(); + if !toml_text.is_empty() { + fs::write(tmp.path().join("botignore.toml"), toml_text).unwrap(); + } + tmp +} + +#[test] +fn botignore_blocks_read() { + let tmp = make_project(".env\n", ""); + let policy = Policy::load(tmp.path()).unwrap(); + let target = tmp.path().join(".env"); + fs::write(&target, "").unwrap(); + let d = policy.check(Op::Read, &target).unwrap(); + assert!(matches!(d, Decision::Deny(_))); +} + +#[test] +fn botignore_blocks_write_too() { + let tmp = make_project(".env\n", ""); + let policy = Policy::load(tmp.path()).unwrap(); + let target = tmp.path().join(".env"); + let d = policy.check(Op::Write, &target).unwrap(); + assert!(matches!(d, Decision::Deny(_))); +} + +#[test] +fn unmatched_path_allowed() { + let tmp = make_project(".env\n", ""); + let policy = Policy::load(tmp.path()).unwrap(); + let target = tmp.path().join("src/main.rs"); + fs::create_dir_all(target.parent().unwrap()).unwrap(); + fs::write(&target, "").unwrap(); + let d = policy.check(Op::Read, &target).unwrap(); + assert_eq!(d, Decision::Allow); +} + +#[test] +fn toml_read_block_applies_only_to_read() { + let tmp = make_project("", "[read]\npatterns = [\"secrets/**\"]\n"); + let policy = Policy::load(tmp.path()).unwrap(); + let target = tmp.path().join("secrets/key.pem"); + fs::create_dir_all(target.parent().unwrap()).unwrap(); + fs::write(&target, "").unwrap(); + assert!(matches!(policy.check(Op::Read, &target).unwrap(), Decision::Deny(_))); + assert_eq!(policy.check(Op::Write, &target).unwrap(), Decision::Allow); +} + +#[test] +fn toml_write_block_applies_only_to_write() { + let tmp = make_project("", "[write]\npatterns = [\"vendor/**\"]\n"); + let policy = Policy::load(tmp.path()).unwrap(); + let target = tmp.path().join("vendor/lib.rs"); + fs::create_dir_all(target.parent().unwrap()).unwrap(); + fs::write(&target, "").unwrap(); + assert_eq!(policy.check(Op::Read, &target).unwrap(), Decision::Allow); + assert!(matches!(policy.check(Op::Write, &target).unwrap(), Decision::Deny(_))); +} diff --git a/crates/dirigent_fermata/tests/core_project.rs b/crates/dirigent_fermata/tests/core_project.rs new file mode 100644 index 0000000..b4a3dc3 --- /dev/null +++ b/crates/dirigent_fermata/tests/core_project.rs @@ -0,0 +1,69 @@ +use dirigent_fermata::core::project::find_project_root; +use std::fs; +use tempfile::TempDir; + +#[test] +fn finds_botignore_toml_first() { + let tmp = TempDir::new().unwrap(); + let root = tmp.path(); + fs::create_dir_all(root.join("sub/deep")).unwrap(); + fs::write(root.join("botignore.toml"), "").unwrap(); + fs::write(root.join(".botignore"), "").unwrap(); + fs::create_dir_all(root.join(".git")).unwrap(); + + let target = root.join("sub/deep/file.rs"); + fs::write(&target, "").unwrap(); + + let found = find_project_root(&target).unwrap(); + assert_eq!(found, root); +} + +#[test] +fn falls_back_to_botignore() { + let tmp = TempDir::new().unwrap(); + let root = tmp.path(); + fs::create_dir_all(root.join("sub")).unwrap(); + fs::write(root.join(".botignore"), "").unwrap(); + + let target = root.join("sub/file.rs"); + fs::write(&target, "").unwrap(); + + let found = find_project_root(&target).unwrap(); + assert_eq!(found, root); +} + +#[test] +fn falls_back_to_git() { + let tmp = TempDir::new().unwrap(); + let root = tmp.path(); + fs::create_dir_all(root.join("sub")).unwrap(); + fs::create_dir_all(root.join(".git")).unwrap(); + + let target = root.join("sub/file.rs"); + fs::write(&target, "").unwrap(); + + let found = find_project_root(&target).unwrap(); + assert_eq!(found, root); +} + +#[test] +fn returns_none_when_no_marker() { + let tmp = TempDir::new().unwrap(); + let target = tmp.path().join("file.rs"); + std::fs::write(&target, "").unwrap(); + assert!(find_project_root(&target).is_none()); +} + +#[test] +fn walks_up_from_file_path_not_cwd() { + let tmp = TempDir::new().unwrap(); + let root = tmp.path(); + fs::create_dir_all(root.join("a/b/c")).unwrap(); + fs::write(root.join("a/.botignore"), "").unwrap(); + + let target = root.join("a/b/c/file.rs"); + fs::write(&target, "").unwrap(); + + let found = find_project_root(&target).unwrap(); + assert_eq!(found, root.join("a")); +} diff --git a/crates/dirigent_fermata/tests/core_toml_config.rs b/crates/dirigent_fermata/tests/core_toml_config.rs new file mode 100644 index 0000000..1ed9e36 --- /dev/null +++ b/crates/dirigent_fermata/tests/core_toml_config.rs @@ -0,0 +1,47 @@ +use dirigent_fermata::core::toml_config::{BotignoreToml, OpRules, BashRules}; + +#[test] +fn parses_full_config() { + let src = r#" +[read] +patterns = [".env*", "secrets/**"] + +[write] +patterns = ["vendor/**", "*.lock"] + +[bash] +deny = ["rm -rf /", "git push --force*"] +ask = ["rm:*"] +allow_prefixes = ["make test", "git checkout:*"] +"#; + let cfg: BotignoreToml = toml::from_str(src).unwrap(); + assert_eq!(cfg.read.unwrap().patterns, vec![".env*", "secrets/**"]); + assert_eq!(cfg.write.unwrap().patterns, vec!["vendor/**", "*.lock"]); + let bash = cfg.bash.unwrap(); + assert_eq!(bash.deny, vec!["rm -rf /", "git push --force*"]); + assert_eq!(bash.ask, vec!["rm:*"]); + assert_eq!(bash.allow_prefixes, vec!["make test", "git checkout:*"]); +} + +#[test] +fn empty_config_is_valid() { + let cfg: BotignoreToml = toml::from_str("").unwrap(); + assert!(cfg.read.is_none()); + assert!(cfg.write.is_none()); + assert!(cfg.bash.is_none()); +} + +#[test] +fn loads_from_disk_when_present() { + let tmp = tempfile::tempdir().unwrap(); + std::fs::write(tmp.path().join("botignore.toml"), "[read]\npatterns = [\".env\"]\n").unwrap(); + let cfg = BotignoreToml::load(tmp.path()).unwrap(); + assert_eq!(cfg.read.unwrap().patterns, vec![".env"]); +} + +#[test] +fn loads_empty_when_missing() { + let tmp = tempfile::tempdir().unwrap(); + let cfg = BotignoreToml::load(tmp.path()).unwrap(); + assert!(cfg.read.is_none()); +} diff --git a/crates/dirigent_fermata/tests/fixtures_a4.rs b/crates/dirigent_fermata/tests/fixtures_a4.rs new file mode 100644 index 0000000..eea6639 --- /dev/null +++ b/crates/dirigent_fermata/tests/fixtures_a4.rs @@ -0,0 +1,112 @@ +//! Smoke-test contract from `docs/workpad/brainstorm/fermata.md` Appendix A.4. + +use dirigent_fermata::core::{Decision, Op, Policy}; +use std::fs; +use tempfile::TempDir; + +fn fixture() -> TempDir { + let tmp = TempDir::new().unwrap(); + let root = tmp.path(); + fs::write(root.join(".botignore"), ".env\n.env.*\nconf/cert/**\nconf/mitmproxy/**\n").unwrap(); + fs::write( + root.join("botignore.toml"), + r#" +[read] +patterns = [ + "conf/localtestsettings.yaml", + "conf/localsettings.yaml", + "conf/default-secrets.yaml", + ".claude/self-reflections/**", +] + +[write] +patterns = [ + "conf/localtestsettings.yaml", + "conf/localsettings.yaml", + "conf/default-secrets.yaml", +] + +[bash] +deny = ["localtestsettings.yaml", "localsettings.yaml", "default-secrets.yaml", ".env"] +ask = ["rm *", "mv *"] +allow_prefixes = ["make test"] +"#, + ) + .unwrap(); + + fs::create_dir_all(root.join("conf")).unwrap(); + fs::create_dir_all(root.join("datatap")).unwrap(); + fs::create_dir_all(root.join(".claude/self-reflections")).unwrap(); + for f in [".env", "conf/localsettings.yaml", "datatap/foo.py", ".claude/self-reflections/x.md"] { + fs::write(root.join(f), "").unwrap(); + } + tmp +} + +#[test] +fn read_dot_env_denied() { + let t = fixture(); + let p = Policy::load(t.path()).unwrap(); + assert!(matches!(p.check(Op::Read, &t.path().join(".env")).unwrap(), Decision::Deny(_))); +} + +#[test] +fn bash_cat_dot_env_denied() { + let t = fixture(); + let p = Policy::load(t.path()).unwrap(); + assert!(matches!(p.check_command("cat ./.env").unwrap(), Decision::Deny(_))); +} + +#[test] +fn bash_rm_localsettings_denied() { + let t = fixture(); + let p = Policy::load(t.path()).unwrap(); + assert!(matches!( + p.check_command("rm ./conf/localsettings.yaml").unwrap(), + Decision::Deny(_) + )); +} + +#[test] +fn write_localsettings_denied() { + let t = fixture(); + let p = Policy::load(t.path()).unwrap(); + assert!(matches!( + p.check(Op::Write, &t.path().join("conf/localsettings.yaml")).unwrap(), + Decision::Deny(_) + )); +} + +#[test] +fn edit_datatap_foo_py_allowed() { + let t = fixture(); + let p = Policy::load(t.path()).unwrap(); + assert_eq!( + p.check(Op::Write, &t.path().join("datatap/foo.py")).unwrap(), + Decision::Allow + ); +} + +#[test] +fn bash_make_test_allowed() { + let t = fixture(); + let p = Policy::load(t.path()).unwrap(); + assert_eq!(p.check_command("make test").unwrap(), Decision::Allow); +} + +#[test] +fn bash_rm_somefile_asks() { + let t = fixture(); + let p = Policy::load(t.path()).unwrap(); + assert!(matches!(p.check_command("rm somefile").unwrap(), Decision::Ask(_))); +} + +#[test] +fn read_self_reflections_asks() { + // Note: A.4 has self-reflections under "ask" — current toml schema uses `[read].patterns` + // for hard reads. This documents the gap; once toml has a `[read].ask`, switch to Ask. + let t = fixture(); + let p = Policy::load(t.path()).unwrap(); + let d = p.check(Op::Read, &t.path().join(".claude/self-reflections/x.md")).unwrap(); + assert!(matches!(d, Decision::Deny(_))); +} diff --git a/crates/dirigent_fermata/tests/harness_claude.rs b/crates/dirigent_fermata/tests/harness_claude.rs new file mode 100644 index 0000000..2521973 --- /dev/null +++ b/crates/dirigent_fermata/tests/harness_claude.rs @@ -0,0 +1,86 @@ +use dirigent_fermata::core::{Decision, Reason}; +use dirigent_fermata::harness::{HarnessAdapter, PathKind, ToolOp}; +use dirigent_fermata::harness::claude::ClaudeAdapter; + +#[test] +fn parses_read_payload() { + let payload = br#"{"tool_name":"Read","tool_input":{"file_path":"/proj/.env"}}"#; + let call = ClaudeAdapter.parse_request(payload).unwrap(); + assert_eq!(call.tool_name, "Read"); + match call.op { + ToolOp::Path { path, kind } => { + assert_eq!(path.to_string_lossy(), "/proj/.env"); + assert_eq!(kind, PathKind::Read); + } + _ => panic!("expected Path op"), + } +} + +#[test] +fn parses_write_payload() { + let payload = br#"{"tool_name":"Write","tool_input":{"file_path":"/proj/out.txt"}}"#; + let call = ClaudeAdapter.parse_request(payload).unwrap(); + assert!(matches!(call.op, ToolOp::Path { kind: PathKind::Write, .. })); +} + +#[test] +fn parses_edit_as_write() { + let payload = br#"{"tool_name":"Edit","tool_input":{"file_path":"/proj/src.rs"}}"#; + let call = ClaudeAdapter.parse_request(payload).unwrap(); + assert!(matches!(call.op, ToolOp::Path { kind: PathKind::Write, .. })); +} + +#[test] +fn parses_multiedit_as_write() { + let payload = br#"{"tool_name":"MultiEdit","tool_input":{"file_path":"/proj/src.rs","edits":[]}}"#; + let call = ClaudeAdapter.parse_request(payload).unwrap(); + assert!(matches!(call.op, ToolOp::Path { kind: PathKind::Write, .. })); +} + +#[test] +fn parses_bash_payload() { + let payload = br#"{"tool_name":"Bash","tool_input":{"command":"rm -rf /"}}"#; + let call = ClaudeAdapter.parse_request(payload).unwrap(); + match call.op { + ToolOp::Command { text } => assert_eq!(text, "rm -rf /"), + _ => panic!("expected Command op"), + } +} + +#[test] +fn renders_deny_as_hookspecificoutput() { + let payload = br#"{"tool_name":"Read","tool_input":{"file_path":"/proj/.env"}}"#; + let call = ClaudeAdapter.parse_request(payload).unwrap(); + let d = Decision::Deny(Reason { + message: "blocked by .botignore: .env".into(), + rule: None, + }); + let out = ClaudeAdapter.render_decision(&call, &d).unwrap(); + let v: serde_json::Value = serde_json::from_slice(&out).unwrap(); + assert_eq!(v["hookSpecificOutput"]["hookEventName"], "PreToolUse"); + assert_eq!(v["hookSpecificOutput"]["permissionDecision"], "deny"); + assert!(v["hookSpecificOutput"]["permissionDecisionReason"] + .as_str() + .unwrap() + .contains(".env")); +} + +#[test] +fn renders_allow_as_allow() { + let payload = br#"{"tool_name":"Read","tool_input":{"file_path":"/proj/src/main.rs"}}"#; + let call = ClaudeAdapter.parse_request(payload).unwrap(); + let out = ClaudeAdapter.render_decision(&call, &Decision::Allow).unwrap(); + let v: serde_json::Value = serde_json::from_slice(&out).unwrap(); + assert_eq!(v["hookSpecificOutput"]["permissionDecision"], "allow"); +} + +#[test] +fn renders_ask_as_ask() { + let payload = br#"{"tool_name":"Bash","tool_input":{"command":"rm something"}}"#; + let call = ClaudeAdapter.parse_request(payload).unwrap(); + let out = ClaudeAdapter + .render_decision(&call, &Decision::Ask(Reason { message: "confirm".into(), rule: None })) + .unwrap(); + let v: serde_json::Value = serde_json::from_slice(&out).unwrap(); + assert_eq!(v["hookSpecificOutput"]["permissionDecision"], "ask"); +} diff --git a/crates/dirigent_inspector/Cargo.toml b/crates/dirigent_inspector/Cargo.toml new file mode 100644 index 0000000..a9fbc13 --- /dev/null +++ b/crates/dirigent_inspector/Cargo.toml @@ -0,0 +1,36 @@ +[package] +name = "dirigent_inspector" +version = "0.1.0" +edition = "2021" + +[lib] +path = "src/lib.rs" + +[dependencies] +# Async traits +async-trait = "0.1" + +# Date/time handling +chrono = { version = "0.4", features = ["serde"] } + +# Protocol types (canonical node types) +dirigent_protocol = { path = "../dirigent_protocol" } + +# Serialization +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" + +# Cross-platform process/system metrics +sysinfo = "0.33" + +# Error handling +thiserror = "2.0" + +# Async runtime +tokio = { version = "1.42", features = ["sync", "time", "rt", "macros"] } + +# Logging +tracing = "0.1" + +[dev-dependencies] +tokio = { version = "1.42", features = ["full"] } diff --git a/crates/dirigent_inspector/src/channel.rs b/crates/dirigent_inspector/src/channel.rs new file mode 100644 index 0000000..7673122 --- /dev/null +++ b/crates/dirigent_inspector/src/channel.rs @@ -0,0 +1,349 @@ +use crate::error::{InspectorError, Result}; +use serde::{Deserialize, Serialize}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio::sync::{mpsc, oneshot}; + +/// A command that can be sent to a node. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct NodeCommand { + /// Unique command ID for correlation with responses. + pub id: String, + /// What kind of command this is. + pub kind: CommandKind, + /// Arbitrary payload data. + pub payload: serde_json::Value, +} + +/// The type of command being sent to a node. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum CommandKind { + /// Request the node to report its internal state (introspective). + Introspect, + /// Execute a named operation. + Execute(String), + /// Custom extension command. + Custom(String), +} + +/// Response from a node to a command. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct CommandResponse { + /// The command ID this is responding to. + pub command_id: String, + /// Whether the command was handled successfully. + pub success: bool, + /// Response data. + pub data: serde_json::Value, +} + +impl CommandResponse { + /// Create a success response. + pub fn ok(command_id: impl Into<String>, data: serde_json::Value) -> Self { + Self { + command_id: command_id.into(), + success: true, + data, + } + } + + /// Create an error response. + pub fn err(command_id: impl Into<String>, message: impl Into<String>) -> Self { + Self { + command_id: command_id.into(), + success: false, + data: serde_json::json!({ "error": message.into() }), + } + } +} + +type CommandPayload = (NodeCommand, oneshot::Sender<CommandResponse>); + +/// Create a new inspector channel pair. +/// +/// Returns `(sender, receiver)` where: +/// - The **sender** is used by callers (UI, API) to send commands to the node. +/// - The **receiver** is used by the node's loop to receive and respond to commands. +/// +/// `capacity` controls the bounded channel size. +pub fn channel(capacity: usize) -> (InspectorChannelSender, InspectorChannelReceiver) { + let (tx, rx) = mpsc::channel(capacity); + let pending = Arc::new(AtomicUsize::new(0)); + + ( + InspectorChannelSender { + tx, + pending: Arc::clone(&pending), + }, + InspectorChannelReceiver { rx, pending }, + ) +} + +/// Caller-side of the inspector channel: send commands, check queue depth. +pub struct InspectorChannelSender { + tx: mpsc::Sender<CommandPayload>, + pending: Arc<AtomicUsize>, +} + +impl InspectorChannelSender { + /// Send a command and wait for the response. + pub async fn send(&self, cmd: NodeCommand) -> Result<CommandResponse> { + let (resp_tx, resp_rx) = oneshot::channel(); + self.pending.fetch_add(1, Ordering::Relaxed); + + self.tx + .send((cmd, resp_tx)) + .await + .map_err(|_| InspectorError::ChannelClosed)?; + + resp_rx.await.map_err(|_| InspectorError::ChannelClosed) + } + + /// Try to send a command without waiting, returning a receiver for the response. + /// + /// This is useful when you want to fire off a command and collect the + /// response later, or in a `select!` branch. + pub fn try_send(&self, cmd: NodeCommand) -> Result<oneshot::Receiver<CommandResponse>> { + let (resp_tx, resp_rx) = oneshot::channel(); + self.pending.fetch_add(1, Ordering::Relaxed); + + self.tx.try_send((cmd, resp_tx)).map_err(|e| match e { + mpsc::error::TrySendError::Full(_) => { + self.pending.fetch_sub(1, Ordering::Relaxed); + InspectorError::ChannelFull + } + mpsc::error::TrySendError::Closed(_) => { + self.pending.fetch_sub(1, Ordering::Relaxed); + InspectorError::ChannelClosed + } + })?; + + Ok(resp_rx) + } + + /// Number of commands currently in the queue (approximate). + pub fn pending_count(&self) -> usize { + self.pending.load(Ordering::Relaxed) + } +} + +impl Clone for InspectorChannelSender { + fn clone(&self) -> Self { + Self { + tx: self.tx.clone(), + pending: Arc::clone(&self.pending), + } + } +} + +/// Node-side of the inspector channel: receive commands, send responses. +pub struct InspectorChannelReceiver { + rx: mpsc::Receiver<CommandPayload>, + pending: Arc<AtomicUsize>, +} + +impl InspectorChannelReceiver { + /// Receive the next command. Returns `None` when all senders are dropped. + /// + /// The returned `oneshot::Sender` must be used to send back a response. + pub async fn recv(&mut self) -> Option<(NodeCommand, oneshot::Sender<CommandResponse>)> { + let result = self.rx.recv().await; + if result.is_some() { + self.pending.fetch_sub(1, Ordering::Relaxed); + } + result + } + + /// Try to receive without blocking. Returns `None` if the channel is empty. + pub fn try_recv(&mut self) -> Option<(NodeCommand, oneshot::Sender<CommandResponse>)> { + match self.rx.try_recv() { + Ok(payload) => { + self.pending.fetch_sub(1, Ordering::Relaxed); + Some(payload) + } + Err(_) => None, + } + } + + /// Number of commands currently in the queue (approximate). + pub fn pending_count(&self) -> usize { + self.pending.load(Ordering::Relaxed) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn introspect_cmd(id: &str) -> NodeCommand { + NodeCommand { + id: id.to_string(), + kind: CommandKind::Introspect, + payload: serde_json::Value::Null, + } + } + + fn exec_cmd(id: &str, name: &str, payload: serde_json::Value) -> NodeCommand { + NodeCommand { + id: id.to_string(), + kind: CommandKind::Execute(name.to_string()), + payload, + } + } + + #[tokio::test] + async fn test_send_recv_response() { + let (sender, mut receiver) = channel(10); + + // Simulate a node loop in a background task + let node_task = tokio::spawn(async move { + if let Some((cmd, resp_tx)) = receiver.recv().await { + assert_eq!(cmd.id, "cmd-1"); + assert!(matches!(cmd.kind, CommandKind::Introspect)); + let _ = resp_tx.send(CommandResponse::ok( + &cmd.id, + serde_json::json!({ "queue_len": 5 }), + )); + } + }); + + let response = sender.send(introspect_cmd("cmd-1")).await.unwrap(); + assert!(response.success); + assert_eq!(response.command_id, "cmd-1"); + assert_eq!(response.data["queue_len"], 5); + + node_task.await.unwrap(); + } + + #[tokio::test] + async fn test_try_send() { + let (sender, mut receiver) = channel(10); + + let resp_rx = sender + .try_send(exec_cmd("cmd-2", "restart", serde_json::json!({}))) + .unwrap(); + + // Node responds + let (cmd, resp_tx) = receiver.recv().await.unwrap(); + assert_eq!(cmd.id, "cmd-2"); + let _ = resp_tx.send(CommandResponse::ok(&cmd.id, serde_json::json!("restarted"))); + + let response = resp_rx.await.unwrap(); + assert!(response.success); + } + + #[tokio::test] + async fn test_pending_count() { + let (sender, mut receiver) = channel(10); + + assert_eq!(sender.pending_count(), 0); + + // Send 3 commands without receiving + let _r1 = sender.try_send(introspect_cmd("a")).unwrap(); + let _r2 = sender.try_send(introspect_cmd("b")).unwrap(); + let _r3 = sender.try_send(introspect_cmd("c")).unwrap(); + + assert_eq!(sender.pending_count(), 3); + assert_eq!(receiver.pending_count(), 3); + + // Receive one + let (cmd, resp_tx) = receiver.recv().await.unwrap(); + let _ = resp_tx.send(CommandResponse::ok(&cmd.id, serde_json::Value::Null)); + + assert_eq!(receiver.pending_count(), 2); + } + + #[tokio::test] + async fn test_try_send_full() { + let (sender, _receiver) = channel(1); + + // Fill the channel + let _r1 = sender.try_send(introspect_cmd("a")).unwrap(); + + // Should fail with ChannelFull + let result = sender.try_send(introspect_cmd("b")); + assert!(matches!(result, Err(InspectorError::ChannelFull))); + } + + #[tokio::test] + async fn test_send_after_receiver_dropped() { + let (sender, receiver) = channel(10); + drop(receiver); + + let result = sender.send(introspect_cmd("orphan")).await; + assert!(matches!(result, Err(InspectorError::ChannelClosed))); + } + + #[tokio::test] + async fn test_recv_after_sender_dropped() { + let (sender, mut receiver) = channel(10); + drop(sender); + + let result = receiver.recv().await; + assert!(result.is_none()); + } + + #[tokio::test] + async fn test_try_recv_empty() { + let (_sender, mut receiver) = channel(10); + assert!(receiver.try_recv().is_none()); + } + + #[tokio::test] + async fn test_error_response() { + let (sender, mut receiver) = channel(10); + + let node_task = tokio::spawn(async move { + if let Some((cmd, resp_tx)) = receiver.recv().await { + let _ = resp_tx.send(CommandResponse::err(&cmd.id, "not supported")); + } + }); + + let response = sender + .send(exec_cmd("cmd-x", "unknown", serde_json::Value::Null)) + .await + .unwrap(); + assert!(!response.success); + assert_eq!(response.data["error"], "not supported"); + + node_task.await.unwrap(); + } + + #[tokio::test] + async fn test_multiple_senders() { + let (sender, mut receiver) = channel(10); + let sender2 = sender.clone(); + + let _r1 = sender.try_send(introspect_cmd("from-1")).unwrap(); + let _r2 = sender2.try_send(introspect_cmd("from-2")).unwrap(); + + let (cmd1, resp_tx1) = receiver.recv().await.unwrap(); + let _ = resp_tx1.send(CommandResponse::ok(&cmd1.id, serde_json::Value::Null)); + + let (cmd2, resp_tx2) = receiver.recv().await.unwrap(); + let _ = resp_tx2.send(CommandResponse::ok(&cmd2.id, serde_json::Value::Null)); + + // Both commands were received (order may vary since mpsc is FIFO) + let ids = vec![cmd1.id, cmd2.id]; + assert!(ids.contains(&"from-1".to_string())); + assert!(ids.contains(&"from-2".to_string())); + } + + #[test] + fn test_command_serialization() { + let cmd = exec_cmd("cmd-1", "restart", serde_json::json!({"force": true})); + let json = serde_json::to_string(&cmd).unwrap(); + let deserialized: NodeCommand = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.id, "cmd-1"); + assert!(matches!(deserialized.kind, CommandKind::Execute(ref name) if name == "restart")); + } + + #[test] + fn test_response_serialization() { + let resp = CommandResponse::ok("cmd-1", serde_json::json!({"status": "done"})); + let json = serde_json::to_string(&resp).unwrap(); + let deserialized: CommandResponse = serde_json::from_str(&json).unwrap(); + assert!(deserialized.success); + assert_eq!(deserialized.data["status"], "done"); + } +} diff --git a/crates/dirigent_inspector/src/error.rs b/crates/dirigent_inspector/src/error.rs new file mode 100644 index 0000000..9ec227c --- /dev/null +++ b/crates/dirigent_inspector/src/error.rs @@ -0,0 +1,34 @@ +use crate::node::NodeId; + +/// Errors that can occur in the inspector system. +#[derive(Debug, thiserror::Error)] +pub enum InspectorError { + #[error("node not found: {0}")] + NodeNotFound(NodeId), + + #[error("node already exists: {0}")] + NodeAlreadyExists(NodeId), + + #[error("parent node not found: {0}")] + ParentNotFound(NodeId), + + #[error("cannot remove root node")] + CannotRemoveRoot, + + #[error("channel closed")] + ChannelClosed, + + #[error("channel full")] + ChannelFull, + + #[error("command timed out")] + CommandTimeout, + + #[error("process not found: pid {0}")] + ProcessNotFound(u32), + + #[error("{0}")] + Internal(String), +} + +pub type Result<T> = std::result::Result<T, InspectorError>; diff --git a/crates/dirigent_inspector/src/handle.rs b/crates/dirigent_inspector/src/handle.rs new file mode 100644 index 0000000..e00ff76 --- /dev/null +++ b/crates/dirigent_inspector/src/handle.rs @@ -0,0 +1,262 @@ +use crate::error::Result; +use crate::node::{Inspectable, NodeId, NodeMetadata, NodeState}; +use crate::registry::InspectorRegistry; +use std::collections::HashMap; +use std::sync::Arc; + +/// Handle to a registered node in the inspector tree. +/// +/// Returned when a node is registered. The producer (connector, service, etc.) +/// uses this handle to update their node's state and properties without needing +/// direct access to the registry. +/// +/// When the handle is dropped, the node is automatically deregistered (best-effort). +/// To keep the node alive without the handle, call `detach()` before dropping. +pub struct NodeHandle { + id: NodeId, + registry: Arc<InspectorRegistry>, + detached: bool, +} + +impl NodeHandle { + pub(crate) fn new(id: NodeId, registry: Arc<InspectorRegistry>) -> Self { + Self { + id, + registry, + detached: false, + } + } + + /// Get this node's ID. + pub fn id(&self) -> &NodeId { + &self.id + } + + /// Update this node's lifecycle state. + pub async fn set_state(&self, state: NodeState) -> Result<()> { + self.registry.update_state(&self.id, state).await + } + + /// Set a single property on this node. + pub async fn set_property(&self, key: &str, value: serde_json::Value) -> Result<()> { + let mut props = HashMap::new(); + props.insert(key.to_string(), value); + self.registry.update_properties(&self.id, props).await + } + + /// Set multiple properties on this node. + pub async fn set_properties(&self, props: HashMap<String, serde_json::Value>) -> Result<()> { + self.registry.update_properties(&self.id, props).await + } + + /// Register a child node under this node. + /// + /// Returns a new `NodeHandle` for the child. + pub async fn register_child( + &self, + child_id: NodeId, + metadata: NodeMetadata, + inspectable: Option<Arc<dyn Inspectable>>, + ) -> Result<NodeHandle> { + self.registry + .register(child_id, &self.id, metadata, inspectable) + .await + } + + /// Explicitly deregister this node and consume the handle. + pub async fn deregister(mut self) -> Result<()> { + self.detached = true; // prevent Drop from double-deregistering + self.registry.deregister(&self.id).await + } + + /// Detach this handle so the node survives when the handle is dropped. + /// + /// After calling this, dropping the handle will NOT deregister the node. + /// The node can still be removed via `InspectorRegistry::deregister()`. + pub fn detach(&mut self) { + self.detached = true; + } + + /// Get a reference to the registry this handle is connected to. + pub fn registry(&self) -> &Arc<InspectorRegistry> { + &self.registry + } +} + +impl Drop for NodeHandle { + fn drop(&mut self) { + if !self.detached { + let id = self.id.clone(); + let registry = Arc::clone(&self.registry); + // Best-effort async deregister from a sync Drop context + tokio::spawn(async move { + let _ = registry.deregister(&id).await; + }); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::node::{NodeKind, NodeMetadata, NodeState}; + + #[tokio::test] + async fn test_handle_set_state() { + let registry = Arc::new(InspectorRegistry::new()); + let root = registry.root_id().await; + + let handle = registry + .register( + NodeId::new("dirigent/test"), + &root, + NodeMetadata::new(NodeKind::Service, "Test"), + None, + ) + .await + .unwrap(); + + handle.set_state(NodeState::Running).await.unwrap(); + + let meta = registry + .get_node(&NodeId::new("dirigent/test")) + .await + .unwrap(); + assert_eq!(meta.state, NodeState::Running); + } + + #[tokio::test] + async fn test_handle_set_property() { + let registry = Arc::new(InspectorRegistry::new()); + let root = registry.root_id().await; + + let handle = registry + .register( + NodeId::new("dirigent/proc"), + &root, + NodeMetadata::new(NodeKind::Process, "Proc"), + None, + ) + .await + .unwrap(); + + handle + .set_property("pid", serde_json::json!(9999)) + .await + .unwrap(); + + let meta = registry + .get_node(&NodeId::new("dirigent/proc")) + .await + .unwrap(); + assert_eq!(meta.properties["pid"], serde_json::json!(9999)); + } + + #[tokio::test] + async fn test_handle_register_child() { + let registry = Arc::new(InspectorRegistry::new()); + let root = registry.root_id().await; + + let parent_handle = registry + .register( + NodeId::new("dirigent/parent"), + &root, + NodeMetadata::new(NodeKind::Connector, "Parent"), + None, + ) + .await + .unwrap(); + + let child_handle = parent_handle + .register_child( + NodeId::new("dirigent/parent/child"), + NodeMetadata::new(NodeKind::Process, "Child"), + None, + ) + .await + .unwrap(); + + assert_eq!(child_handle.id().as_str(), "dirigent/parent/child"); + assert!( + registry + .contains(&NodeId::new("dirigent/parent/child")) + .await + ); + } + + #[tokio::test] + async fn test_handle_deregister() { + let registry = Arc::new(InspectorRegistry::new()); + let root = registry.root_id().await; + + let handle = registry + .register( + NodeId::new("dirigent/temp"), + &root, + NodeMetadata::new(NodeKind::AsyncTask, "Temp"), + None, + ) + .await + .unwrap(); + + assert!(registry.contains(&NodeId::new("dirigent/temp")).await); + + handle.deregister().await.unwrap(); + + assert!(!registry.contains(&NodeId::new("dirigent/temp")).await); + } + + #[tokio::test] + async fn test_handle_drop_deregisters() { + let registry = Arc::new(InspectorRegistry::new()); + let root = registry.root_id().await; + + { + let _handle = registry + .register( + NodeId::new("dirigent/ephemeral"), + &root, + NodeMetadata::new(NodeKind::AsyncTask, "Ephemeral"), + None, + ) + .await + .unwrap(); + // handle dropped here + } + + // Give the spawned deregister task a moment to run + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + + assert!( + !registry.contains(&NodeId::new("dirigent/ephemeral")).await, + "Node should be deregistered after handle drop" + ); + } + + #[tokio::test] + async fn test_handle_detach_survives_drop() { + let registry = Arc::new(InspectorRegistry::new()); + let root = registry.root_id().await; + + { + let mut handle = registry + .register( + NodeId::new("dirigent/persistent"), + &root, + NodeMetadata::new(NodeKind::Service, "Persistent"), + None, + ) + .await + .unwrap(); + handle.detach(); + // handle dropped here, but detached + } + + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + + assert!( + registry.contains(&NodeId::new("dirigent/persistent")).await, + "Detached node should survive handle drop" + ); + } +} diff --git a/crates/dirigent_inspector/src/lib.rs b/crates/dirigent_inspector/src/lib.rs new file mode 100644 index 0000000..e3b11b0 --- /dev/null +++ b/crates/dirigent_inspector/src/lib.rs @@ -0,0 +1,23 @@ +pub mod channel; +pub mod error; +pub mod handle; +pub mod node; +pub mod process; +pub mod registry; +pub mod snapshot; +pub mod system; +pub mod tree; + +// Re-export commonly used types +pub use channel::{ + channel as inspector_channel, CommandKind, CommandResponse, InspectorChannelReceiver, + InspectorChannelSender, NodeCommand, +}; +pub use error::{InspectorError, Result}; +pub use handle::NodeHandle; +pub use node::{Inspectable, NodeId, NodeKind, NodeMetadata, NodeState}; +pub use process::{ProcessInfo, ProcessMonitor, ProcessStatus}; +pub use registry::{InspectorEvent, InspectorRegistry}; +pub use snapshot::{NodeSnapshot, TreeSnapshot}; +pub use system::{SystemInfo, SystemMonitor}; +pub use tree::{NodeTree, TreeNode}; diff --git a/crates/dirigent_inspector/src/node.rs b/crates/dirigent_inspector/src/node.rs new file mode 100644 index 0000000..519f89d --- /dev/null +++ b/crates/dirigent_inspector/src/node.rs @@ -0,0 +1,109 @@ +use async_trait::async_trait; + +// Re-export canonical types from dirigent_protocol +pub use dirigent_protocol::inspector::{NodeId, NodeKind, NodeMetadata, NodeState}; + +/// Trait for components that support rich introspection. +/// +/// Implementing this trait allows a component to provide detailed status reports +/// when queried. The `inspect()` method runs in the component's own async context +/// (introspective), while `current_state()` is synchronous for quick polling. +#[async_trait] +pub trait Inspectable: Send + Sync { + /// Return a detailed status report as JSON. + /// + /// This is an introspective operation: it runs inside the node's context + /// and can access internal state that isn't part of the standard metadata. + async fn inspect(&self) -> serde_json::Value; + + /// Return the current lifecycle state. + /// + /// This should be a cheap, synchronous operation returning the node's + /// current state without blocking. + fn current_state(&self) -> NodeState; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_node_id_child() { + let root = NodeId::new("dirigent"); + let child = root.child("connectors"); + assert_eq!(child.as_str(), "dirigent/connectors"); + + let grandchild = child.child("acp-claude"); + assert_eq!(grandchild.as_str(), "dirigent/connectors/acp-claude"); + } + + #[test] + fn test_node_id_parent() { + let id = NodeId::new("dirigent/connectors/acp-claude"); + assert_eq!(id.parent().unwrap().as_str(), "dirigent/connectors"); + assert_eq!( + id.parent().unwrap().parent().unwrap().as_str(), + "dirigent" + ); + assert!(NodeId::new("dirigent").parent().is_none()); + } + + #[test] + fn test_node_id_name() { + assert_eq!(NodeId::new("dirigent/connectors/acp").name(), "acp"); + assert_eq!(NodeId::new("dirigent").name(), "dirigent"); + } + + #[test] + fn test_node_id_from_str() { + let id: NodeId = "dirigent/system/host".into(); + assert_eq!(id.as_str(), "dirigent/system/host"); + } + + #[test] + fn test_node_metadata_builder() { + let meta = NodeMetadata::new(NodeKind::Connector, "My Connector") + .with_state(NodeState::Running) + .with_property("base_url", serde_json::json!("http://localhost:3000")); + + assert_eq!(meta.kind, NodeKind::Connector); + assert_eq!(meta.label, "My Connector"); + assert_eq!(meta.state, NodeState::Running); + assert_eq!( + meta.properties.get("base_url").unwrap(), + &serde_json::json!("http://localhost:3000") + ); + } + + #[test] + fn test_node_metadata_serialization() { + let meta = NodeMetadata::new(NodeKind::Process, "stdio-transport") + .with_state(NodeState::Busy("processing message".into())) + .with_property("pid", serde_json::json!(12345)); + + let json = serde_json::to_string(&meta).unwrap(); + let deserialized: NodeMetadata = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.kind, NodeKind::Process); + assert_eq!(deserialized.label, "stdio-transport"); + assert_eq!( + deserialized.state, + NodeState::Busy("processing message".into()) + ); + } + + #[test] + fn test_node_state_display() { + assert_eq!(NodeState::Running.to_string(), "Running"); + assert_eq!( + NodeState::Error("timeout".into()).to_string(), + "Error(timeout)" + ); + } + + #[test] + fn test_node_kind_display() { + assert_eq!(NodeKind::Root.to_string(), "Root"); + assert_eq!(NodeKind::Custom("agent".into()).to_string(), "Custom(agent)"); + } +} diff --git a/crates/dirigent_inspector/src/process.rs b/crates/dirigent_inspector/src/process.rs new file mode 100644 index 0000000..d85f11c --- /dev/null +++ b/crates/dirigent_inspector/src/process.rs @@ -0,0 +1,381 @@ +use crate::node::NodeId; +use crate::registry::InspectorRegistry; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; +use sysinfo::{Pid, ProcessesToUpdate, System}; +use tokio::task::JoinHandle; +use tracing::{debug, warn}; + +/// Information about a monitored OS process. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ProcessInfo { + pub pid: u32, + pub name: String, + pub command: Vec<String>, + pub exe: Option<PathBuf>, + pub cwd: Option<PathBuf>, + pub status: ProcessStatus, + pub cpu_usage_percent: f32, + pub memory_bytes: u64, + pub virtual_memory_bytes: u64, + pub start_time_secs: u64, + pub run_time_secs: u64, + pub disk_read_bytes: u64, + pub disk_written_bytes: u64, +} + +/// Simplified cross-platform process status. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub enum ProcessStatus { + Running, + Sleeping, + Stopped, + Zombie, + Dead, + Unknown, +} + +impl From<sysinfo::ProcessStatus> for ProcessStatus { + fn from(s: sysinfo::ProcessStatus) -> Self { + match s { + sysinfo::ProcessStatus::Run => ProcessStatus::Running, + sysinfo::ProcessStatus::Sleep | sysinfo::ProcessStatus::Idle => ProcessStatus::Sleeping, + sysinfo::ProcessStatus::Stop | sysinfo::ProcessStatus::Tracing => { + ProcessStatus::Stopped + } + sysinfo::ProcessStatus::Zombie => ProcessStatus::Zombie, + sysinfo::ProcessStatus::Dead | sysinfo::ProcessStatus::UninterruptibleDiskSleep => { + ProcessStatus::Dead + } + _ => ProcessStatus::Unknown, + } + } +} + +/// Direction of the last observed I/O activity. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub enum IoDirection { + Read, + Write, +} + +/// Information about the most recent I/O activity. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct IoActivity { + pub direction: IoDirection, + pub timestamp: chrono::DateTime<chrono::Utc>, +} + +/// Monitors a set of OS processes by PID, providing metrics via `sysinfo`. +/// +/// Each tracked PID is associated with a `NodeId` in the inspector tree. +/// The monitor can be polled manually or run as a background task that +/// periodically refreshes and updates the registry. +pub struct ProcessMonitor { + system: System, + tracked: HashMap<u32, NodeId>, +} + +impl ProcessMonitor { + /// Create a new process monitor. + pub fn new() -> Self { + Self { + system: System::new(), + tracked: HashMap::new(), + } + } + + /// Start tracking a process by PID, associated with the given node ID. + pub fn track(&mut self, pid: u32, node_id: NodeId) { + debug!(pid, node_id = %node_id, "Tracking process"); + self.tracked.insert(pid, node_id); + } + + /// Stop tracking a process. + pub fn untrack(&mut self, pid: u32) { + debug!(pid, "Untracking process"); + self.tracked.remove(&pid); + } + + /// Get the set of tracked PIDs and their node IDs. + pub fn tracked_pids(&self) -> &HashMap<u32, NodeId> { + &self.tracked + } + + /// Refresh data for all tracked processes and return their info. + pub fn refresh(&mut self) -> HashMap<u32, ProcessInfo> { + // Build list of PIDs to refresh + let pids: Vec<Pid> = self.tracked.keys().map(|&p| Pid::from_u32(p)).collect(); + + // Refresh only tracked processes + self.system + .refresh_processes(ProcessesToUpdate::Some(&pids), true); + + let mut results = HashMap::new(); + + for (&pid, _node_id) in &self.tracked { + let sysinfo_pid = Pid::from_u32(pid); + if let Some(process) = self.system.process(sysinfo_pid) { + let disk_usage = process.disk_usage(); + let info = ProcessInfo { + pid, + name: process.name().to_string_lossy().to_string(), + command: process + .cmd() + .iter() + .map(|s| s.to_string_lossy().to_string()) + .collect(), + exe: process.exe().map(|p| p.to_path_buf()), + cwd: process.cwd().map(|p| p.to_path_buf()), + status: ProcessStatus::from(process.status()), + cpu_usage_percent: process.cpu_usage(), + memory_bytes: process.memory(), + virtual_memory_bytes: process.virtual_memory(), + start_time_secs: process.start_time(), + run_time_secs: process.run_time(), + disk_read_bytes: disk_usage.read_bytes, + disk_written_bytes: disk_usage.written_bytes, + }; + results.insert(pid, info); + } + } + + results + } + + /// Get info for a single tracked process. + pub fn get(&mut self, pid: u32) -> Option<ProcessInfo> { + let sysinfo_pid = Pid::from_u32(pid); + self.system + .refresh_processes(ProcessesToUpdate::Some(&[sysinfo_pid]), true); + + self.system.process(sysinfo_pid).map(|process| { + let disk_usage = process.disk_usage(); + ProcessInfo { + pid, + name: process.name().to_string_lossy().to_string(), + command: process + .cmd() + .iter() + .map(|s| s.to_string_lossy().to_string()) + .collect(), + exe: process.exe().map(|p| p.to_path_buf()), + cwd: process.cwd().map(|p| p.to_path_buf()), + status: ProcessStatus::from(process.status()), + cpu_usage_percent: process.cpu_usage(), + memory_bytes: process.memory(), + virtual_memory_bytes: process.virtual_memory(), + start_time_secs: process.start_time(), + run_time_secs: process.run_time(), + disk_read_bytes: disk_usage.read_bytes, + disk_written_bytes: disk_usage.written_bytes, + } + }) + } + + /// Check if a process is alive (outrospective: queries the OS directly). + pub fn is_alive(&mut self, pid: u32) -> bool { + let sysinfo_pid = Pid::from_u32(pid); + self.system + .refresh_processes(ProcessesToUpdate::Some(&[sysinfo_pid]), true); + self.system.process(sysinfo_pid).is_some() + } + + /// Spawn a background task that periodically refreshes tracked processes + /// and updates their nodes in the registry. + /// + /// The task runs until the returned `JoinHandle` is aborted or the + /// monitor is dropped. + pub fn start_polling( + mut self, + registry: Arc<InspectorRegistry>, + interval: Duration, + ) -> JoinHandle<()> { + tokio::spawn(async move { + let mut ticker = tokio::time::interval(interval); + loop { + ticker.tick().await; + + let infos = self.refresh(); + + for (pid, info) in &infos { + if let Some(node_id) = self.tracked.get(pid) { + let mut props = HashMap::new(); + props.insert("pid".to_string(), serde_json::json!(info.pid)); + props.insert("name".to_string(), serde_json::json!(info.name)); + props.insert("command".to_string(), serde_json::json!(info.command)); + props.insert( + "status".to_string(), + serde_json::to_value(&info.status).unwrap_or_default(), + ); + props.insert( + "cpu_usage_percent".to_string(), + serde_json::json!(info.cpu_usage_percent), + ); + props.insert( + "memory_bytes".to_string(), + serde_json::json!(info.memory_bytes), + ); + props.insert( + "virtual_memory_bytes".to_string(), + serde_json::json!(info.virtual_memory_bytes), + ); + props.insert( + "run_time_secs".to_string(), + serde_json::json!(info.run_time_secs), + ); + props.insert( + "disk_read_bytes".to_string(), + serde_json::json!(info.disk_read_bytes), + ); + props.insert( + "disk_written_bytes".to_string(), + serde_json::json!(info.disk_written_bytes), + ); + + if let Err(e) = registry.update_properties(node_id, props).await { + warn!( + pid, + node_id = %node_id, + error = %e, + "Failed to update process node properties" + ); + } + } + } + + // Check for dead processes + let dead_pids: Vec<u32> = self + .tracked + .keys() + .filter(|&&pid| !infos.contains_key(&pid)) + .copied() + .collect(); + + for pid in dead_pids { + if let Some(node_id) = self.tracked.get(&pid) { + debug!(pid, node_id = %node_id, "Process no longer alive"); + let _ = registry + .update_state(node_id, crate::node::NodeState::Stopped) + .await; + } + } + } + }) + } +} + +impl Default for ProcessMonitor { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_process_status_from_sysinfo() { + assert_eq!( + ProcessStatus::from(sysinfo::ProcessStatus::Run), + ProcessStatus::Running + ); + assert_eq!( + ProcessStatus::from(sysinfo::ProcessStatus::Sleep), + ProcessStatus::Sleeping + ); + assert_eq!( + ProcessStatus::from(sysinfo::ProcessStatus::Zombie), + ProcessStatus::Zombie + ); + assert_eq!( + ProcessStatus::from(sysinfo::ProcessStatus::Stop), + ProcessStatus::Stopped + ); + } + + #[test] + fn test_process_monitor_track_untrack() { + let mut monitor = ProcessMonitor::new(); + let node_id = NodeId::new("dirigent/test/proc"); + + monitor.track(1234, node_id.clone()); + assert!(monitor.tracked_pids().contains_key(&1234)); + + monitor.untrack(1234); + assert!(!monitor.tracked_pids().contains_key(&1234)); + } + + #[test] + fn test_process_monitor_refresh_current_process() { + let mut monitor = ProcessMonitor::new(); + let current_pid = std::process::id(); + let node_id = NodeId::new("dirigent/test/self"); + + monitor.track(current_pid, node_id); + + let infos = monitor.refresh(); + assert!( + infos.contains_key(¤t_pid), + "Current process should be found" + ); + + let info = &infos[¤t_pid]; + assert_eq!(info.pid, current_pid); + assert!(!info.name.is_empty()); + assert!(info.memory_bytes > 0); + } + + #[test] + fn test_process_monitor_get_current_process() { + let mut monitor = ProcessMonitor::new(); + let current_pid = std::process::id(); + + let info = monitor.get(current_pid); + assert!(info.is_some(), "Should find current process"); + + let info = info.unwrap(); + assert_eq!(info.pid, current_pid); + } + + #[test] + fn test_process_monitor_is_alive() { + let mut monitor = ProcessMonitor::new(); + + // Current process should be alive + assert!(monitor.is_alive(std::process::id())); + + // A very high PID should not exist + assert!(!monitor.is_alive(u32::MAX)); + } + + #[test] + fn test_process_info_serialization() { + let info = ProcessInfo { + pid: 1234, + name: "test".to_string(), + command: vec!["test".to_string(), "--flag".to_string()], + exe: Some(PathBuf::from("/usr/bin/test")), + cwd: Some(PathBuf::from("/home/user")), + status: ProcessStatus::Running, + cpu_usage_percent: 12.5, + memory_bytes: 1024 * 1024, + virtual_memory_bytes: 2048 * 1024, + start_time_secs: 1700000000, + run_time_secs: 3600, + disk_read_bytes: 1024, + disk_written_bytes: 2048, + }; + + let json = serde_json::to_string(&info).unwrap(); + let deserialized: ProcessInfo = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.pid, 1234); + assert_eq!(deserialized.status, ProcessStatus::Running); + assert_eq!(deserialized.memory_bytes, 1024 * 1024); + } +} diff --git a/crates/dirigent_inspector/src/registry.rs b/crates/dirigent_inspector/src/registry.rs new file mode 100644 index 0000000..bc96f73 --- /dev/null +++ b/crates/dirigent_inspector/src/registry.rs @@ -0,0 +1,459 @@ +use crate::error::Result; +use crate::node::{Inspectable, NodeId, NodeKind, NodeMetadata, NodeState}; +use crate::tree::NodeTree; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::{broadcast, RwLock}; + +/// Events emitted by the registry when the tree changes. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum InspectorEvent { + NodeRegistered { + id: NodeId, + parent: NodeId, + kind: NodeKind, + }, + NodeRemoved { + id: NodeId, + }, + StateChanged { + id: NodeId, + old: NodeState, + new: NodeState, + }, + PropertiesUpdated { + id: NodeId, + keys: Vec<String>, + }, +} + +/// Central registry for the inspector tree. +/// +/// Thread-safe: all operations acquire the internal `RwLock` as needed. +/// Emits `InspectorEvent`s on a broadcast channel for reactive consumers. +pub struct InspectorRegistry { + tree: Arc<RwLock<NodeTree>>, + event_tx: broadcast::Sender<InspectorEvent>, +} + +impl InspectorRegistry { + /// Create a new registry with a root node. + /// + /// The root node ID is `"dirigent"` by default, with `NodeKind::Root`. + pub fn new() -> Self { + let root_id = NodeId::new("dirigent"); + let root_meta = + NodeMetadata::new(NodeKind::Root, "Dirigent").with_state(NodeState::Running); + let tree = NodeTree::new(root_id, root_meta); + let (event_tx, _) = broadcast::channel(500); + + Self { + tree: Arc::new(RwLock::new(tree)), + event_tx, + } + } + + /// Create a registry with a custom root node. + pub fn with_root(root_id: NodeId, root_metadata: NodeMetadata) -> Self { + let tree = NodeTree::new(root_id, root_metadata); + let (event_tx, _) = broadcast::channel(500); + + Self { + tree: Arc::new(RwLock::new(tree)), + event_tx, + } + } + + /// Get the root node ID. + pub async fn root_id(&self) -> NodeId { + let tree = self.tree.read().await; + tree.root_id().clone() + } + + /// Register a new node under the given parent. + /// + /// Returns a `NodeHandle` that the producer can use to update this node. + pub async fn register( + self: &Arc<Self>, + id: NodeId, + parent: &NodeId, + metadata: NodeMetadata, + inspectable: Option<Arc<dyn Inspectable>>, + ) -> Result<crate::handle::NodeHandle> { + let kind = metadata.kind.clone(); + let parent_clone = parent.clone(); + + { + let mut tree = self.tree.write().await; + tree.insert(id.clone(), parent, metadata, inspectable)?; + } + + let _ = self.event_tx.send(InspectorEvent::NodeRegistered { + id: id.clone(), + parent: parent_clone, + kind, + }); + + Ok(crate::handle::NodeHandle::new(id, Arc::clone(self))) + } + + /// Deregister a node (and reparent its children to its parent). + pub async fn deregister(&self, id: &NodeId) -> Result<()> { + { + let mut tree = self.tree.write().await; + tree.remove(id)?; + } + + let _ = self + .event_tx + .send(InspectorEvent::NodeRemoved { id: id.clone() }); + Ok(()) + } + + /// Deregister a node and all its descendants. + pub async fn deregister_subtree(&self, id: &NodeId) -> Result<()> { + { + let mut tree = self.tree.write().await; + tree.remove_subtree(id)?; + } + + let _ = self + .event_tx + .send(InspectorEvent::NodeRemoved { id: id.clone() }); + Ok(()) + } + + /// Update a node's lifecycle state. + pub async fn update_state(&self, id: &NodeId, state: NodeState) -> Result<()> { + let old = { + let mut tree = self.tree.write().await; + tree.update_state(id, state.clone())? + }; + + if old != state { + let _ = self.event_tx.send(InspectorEvent::StateChanged { + id: id.clone(), + old, + new: state, + }); + } + + Ok(()) + } + + /// Update or insert properties on a node. + pub async fn update_properties( + &self, + id: &NodeId, + props: HashMap<String, serde_json::Value>, + ) -> Result<()> { + let keys = { + let mut tree = self.tree.write().await; + tree.update_properties(id, props)? + }; + + if !keys.is_empty() { + let _ = self.event_tx.send(InspectorEvent::PropertiesUpdated { + id: id.clone(), + keys, + }); + } + + Ok(()) + } + + /// Get a clone of a node's metadata. + pub async fn get_node(&self, id: &NodeId) -> Option<NodeMetadata> { + let tree = self.tree.read().await; + tree.get(id).map(|n| n.metadata.clone()) + } + + /// Get metadata for all direct children of a node. + pub async fn get_children(&self, id: &NodeId) -> Vec<(NodeId, NodeMetadata)> { + let tree = self.tree.read().await; + tree.children(id) + .into_iter() + .map(|n| (n.id.clone(), n.metadata.clone())) + .collect() + } + + /// Check if a node exists. + pub async fn contains(&self, id: &NodeId) -> bool { + let tree = self.tree.read().await; + tree.contains(id) + } + + /// Get the total number of nodes. + pub async fn node_count(&self) -> usize { + let tree = self.tree.read().await; + tree.len() + } + + /// Call `Inspectable::inspect()` on a node, if it implements the trait. + /// + /// Returns `None` if the node doesn't exist or doesn't have an `Inspectable` impl. + pub async fn inspect_node(&self, id: &NodeId) -> Option<serde_json::Value> { + let inspectable = { + let tree = self.tree.read().await; + tree.get(id) + .and_then(|n| n.inspectable.as_ref().map(Arc::clone)) + }; + + match inspectable { + Some(i) => Some(i.inspect().await), + None => None, + } + } + + /// Subscribe to tree change events. + pub fn subscribe(&self) -> broadcast::Receiver<InspectorEvent> { + self.event_tx.subscribe() + } + + /// Get a snapshot of the entire tree (see snapshot module). + pub async fn snapshot(&self) -> crate::snapshot::TreeSnapshot { + let tree = self.tree.read().await; + crate::snapshot::TreeSnapshot::from_tree(&tree) + } +} + +impl Default for InspectorRegistry { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_registry_new_has_root() { + let registry = Arc::new(InspectorRegistry::new()); + let root = registry.root_id().await; + assert_eq!(root.as_str(), "dirigent"); + assert!(registry.contains(&root).await); + assert_eq!(registry.node_count().await, 1); + } + + #[tokio::test] + async fn test_register_and_get() { + let registry = Arc::new(InspectorRegistry::new()); + let root = registry.root_id().await; + + let id = NodeId::new("dirigent/connectors"); + let meta = NodeMetadata::new(NodeKind::Connector, "Connectors"); + + let handle = registry + .register(id.clone(), &root, meta, None) + .await + .unwrap(); + assert_eq!(handle.id().as_str(), "dirigent/connectors"); + + let retrieved = registry.get_node(&id).await.unwrap(); + assert_eq!(retrieved.label, "Connectors"); + assert_eq!(registry.node_count().await, 2); + } + + #[tokio::test] + async fn test_register_emits_event() { + let registry = Arc::new(InspectorRegistry::new()); + let root = registry.root_id().await; + let mut rx = registry.subscribe(); + + let id = NodeId::new("dirigent/test"); + let meta = NodeMetadata::new(NodeKind::Service, "Test"); + + let _handle = registry + .register(id.clone(), &root, meta, None) + .await + .unwrap(); + + let event = rx.recv().await.unwrap(); + match event { + InspectorEvent::NodeRegistered { + id: event_id, + parent, + kind, + } => { + assert_eq!(event_id.as_str(), "dirigent/test"); + assert_eq!(parent.as_str(), "dirigent"); + assert_eq!(kind, NodeKind::Service); + } + _ => panic!("Expected NodeRegistered event"), + } + } + + #[tokio::test] + async fn test_deregister() { + let registry = Arc::new(InspectorRegistry::new()); + let root = registry.root_id().await; + + let id = NodeId::new("dirigent/temp"); + let meta = NodeMetadata::new(NodeKind::AsyncTask, "Temp"); + let _handle = registry + .register(id.clone(), &root, meta, None) + .await + .unwrap(); + + assert!(registry.contains(&id).await); + registry.deregister(&id).await.unwrap(); + assert!(!registry.contains(&id).await); + } + + #[tokio::test] + async fn test_update_state_emits_event() { + let registry = Arc::new(InspectorRegistry::new()); + let root = registry.root_id().await; + + let id = NodeId::new("dirigent/svc"); + let meta = NodeMetadata::new(NodeKind::Service, "Service"); + let _handle = registry + .register(id.clone(), &root, meta, None) + .await + .unwrap(); + + let mut rx = registry.subscribe(); + + registry + .update_state(&id, NodeState::Error("crash".into())) + .await + .unwrap(); + + let event = rx.recv().await.unwrap(); + match event { + InspectorEvent::StateChanged { id: eid, old, new } => { + assert_eq!(eid.as_str(), "dirigent/svc"); + assert_eq!(old, NodeState::Initializing); + assert_eq!(new, NodeState::Error("crash".into())); + } + _ => panic!("Expected StateChanged event"), + } + } + + #[tokio::test] + async fn test_no_event_on_same_state() { + let registry = Arc::new(InspectorRegistry::new()); + let root = registry.root_id().await; + + let id = NodeId::new("dirigent/svc"); + let meta = NodeMetadata::new(NodeKind::Service, "Service").with_state(NodeState::Running); + let _handle = registry + .register(id.clone(), &root, meta, None) + .await + .unwrap(); + + let mut rx = registry.subscribe(); + + // Update to the same state + registry + .update_state(&id, NodeState::Running) + .await + .unwrap(); + + // Should not receive a StateChanged event + let result = tokio::time::timeout(std::time::Duration::from_millis(50), rx.recv()).await; + assert!(result.is_err(), "Should not emit event for same state"); + } + + #[tokio::test] + async fn test_update_properties() { + let registry = Arc::new(InspectorRegistry::new()); + let root = registry.root_id().await; + + let id = NodeId::new("dirigent/proc"); + let meta = NodeMetadata::new(NodeKind::Process, "Process"); + let _handle = registry + .register(id.clone(), &root, meta, None) + .await + .unwrap(); + + let mut props = HashMap::new(); + props.insert("pid".to_string(), serde_json::json!(1234)); + props.insert("cpu".to_string(), serde_json::json!(50.0)); + + registry.update_properties(&id, props).await.unwrap(); + + let node = registry.get_node(&id).await.unwrap(); + assert_eq!(node.properties["pid"], serde_json::json!(1234)); + assert_eq!(node.properties["cpu"], serde_json::json!(50.0)); + } + + #[tokio::test] + async fn test_get_children() { + let registry = Arc::new(InspectorRegistry::new()); + let root = registry.root_id().await; + + let _h1 = registry + .register( + NodeId::new("dirigent/a"), + &root, + NodeMetadata::new(NodeKind::Service, "A"), + None, + ) + .await + .unwrap(); + let _h2 = registry + .register( + NodeId::new("dirigent/b"), + &root, + NodeMetadata::new(NodeKind::Service, "B"), + None, + ) + .await + .unwrap(); + + let children = registry.get_children(&root).await; + assert_eq!(children.len(), 2); + let labels: Vec<&str> = children.iter().map(|(_, m)| m.label.as_str()).collect(); + assert!(labels.contains(&"A")); + assert!(labels.contains(&"B")); + } + + #[tokio::test] + async fn test_inspect_node() { + use async_trait::async_trait; + + struct MockInspectable; + + #[async_trait] + impl Inspectable for MockInspectable { + async fn inspect(&self) -> serde_json::Value { + serde_json::json!({ "internal_queue_len": 42 }) + } + fn current_state(&self) -> NodeState { + NodeState::Running + } + } + + let registry = Arc::new(InspectorRegistry::new()); + let root = registry.root_id().await; + + let id = NodeId::new("dirigent/inspectable"); + let meta = NodeMetadata::new(NodeKind::Service, "Inspectable Service"); + let _handle = registry + .register(id.clone(), &root, meta, Some(Arc::new(MockInspectable))) + .await + .unwrap(); + + let result = registry.inspect_node(&id).await; + assert_eq!( + result.unwrap(), + serde_json::json!({ "internal_queue_len": 42 }) + ); + + // Non-inspectable node returns None + let id2 = NodeId::new("dirigent/plain"); + let _h2 = registry + .register( + id2.clone(), + &root, + NodeMetadata::new(NodeKind::Service, "Plain"), + None, + ) + .await + .unwrap(); + assert!(registry.inspect_node(&id2).await.is_none()); + } +} diff --git a/crates/dirigent_inspector/src/snapshot.rs b/crates/dirigent_inspector/src/snapshot.rs new file mode 100644 index 0000000..a2599b7 --- /dev/null +++ b/crates/dirigent_inspector/src/snapshot.rs @@ -0,0 +1,154 @@ +use crate::node::{NodeId, NodeMetadata}; +use crate::tree::NodeTree; +use serde::{Deserialize, Serialize}; + +/// A serializable point-in-time capture of the entire inspector tree. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TreeSnapshot { + pub timestamp: chrono::DateTime<chrono::Utc>, + pub nodes: Vec<NodeSnapshot>, +} + +/// Snapshot of a single node. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct NodeSnapshot { + pub id: NodeId, + pub parent: Option<NodeId>, + pub children: Vec<NodeId>, + pub metadata: NodeMetadata, +} + +impl TreeSnapshot { + /// Create a snapshot from a `NodeTree`. + pub(crate) fn from_tree(tree: &NodeTree) -> Self { + let nodes = tree + .all_nodes() + .into_iter() + .map(|n| NodeSnapshot { + id: n.id.clone(), + parent: n.parent.clone(), + children: n.children.clone(), + metadata: n.metadata.clone(), + }) + .collect(); + + Self { + timestamp: chrono::Utc::now(), + nodes, + } + } + + /// Serialize the snapshot to a JSON `Value`. + pub fn to_json(&self) -> serde_json::Value { + serde_json::to_value(self).unwrap_or_default() + } + + /// Serialize the snapshot to pretty-printed JSON string. + pub fn to_json_pretty(&self) -> String { + serde_json::to_string_pretty(self).unwrap_or_default() + } + + /// Number of nodes in the snapshot. + pub fn node_count(&self) -> usize { + self.nodes.len() + } + + /// Find a node by ID. + pub fn find(&self, id: &NodeId) -> Option<&NodeSnapshot> { + self.nodes.iter().find(|n| &n.id == id) + } + + /// Get the root node (the one with no parent). + pub fn root(&self) -> Option<&NodeSnapshot> { + self.nodes.iter().find(|n| n.parent.is_none()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::node::{NodeKind, NodeMetadata, NodeState}; + use crate::tree::NodeTree; + + fn make_tree() -> NodeTree { + let root = NodeId::new("dirigent"); + let mut tree = NodeTree::new( + root.clone(), + NodeMetadata::new(NodeKind::Root, "Dirigent").with_state(NodeState::Running), + ); + tree.insert( + NodeId::new("dirigent/svc"), + &root, + NodeMetadata::new(NodeKind::Service, "Service A").with_state(NodeState::Running), + None, + ) + .unwrap(); + tree.insert( + NodeId::new("dirigent/svc/task"), + &NodeId::new("dirigent/svc"), + NodeMetadata::new(NodeKind::AsyncTask, "Task 1"), + None, + ) + .unwrap(); + tree + } + + #[test] + fn test_snapshot_from_tree() { + let tree = make_tree(); + let snap = TreeSnapshot::from_tree(&tree); + + assert_eq!(snap.node_count(), 3); + assert!(snap.find(&NodeId::new("dirigent")).is_some()); + assert!(snap.find(&NodeId::new("dirigent/svc")).is_some()); + assert!(snap.find(&NodeId::new("dirigent/svc/task")).is_some()); + } + + #[test] + fn test_snapshot_root() { + let tree = make_tree(); + let snap = TreeSnapshot::from_tree(&tree); + let root = snap.root().unwrap(); + assert_eq!(root.id.as_str(), "dirigent"); + assert!(root.parent.is_none()); + } + + #[test] + fn test_snapshot_serialization_roundtrip() { + let tree = make_tree(); + let snap = TreeSnapshot::from_tree(&tree); + + let json = serde_json::to_string(&snap).unwrap(); + let deserialized: TreeSnapshot = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.node_count(), snap.node_count()); + assert_eq!( + deserialized + .find(&NodeId::new("dirigent/svc")) + .unwrap() + .metadata + .label, + "Service A" + ); + } + + #[test] + fn test_snapshot_to_json_pretty() { + let tree = make_tree(); + let snap = TreeSnapshot::from_tree(&tree); + let pretty = snap.to_json_pretty(); + + assert!(pretty.contains("dirigent")); + assert!(pretty.contains("Service A")); + } + + #[test] + fn test_snapshot_to_json_value() { + let tree = make_tree(); + let snap = TreeSnapshot::from_tree(&tree); + let val = snap.to_json(); + + assert!(val.get("timestamp").is_some()); + assert!(val.get("nodes").unwrap().is_array()); + } +} diff --git a/crates/dirigent_inspector/src/system.rs b/crates/dirigent_inspector/src/system.rs new file mode 100644 index 0000000..4f71777 --- /dev/null +++ b/crates/dirigent_inspector/src/system.rs @@ -0,0 +1,231 @@ +use crate::node::NodeId; +use crate::registry::InspectorRegistry; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; +use sysinfo::System; +use tokio::task::JoinHandle; +use tracing::warn; + +/// Host system information. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct SystemInfo { + pub hostname: Option<String>, + pub os_name: Option<String>, + pub os_version: Option<String>, + pub kernel_version: Option<String>, + pub arch: String, + pub total_memory_bytes: u64, + pub used_memory_bytes: u64, + pub available_memory_bytes: u64, + pub total_swap_bytes: u64, + pub used_swap_bytes: u64, + pub cpu_count: usize, + pub physical_core_count: Option<usize>, + pub global_cpu_usage_percent: f32, + pub uptime_secs: u64, +} + +/// Monitors host system metrics (memory, CPU, etc.). +pub struct SystemMonitor { + system: System, +} + +impl SystemMonitor { + /// Create a new system monitor. + pub fn new() -> Self { + let mut system = System::new(); + // Initial refresh to populate baseline data + system.refresh_memory(); + system.refresh_cpu_usage(); + Self { system } + } + + /// Refresh and return current system information. + pub fn refresh(&mut self) -> SystemInfo { + self.system.refresh_memory(); + self.system.refresh_cpu_usage(); + + SystemInfo { + hostname: System::host_name(), + os_name: System::name(), + os_version: System::os_version(), + kernel_version: System::kernel_version(), + arch: System::cpu_arch(), + total_memory_bytes: self.system.total_memory(), + used_memory_bytes: self.system.used_memory(), + available_memory_bytes: self.system.available_memory(), + total_swap_bytes: self.system.total_swap(), + used_swap_bytes: self.system.used_swap(), + cpu_count: self.system.cpus().len(), + physical_core_count: self.system.physical_core_count(), + global_cpu_usage_percent: self.system.global_cpu_usage(), + uptime_secs: System::uptime(), + } + } + + /// Spawn a background task that periodically updates a system node in the registry. + /// + /// The `node_id` should already be registered in the tree (e.g., "dirigent/system/host"). + pub fn start_polling( + mut self, + registry: Arc<InspectorRegistry>, + node_id: NodeId, + interval: Duration, + ) -> JoinHandle<()> { + tokio::spawn(async move { + let mut ticker = tokio::time::interval(interval); + loop { + ticker.tick().await; + + let info = self.refresh(); + + let mut props = HashMap::new(); + props.insert("hostname".to_string(), serde_json::json!(info.hostname)); + props.insert("os_name".to_string(), serde_json::json!(info.os_name)); + props.insert("os_version".to_string(), serde_json::json!(info.os_version)); + props.insert("arch".to_string(), serde_json::json!(info.arch)); + props.insert( + "total_memory_bytes".to_string(), + serde_json::json!(info.total_memory_bytes), + ); + props.insert( + "used_memory_bytes".to_string(), + serde_json::json!(info.used_memory_bytes), + ); + props.insert( + "available_memory_bytes".to_string(), + serde_json::json!(info.available_memory_bytes), + ); + props.insert( + "total_swap_bytes".to_string(), + serde_json::json!(info.total_swap_bytes), + ); + props.insert( + "used_swap_bytes".to_string(), + serde_json::json!(info.used_swap_bytes), + ); + props.insert("cpu_count".to_string(), serde_json::json!(info.cpu_count)); + props.insert( + "physical_core_count".to_string(), + serde_json::json!(info.physical_core_count), + ); + props.insert( + "global_cpu_usage_percent".to_string(), + serde_json::json!(info.global_cpu_usage_percent), + ); + props.insert( + "uptime_secs".to_string(), + serde_json::json!(info.uptime_secs), + ); + + if let Err(e) = registry.update_properties(&node_id, props).await { + warn!( + node_id = %node_id, + error = %e, + "Failed to update system node properties" + ); + } + } + }) + } +} + +impl Default for SystemMonitor { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_system_monitor_refresh() { + let mut monitor = SystemMonitor::new(); + + // Need a brief sleep for CPU usage sampling + std::thread::sleep(std::time::Duration::from_millis(200)); + + let info = monitor.refresh(); + + assert!(info.total_memory_bytes > 0, "Should have total memory"); + assert!(info.cpu_count > 0, "Should have at least 1 CPU"); + assert!(info.uptime_secs > 0, "Should have uptime"); + assert!(info.arch.len() > 0, "Should have arch info"); + } + + #[test] + fn test_system_info_serialization() { + let info = SystemInfo { + hostname: Some("testhost".to_string()), + os_name: Some("macOS".to_string()), + os_version: Some("14.0".to_string()), + kernel_version: Some("23.0.0".to_string()), + arch: "arm64".to_string(), + total_memory_bytes: 16 * 1024 * 1024 * 1024, + used_memory_bytes: 8 * 1024 * 1024 * 1024, + available_memory_bytes: 8 * 1024 * 1024 * 1024, + total_swap_bytes: 2 * 1024 * 1024 * 1024, + used_swap_bytes: 512 * 1024 * 1024, + cpu_count: 10, + physical_core_count: Some(10), + global_cpu_usage_percent: 25.5, + uptime_secs: 86400, + }; + + let json = serde_json::to_string(&info).unwrap(); + let deserialized: SystemInfo = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.hostname, Some("testhost".to_string())); + assert_eq!(deserialized.total_memory_bytes, 16 * 1024 * 1024 * 1024); + assert_eq!(deserialized.cpu_count, 10); + } + + #[tokio::test] + async fn test_system_monitor_polling() { + use crate::node::{NodeKind, NodeMetadata, NodeState}; + + let registry = Arc::new(InspectorRegistry::new()); + let root = registry.root_id().await; + + // Register system node + let node_id = NodeId::new("dirigent/system/host"); + let mut handle = registry + .register( + node_id.clone(), + &root, + NodeMetadata::new(NodeKind::System, "Host System").with_state(NodeState::Running), + None, + ) + .await + .unwrap(); + // Detach so polling task can update it + // (handle would deregister on drop otherwise) + handle.detach(); + + let monitor = SystemMonitor::new(); + let task = monitor.start_polling( + Arc::clone(®istry), + node_id.clone(), + Duration::from_millis(100), + ); + + // Wait for at least one poll cycle + tokio::time::sleep(Duration::from_millis(250)).await; + + let meta = registry.get_node(&node_id).await.unwrap(); + assert!( + meta.properties.contains_key("total_memory_bytes"), + "Should have system metrics after polling" + ); + assert!( + meta.properties.contains_key("cpu_count"), + "Should have CPU count" + ); + + task.abort(); + } +} diff --git a/crates/dirigent_inspector/src/tree.rs b/crates/dirigent_inspector/src/tree.rs new file mode 100644 index 0000000..1ac6c80 --- /dev/null +++ b/crates/dirigent_inspector/src/tree.rs @@ -0,0 +1,544 @@ +use crate::error::{InspectorError, Result}; +use crate::node::{Inspectable, NodeId, NodeMetadata, NodeState}; +use std::collections::HashMap; +use std::sync::Arc; + +/// A node within the inspector tree, holding metadata and relationships. +pub struct TreeNode { + pub id: NodeId, + pub metadata: NodeMetadata, + pub parent: Option<NodeId>, + pub children: Vec<NodeId>, + pub inspectable: Option<Arc<dyn Inspectable>>, +} + +impl TreeNode { + fn new( + id: NodeId, + metadata: NodeMetadata, + parent: Option<NodeId>, + inspectable: Option<Arc<dyn Inspectable>>, + ) -> Self { + Self { + id, + metadata, + parent, + children: Vec::new(), + inspectable, + } + } +} + +/// The inspector tree: a rooted tree of `TreeNode`s with parent-child relationships. +/// +/// All mutations go through this struct. Thread safety is provided by wrapping +/// `NodeTree` in `Arc<RwLock<NodeTree>>` at the registry level. +pub struct NodeTree { + nodes: HashMap<NodeId, TreeNode>, + root: NodeId, +} + +impl NodeTree { + /// Create a new tree with a root node. + /// + /// The root node is always present and cannot be removed. + pub fn new(root_id: NodeId, root_metadata: NodeMetadata) -> Self { + let mut nodes = HashMap::new(); + nodes.insert( + root_id.clone(), + TreeNode::new(root_id.clone(), root_metadata, None, None), + ); + Self { + nodes, + root: root_id, + } + } + + /// Get the root node ID. + pub fn root_id(&self) -> &NodeId { + &self.root + } + + /// Insert a new node as a child of the given parent. + pub fn insert( + &mut self, + id: NodeId, + parent: &NodeId, + metadata: NodeMetadata, + inspectable: Option<Arc<dyn Inspectable>>, + ) -> Result<()> { + if self.nodes.contains_key(&id) { + return Err(InspectorError::NodeAlreadyExists(id)); + } + if !self.nodes.contains_key(parent) { + return Err(InspectorError::ParentNotFound(parent.clone())); + } + + let node = TreeNode::new(id.clone(), metadata, Some(parent.clone()), inspectable); + self.nodes.insert(id.clone(), node); + + // Add as child of parent + if let Some(parent_node) = self.nodes.get_mut(parent) { + parent_node.children.push(id); + } + + Ok(()) + } + + /// Remove a node and reparent its children to the node's parent. + /// + /// The root node cannot be removed. + pub fn remove(&mut self, id: &NodeId) -> Result<()> { + if id == &self.root { + return Err(InspectorError::CannotRemoveRoot); + } + + let node = self + .nodes + .get(id) + .ok_or_else(|| InspectorError::NodeNotFound(id.clone()))?; + + let parent_id = node.parent.clone(); + let children: Vec<NodeId> = node.children.clone(); + + // Reparent children to the removed node's parent + if let Some(ref parent_id) = parent_id { + for child_id in &children { + if let Some(child) = self.nodes.get_mut(child_id) { + child.parent = Some(parent_id.clone()); + } + } + // Update parent's children: remove this node, add reparented children + if let Some(parent) = self.nodes.get_mut(parent_id) { + parent.children.retain(|c| c != id); + parent.children.extend(children); + } + } + + self.nodes.remove(id); + Ok(()) + } + + /// Remove a node and all its descendants. + /// + /// The root node cannot be removed. + pub fn remove_subtree(&mut self, id: &NodeId) -> Result<()> { + if id == &self.root { + return Err(InspectorError::CannotRemoveRoot); + } + + if !self.nodes.contains_key(id) { + return Err(InspectorError::NodeNotFound(id.clone())); + } + + // Collect all descendant IDs via BFS + let mut to_remove = Vec::new(); + let mut queue = vec![id.clone()]; + while let Some(current) = queue.pop() { + if let Some(node) = self.nodes.get(¤t) { + queue.extend(node.children.clone()); + } + to_remove.push(current); + } + + // Remove this node from parent's children list + let parent_id = self.nodes.get(id).and_then(|n| n.parent.clone()); + if let Some(parent_id) = parent_id { + if let Some(parent) = self.nodes.get_mut(&parent_id) { + parent.children.retain(|c| c != id); + } + } + + // Remove all collected nodes + for node_id in &to_remove { + self.nodes.remove(node_id); + } + + Ok(()) + } + + /// Get a reference to a node by ID. + pub fn get(&self, id: &NodeId) -> Option<&TreeNode> { + self.nodes.get(id) + } + + /// Get a mutable reference to a node by ID. + pub fn get_mut(&mut self, id: &NodeId) -> Option<&mut TreeNode> { + self.nodes.get_mut(id) + } + + /// Get the direct children of a node. + pub fn children(&self, id: &NodeId) -> Vec<&TreeNode> { + self.nodes + .get(id) + .map(|node| { + node.children + .iter() + .filter_map(|child_id| self.nodes.get(child_id)) + .collect() + }) + .unwrap_or_default() + } + + /// Get all ancestors of a node, from immediate parent to root. + pub fn ancestors(&self, id: &NodeId) -> Vec<&TreeNode> { + let mut result = Vec::new(); + let mut current = self.nodes.get(id).and_then(|n| n.parent.as_ref()); + while let Some(parent_id) = current { + if let Some(parent) = self.nodes.get(parent_id) { + result.push(parent); + current = parent.parent.as_ref(); + } else { + break; + } + } + result + } + + /// Get all nodes in the tree. + pub fn all_nodes(&self) -> Vec<&TreeNode> { + self.nodes.values().collect() + } + + /// Total number of nodes in the tree. + pub fn len(&self) -> usize { + self.nodes.len() + } + + /// Whether the tree is empty (only root). + pub fn is_empty(&self) -> bool { + self.nodes.len() <= 1 + } + + /// Update a node's state, returning the old state. + pub fn update_state(&mut self, id: &NodeId, state: NodeState) -> Result<NodeState> { + let node = self + .nodes + .get_mut(id) + .ok_or_else(|| InspectorError::NodeNotFound(id.clone()))?; + let old = node.metadata.state.clone(); + node.metadata.state = state; + node.metadata.last_updated = chrono::Utc::now(); + Ok(old) + } + + /// Update or insert properties on a node. Returns only the keys whose values actually changed. + /// + /// A key is considered changed if it is new (not previously present) or if its value + /// differs from the stored value (compared via `serde_json::Value::PartialEq`). + /// The `last_updated` timestamp is only bumped when at least one value changed. + pub fn update_properties( + &mut self, + id: &NodeId, + props: HashMap<String, serde_json::Value>, + ) -> Result<Vec<String>> { + let node = self + .nodes + .get_mut(id) + .ok_or_else(|| InspectorError::NodeNotFound(id.clone()))?; + + let mut changed_keys = Vec::new(); + for (key, new_value) in props { + let is_changed = match node.metadata.properties.get(&key) { + Some(existing) => existing != &new_value, + None => true, // new key + }; + if is_changed { + changed_keys.push(key.clone()); + node.metadata.properties.insert(key, new_value); + } + } + + if !changed_keys.is_empty() { + node.metadata.last_updated = chrono::Utc::now(); + } + + Ok(changed_keys) + } + + /// Check if a node exists. + pub fn contains(&self, id: &NodeId) -> bool { + self.nodes.contains_key(id) + } + + /// Get all descendant IDs of a node (not including the node itself). + pub fn descendants(&self, id: &NodeId) -> Vec<NodeId> { + let mut result = Vec::new(); + let mut queue: Vec<NodeId> = self + .nodes + .get(id) + .map(|n| n.children.clone()) + .unwrap_or_default(); + + while let Some(current) = queue.pop() { + if let Some(node) = self.nodes.get(¤t) { + queue.extend(node.children.clone()); + } + result.push(current); + } + result + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::node::{NodeKind, NodeMetadata}; + + fn root_meta() -> NodeMetadata { + NodeMetadata::new(NodeKind::Root, "Dirigent") + } + + fn connector_meta(label: &str) -> NodeMetadata { + NodeMetadata::new(NodeKind::Connector, label).with_state(NodeState::Running) + } + + fn make_tree() -> NodeTree { + let root = NodeId::new("dirigent"); + let mut tree = NodeTree::new(root.clone(), root_meta()); + + // Add category nodes + tree.insert( + root.child("connectors"), + &root, + NodeMetadata::new(NodeKind::Root, "Connectors"), + None, + ) + .unwrap(); + tree.insert( + root.child("services"), + &root, + NodeMetadata::new(NodeKind::Root, "Services"), + None, + ) + .unwrap(); + + // Add connectors + let connectors = root.child("connectors"); + tree.insert( + connectors.child("opencode-1"), + &connectors, + connector_meta("OpenCode #1"), + None, + ) + .unwrap(); + tree.insert( + connectors.child("acp-claude"), + &connectors, + connector_meta("ACP Claude"), + None, + ) + .unwrap(); + + // Add a process child under acp-claude + let acp = connectors.child("acp-claude"); + tree.insert( + acp.child("stdio-process"), + &acp, + NodeMetadata::new(NodeKind::Process, "stdio-transport") + .with_state(NodeState::Running) + .with_property("pid", serde_json::json!(42)), + None, + ) + .unwrap(); + + tree + } + + #[test] + fn test_new_tree_has_root() { + let root = NodeId::new("dirigent"); + let tree = NodeTree::new(root.clone(), root_meta()); + + assert_eq!(tree.len(), 1); + assert!(tree.get(&root).is_some()); + assert_eq!(tree.root_id(), &root); + } + + #[test] + fn test_insert_and_lookup() { + let tree = make_tree(); + + assert_eq!(tree.len(), 6); // root + connectors + services + opencode + acp + stdio + assert!(tree.contains(&NodeId::new("dirigent/connectors/acp-claude"))); + assert!(tree.contains(&NodeId::new("dirigent/connectors/acp-claude/stdio-process"))); + } + + #[test] + fn test_insert_duplicate_fails() { + let mut tree = make_tree(); + let result = tree.insert( + NodeId::new("dirigent/connectors/opencode-1"), + &NodeId::new("dirigent/connectors"), + connector_meta("Duplicate"), + None, + ); + assert!(matches!(result, Err(InspectorError::NodeAlreadyExists(_)))); + } + + #[test] + fn test_insert_missing_parent_fails() { + let mut tree = make_tree(); + let result = tree.insert( + NodeId::new("dirigent/nonexistent/child"), + &NodeId::new("dirigent/nonexistent"), + connector_meta("Orphan"), + None, + ); + assert!(matches!(result, Err(InspectorError::ParentNotFound(_)))); + } + + #[test] + fn test_children() { + let tree = make_tree(); + let connectors = NodeId::new("dirigent/connectors"); + let children = tree.children(&connectors); + + assert_eq!(children.len(), 2); + let child_ids: Vec<&str> = children.iter().map(|c| c.id.as_str()).collect(); + assert!(child_ids.contains(&"dirigent/connectors/opencode-1")); + assert!(child_ids.contains(&"dirigent/connectors/acp-claude")); + } + + #[test] + fn test_ancestors() { + let tree = make_tree(); + let stdio = NodeId::new("dirigent/connectors/acp-claude/stdio-process"); + let ancestors = tree.ancestors(&stdio); + + assert_eq!(ancestors.len(), 3); + assert_eq!(ancestors[0].id.as_str(), "dirigent/connectors/acp-claude"); + assert_eq!(ancestors[1].id.as_str(), "dirigent/connectors"); + assert_eq!(ancestors[2].id.as_str(), "dirigent"); + } + + #[test] + fn test_remove_reparents_children() { + let mut tree = make_tree(); + + // Remove acp-claude; its child (stdio-process) should be reparented to "connectors" + tree.remove(&NodeId::new("dirigent/connectors/acp-claude")) + .unwrap(); + + assert!(!tree.contains(&NodeId::new("dirigent/connectors/acp-claude"))); + assert!(tree.contains(&NodeId::new("dirigent/connectors/acp-claude/stdio-process"))); + + // stdio-process should now be a child of connectors + let connectors = tree.get(&NodeId::new("dirigent/connectors")).unwrap(); + assert!(connectors + .children + .contains(&NodeId::new("dirigent/connectors/acp-claude/stdio-process"))); + + // stdio-process parent should be connectors + let stdio = tree + .get(&NodeId::new("dirigent/connectors/acp-claude/stdio-process")) + .unwrap(); + assert_eq!( + stdio.parent.as_ref().unwrap().as_str(), + "dirigent/connectors" + ); + } + + #[test] + fn test_remove_subtree() { + let mut tree = make_tree(); + + tree.remove_subtree(&NodeId::new("dirigent/connectors/acp-claude")) + .unwrap(); + + assert!(!tree.contains(&NodeId::new("dirigent/connectors/acp-claude"))); + assert!(!tree.contains(&NodeId::new("dirigent/connectors/acp-claude/stdio-process"))); + assert_eq!(tree.len(), 4); // root + connectors + services + opencode + + // connectors should only have opencode-1 + let connectors = tree.get(&NodeId::new("dirigent/connectors")).unwrap(); + assert_eq!(connectors.children.len(), 1); + } + + #[test] + fn test_cannot_remove_root() { + let mut tree = make_tree(); + let result = tree.remove(&NodeId::new("dirigent")); + assert!(matches!(result, Err(InspectorError::CannotRemoveRoot))); + + let result = tree.remove_subtree(&NodeId::new("dirigent")); + assert!(matches!(result, Err(InspectorError::CannotRemoveRoot))); + } + + #[test] + fn test_update_state() { + let mut tree = make_tree(); + let id = NodeId::new("dirigent/connectors/opencode-1"); + + let old = tree + .update_state(&id, NodeState::Error("timeout".into())) + .unwrap(); + assert_eq!(old, NodeState::Running); + + let node = tree.get(&id).unwrap(); + assert_eq!(node.metadata.state, NodeState::Error("timeout".into())); + } + + #[test] + fn test_update_properties() { + let mut tree = make_tree(); + let id = NodeId::new("dirigent/connectors/acp-claude/stdio-process"); + + let mut props = HashMap::new(); + props.insert("cpu_percent".to_string(), serde_json::json!(45.2)); + props.insert("memory_mb".to_string(), serde_json::json!(128)); + + let keys = tree.update_properties(&id, props).unwrap(); + assert_eq!(keys.len(), 2); + + let node = tree.get(&id).unwrap(); + assert_eq!(node.metadata.properties["pid"], serde_json::json!(42)); // original preserved + assert_eq!( + node.metadata.properties["cpu_percent"], + serde_json::json!(45.2) + ); + } + + #[test] + fn test_update_properties_no_change() { + let mut tree = make_tree(); + let id = NodeId::new("dirigent/connectors/acp-claude/stdio-process"); + + // First update: new key, should be reported as changed + let mut props = HashMap::new(); + props.insert("cpu_percent".to_string(), serde_json::json!(45.2)); + let keys = tree.update_properties(&id, props).unwrap(); + assert_eq!(keys.len(), 1); + + // Second update: same value, should NOT be reported as changed + let mut props = HashMap::new(); + props.insert("cpu_percent".to_string(), serde_json::json!(45.2)); + let keys = tree.update_properties(&id, props).unwrap(); + assert_eq!(keys.len(), 0, "Same value should not be reported as changed"); + + // Third update: different value, should be reported + let mut props = HashMap::new(); + props.insert("cpu_percent".to_string(), serde_json::json!(50.0)); + let keys = tree.update_properties(&id, props).unwrap(); + assert_eq!(keys.len(), 1); + } + + #[test] + fn test_descendants() { + let tree = make_tree(); + let root = NodeId::new("dirigent"); + let descendants = tree.descendants(&root); + + assert_eq!(descendants.len(), 5); // all except root itself + } + + #[test] + fn test_is_empty() { + let root = NodeId::new("dirigent"); + let tree = NodeTree::new(root, root_meta()); + assert!(tree.is_empty()); + + let tree = make_tree(); + assert!(!tree.is_empty()); + } +} diff --git a/crates/dirigent_inspector/tests/integration.rs b/crates/dirigent_inspector/tests/integration.rs new file mode 100644 index 0000000..cf40c09 --- /dev/null +++ b/crates/dirigent_inspector/tests/integration.rs @@ -0,0 +1,355 @@ +use dirigent_inspector::*; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +/// Full integration test simulating a Dirigent-like setup: +/// - Root node "dirigent" +/// - Connector nodes under "dirigent/connectors" +/// - Process node under a connector +/// - Service nodes under "dirigent/services" +/// - System node under "dirigent/system" +/// - Bidirectional channel communication +/// - Snapshot capture +#[tokio::test] +async fn test_full_inspector_tree() { + let registry = Arc::new(InspectorRegistry::new()); + let root = registry.root_id().await; + assert_eq!(root.as_str(), "dirigent"); + + // Subscribe to events + let mut event_rx = registry.subscribe(); + + // -- Build the tree structure -- + + // Category: connectors + let connectors_handle = registry + .register( + NodeId::new("dirigent/connectors"), + &root, + NodeMetadata::new(NodeKind::Custom("category".into()), "Connectors") + .with_state(NodeState::Running), + None, + ) + .await + .unwrap(); + + // Category: services + let mut services_handle = registry + .register( + NodeId::new("dirigent/services"), + &root, + NodeMetadata::new(NodeKind::Custom("category".into()), "Services") + .with_state(NodeState::Running), + None, + ) + .await + .unwrap(); + + // Category: system + let mut system_handle = registry + .register( + NodeId::new("dirigent/system"), + &root, + NodeMetadata::new(NodeKind::Custom("category".into()), "System") + .with_state(NodeState::Running), + None, + ) + .await + .unwrap(); + + // Connector: ACP Claude + let acp_handle = connectors_handle + .register_child( + NodeId::new("dirigent/connectors/acp-claude"), + NodeMetadata::new(NodeKind::Connector, "ACP Claude") + .with_state(NodeState::Running) + .with_property("transport", serde_json::json!("stdio")), + None, + ) + .await + .unwrap(); + + // Process: stdio transport child + let current_pid = std::process::id(); + let proc_handle = acp_handle + .register_child( + NodeId::new("dirigent/connectors/acp-claude/stdio-process"), + NodeMetadata::new(NodeKind::Process, "stdio-transport") + .with_state(NodeState::Running) + .with_property("pid", serde_json::json!(current_pid)), + None, + ) + .await + .unwrap(); + + // Connector: OpenCode + let _opencode_handle = connectors_handle + .register_child( + NodeId::new("dirigent/connectors/opencode-1"), + NodeMetadata::new(NodeKind::Connector, "OpenCode #1") + .with_state(NodeState::Running) + .with_property("base_url", serde_json::json!("http://localhost:12225")), + None, + ) + .await + .unwrap(); + + // Service: Archivist + let _archivist_handle = services_handle + .register_child( + NodeId::new("dirigent/services/archivist"), + NodeMetadata::new(NodeKind::Service, "Archivist EventHandler") + .with_state(NodeState::Idle), + None, + ) + .await + .unwrap(); + + // System: Host + let _host_handle = system_handle + .register_child( + NodeId::new("dirigent/system/host"), + NodeMetadata::new(NodeKind::System, "Host Machine").with_state(NodeState::Running), + None, + ) + .await + .unwrap(); + + // -- Verify tree structure -- + // dirigent, connectors, services, system, acp-claude, stdio-process, opencode-1, archivist, host = 9 + assert_eq!(registry.node_count().await, 9); + + // Check children + let root_children = registry.get_children(&root).await; + assert_eq!(root_children.len(), 3); // connectors, services, system + + let connector_children = registry + .get_children(&NodeId::new("dirigent/connectors")) + .await; + assert_eq!(connector_children.len(), 2); // acp-claude, opencode-1 + + let acp_children = registry + .get_children(&NodeId::new("dirigent/connectors/acp-claude")) + .await; + assert_eq!(acp_children.len(), 1); // stdio-process + + // -- State transitions -- + proc_handle + .set_state(NodeState::Busy("processing message".into())) + .await + .unwrap(); + + let proc_meta = registry + .get_node(&NodeId::new("dirigent/connectors/acp-claude/stdio-process")) + .await + .unwrap(); + assert_eq!( + proc_meta.state, + NodeState::Busy("processing message".into()) + ); + + // -- Property updates -- + let mut props = HashMap::new(); + props.insert("cpu_percent".to_string(), serde_json::json!(23.5)); + props.insert("memory_mb".to_string(), serde_json::json!(256)); + proc_handle.set_properties(props).await.unwrap(); + + let proc_meta = registry + .get_node(&NodeId::new("dirigent/connectors/acp-claude/stdio-process")) + .await + .unwrap(); + assert_eq!(proc_meta.properties["cpu_percent"], serde_json::json!(23.5)); + assert_eq!(proc_meta.properties["pid"], serde_json::json!(current_pid)); // original preserved + + // -- Snapshot -- + let snapshot = registry.snapshot().await; + assert_eq!(snapshot.node_count(), 9); + + // Verify snapshot structure + let snap_root = snapshot.root().unwrap(); + assert_eq!(snap_root.id.as_str(), "dirigent"); + assert_eq!(snap_root.children.len(), 3); + + let snap_proc = snapshot + .find(&NodeId::new("dirigent/connectors/acp-claude/stdio-process")) + .unwrap(); + assert_eq!( + snap_proc.parent, + Some(NodeId::new("dirigent/connectors/acp-claude")) + ); + assert_eq!( + snap_proc.metadata.state, + NodeState::Busy("processing message".into()) + ); + + // Snapshot serialization roundtrip + let json = serde_json::to_string(&snapshot).unwrap(); + let deserialized: TreeSnapshot = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.node_count(), 9); + + // -- Events -- + // Drain all events that were emitted during setup + let mut event_count = 0; + while let Ok(event) = event_rx.try_recv() { + event_count += 1; + // Just verify they're valid events + match event { + InspectorEvent::NodeRegistered { .. } + | InspectorEvent::StateChanged { .. } + | InspectorEvent::PropertiesUpdated { .. } + | InspectorEvent::NodeRemoved { .. } => {} + } + } + assert!(event_count > 0, "Should have received events"); + + // -- Cleanup: detach category handles so they survive -- + services_handle.detach(); + system_handle.detach(); +} + +/// Test process monitor with the current process. +#[tokio::test] +async fn test_process_monitor_integration() { + let registry = Arc::new(InspectorRegistry::new()); + let root = registry.root_id().await; + + let current_pid = std::process::id(); + let node_id = NodeId::new("dirigent/test-process"); + + let mut handle = registry + .register( + node_id.clone(), + &root, + NodeMetadata::new(NodeKind::Process, "Test Process").with_state(NodeState::Running), + None, + ) + .await + .unwrap(); + handle.detach(); + + // Create monitor and track current process + let mut monitor = ProcessMonitor::new(); + monitor.track(current_pid, node_id.clone()); + + // Start polling + let task = monitor.start_polling(Arc::clone(®istry), Duration::from_millis(100)); + + // Wait for data to be populated + tokio::time::sleep(Duration::from_millis(350)).await; + + let meta = registry.get_node(&node_id).await.unwrap(); + assert!( + meta.properties.contains_key("pid"), + "Should have PID property" + ); + assert!( + meta.properties.contains_key("memory_bytes"), + "Should have memory property" + ); + assert_eq!(meta.properties["pid"], serde_json::json!(current_pid)); + + task.abort(); +} + +/// Test bidirectional channel communication with a simulated node loop. +#[tokio::test] +async fn test_channel_integration() { + let (sender, mut receiver) = inspector_channel(10); + + // Simulate a node loop that handles commands + let node_loop = tokio::spawn(async move { + let mut handled = 0; + while let Some((cmd, resp_tx)) = receiver.recv().await { + let response = match cmd.kind { + CommandKind::Introspect => CommandResponse::ok( + &cmd.id, + serde_json::json!({ + "queue_depth": receiver.pending_count(), + "sessions_active": 3 + }), + ), + CommandKind::Execute(ref name) if name == "restart" => { + CommandResponse::ok(&cmd.id, serde_json::json!("restarting...")) + } + _ => CommandResponse::err(&cmd.id, "unknown command"), + }; + let _ = resp_tx.send(response); + handled += 1; + if handled >= 2 { + break; + } + } + }); + + // Send introspect command + let resp = sender + .send(NodeCommand { + id: "cmd-1".to_string(), + kind: CommandKind::Introspect, + payload: serde_json::Value::Null, + }) + .await + .unwrap(); + assert!(resp.success); + assert_eq!(resp.data["sessions_active"], 3); + + // Send execute command + let resp = sender + .send(NodeCommand { + id: "cmd-2".to_string(), + kind: CommandKind::Execute("restart".to_string()), + payload: serde_json::Value::Null, + }) + .await + .unwrap(); + assert!(resp.success); + + node_loop.await.unwrap(); +} + +/// Test that dropping a handle auto-deregisters, and subtree removal works. +#[tokio::test] +async fn test_lifecycle_management() { + let registry = Arc::new(InspectorRegistry::new()); + let root = registry.root_id().await; + + // Build a subtree + let parent = registry + .register( + NodeId::new("dirigent/parent"), + &root, + NodeMetadata::new(NodeKind::Connector, "Parent"), + None, + ) + .await + .unwrap(); + + let _child1 = parent + .register_child( + NodeId::new("dirigent/parent/child1"), + NodeMetadata::new(NodeKind::Process, "Child 1"), + None, + ) + .await + .unwrap(); + + let _child2 = parent + .register_child( + NodeId::new("dirigent/parent/child2"), + NodeMetadata::new(NodeKind::AsyncTask, "Child 2"), + None, + ) + .await + .unwrap(); + + assert_eq!(registry.node_count().await, 4); // root + parent + 2 children + + // Remove entire subtree + registry + .deregister_subtree(&NodeId::new("dirigent/parent")) + .await + .unwrap(); + + assert_eq!(registry.node_count().await, 1); // only root remains +} diff --git a/crates/dirigent_langfuse/CLAUDE.md b/crates/dirigent_langfuse/CLAUDE.md new file mode 100644 index 0000000..d2dcc10 --- /dev/null +++ b/crates/dirigent_langfuse/CLAUDE.md @@ -0,0 +1,72 @@ +# Package: dirigent_langfuse + +Phase 4 stream backend that mirrors BusEvents to a Langfuse ingestion +endpoint. + +## Scope + +- `LangfuseFactory` registered as `kind = "langfuse"` in the + `StreamFactoryRegistry`. +- `LangfuseStream` implements `SessionStream`: + - Maps each `BusEvent` via `mapping::bus_event_to_items`. + - Buffers up to 32 items per flush; flushes eagerly when full and on + shutdown. + - POSTs `{host}/api/public/ingestion` with basic-auth + `(public_key, secret_key)`. + +## File map + +- `src/lib.rs` — public API: `LangfuseStream`, `LangfuseConfig`, + `LangfuseFactory`. +- `src/client.rs` — `LangfuseClient` (reqwest wrapper with retry) and + the `LangfuseStream` implementation. +- `src/mapping.rs` — `bus_event_to_items` mapping. +- `src/factory.rs` — `StreamFactory` impl. + +## Event → ingestion mapping + +| BusEvent variant | Langfuse item | +|------------------|---------------| +| `SessionCreated` | `trace-create` (id = `scroll_id`) | +| `MessageStarted` | `generation-create` | +| `MessageCompleted` | `generation-update` with output | +| `SessionUpdate` (non-tool) | skipped | +| All others | skipped | + +Events without a bound `scroll_id` (no late-bind hit) are dropped — the +implementation does NOT buffer pending events keyed by connector_id / +native_session_id in Phase 4. If buffering is needed later, extend +`LangfuseStream::on_event`. + +## Failure modes + +- Transport error → `StreamOutcome::Failed(StreamError::Transport)`. + Health drift applies; the stream goes Degraded after one failure and + Unavailable after five consecutive failures. +- 5xx response → retried up to 3 times with exponential backoff + (100ms → 200 → 400 → 800, capped at 1s). +- 4xx response → returned as `LangfuseError::Status(code)`; no retry. +- Empty scroll_id → `StreamOutcome::Skipped` (not a failure). + +## Configuration + +```toml +[[streams]] +name = "langfuse-prod" +type = "langfuse" +enabled = true +[streams.scope] +kind = "connector" +connector_uid = "01985d00-..." +[streams.params] +host = "https://langfuse.example.com" +public_key = "pk-lf-..." +secret_key = "sk-lf-..." +``` + +## Deferred + +- Tool-call → span mapping (`SpanCreate`/`SpanUpdate`): scaffolded but + not yet populated. +- Buffering pending events keyed by `(connector_id, native_session_id)` + for late-bind scenarios. diff --git a/crates/dirigent_langfuse/Cargo.toml b/crates/dirigent_langfuse/Cargo.toml new file mode 100644 index 0000000..3c97f02 --- /dev/null +++ b/crates/dirigent_langfuse/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "dirigent_langfuse" +version = "0.1.0" +edition = "2021" + +[features] +default = [] +server = ["dep:reqwest", "dep:tokio", "dep:dirigent_core", "dirigent_core/server"] + +[dependencies] +async-trait = "0.1" +chrono = { version = "0.4", features = ["serde"] } +dirigent_core = { path = "../dirigent_core", optional = true } +dirigent_protocol = { path = "../dirigent_protocol" } +reqwest = { version = "0.12", optional = true, features = ["json"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +thiserror = "2.0" +tokio = { version = "1", optional = true, features = ["rt", "sync", "macros"] } +toml = "0.8" +tracing = "0.1" +url = "2" +uuid = { version = "1", features = ["v4", "v7"] } diff --git a/crates/dirigent_langfuse/src/client.rs b/crates/dirigent_langfuse/src/client.rs new file mode 100644 index 0000000..131df3d --- /dev/null +++ b/crates/dirigent_langfuse/src/client.rs @@ -0,0 +1,204 @@ +//! Langfuse ingestion client. Phase 4 feature-gated on `server`. + +use std::sync::Arc; +#[cfg(feature = "server")] +use std::time::Duration; + +use async_trait::async_trait; +use chrono::Utc; +use thiserror::Error; +#[cfg(feature = "server")] +use tokio::sync::Mutex; +#[cfg(feature = "server")] +use tracing::warn; + +use dirigent_protocol::streaming::{ + BusEvent, SessionStream, StreamKind, StreamOutcome, StreamScope, StreamSummary, +}; +#[cfg(feature = "server")] +use dirigent_protocol::streaming::StreamError; + +#[cfg(feature = "server")] +use crate::mapping::{bus_event_to_items, IngestItem}; + +/// Langfuse stream configuration (credentials + host). +#[derive(Debug, Clone)] +pub struct LangfuseConfig { + pub host: String, + pub public_key: String, + pub secret_key: String, +} + +#[derive(Debug, Error)] +#[cfg_attr(not(feature = "server"), allow(dead_code))] +pub enum LangfuseError { + #[error("transport: {0}")] + Transport(String), + #[error("unexpected status: {0}")] + Status(u16), + #[error("serialisation: {0}")] + Serialisation(String), +} + +/// Thin wrapper around `reqwest::Client` that POSTs batches to +/// `{host}/api/public/ingestion` with HTTP basic auth. +#[cfg(feature = "server")] +pub(crate) struct LangfuseClient { + http: reqwest::Client, + host: String, + auth: (String, String), +} + +#[cfg(feature = "server")] +impl LangfuseClient { + pub fn new(config: LangfuseConfig) -> Result<Self, LangfuseError> { + let http = reqwest::Client::builder() + .timeout(Duration::from_secs(10)) + .build() + .map_err(|e| LangfuseError::Transport(e.to_string()))?; + Ok(Self { + http, + host: config.host, + auth: (config.public_key, config.secret_key), + }) + } + + pub async fn ingest_batch(&self, batch: Vec<IngestItem>) -> Result<(), LangfuseError> { + if batch.is_empty() { + return Ok(()); + } + let url = format!("{}/api/public/ingestion", self.host.trim_end_matches('/')); + let payload = serde_json::json!({ "batch": batch }); + + let mut attempt = 0u32; + let mut delay_ms = 100u64; + loop { + let resp = self + .http + .post(&url) + .basic_auth(&self.auth.0, Some(&self.auth.1)) + .json(&payload) + .send() + .await; + + match resp { + Ok(r) if r.status().is_success() => return Ok(()), + Ok(r) if r.status().is_server_error() && attempt < 3 => { + warn!(status = %r.status(), attempt, "langfuse ingestion 5xx; retrying"); + } + Ok(r) => return Err(LangfuseError::Status(r.status().as_u16())), + Err(e) if attempt < 3 => { + warn!(error = %e, attempt, "langfuse transport error; retrying"); + } + Err(e) => return Err(LangfuseError::Transport(e.to_string())), + } + tokio::time::sleep(Duration::from_millis(delay_ms)).await; + attempt += 1; + delay_ms = (delay_ms * 2).min(1000); + } + } +} + +/// A live Langfuse stream. Buffers items in-memory and flushes every N ms +/// or M items, whichever is first. +pub struct LangfuseStream { + pub config: LangfuseConfig, + pub scope: StreamScope, + pub name: String, + pub active_since: chrono::DateTime<chrono::Utc>, + #[cfg(feature = "server")] + client: Arc<LangfuseClient>, + #[cfg(feature = "server")] + buffer: Arc<Mutex<Vec<IngestItem>>>, +} + +#[cfg(feature = "server")] +const FLUSH_ITEMS: usize = 32; + +impl LangfuseStream { + #[cfg(feature = "server")] + pub fn new( + name: String, + config: LangfuseConfig, + scope: StreamScope, + ) -> Result<Arc<Self>, LangfuseError> { + let client = Arc::new(LangfuseClient::new(config.clone())?); + Ok(Arc::new(Self { + config, + scope, + name, + active_since: Utc::now(), + client, + buffer: Arc::new(Mutex::new(Vec::new())), + })) + } + + #[cfg(not(feature = "server"))] + pub fn new(name: String, config: LangfuseConfig, scope: StreamScope) -> Arc<Self> { + Arc::new(Self { + config, + scope, + name, + active_since: Utc::now(), + }) + } + + #[cfg(feature = "server")] + async fn flush(&self) -> Result<(), LangfuseError> { + let mut buf = self.buffer.lock().await; + if buf.is_empty() { + return Ok(()); + } + let batch: Vec<_> = buf.drain(..).collect(); + drop(buf); + self.client.ingest_batch(batch).await + } +} + +#[async_trait] +impl SessionStream for LangfuseStream { + fn summary(&self) -> StreamSummary { + StreamSummary { + name: self.name.clone(), + kind: StreamKind::Langfuse, + target: format!("langfuse: {}", self.config.host), + active_since: self.active_since, + } + } + fn scope(&self) -> StreamScope { + self.scope.clone() + } + + #[cfg(feature = "server")] + async fn on_event(&self, event: &BusEvent) -> StreamOutcome { + let items = bus_event_to_items(event); + if items.is_empty() { + return StreamOutcome::Skipped; + } + + let mut buf = self.buffer.lock().await; + buf.extend(items); + if buf.len() >= FLUSH_ITEMS { + let batch: Vec<_> = buf.drain(..).collect(); + drop(buf); + match self.client.ingest_batch(batch).await { + Ok(()) => StreamOutcome::Ok, + Err(e) => StreamOutcome::Failed(StreamError::Transport(e.to_string())), + } + } else { + StreamOutcome::Ok + } + } + + #[cfg(not(feature = "server"))] + async fn on_event(&self, _event: &BusEvent) -> StreamOutcome { + StreamOutcome::Ok + } + + async fn shutdown(&self) { + #[cfg(feature = "server")] + { + let _ = self.flush().await; + } + } +} diff --git a/crates/dirigent_langfuse/src/factory.rs b/crates/dirigent_langfuse/src/factory.rs new file mode 100644 index 0000000..8abea8c --- /dev/null +++ b/crates/dirigent_langfuse/src/factory.rs @@ -0,0 +1,48 @@ +//! Phase 4: factory that builds a stub `LangfuseStream`. Task 22 upgrades +//! it to read credentials from params and construct a real client. + +use std::sync::Arc; + +use async_trait::async_trait; + +use dirigent_core::sharing::{StreamBuildError, StreamConfig, StreamFactory}; +use dirigent_protocol::streaming::SessionStream; + +use crate::client::{LangfuseConfig, LangfuseStream}; + +pub struct LangfuseFactory; + +#[async_trait] +impl StreamFactory for LangfuseFactory { + fn kind(&self) -> &'static str { "langfuse" } + + async fn build(&self, cfg: &StreamConfig) -> Result<Arc<dyn SessionStream>, StreamBuildError> { + // Parse params. Required fields: + // host: String (URL) + // public_key: String + // secret_key: String + // + // Phase 4 stub: parse-or-fail, then construct LangfuseStream with + // the parsed config. Task 22 uses the host to build a reqwest client. + + let host = cfg.params + .get("host").and_then(|v| v.as_str()) + .ok_or_else(|| StreamBuildError::Config("missing `host` (url string)".into()))?; + let public_key = cfg.params + .get("public_key").and_then(|v| v.as_str()) + .ok_or_else(|| StreamBuildError::Config("missing `public_key`".into()))?; + let secret_key = cfg.params + .get("secret_key").and_then(|v| v.as_str()) + .ok_or_else(|| StreamBuildError::Config("missing `secret_key`".into()))?; + + let lf_cfg = LangfuseConfig { + host: host.to_string(), + public_key: public_key.to_string(), + secret_key: secret_key.to_string(), + }; + + let stream = LangfuseStream::new(cfg.name.clone(), lf_cfg, cfg.scope.clone()) + .map_err(|e| StreamBuildError::Transport(e.to_string()))?; + Ok(stream as Arc<dyn SessionStream>) + } +} diff --git a/crates/dirigent_langfuse/src/lib.rs b/crates/dirigent_langfuse/src/lib.rs new file mode 100644 index 0000000..50d9f14 --- /dev/null +++ b/crates/dirigent_langfuse/src/lib.rs @@ -0,0 +1,13 @@ +//! Langfuse SessionStream implementation. +//! +//! Phase 4 scope: stub implementation. Task 22 adds the real HTTP +//! client + event-to-ingestion mapping. + +mod client; +#[cfg(feature = "server")] +mod factory; +mod mapping; + +pub use client::{LangfuseConfig, LangfuseStream}; +#[cfg(feature = "server")] +pub use factory::LangfuseFactory; diff --git a/crates/dirigent_langfuse/src/mapping.rs b/crates/dirigent_langfuse/src/mapping.rs new file mode 100644 index 0000000..5660945 --- /dev/null +++ b/crates/dirigent_langfuse/src/mapping.rs @@ -0,0 +1,173 @@ +//! BusEvent → Langfuse ingestion mapping. +//! +//! Maps the common BusEvent kinds to Langfuse ingestion items (traces, +//! generations, spans). Events without a `scroll_id` are dropped — +//! Langfuse requires a trace id up-front. + +// The items below are only wired into the stream when the `server` +// feature is on; the default-feature build keeps them for symmetry but +// does not reference them, so allow dead-code warnings there. +#![cfg_attr(not(feature = "server"), allow(dead_code))] + +use chrono::{DateTime, Utc}; +use serde::Serialize; +use uuid::Uuid; + +use dirigent_protocol::{streaming::BusEvent, Event}; + +/// A single Langfuse ingestion item. +/// +/// Batched into `{ "batch": [...] }` in `LangfuseClient::ingest_batch`. +#[derive(Debug, Clone, Serialize)] +pub struct IngestItem { + pub id: String, // UUIDv7 + pub timestamp: DateTime<Utc>, + #[serde(rename = "type")] + pub kind: IngestKind, + pub body: serde_json::Value, +} + +#[derive(Debug, Clone, Copy, Serialize)] +#[serde(rename_all = "kebab-case")] +#[allow(dead_code)] // SpanCreate/SpanUpdate reserved for future tool-call mapping +pub enum IngestKind { + TraceCreate, + GenerationCreate, + GenerationUpdate, + SpanCreate, + SpanUpdate, +} + +pub fn bus_event_to_items(bus_event: &BusEvent) -> Vec<IngestItem> { + let Some(scroll_id) = bus_event.routing.scroll_id else { + // No scroll_id binding yet — drop. Upstream callers may choose to + // buffer pending events keyed by (connector_id, native_id) until + // SessionRegistered arrives; Phase 4 scope: drop and log. + return Vec::new(); + }; + + let trace_id = scroll_id.to_string(); + let now = Utc::now(); + + match &*bus_event.event { + Event::SessionCreated { session, .. } => { + // `session.title` is a `String`; fall back to the id if empty. + let name = if session.title.is_empty() { + session.id.clone() + } else { + session.title.clone() + }; + vec![IngestItem { + id: Uuid::now_v7().to_string(), + timestamp: now, + kind: IngestKind::TraceCreate, + body: serde_json::json!({ + "id": trace_id, + "name": name, + }), + }] + } + Event::MessageStarted { message, .. } => { + vec![IngestItem { + id: Uuid::now_v7().to_string(), + timestamp: now, + kind: IngestKind::GenerationCreate, + body: serde_json::json!({ + "id": message.id, + "traceId": trace_id, + "name": format!("{:?}", message.role), + "startTime": message.created_at, + }), + }] + } + Event::MessageCompleted { message, .. } => { + vec![IngestItem { + id: Uuid::now_v7().to_string(), + timestamp: now, + kind: IngestKind::GenerationUpdate, + body: serde_json::json!({ + "id": message.id, + "traceId": trace_id, + "endTime": now, + "output": serialize_content(&message.content), + }), + }] + } + Event::TurnComplete { .. } => Vec::new(), // captured by MessageCompleted + // SessionUpdate::ToolCall* — would need a case-by-case mapping; out of + // Phase 4 scope. Return empty for now. + _ => Vec::new(), + } +} + +fn serialize_content(parts: &[dirigent_protocol::MessagePart]) -> serde_json::Value { + serde_json::to_value(parts).unwrap_or(serde_json::Value::Null) +} + +#[cfg(test)] +mod tests { + use super::*; + use dirigent_protocol::streaming::{BusEvent, EventKind, EventOrigin, EventRouting}; + use dirigent_protocol::{Event, Message, MessageRole, MessageStatus}; + use std::sync::Arc; + + fn make_bus_event_with_scroll(event: Event, scroll_id: Uuid) -> BusEvent { + BusEvent { + routing: EventRouting { + scroll_id: Some(scroll_id), + connector_uid: Some(Uuid::new_v4()), + connector_id: Some("c".into()), + native_session_id: Some("s".into()), + kind: EventKind::Message, + }, + origin: EventOrigin::Runtime, + event: Arc::new(event), + } + } + + #[test] + fn message_started_produces_generation_create() { + let scroll_id = Uuid::new_v4(); + let msg = Message { + id: "m1".into(), + session_id: "s".into(), + role: MessageRole::Assistant, + created_at: chrono::Utc::now(), + content: vec![], + status: MessageStatus::Streaming, + metadata: None, + }; + let bus_event = make_bus_event_with_scroll( + Event::MessageStarted { + connector_id: "c".into(), + message: msg, + }, + scroll_id, + ); + let items = bus_event_to_items(&bus_event); + assert_eq!(items.len(), 1); + assert!(matches!(items[0].kind, IngestKind::GenerationCreate)); + } + + #[test] + fn no_scroll_id_drops_event() { + let event = Event::Connected; + let bus_event = BusEvent { + routing: EventRouting::default(), + origin: EventOrigin::Runtime, + event: Arc::new(event), + }; + let items = bus_event_to_items(&bus_event); + assert_eq!(items.len(), 0); + } + + #[test] + fn unmapped_event_returns_empty() { + // `Connected` is not one of our mapped variants even when a scroll_id + // is bound → expect 0 items. + let scroll_id = Uuid::new_v4(); + let bus_event = make_bus_event_with_scroll(Event::Connected, scroll_id); + let items = bus_event_to_items(&bus_event); + assert_eq!(items.len(), 0); + } +} diff --git a/crates/dirigent_matrix/CLAUDE.md b/crates/dirigent_matrix/CLAUDE.md new file mode 100644 index 0000000..2d5793d --- /dev/null +++ b/crates/dirigent_matrix/CLAUDE.md @@ -0,0 +1,96 @@ +# Package: dirigent_matrix + +Matrix integration for Dirigent session sharing. + +## Quick Facts +- **Type**: Library +- **Main Entry**: src/lib.rs +- **Dependencies**: matrix-sdk, dirigent_protocol, tokio, serde, thiserror, async-trait +- **Status**: Phase 1 -- Bot-mode session sharing + +## Purpose + +Provides bidirectional bridging between Dirigent sessions and Matrix rooms. +A session can be "shared" to a Matrix room, allowing Matrix users to send +messages to the agent and see responses in real-time. + +## Architecture + +### MatrixService (`service.rs`) +Central singleton owning the matrix-sdk Client. Handles: +- Bot authentication (login with username/password, session restore via SQLite store) +- Background sync loop for receiving Matrix events +- Share registry (tracks active session shares by connector_id + session_id) +- Room message dispatch to appropriate shares + +### MatrixSessionShare (`share.rs`) +Bidirectional bridge for one (connector_id, session_id) to one Matrix room: +- **Dirigent to Matrix**: Subscribes to connector events, forwards completed assistant messages as m.notice +- **Matrix to Dirigent**: Receives room messages via MatrixService dispatch, sends ConnectorCommandProxy through mpsc channel +- Implements the `SessionShare` trait from dirigent_protocol + +### MatrixConfig (`config.rs`) +Configuration parsed from `[matrix]` section in dirigent.toml: +- Homeserver URL, username, password source (env var or inline) +- Device ID for session persistence across restarts +- Display name, default invite list, store path + +### Room Management (`room.rs`) +- Private, non-federated room creation for session shares +- Room naming conventions (`"Dirigent: <title>"`) + +## Configuration + +Identity and credentials live in an Account; sharing behavior in `[matrix]`: + +```toml +[accounts.matrix-bot] +type = "matrix" +homeserver = "https://matrix.example.com" +username = "dirigent_bot" +device_id = "DIRIGENT_01" +display_name = "Dirigent Bot" + +[accounts.matrix-bot.credentials.password] +source = "env" +key = "DIRIGENT_MATRIX_PASSWORD" + +[matrix] +account = "matrix-bot" +default_invite = ["@user:example.com"] +store_path = "matrix/bot/store" +``` + +## Key Types +- `MatrixService` -- Singleton service, owns Client and share registry +- `MatrixSessionShare` -- Bidirectional session-to-room bridge +- `MatrixBehaviorConfig` -- Sharing behavior (account ref, invites, store path) +- `ConnectorCommandProxy` -- Message proxy decoupling from dirigent_core types +- `CreateRoomOptions` -- Room creation parameters + +## Integration with CoreRuntime + +The MatrixService is wired into CoreRuntime as an optional component (like archivist): +- `CoreRuntime::start_matrix_service()` -- Resolves Account from config, creates and starts service +- `CoreRuntime::create_matrix_share()` -- Creates room, starts bridge, spawns command proxy task +- `CoreRuntime::matrix_service()` -- Accessor for the running service + +## Event Flow + +``` +Connector emits Event::MessageCompleted (role=assistant) + -> MatrixSessionShare event forwarder task + -> Sends m.notice to Matrix room + +Matrix user sends message in room + -> MatrixService sync loop receives SyncRoomMessageEvent + -> Looks up share by room_id + -> share.inject_message(text) -> ConnectorCommandProxy + -> Proxy task translates to ConnectorCommand::SendMessage + -> Connector processes message +``` + +## Related Packages +- **dirigent_protocol**: SessionShare trait, Event types consumed by share forwarder +- **dirigent_core**: CoreRuntime integration, ConnectorCommand, ConnectorHandle +- **dirigent_config**: Path resolution (DIRIGENT_DATA_DIR for SQLite store) diff --git a/crates/dirigent_matrix/Cargo.toml b/crates/dirigent_matrix/Cargo.toml new file mode 100644 index 0000000..142ac40 --- /dev/null +++ b/crates/dirigent_matrix/Cargo.toml @@ -0,0 +1,47 @@ +[package] +name = "dirigent_matrix" +version = "0.1.0" +edition = "2021" + +[lib] +path = "src/lib.rs" + +[features] +default = ["bundled-sqlite"] +bundled-sqlite = ["matrix-sdk/bundled-sqlite"] + +[dependencies] +# Matrix SDK +matrix-sdk = { version = "0.9", default-features = false, features = ["rustls-tls", "sqlite"] } + +# Internal dependencies +dirigent_protocol = { path = "../dirigent_protocol" } +dirigent_auth = { path = "../dirigent_auth" } + +# Async runtime +tokio = { version = "1.42", features = ["sync", "time", "macros", "rt"] } + +# Markdown rendering +pulldown-cmark = "0.12" + +# Serialization +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" + +# Logging +tracing = "0.1" + +# Error handling +thiserror = "2.0" + +# Async traits +async-trait = "0.1" + +# UUID +uuid = { version = "1.11", features = ["v7"] } + +# Timestamps (for StreamSummary::active_since) +chrono = { version = "0.4", features = ["serde"] } + +[dev-dependencies] +tokio = { version = "1.42", features = ["full"] } diff --git a/crates/dirigent_matrix/src/config.rs b/crates/dirigent_matrix/src/config.rs new file mode 100644 index 0000000..3582974 --- /dev/null +++ b/crates/dirigent_matrix/src/config.rs @@ -0,0 +1,73 @@ +use serde::{Deserialize, Serialize}; + +/// How Dirigent connects to Matrix. +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum MatrixConnectionMode { + /// Dedicated bot user with username/password login (existing behavior). + #[default] + Bot, + /// Appservice-provisioned virtual user with stored access token. + Provisioned, +} + +/// A persistent Matrix room defined in configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PersistentRoom { + /// Human-readable label for this room (shown in room picker). + pub label: String, + /// Matrix room ID (e.g. "!abc:matrix.org"). + pub room_id: String, +} + +/// Matrix sharing behavior — separate from identity. +/// +/// Identity and credentials come from an Account referenced by name. +/// This struct only defines sharing behavior. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MatrixBehaviorConfig { + /// Account name (key in [accounts.*]) for the Matrix connection. + pub account: String, + /// Connection mode: "bot" (default) or "provisioned". + #[serde(default)] + pub mode: MatrixConnectionMode, + /// Matrix user IDs to invite to newly created share rooms. + #[serde(default)] + pub default_invite: Vec<String>, + /// Directory for matrix-sdk SQLite store (relative to DIRIGENT_DATA_DIR). + #[serde(default = "default_store_path")] + pub store_path: String, + /// Pre-defined rooms that always appear in the room selection UI. + #[serde(default)] + pub rooms: Vec<PersistentRoom>, +} + +fn default_store_path() -> String { + "matrix/bot/store".to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_mode_is_bot() { + let json = r#"{"account": "matrix-bot"}"#; + let config: MatrixBehaviorConfig = serde_json::from_str(json).unwrap(); + assert_eq!(config.mode, MatrixConnectionMode::Bot); + } + + #[test] + fn test_provisioned_mode() { + let json = r#"{"account": "matrix-virt", "mode": "provisioned"}"#; + let config: MatrixBehaviorConfig = serde_json::from_str(json).unwrap(); + assert_eq!(config.mode, MatrixConnectionMode::Provisioned); + } + + #[test] + fn test_bot_mode_explicit() { + let json = r#"{"account": "matrix-bot", "mode": "bot"}"#; + let config: MatrixBehaviorConfig = serde_json::from_str(json).unwrap(); + assert_eq!(config.mode, MatrixConnectionMode::Bot); + } +} diff --git a/crates/dirigent_matrix/src/error.rs b/crates/dirigent_matrix/src/error.rs new file mode 100644 index 0000000..1ccbbe0 --- /dev/null +++ b/crates/dirigent_matrix/src/error.rs @@ -0,0 +1,39 @@ +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum MatrixError { + #[error("Matrix SDK error: {0}")] + Sdk(#[from] matrix_sdk::Error), + + #[error("Matrix HTTP error: {0}")] + Http(#[from] matrix_sdk::HttpError), + + #[error("Matrix client build error: {0}")] + ClientBuild(#[from] matrix_sdk::ClientBuildError), + + #[error("Not logged in")] + NotLoggedIn, + + #[error("Room not found: {0}")] + RoomNotFound(String), + + #[error("Share not found: connector={connector_id} session={session_id}")] + ShareNotFound { + connector_id: String, + session_id: String, + }, + + #[error("Share already exists: connector={connector_id} session={session_id}")] + ShareAlreadyExists { + connector_id: String, + session_id: String, + }, + + #[error("Configuration error: {0}")] + Config(String), + + #[error("Channel closed")] + ChannelClosed, +} + +pub type Result<T> = std::result::Result<T, MatrixError>; diff --git a/crates/dirigent_matrix/src/lib.rs b/crates/dirigent_matrix/src/lib.rs new file mode 100644 index 0000000..d2fcb58 --- /dev/null +++ b/crates/dirigent_matrix/src/lib.rs @@ -0,0 +1,17 @@ +//! Matrix integration for Dirigent session sharing +//! +//! This package provides bidirectional bridging between Dirigent sessions +//! and Matrix rooms. A session can be "shared" to a Matrix room, allowing +//! Matrix users to interact with the agent and see responses in real-time. + +pub mod config; +pub mod error; +pub mod room; +pub mod service; +pub mod share; + +pub use config::MatrixBehaviorConfig; +pub use error::{MatrixError, Result}; +pub use room::CreateRoomOptions; +pub use service::MatrixService; +pub use share::{ConnectorCommandProxy, MatrixSessionShare}; diff --git a/crates/dirigent_matrix/src/room.rs b/crates/dirigent_matrix/src/room.rs new file mode 100644 index 0000000..9c27ce1 --- /dev/null +++ b/crates/dirigent_matrix/src/room.rs @@ -0,0 +1,81 @@ +//! Matrix room creation and management helpers. + +use matrix_sdk::{ + ruma::{ + api::client::room::create_room::v3::{ + CreationContent, Request as CreateRoomRequest, RoomPreset, + }, + OwnedRoomId, OwnedUserId, + }, + Client, +}; +use tracing::debug; + +/// Options for creating a new Matrix room for session sharing. +pub struct CreateRoomOptions { + /// Human-readable room name. + pub name: String, + /// Optional room topic. + pub topic: Option<String>, + /// Matrix user IDs (as strings, e.g. "@user:example.com") to invite at + /// creation time. Invalid IDs are silently skipped. + pub invite: Vec<String>, +} + +/// Create a private, non-federated Matrix room for bridging a Dirigent session. +/// +/// The room is configured as a `PrivateChat` (invite-only, shared history) and +/// the `m.federate` flag is disabled so the room does not appear on remote +/// homeservers. +/// +/// # Errors +/// +/// Returns [`crate::MatrixError::NotLoggedIn`] if the client is not authenticated. +/// Other errors propagate from the Matrix SDK. +pub async fn create_share_room( + client: &Client, + options: CreateRoomOptions, +) -> crate::Result<OwnedRoomId> { + if !client.logged_in() { + return Err(crate::MatrixError::NotLoggedIn); + } + + let invite: Vec<OwnedUserId> = options + .invite + .iter() + .filter_map(|id| id.parse::<OwnedUserId>().ok()) + .collect(); + + let mut request = CreateRoomRequest::new(); + request.name = Some(options.name.clone()); + request.topic = options.topic.clone(); + request.invite = invite; + request.preset = Some(RoomPreset::PrivateChat); + + // Disable federation so the room stays on the local homeserver. + let mut creation_content = CreationContent::new(); + creation_content.federate = false; + request.creation_content = + Some(matrix_sdk::ruma::serde::Raw::new(&creation_content).map_err(|e| { + crate::MatrixError::Config(format!( + "Failed to serialize creation content: {}", + e + )) + })?); + + debug!(room_name = %options.name, "Creating Matrix share room"); + + let room = client.create_room(request).await?; + Ok(room.room_id().to_owned()) +} + +/// Generate a human-readable Matrix room name for a session. +/// +/// Format: `"Dirigent: <session_title>"` or `"Dirigent: <connector_id>"` when +/// no title is available. +pub fn room_name_for_session(connector_id: &str, session_title: Option<&str>) -> String { + match session_title.filter(|t| !t.is_empty()) { + Some(title) => format!("Dirigent: {}", title), + None => format!("Dirigent: {}", connector_id), + } +} diff --git a/crates/dirigent_matrix/src/service.rs b/crates/dirigent_matrix/src/service.rs new file mode 100644 index 0000000..2366729 --- /dev/null +++ b/crates/dirigent_matrix/src/service.rs @@ -0,0 +1,436 @@ +use std::{ + collections::HashMap, + path::PathBuf, + sync::Arc, +}; + +use matrix_sdk::{ + config::SyncSettings, + ruma::{ + events::room::message::{MessageType, SyncRoomMessageEvent}, + OwnedUserId, + }, + Client, Room, +}; +use tokio::sync::RwLock; +use tracing::{debug, error, info, warn}; + +use crate::{config::MatrixBehaviorConfig, error::MatrixError, share::MatrixSessionShare, Result}; + +/// Key for the share registry: (connector_id, session_id) +type ShareKey = (String, String); + +/// Central Matrix service. +/// +/// Owns the SDK [`Client`], handles login/session-restore, manages the sync +/// loop, and maintains a registry of [`MatrixSessionShare`]s that bridge +/// Dirigent sessions to Matrix rooms. +pub struct MatrixService { + account: dirigent_auth::Account, + homeserver: String, + username: String, + display_name_str: String, + device_id: String, + behavior: MatrixBehaviorConfig, + data_dir: PathBuf, + client: Arc<RwLock<Option<Client>>>, + shares: Arc<RwLock<HashMap<ShareKey, MatrixSessionShare>>>, +} + +impl MatrixService { + /// Create a new (not yet logged-in) service from an Account and behavior config. + pub fn from_account( + account: &dirigent_auth::Account, + behavior: MatrixBehaviorConfig, + data_dir: PathBuf, + ) -> Result<Self> { + let homeserver = account + .property_str("homeserver") + .ok_or_else(|| MatrixError::Config("Account missing 'homeserver' property".into()))? + .to_string(); + let username = account + .profile + .username + .clone() + .ok_or_else(|| { + MatrixError::Config("Account missing 'username' in profile".into()) + })?; + let device_id = account + .property_str_or("device_id", "DIRIGENT_01") + .to_string(); + let display_name_str = account.display_name().to_string(); + + Ok(Self { + account: account.clone(), + homeserver, + username, + display_name_str, + device_id, + behavior, + data_dir, + client: Arc::new(RwLock::new(None)), + shares: Arc::new(RwLock::new(HashMap::new())), + }) + } + + /// Return the current behavior configuration. + pub fn behavior(&self) -> &MatrixBehaviorConfig { + &self.behavior + } + + /// Return a clone of the inner [`Client`], if logged in. + pub async fn client_cloned(&self) -> Option<Client> { + self.client.read().await.clone() + } + + /// Resolve a Matrix [`Room`] handle from its string room id. + /// + /// Returns: + /// - `Ok(Some(room))` — client is logged in and knows the room + /// - `Ok(None)` — client is logged in but the room isn't known (not + /// joined / never invited / wrong id) + /// - `Err(MatrixError::NotLoggedIn)` — no client yet + /// - `Err(MatrixError::Config(..))` — `room_id` isn't a valid Matrix id + /// + /// Exposed for consumers (e.g. `dirigent_core`'s `MatrixFactory`) + /// that need to look up a pre-existing room without taking a + /// `matrix_sdk` dependency of their own. + pub async fn room_by_id(&self, room_id: &str) -> Result<Option<Room>> { + let client = self + .client_cloned() + .await + .ok_or(MatrixError::NotLoggedIn)?; + + let parsed: matrix_sdk::ruma::OwnedRoomId = room_id + .parse() + .map_err(|e: matrix_sdk::ruma::IdParseError| { + MatrixError::Config(format!("invalid room_id '{}': {}", room_id, e)) + })?; + + Ok(client.get_room(&parsed)) + } + + // ----------------------------------------------------------------------- + // Authentication + // ----------------------------------------------------------------------- + + /// Build the SDK client (with SQLite store) and authenticate. + /// + /// Attempts to restore a previously persisted session first. Falls back + /// to a fresh username/password login when no session is found. + pub async fn login(&self) -> Result<()> { + let store_path = self.data_dir.join(&self.behavior.store_path); + std::fs::create_dir_all(&store_path).map_err(|e| { + MatrixError::Config(format!("Failed to create store directory: {}", e)) + })?; + + let client = Client::builder() + .homeserver_url(&self.homeserver) + .sqlite_store(&store_path, None) + .build() + .await?; + + // Try to restore an existing session from the store. + if client.logged_in() { + info!( + homeserver = %self.homeserver, + username = %self.username, + "Restored existing Matrix session" + ); + *self.client.write().await = Some(client); + return Ok(()); + } + + // No stored session — authenticate based on the configured mode. + match self.behavior.mode { + crate::config::MatrixConnectionMode::Bot => { + // Existing flow: login with bot username/password. + let password = self + .account + .resolve_credential("password") + .map_err(|e| { + MatrixError::Config(format!("Failed to resolve password: {}", e)) + })?; + + info!( + homeserver = %self.homeserver, + username = %self.username, + "Performing fresh Matrix login (bot mode)" + ); + + client + .matrix_auth() + .login_username(&self.username, &password) + .device_id(&self.device_id) + .initial_device_display_name(&self.display_name_str) + .send() + .await?; + } + crate::config::MatrixConnectionMode::Provisioned => { + // Restore session from a stored virtual-user access token. + let token = self + .account + .resolve_credential("token") + .map_err(|e| { + MatrixError::Config(format!("Failed to resolve token: {}", e)) + })?; + + let user_id_str = self + .account + .property_str("user_id") + .ok_or_else(|| { + MatrixError::Config( + "Provisioned mode requires 'user_id' property on account".into(), + ) + })?; + + info!( + homeserver = %self.homeserver, + user_id = %user_id_str, + "Restoring provisioned session with access token" + ); + + use matrix_sdk::matrix_auth::MatrixSession; + use matrix_sdk::ruma::{OwnedDeviceId, OwnedUserId}; + + let user_id: OwnedUserId = user_id_str.try_into().map_err(|_| { + MatrixError::Config(format!("Invalid user_id: {}", user_id_str)) + })?; + let device_id: OwnedDeviceId = self.device_id.clone().into(); + + let session = MatrixSession { + meta: matrix_sdk::SessionMeta { + user_id, + device_id, + }, + tokens: matrix_sdk::matrix_auth::MatrixSessionTokens { + access_token: token, + refresh_token: None, + }, + }; + + client.matrix_auth().restore_session(session).await?; + } + } + + info!( + homeserver = %self.homeserver, + username = %self.username, + "Matrix login successful" + ); + + *self.client.write().await = Some(client); + Ok(()) + } + + // ----------------------------------------------------------------------- + // Sync loop + // ----------------------------------------------------------------------- + + /// Start the background sync task. + /// + /// Registers an event handler for incoming room messages and spawns a + /// task that runs `client.sync()` indefinitely. Returns immediately + /// after spawning. + /// + /// # Errors + /// + /// Returns [`MatrixError::NotLoggedIn`] if `login()` was not called first. + pub async fn start_sync(&self) -> Result<()> { + let client = self + .client + .read() + .await + .clone() + .ok_or(MatrixError::NotLoggedIn)?; + + if !client.logged_in() { + return Err(MatrixError::NotLoggedIn); + } + + let shares = Arc::clone(&self.shares); + let bot_user_id: Option<OwnedUserId> = client.user_id().map(|u| u.to_owned()); + + // Register the room message event handler. + client.add_event_handler({ + let shares = Arc::clone(&shares); + let bot_user_id = bot_user_id.clone(); + move |ev: SyncRoomMessageEvent, room: Room| { + let shares = Arc::clone(&shares); + let bot_user_id = bot_user_id.clone(); + async move { + on_room_message(ev, room, shares, bot_user_id).await; + } + } + }); + + // Spawn the sync loop. + tokio::spawn(async move { + info!("Starting Matrix sync loop"); + if let Err(e) = client.sync(SyncSettings::default()).await { + error!("Matrix sync loop exited with error: {}", e); + } + }); + + Ok(()) + } + + // ----------------------------------------------------------------------- + // Share registry + // ----------------------------------------------------------------------- + + /// Register a share, making it eligible to receive Matrix messages. + /// + /// # Errors + /// + /// Returns [`MatrixError::ShareAlreadyExists`] if a share for the same + /// `(connector_id, session_id)` pair is already registered. + pub async fn register_share(&self, share: MatrixSessionShare) -> Result<()> { + let key = (share.connector_id.clone(), share.session_id.clone()); + let mut map = self.shares.write().await; + if map.contains_key(&key) { + return Err(MatrixError::ShareAlreadyExists { + connector_id: key.0, + session_id: key.1, + }); + } + map.insert(key, share); + Ok(()) + } + + /// Remove a share from the registry and shut it down. + /// + /// # Errors + /// + /// Returns [`MatrixError::ShareNotFound`] if no matching share exists. + pub async fn remove_share(&self, connector_id: &str, session_id: &str) -> Result<()> { + let key = (connector_id.to_owned(), session_id.to_owned()); + let share = self + .shares + .write() + .await + .remove(&key) + .ok_or_else(|| MatrixError::ShareNotFound { + connector_id: connector_id.to_owned(), + session_id: session_id.to_owned(), + })?; + share.shutdown().await; + Ok(()) + } + + /// Return the number of currently registered shares. + pub async fn share_count(&self) -> usize { + self.shares.read().await.len() + } + + /// Return the room IDs and keys for all currently registered shares. + /// + /// Returns a list of `(connector_id, session_id, room_id)` tuples. + pub async fn list_shares(&self) -> Vec<(String, String, String)> { + self.shares + .read() + .await + .iter() + .map(|((cid, sid), s)| (cid.clone(), sid.clone(), s.room_id.clone())) + .collect() + } + + /// Look up a share by connector and session ID. + /// + /// Returns `None` when not found. + pub async fn get_share( + &self, + connector_id: &str, + session_id: &str, + ) -> Option<(String, bool)> { + let key = (connector_id.to_owned(), session_id.to_owned()); + let map = self.shares.read().await; + if let Some(share) = map.get(&key) { + let room_id = share.room_id.clone(); + // is_active requires an await; we can't hold the read guard across + // the await point. Drop the guard first. + drop(map); + let active = { + // Re-acquire to call is_active (we still have shares Arc) + let map = self.shares.read().await; + if let Some(s) = map.get(&key) { + s.is_active().await + } else { + false + } + }; + Some((room_id, active)) + } else { + None + } + } + + /// Shut down all shares and signal the service to stop. + pub async fn shutdown(&self) { + let mut map = self.shares.write().await; + for (_, share) in map.drain() { + share.shutdown().await; + } + // Release the client so the sync loop can terminate naturally. + *self.client.write().await = None; + } +} + +// --------------------------------------------------------------------------- +// Event handler: Matrix → Dirigent +// --------------------------------------------------------------------------- + +/// Called by the SDK for every incoming room message. +/// +/// Looks up which registered share owns this room, then calls +/// `share.inject_message(text)` so the text flows into the Dirigent session. +/// Messages from the bot itself are skipped. +async fn on_room_message( + ev: SyncRoomMessageEvent, + room: Room, + shares: Arc<RwLock<HashMap<ShareKey, MatrixSessionShare>>>, + bot_user_id: Option<OwnedUserId>, +) { + // We only care about original (non-redacted, non-edited) messages. + let original = match ev.as_original() { + Some(o) => o, + None => return, + }; + + // Skip bot's own messages. + if let Some(bot_id) = &bot_user_id { + if original.sender == *bot_id { + return; + } + } + + // Extract plain-text body. + let text = match &original.content.msgtype { + MessageType::Text(t) => t.body.clone(), + MessageType::Notice(_) => return, // ignore notices (including our own) + _ => return, + }; + + let room_id_str = room.room_id().as_str().to_owned(); + + // Find the share that owns this room. + let map = shares.read().await; + let matching = map + .values() + .find(|s| s.room_id == room_id_str); + + if let Some(share) = matching { + debug!( + room_id = %room_id_str, + connector_id = %share.connector_id, + session_id = %share.session_id, + "Injecting Matrix message into Dirigent session" + ); + share.inject_message(&text).await; + } else { + warn!( + room_id = %room_id_str, + "Received message in unregistered room, ignoring" + ); + } +} diff --git a/crates/dirigent_matrix/src/share.rs b/crates/dirigent_matrix/src/share.rs new file mode 100644 index 0000000..9a487b4 --- /dev/null +++ b/crates/dirigent_matrix/src/share.rs @@ -0,0 +1,723 @@ +use std::sync::Arc; + +use chrono::{DateTime, Utc}; +use matrix_sdk::Room; +use tokio::sync::{broadcast, mpsc, oneshot, Mutex, RwLock}; +use tracing::{debug, error, warn}; +use uuid::Uuid; + +use dirigent_protocol::accumulator::{AccumulatedMessage, AccumulatedPart, MessageAccumulator, ToolCallData}; +use dirigent_protocol::{ContentBlock, Event, MessageRole, SessionUpdate}; + +/// Command proxy sent from Matrix → Dirigent direction. +/// +/// The caller who wires up the share is responsible for translating this +/// into a real `ConnectorCommand::SendMessage` (or equivalent) using their +/// own connector/session handle. +#[derive(Debug, Clone)] +pub struct ConnectorCommandProxy { + pub session_id: String, + pub text: String, +} + +/// A bidirectional bridge between a Dirigent session and a Matrix room. +/// +/// **Dirigent → Matrix**: Subscribes to a broadcast event stream, forwards +/// completed assistant messages and session errors into the Matrix room as +/// `m.notice` messages. +/// +/// **Matrix → Dirigent**: The `inject_message` method is called by +/// `MatrixService` when a Matrix message arrives; it sends a +/// `ConnectorCommandProxy` through an mpsc channel that the owner can read. +pub struct MatrixSessionShare { + /// Connector that owns the session. + pub connector_id: String, + /// Session being bridged (native connector session ID). + pub session_id: String, + /// Scroll ID for the archived session this share is scoped to. + /// + /// Required by `SessionStream::scope()` to select `StreamScope::Session`. + pub scroll_id: Uuid, + /// Matrix room ID (as a string, e.g. "!abc:example.com"). + pub room_id: String, + /// Room handle used by the stream `on_event` path when no legacy + /// forwarder task is running. Behind an `Option` so `start()` — which + /// consumes the `Room` when spawning the legacy forwarder — can leave + /// it empty. + room_for_stream: Option<Room>, + /// Shared message accumulator so streaming chunks survive across + /// multiple `on_event` calls (and the legacy forwarder task). + accumulator: Arc<Mutex<MessageAccumulator>>, + /// When this share was activated (for `StreamSummary::active_since`). + active_since: DateTime<Utc>, + + /// Sender side of the Matrix→Dirigent command channel. + command_tx: mpsc::Sender<ConnectorCommandProxy>, + /// Shutdown signal for the event-forwarder task. + shutdown_tx: Arc<RwLock<Option<oneshot::Sender<()>>>>, + /// Whether the forwarder task is still running. + is_active: Arc<RwLock<bool>>, +} + +impl MatrixSessionShare { + /// Construct and start a new `MatrixSessionShare`. + /// + /// Spawns a background task that reads from `event_rx`, filters for + /// events belonging to `(connector_id, session_id)`, and forwards + /// relevant ones into the Matrix `room`. + /// + /// Returns the share and the receiver end of the Matrix→Dirigent channel. + pub fn start( + connector_id: String, + session_id: String, + room_id: String, + room: Room, + event_rx: broadcast::Receiver<Event>, + ) -> (Self, mpsc::Receiver<ConnectorCommandProxy>) { + // Legacy start() keeps ownership of the Room inside the forwarder + // task; `scroll_id` isn't known by the legacy call path so we + // default it to `Uuid::nil()`. The stream-path consumers use + // `new_for_stream` instead and supply a real scroll_id. + Self::start_with_scroll( + connector_id, + session_id, + Uuid::nil(), + room_id, + room, + event_rx, + ) + } + + /// Same as `start`, but lets callers attach the `scroll_id` of the + /// archived session — required for `SessionStream::scope()` to be + /// meaningful when the share is also driven as a stream. + pub fn start_with_scroll( + connector_id: String, + session_id: String, + scroll_id: Uuid, + room_id: String, + room: Room, + event_rx: broadcast::Receiver<Event>, + ) -> (Self, mpsc::Receiver<ConnectorCommandProxy>) { + let (command_tx, command_rx) = mpsc::channel(32); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + let is_active = Arc::new(RwLock::new(true)); + let accumulator = Arc::new(Mutex::new(MessageAccumulator::new())); + + let share = MatrixSessionShare { + connector_id: connector_id.clone(), + session_id: session_id.clone(), + scroll_id, + room_id: room_id.clone(), + // Legacy path: the forwarder task owns the Room, so we don't + // hold a second handle on the struct. `on_event` would double- + // drive delivery if both were active. + room_for_stream: None, + accumulator: accumulator.clone(), + active_since: Utc::now(), + command_tx, + shutdown_tx: Arc::new(RwLock::new(Some(shutdown_tx))), + is_active: is_active.clone(), + }; + + // Spawn the event-forwarder task + tokio::spawn(run_event_forwarder( + connector_id, + session_id, + room, + event_rx, + shutdown_rx, + is_active, + accumulator, + )); + + (share, command_rx) + } + + /// Construct a share wired for the stream path (`SessionStream`) only. + /// + /// No legacy event-forwarder task is spawned — the + /// `StreamRegistry` worker will drive `on_event` instead. The Room + /// handle is retained on the struct so `on_event` can deliver to it. + /// + /// Returns the share plus the receiver end of the Matrix→Dirigent + /// command channel (identical semantics to `start`). + pub fn new_for_stream( + connector_id: String, + session_id: String, + scroll_id: Uuid, + room_id: String, + room: Room, + ) -> (Self, mpsc::Receiver<ConnectorCommandProxy>) { + let (command_tx, command_rx) = mpsc::channel(32); + let (shutdown_tx, _shutdown_rx) = oneshot::channel(); + let is_active = Arc::new(RwLock::new(true)); + let accumulator = Arc::new(Mutex::new(MessageAccumulator::new())); + + let share = MatrixSessionShare { + connector_id, + session_id, + scroll_id, + room_id, + room_for_stream: Some(room), + accumulator, + active_since: Utc::now(), + command_tx, + shutdown_tx: Arc::new(RwLock::new(Some(shutdown_tx))), + is_active, + }; + + (share, command_rx) + } + + /// Inject a message received from Matrix into the Dirigent session. + /// + /// Called by `MatrixService` when a room message arrives (filtered to + /// skip the bot's own messages). Sends a `ConnectorCommandProxy` through + /// the internal mpsc channel. + pub async fn inject_message(&self, text: &str) { + let proxy = ConnectorCommandProxy { + session_id: self.session_id.clone(), + text: text.to_owned(), + }; + if let Err(e) = self.command_tx.send(proxy).await { + warn!( + connector_id = %self.connector_id, + session_id = %self.session_id, + "Failed to inject Matrix message into session (channel closed): {}", + e + ); + } + } + + /// Signal the event-forwarder task to stop and wait for it to finish. + pub async fn shutdown(&self) { + let tx = self.shutdown_tx.write().await.take(); + if let Some(tx) = tx { + let _ = tx.send(()); + } + // Give the task a moment to notice the shutdown signal. + // The is_active flag is set to false by the task itself when it exits. + } + + /// Whether the event-forwarder task is still running. + pub async fn is_active(&self) -> bool { + *self.is_active.read().await + } +} + +// --------------------------------------------------------------------------- +// Internal event-forwarder task +// --------------------------------------------------------------------------- + +async fn run_event_forwarder( + connector_id: String, + session_id: String, + room: Room, + mut event_rx: broadcast::Receiver<Event>, + shutdown_rx: oneshot::Receiver<()>, + is_active: Arc<RwLock<bool>>, + accumulator: Arc<Mutex<MessageAccumulator>>, +) { + // Fuse the shutdown signal so we can use it inside tokio::select! + let mut shutdown_rx = shutdown_rx; + + loop { + tokio::select! { + _ = &mut shutdown_rx => { + debug!( + connector_id = %connector_id, + session_id = %session_id, + "MatrixSessionShare forwarder received shutdown signal" + ); + break; + } + result = event_rx.recv() => { + match result { + Err(broadcast::error::RecvError::Closed) => { + debug!( + connector_id = %connector_id, + session_id = %session_id, + "Event broadcast channel closed, stopping forwarder" + ); + break; + } + Err(broadcast::error::RecvError::Lagged(n)) => { + warn!( + connector_id = %connector_id, + session_id = %session_id, + "Event forwarder lagged by {} messages", + n + ); + continue; + } + Ok(event) => { + let mut acc = accumulator.lock().await; + handle_event( + &event, + &connector_id, + &session_id, + &room, + &mut *acc, + ).await; + } + } + } + } + } + + *is_active.write().await = false; +} + +/// Handle a single protocol event: forward relevant ones to the Matrix room. +async fn handle_event( + event: &Event, + connector_id: &str, + session_id: &str, + room: &Room, + accumulator: &mut MessageAccumulator, +) { + use matrix_sdk::ruma::events::room::message::RoomMessageEventContent; + + match event { + // -- Streaming accumulation: gather chunks into the accumulator -- + Event::SessionUpdate { + connector_id: cid, + session_id: sid, + update, + } if cid == connector_id && sid == session_id => { + // Send typing indicator on first chunk of a new message + let is_agent_chunk = matches!( + update, + SessionUpdate::AgentMessageChunk { .. } + | SessionUpdate::AgentThoughtChunk { .. } + | SessionUpdate::ToolCall { .. } + ); + let msg_id = match update { + SessionUpdate::AgentMessageChunk { message_id, .. } + | SessionUpdate::AgentThoughtChunk { message_id, .. } + | SessionUpdate::ToolCall { message_id, .. } + | SessionUpdate::ToolCallUpdate { message_id, .. } => Some(message_id.as_str()), + _ => None, + }; + if is_agent_chunk { + if let Some(mid) = msg_id { + if !accumulator.has_buffer(mid) { + // First chunk for this message — start typing indicator + let _ = room.typing_notice(true).await; + } + } + } + + match update { + SessionUpdate::AgentMessageChunk { + message_id, + content, + .. + } => { + accumulator.add_chunk(message_id, session_id, &connector_id, "assistant", content.clone()); + } + SessionUpdate::AgentThoughtChunk { + message_id, + content, + .. + } => { + if let ContentBlock::Text { text } = content { + accumulator.add_thinking(message_id, session_id, &connector_id, text); + } + } + SessionUpdate::ToolCall { + message_id, + tool_call, + .. + } => { + accumulator.add_or_update_tool_call( + message_id, + ToolCallData { + id: tool_call.id.clone(), + tool_name: tool_call.tool_name.clone(), + input: tool_call.raw_input.clone().unwrap_or_default(), + output: tool_call.raw_output.clone(), + }, + ); + } + SessionUpdate::ToolCallUpdate { + message_id, + tool_call, + .. + } => { + accumulator.add_or_update_tool_call( + message_id, + ToolCallData { + id: tool_call.id.clone(), + tool_name: tool_call.tool_name.clone(), + input: tool_call.raw_input.clone().unwrap_or_default(), + output: tool_call.raw_output.clone(), + }, + ); + } + _ => {} // UserMessageChunk, Unknown -- not forwarded + } + } + + // -- Streaming finalization: send accumulated content on TurnComplete -- + Event::TurnComplete { + connector_id: cid, + session_id: sid, + message_id, + .. + } if cid == connector_id && sid == session_id => { + if let Some(accumulated) = accumulator.finalize(message_id) { + if accumulated.role == "assistant" && !accumulated.is_empty() { + send_accumulated_to_matrix(&accumulated, room).await; + } + } + // Stop typing indicator + let _ = room.typing_notice(false).await; + debug!( + connector_id = %connector_id, + session_id = %session_id, + message_id = %message_id, + "TurnComplete received for bridged session" + ); + } + + // -- Non-streaming fallback: send content from MessageCompleted -- + Event::MessageCompleted { + connector_id: cid, + message, + } if cid == connector_id && message.session_id == session_id => { + if message.role != MessageRole::Assistant { + debug!( + connector_id = %connector_id, + session_id = %session_id, + role = ?message.role, + "Skipping non-assistant MessageCompleted" + ); + return; + } + + // Non-streaming path: content populated directly in MessageCompleted. + // Skip if the accumulator already has data for this message (streaming + // path handles delivery via TurnComplete instead). + if !message.content.is_empty() && !accumulator.has_buffer(&message.id) { + let accumulated = AccumulatedMessage::from_message_parts( + message.id.clone(), + message.session_id.clone(), + connector_id.to_string(), + "assistant".to_string(), + &message.content, + ); + send_accumulated_to_matrix(&accumulated, room).await; + } + } + + Event::SessionError { + connector_id: cid, + session_id: sid, + error_message, + is_recoverable, + .. + } if cid == connector_id && sid == session_id => { + let notice = if *is_recoverable { + format!("\u{26a0}\u{fe0f} Session warning: {}", error_message) + } else { + format!("\u{274c} Session error (unrecoverable): {}", error_message) + }; + let content = RoomMessageEventContent::notice_plain(notice); + if let Err(e) = room.send(content).await { + error!( + connector_id = %connector_id, + session_id = %session_id, + "Failed to send session error to Matrix room: {}", + e + ); + } + } + + // Events for *this* session but not handled above (expected, low noise) + Event::SessionIdle { session_id: sid, .. } + | Event::SessionMetadataUpdated { + session_id: sid, .. + } + if sid == session_id => + { + debug!( + connector_id = %connector_id, + session_id = %session_id, + event = event_name(event), + "Ignoring non-forwarded event for this session" + ); + } + + // Events for a *different* session/connector -- expected, just noise + Event::MessageCompleted { + connector_id: cid, + message, + } if cid != connector_id || message.session_id != session_id => { + // Different session's message -- expected on a shared broadcast channel + } + Event::TurnComplete { + connector_id: cid, + session_id: sid, + .. + } if cid != connector_id || sid != session_id => { + // Different session -- expected + } + Event::SessionError { + connector_id: cid, + session_id: sid, + .. + } if cid != connector_id || sid != session_id => { + // Different session -- expected + } + Event::SessionUpdate { + connector_id: cid, + session_id: sid, + .. + } if cid != connector_id || sid != session_id => { + // Different session -- expected + } + + // Connector lifecycle, inspector, system events -- expected on broadcast + Event::ConnectorCreated { .. } + | Event::ConnectorRemoved { .. } + | Event::ConnectorStateChanged { .. } + | Event::Connected + | Event::Disconnected + | Event::InspectorSnapshot { .. } + | Event::InspectorNodeRegistered { .. } + | Event::InspectorNodeRemoved { .. } + | Event::InspectorStateChanged { .. } + | Event::InspectorPropertiesUpdated { .. } + | Event::SystemTaskStatusChanged { .. } + | Event::SessionsListed { .. } + | Event::SessionCreated { .. } + | Event::SessionUpdated { .. } + | Event::SessionDeleted { .. } + | Event::SessionClosed { .. } + | Event::MessagesListed { .. } + | Event::SessionSystemMessageSet { .. } + | Event::SessionMetadataReceived { .. } + | Event::SessionTransferred { .. } + | Event::ForwardingPanic { .. } + | Event::AgentRequest { .. } + | Event::AcpClientConnected { .. } + | Event::AcpClientDisconnected { .. } + | Event::AcpClientSessionOpened { .. } + | Event::AcpClientSessionRouted { .. } + | Event::MessageStarted { .. } + | Event::MessageFailed { .. } + | Event::Error { .. } + | Event::SessionRegistered { .. } => { + // Expected broadcast traffic, not relevant to this share + } + + other => { + warn!( + connector_id = %connector_id, + session_id = %session_id, + event = event_name(other), + "Unhandled event type in Matrix forwarder" + ); + } + } +} + +/// Render markdown text to HTML for Matrix consumption. +fn markdown_to_html(markdown: &str) -> String { + use pulldown_cmark::{Options, Parser}; + + let mut options = Options::empty(); + options.insert(Options::ENABLE_STRIKETHROUGH); + options.insert(Options::ENABLE_TABLES); + + let parser = Parser::new_ext(markdown, options); + let mut html = String::new(); + pulldown_cmark::html::push_html(&mut html, parser); + html +} + +/// Send an accumulated message to a Matrix room, one message per content part. +async fn send_accumulated_to_matrix(msg: &AccumulatedMessage, room: &Room) { + use matrix_sdk::ruma::events::room::message::RoomMessageEventContent; + + for part in &msg.parts { + let content = match part { + AccumulatedPart::Text { text } if !text.is_empty() => { + let html = markdown_to_html(text); + Some(RoomMessageEventContent::text_html(text.clone(), html)) + } + AccumulatedPart::Thinking { text } if !text.is_empty() => { + Some(RoomMessageEventContent::notice_plain(format!( + "\u{1f4ad} {text}" + ))) + } + AccumulatedPart::Tool { data } => { + let mut notice = format!("\u{1f527} Tool: {}", data.tool_name); + if let Some(out) = &data.output { + let out_str = if let Some(s) = out.as_str() { + s.to_string() + } else { + serde_json::to_string_pretty(out).unwrap_or_default() + }; + if !out_str.is_empty() { + let truncated = if out_str.len() > 500 { + format!("{}... (truncated)", &out_str[..500]) + } else { + out_str + }; + notice.push_str(&format!("\nOutput: {truncated}")); + } + } + Some(RoomMessageEventContent::notice_plain(notice)) + } + _ => None, + }; + if let Some(content) = content { + if let Err(e) = room.send(content).await { + error!("Failed to send message part to Matrix room: {}", e); + } + } + } +} + +/// Return a human-readable name for an event variant (for logging). +fn event_name(event: &Event) -> &'static str { + match event { + Event::SessionsListed { .. } => "SessionsListed", + Event::SessionCreated { .. } => "SessionCreated", + Event::SessionUpdated { .. } => "SessionUpdated", + Event::SessionMetadataUpdated { .. } => "SessionMetadataUpdated", + Event::SessionDeleted { .. } => "SessionDeleted", + Event::SessionClosed { .. } => "SessionClosed", + Event::SessionSystemMessageSet { .. } => "SessionSystemMessageSet", + Event::SessionIdle { .. } => "SessionIdle", + Event::SessionMetadataReceived { .. } => "SessionMetadataReceived", + Event::TurnComplete { .. } => "TurnComplete", + Event::SessionError { .. } => "SessionError", + Event::SessionTransferred { .. } => "SessionTransferred", + Event::ForwardingPanic { .. } => "ForwardingPanic", + Event::SessionUpdate { .. } => "SessionUpdate", + Event::AgentRequest { .. } => "AgentRequest", + Event::AcpClientConnected { .. } => "AcpClientConnected", + Event::AcpClientDisconnected { .. } => "AcpClientDisconnected", + Event::AcpClientSessionOpened { .. } => "AcpClientSessionOpened", + Event::AcpClientSessionRouted { .. } => "AcpClientSessionRouted", + Event::MessagesListed { .. } => "MessagesListed", + Event::MessageStarted { .. } => "MessageStarted", + Event::MessageCompleted { .. } => "MessageCompleted", + Event::MessageFailed { .. } => "MessageFailed", + Event::ConnectorCreated { .. } => "ConnectorCreated", + Event::ConnectorRemoved { .. } => "ConnectorRemoved", + Event::ConnectorStateChanged { .. } => "ConnectorStateChanged", + Event::Connected => "Connected", + Event::Disconnected => "Disconnected", + Event::Error { .. } => "Error", + Event::InspectorSnapshot { .. } => "InspectorSnapshot", + Event::InspectorNodeRegistered { .. } => "InspectorNodeRegistered", + Event::InspectorNodeRemoved { .. } => "InspectorNodeRemoved", + Event::InspectorStateChanged { .. } => "InspectorStateChanged", + Event::InspectorPropertiesUpdated { .. } => "InspectorPropertiesUpdated", + Event::SessionRegistered { .. } => "SessionRegistered", + Event::SystemTaskStatusChanged { .. } => "SystemTaskStatusChanged", + } +} + +// --------------------------------------------------------------------------- +// SessionShare trait implementation +// --------------------------------------------------------------------------- + +#[async_trait::async_trait] +impl dirigent_protocol::sharing::SessionShare for MatrixSessionShare { + fn summary(&self) -> dirigent_protocol::sharing::ShareSummary { + dirigent_protocol::sharing::ShareSummary { + id: format!("matrix:{}:{}", self.connector_id, self.session_id), + connector_id: self.connector_id.clone(), + session_id: self.session_id.clone(), + backend: "matrix".to_string(), + destination: self.room_id.clone(), + active: self.is_active.try_read().map(|g| *g).unwrap_or(false), + } + } + + fn is_active(&self) -> bool { + self.is_active.try_read().map(|g| *g).unwrap_or(false) + } + + async fn shutdown(&self) { + // Delegate to the existing shutdown method (same implementation) + let tx = self.shutdown_tx.write().await.take(); + if let Some(tx) = tx { + let _ = tx.send(()); + } + } +} + +// --------------------------------------------------------------------------- +// SessionStream trait implementation (Phase 4 migration, Task 18) +// --------------------------------------------------------------------------- +// +// Dual-impl: `MatrixSessionShare` keeps its bi-directional `SessionShare` +// impl (room management, `inject_message`) while also gaining a +// uni-directional `SessionStream` impl so the `StreamRegistry` can drive it +// via the central `SharingBus`. +// +// When driven as a stream, the caller (factory) is expected to construct +// the share via `new_for_stream(..)` so a Room handle is stored on the +// struct. `on_event` then translates bus events back to the legacy +// `handle_event` dispatcher, preserving the accumulator state across +// calls via the shared `Arc<Mutex<MessageAccumulator>>`. + +#[async_trait::async_trait] +impl dirigent_protocol::streaming::SessionStream for MatrixSessionShare { + fn summary(&self) -> dirigent_protocol::streaming::StreamSummary { + dirigent_protocol::streaming::StreamSummary { + name: format!("{}:{}", self.connector_id, self.session_id), + kind: dirigent_protocol::streaming::StreamKind::Matrix, + target: format!("matrix:{}", self.room_id), + active_since: self.active_since, + } + } + + fn scope(&self) -> dirigent_protocol::streaming::StreamScope { + dirigent_protocol::streaming::StreamScope::Session { + scroll_id: self.scroll_id, + } + } + + async fn on_event( + &self, + event: &dirigent_protocol::streaming::BusEvent, + ) -> dirigent_protocol::streaming::StreamOutcome { + // If we were started via the legacy `start()` path, the Room + // handle lives inside the forwarder task and this method has no + // way to deliver. Treat that as a deliberate skip rather than a + // transport failure. + let room = match &self.room_for_stream { + Some(r) => r, + None => return dirigent_protocol::streaming::StreamOutcome::Skipped, + }; + + let mut acc = self.accumulator.lock().await; + handle_event( + &*event.event, + &self.connector_id, + &self.session_id, + room, + &mut *acc, + ) + .await; + dirigent_protocol::streaming::StreamOutcome::Ok + } + + async fn shutdown(&self) { + // Delegate to the bi-directional shutdown; both impls share the + // same underlying oneshot + is_active signal. + let tx = self.shutdown_tx.write().await.take(); + if let Some(tx) = tx { + let _ = tx.send(()); + } + *self.is_active.write().await = false; + } +} diff --git a/crates/dirigent_process/Cargo.toml b/crates/dirigent_process/Cargo.toml new file mode 100644 index 0000000..cfa4a31 --- /dev/null +++ b/crates/dirigent_process/Cargo.toml @@ -0,0 +1,32 @@ +[package] +name = "dirigent_process" +version = "0.1.0" +edition = "2021" +description = "Cross-platform process lifecycle management for Dirigent" + +[lib] +path = "src/lib.rs" + +[features] +default = [] +tokio = ["dep:tokio"] + +[dependencies] +tracing = "0.1" +tokio = { version = "1", features = ["process", "time"], optional = true } + +[target.'cfg(windows)'.dependencies] +windows-sys = { version = "0.59", features = [ + "Win32_System_JobObjects", + "Win32_System_Threading", + "Win32_Foundation", + "Win32_System_Console", + "Win32_Security", +] } + +[target.'cfg(unix)'.dependencies] +nix = { version = "0.29", features = ["signal", "process"] } +libc = "0.2" + +[dev-dependencies] +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/crates/dirigent_process/src/lib.rs b/crates/dirigent_process/src/lib.rs new file mode 100644 index 0000000..66a9f28 --- /dev/null +++ b/crates/dirigent_process/src/lib.rs @@ -0,0 +1,30 @@ +pub mod traits; +mod shutdown; + +#[cfg(windows)] +mod windows; +#[cfg(unix)] +mod unix; +#[cfg(target_os = "linux")] +mod linux; + +pub use traits::{ProcessGroupManager, ProcessLifecycle}; +pub use shutdown::graceful_shutdown_sync; +#[cfg(feature = "tokio")] +pub use shutdown::graceful_shutdown_async; + +use std::sync::Arc; + +/// Create the platform-appropriate ProcessGroupManager. +/// +/// Call `init()` on the returned manager before use. +pub fn create_manager() -> Arc<dyn ProcessGroupManager> { + #[cfg(windows)] + { Arc::new(windows::WindowsProcessGroupManager::new()) } + + #[cfg(target_os = "linux")] + { Arc::new(linux::LinuxProcessGroupManager::new()) } + + #[cfg(all(unix, not(target_os = "linux")))] + { Arc::new(unix::UnixProcessGroupManager::new()) } +} diff --git a/crates/dirigent_process/src/linux.rs b/crates/dirigent_process/src/linux.rs new file mode 100644 index 0000000..50fa07e --- /dev/null +++ b/crates/dirigent_process/src/linux.rs @@ -0,0 +1,91 @@ +#![cfg(target_os = "linux")] + +use crate::traits::{ProcessGroupManager, ProcessLifecycle}; +use nix::sys::signal::{killpg, Signal}; +use nix::unistd::Pid; +use std::io; +use std::os::unix::process::CommandExt; +use tracing::{debug, info, warn}; + +/// Linux process group manager with kernel-level orphan prevention. +/// +/// Uses `PR_SET_CHILD_SUBREAPER` so orphaned grandchildren are reparented +/// to this process, and `PR_SET_PDEATHSIG` so children auto-die when +/// the parent crashes. +pub struct LinuxProcessGroupManager; + +impl LinuxProcessGroupManager { + pub fn new() -> Self { Self } +} + +impl ProcessGroupManager for LinuxProcessGroupManager { + fn init(&self) -> Result<(), io::Error> { + unsafe { + if libc::prctl(libc::PR_SET_CHILD_SUBREAPER, 1, 0, 0, 0) != 0 { + let err = io::Error::last_os_error(); + warn!(error = %err, "Failed to set PR_SET_CHILD_SUBREAPER"); + return Err(err); + } + } + info!("Linux process group manager initialized (child subreaper enabled)"); + Ok(()) + } + + fn create_lifecycle(&self) -> Box<dyn ProcessLifecycle> { + Box::new(LinuxProcessLifecycle) + } +} + +pub struct LinuxProcessLifecycle; + +impl ProcessLifecycle for LinuxProcessLifecycle { + fn configure_command(&self, cmd: &mut std::process::Command) { + unsafe { + cmd.pre_exec(|| { + if libc::setpgid(0, 0) != 0 { + return Err(io::Error::last_os_error()); + } + if libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGKILL, 0, 0, 0) != 0 { + return Err(io::Error::last_os_error()); + } + Ok(()) + }); + } + } + + #[cfg(feature = "tokio")] + fn configure_async_command(&self, cmd: &mut tokio::process::Command) { + unsafe { + cmd.pre_exec(|| { + if libc::setpgid(0, 0) != 0 { + return Err(io::Error::last_os_error()); + } + if libc::prctl(libc::PR_SET_PDEATHSIG, libc::SIGKILL, 0, 0, 0) != 0 { + return Err(io::Error::last_os_error()); + } + Ok(()) + }); + } + } + + fn register_child(&self, pid: u32) -> Result<(), io::Error> { + debug!(pid, pgid = pid, "Linux child registered (process group + PR_SET_PDEATHSIG)"); + Ok(()) + } + + fn send_shutdown_signal(&self, pid: u32) -> Result<(), io::Error> { + let pgid = Pid::from_raw(pid as i32); + killpg(pgid, Signal::SIGTERM) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + debug!(pid, "Sent SIGTERM to process group"); + Ok(()) + } + + fn send_kill_signal(&self, pid: u32) -> Result<(), io::Error> { + let pgid = Pid::from_raw(pid as i32); + killpg(pgid, Signal::SIGKILL) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + debug!(pid, "Sent SIGKILL to process group"); + Ok(()) + } +} diff --git a/crates/dirigent_process/src/shutdown.rs b/crates/dirigent_process/src/shutdown.rs new file mode 100644 index 0000000..dba4483 --- /dev/null +++ b/crates/dirigent_process/src/shutdown.rs @@ -0,0 +1,66 @@ +use crate::traits::ProcessLifecycle; +use std::time::Duration; + +/// Graceful shutdown: send signal → wait → force kill (sync, blocking). +/// +/// Returns `true` if the process exited within the timeout, `false` if force-killed. +pub fn graceful_shutdown_sync( + lifecycle: &dyn ProcessLifecycle, + child: &mut std::process::Child, + timeout: Duration, +) -> bool { + let pid = child.id(); + if pid == 0 { + return true; + } + + if lifecycle.send_shutdown_signal(pid).is_err() { + return true; + } + + let start = std::time::Instant::now(); + let poll_interval = Duration::from_millis(50); + + while start.elapsed() < timeout { + match child.try_wait() { + Ok(Some(_)) => return true, + Ok(None) => std::thread::sleep(poll_interval), + Err(_) => return true, + } + } + + tracing::debug!(pid, "Graceful shutdown timed out, force killing"); + let _ = lifecycle.send_kill_signal(pid); + let _ = child.wait(); + false +} + +/// Graceful shutdown: send signal → wait → force kill (async, non-blocking). +/// +/// Returns `true` if the process exited within the timeout, `false` if force-killed. +#[cfg(feature = "tokio")] +pub async fn graceful_shutdown_async( + lifecycle: &dyn ProcessLifecycle, + child: &mut tokio::process::Child, + timeout: Duration, +) -> bool { + let pid = match child.id() { + Some(0) | None => return true, + Some(pid) => pid, + }; + + if lifecycle.send_shutdown_signal(pid).is_err() { + return true; + } + + match tokio::time::timeout(timeout, child.wait()).await { + Ok(Ok(_)) => true, + Ok(Err(_)) => true, + Err(_) => { + tracing::debug!(pid, "Graceful shutdown timed out, force killing"); + let _ = lifecycle.send_kill_signal(pid); + let _ = child.wait().await; + false + } + } +} diff --git a/crates/dirigent_process/src/traits.rs b/crates/dirigent_process/src/traits.rs new file mode 100644 index 0000000..b81c865 --- /dev/null +++ b/crates/dirigent_process/src/traits.rs @@ -0,0 +1,40 @@ +use std::io; + +/// Global process group manager — one per application lifetime. +/// +/// On Windows, owns a Job Object with KILL_ON_JOB_CLOSE. +/// On Linux, configures the process as a child subreaper. +/// On macOS, no-op (process groups handle cleanup). +pub trait ProcessGroupManager: Send + Sync { + /// Initialize platform-specific parent process configuration. + fn init(&self) -> Result<(), io::Error>; + + /// Create a lifecycle handle for managing a child process. + fn create_lifecycle(&self) -> Box<dyn ProcessLifecycle>; +} + +/// Per-child process lifecycle manager. +/// +/// All methods are synchronous — OS signal/handle calls are instant. +/// For timeout-based shutdown, use the free functions in the `shutdown` module. +pub trait ProcessLifecycle: Send + Sync { + /// Configure a std::process::Command before spawning. + /// Sets platform-specific flags (process group, creation flags, pre_exec hooks). + fn configure_command(&self, cmd: &mut std::process::Command); + + /// Configure a tokio::process::Command before spawning. + #[cfg(feature = "tokio")] + fn configure_async_command(&self, cmd: &mut tokio::process::Command); + + /// Register a spawned child with the lifecycle manager. + /// Must be called immediately after spawn with the child's PID. + fn register_child(&self, pid: u32) -> Result<(), io::Error>; + + /// Send a graceful shutdown signal to the process (and its tree). + /// Windows: CTRL_BREAK_EVENT. Unix: SIGTERM to process group. + fn send_shutdown_signal(&self, pid: u32) -> Result<(), io::Error>; + + /// Forcefully kill the process (and its tree). + /// Windows: TerminateProcess. Unix: SIGKILL to process group. + fn send_kill_signal(&self, pid: u32) -> Result<(), io::Error>; +} diff --git a/crates/dirigent_process/src/unix.rs b/crates/dirigent_process/src/unix.rs new file mode 100644 index 0000000..a63b470 --- /dev/null +++ b/crates/dirigent_process/src/unix.rs @@ -0,0 +1,78 @@ +#![cfg(unix)] + +use crate::traits::{ProcessGroupManager, ProcessLifecycle}; +use nix::sys::signal::{killpg, Signal}; +use nix::unistd::Pid; +use std::io; +use std::os::unix::process::CommandExt; +use tracing::{debug, info}; + +/// macOS / generic Unix process group manager. +/// +/// Uses process groups for tree management. No kernel-level orphan +/// prevention (macOS lacks `PR_SET_PDEATHSIG`). Relies on launchd +/// supervision for crash recovery. +pub struct UnixProcessGroupManager; + +impl UnixProcessGroupManager { + pub fn new() -> Self { Self } +} + +impl ProcessGroupManager for UnixProcessGroupManager { + fn init(&self) -> Result<(), io::Error> { + info!("Unix process group manager initialized"); + Ok(()) + } + + fn create_lifecycle(&self) -> Box<dyn ProcessLifecycle> { + Box::new(UnixProcessLifecycle) + } +} + +pub struct UnixProcessLifecycle; + +impl ProcessLifecycle for UnixProcessLifecycle { + fn configure_command(&self, cmd: &mut std::process::Command) { + unsafe { + cmd.pre_exec(|| { + if libc::setpgid(0, 0) != 0 { + return Err(io::Error::last_os_error()); + } + Ok(()) + }); + } + } + + #[cfg(feature = "tokio")] + fn configure_async_command(&self, cmd: &mut tokio::process::Command) { + unsafe { + cmd.pre_exec(|| { + if libc::setpgid(0, 0) != 0 { + return Err(io::Error::last_os_error()); + } + Ok(()) + }); + } + } + + fn register_child(&self, pid: u32) -> Result<(), io::Error> { + debug!(pid, pgid = pid, "Child registered in its own process group"); + Ok(()) + } + + fn send_shutdown_signal(&self, pid: u32) -> Result<(), io::Error> { + let pgid = Pid::from_raw(pid as i32); + killpg(pgid, Signal::SIGTERM) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + debug!(pid, "Sent SIGTERM to process group"); + Ok(()) + } + + fn send_kill_signal(&self, pid: u32) -> Result<(), io::Error> { + let pgid = Pid::from_raw(pid as i32); + killpg(pgid, Signal::SIGKILL) + .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + debug!(pid, "Sent SIGKILL to process group"); + Ok(()) + } +} diff --git a/crates/dirigent_process/src/windows.rs b/crates/dirigent_process/src/windows.rs new file mode 100644 index 0000000..90e1dad --- /dev/null +++ b/crates/dirigent_process/src/windows.rs @@ -0,0 +1,199 @@ +#![cfg(windows)] + +use crate::traits::{ProcessGroupManager, ProcessLifecycle}; +use std::io; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Mutex; +use tracing::{debug, info, warn}; +use windows_sys::Win32::Foundation::{CloseHandle, FALSE, HANDLE}; +use windows_sys::Win32::System::JobObjects::{ + AssignProcessToJobObject, CreateJobObjectW, JobObjectExtendedLimitInformation, + SetInformationJobObject, JOBOBJECT_EXTENDED_LIMIT_INFORMATION, JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE, +}; +use windows_sys::Win32::System::Threading::{ + OpenProcess, TerminateProcess, PROCESS_ALL_ACCESS, +}; +use windows_sys::Win32::System::Console::{ + GenerateConsoleCtrlEvent, CTRL_BREAK_EVENT, +}; + +/// Windows process group manager using Job Objects. +/// +/// Creates a Job Object with `JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE` — when +/// this manager is dropped (or the process crashes), the OS automatically +/// kills all assigned child processes including grandchildren. +pub struct WindowsProcessGroupManager { + /// Wrapped in a Mutex so we can mutate through a shared reference after init. + job_handle: Mutex<HANDLE>, + initialized: AtomicBool, +} + +// Safety: HANDLE (*mut c_void) is not Send/Sync by default, but we only +// mutate it during init() (guarded by AtomicBool + Mutex) and read it +// (via copy) in create_lifecycle() and drop. No concurrent mutation occurs. +unsafe impl Send for WindowsProcessGroupManager {} +unsafe impl Sync for WindowsProcessGroupManager {} + +impl WindowsProcessGroupManager { + pub fn new() -> Self { + Self { + job_handle: Mutex::new(std::ptr::null_mut()), + initialized: AtomicBool::new(false), + } + } + + fn handle(&self) -> HANDLE { + *self.job_handle.lock().unwrap() + } +} + +impl Default for WindowsProcessGroupManager { + fn default() -> Self { + Self::new() + } +} + +impl ProcessGroupManager for WindowsProcessGroupManager { + fn init(&self) -> Result<(), io::Error> { + if self.initialized.swap(true, Ordering::SeqCst) { + return Ok(()); + } + + unsafe { + let handle = CreateJobObjectW(std::ptr::null(), std::ptr::null()); + if handle.is_null() { + self.initialized.store(false, Ordering::SeqCst); + return Err(io::Error::last_os_error()); + } + + let mut info: JOBOBJECT_EXTENDED_LIMIT_INFORMATION = std::mem::zeroed(); + info.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; + + let result = SetInformationJobObject( + handle, + JobObjectExtendedLimitInformation, + &info as *const _ as *const _, + std::mem::size_of::<JOBOBJECT_EXTENDED_LIMIT_INFORMATION>() as u32, + ); + + if result == FALSE { + CloseHandle(handle); + self.initialized.store(false, Ordering::SeqCst); + return Err(io::Error::last_os_error()); + } + + *self.job_handle.lock().unwrap() = handle; + + info!("Windows Job Object created with KILL_ON_JOB_CLOSE"); + Ok(()) + } + } + + fn create_lifecycle(&self) -> Box<dyn ProcessLifecycle> { + Box::new(WindowsProcessLifecycle { + job_handle: self.handle(), + }) + } +} + +impl Drop for WindowsProcessGroupManager { + fn drop(&mut self) { + let handle = self.handle(); + if !handle.is_null() { + unsafe { CloseHandle(handle); } + debug!("Windows Job Object closed"); + } + } +} + +/// Per-child lifecycle manager for Windows. +/// +/// Assigns children to the parent's Job Object and uses +/// `CTRL_BREAK_EVENT` / `TerminateProcess` for shutdown. +pub struct WindowsProcessLifecycle { + job_handle: HANDLE, +} + +// Safety: same reasoning as WindowsProcessGroupManager — HANDLE is used +// read-only after construction (only passed to OS APIs). +unsafe impl Send for WindowsProcessLifecycle {} +unsafe impl Sync for WindowsProcessLifecycle {} + +impl ProcessLifecycle for WindowsProcessLifecycle { + fn configure_command(&self, cmd: &mut std::process::Command) { + use std::os::windows::process::CommandExt; + const CREATE_NEW_PROCESS_GROUP: u32 = 0x0000_0200; + const CREATE_NO_WINDOW: u32 = 0x0800_0000; + cmd.creation_flags(CREATE_NEW_PROCESS_GROUP | CREATE_NO_WINDOW); + } + + #[cfg(feature = "tokio")] + fn configure_async_command(&self, cmd: &mut tokio::process::Command) { + use std::os::windows::process::CommandExt; + const CREATE_NEW_PROCESS_GROUP: u32 = 0x0000_0200; + const CREATE_NO_WINDOW: u32 = 0x0800_0000; + cmd.creation_flags(CREATE_NEW_PROCESS_GROUP | CREATE_NO_WINDOW); + } + + fn register_child(&self, pid: u32) -> Result<(), io::Error> { + if self.job_handle.is_null() { + warn!(pid, "Job Object not initialized, skipping child registration"); + return Ok(()); + } + + unsafe { + let process_handle = OpenProcess(PROCESS_ALL_ACCESS, FALSE, pid); + if process_handle.is_null() { + return Err(io::Error::last_os_error()); + } + + let result = AssignProcessToJobObject(self.job_handle, process_handle); + CloseHandle(process_handle); + + if result == FALSE { + let err = io::Error::last_os_error(); + warn!(pid, error = %err, "Failed to assign process to Job Object (may already be in a job)"); + return Err(err); + } + + debug!(pid, "Process assigned to Job Object"); + Ok(()) + } + } + + fn send_shutdown_signal(&self, pid: u32) -> Result<(), io::Error> { + unsafe { + if GenerateConsoleCtrlEvent(CTRL_BREAK_EVENT, pid) == FALSE { + return Err(io::Error::last_os_error()); + } + } + debug!(pid, "Sent CTRL_BREAK_EVENT"); + Ok(()) + } + + fn send_kill_signal(&self, pid: u32) -> Result<(), io::Error> { + unsafe { + let handle = OpenProcess(PROCESS_ALL_ACCESS, FALSE, pid); + if handle.is_null() { + let err = io::Error::last_os_error(); + // ERROR_INVALID_PARAMETER (87) means process already exited + if err.raw_os_error() == Some(87) { + return Ok(()); + } + return Err(err); + } + let result = TerminateProcess(handle, 1); + CloseHandle(handle); + if result == FALSE { + let err = io::Error::last_os_error(); + // ERROR_ACCESS_DENIED (5) — process may have already exited + if err.raw_os_error() == Some(5) { + return Ok(()); + } + return Err(err); + } + } + debug!(pid, "Sent TerminateProcess"); + Ok(()) + } +} diff --git a/crates/dirigent_process/tests/lifecycle.rs b/crates/dirigent_process/tests/lifecycle.rs new file mode 100644 index 0000000..39a5264 --- /dev/null +++ b/crates/dirigent_process/tests/lifecycle.rs @@ -0,0 +1,149 @@ +use dirigent_process::{create_manager, graceful_shutdown_sync}; +use std::process::Command; +use std::time::Duration; + +/// Build a long-running command that does not require a TTY on any platform. +/// +/// On Windows, `timeout /t N /nobreak` fails when stdin is a pipe (no console), +/// so we use `ping -n N 127.0.0.1` which sleeps approximately N-1 seconds with +/// no TTY requirement. +/// +/// On Unix, `sleep N` is the idiomatic choice. +#[cfg(windows)] +fn long_sleep_cmd(seconds: u32) -> Command { + let mut cmd = Command::new("ping"); + cmd.args(["-n", &seconds.to_string(), "127.0.0.1"]); + cmd +} + +#[cfg(unix)] +fn long_sleep_cmd(seconds: u32) -> Command { + let mut cmd = Command::new("sleep"); + cmd.arg(seconds.to_string()); + cmd +} + +#[cfg(all(windows, feature = "tokio"))] +fn long_sleep_async_cmd(seconds: u32) -> tokio::process::Command { + let mut cmd = tokio::process::Command::new("ping"); + cmd.args(["-n", &seconds.to_string(), "127.0.0.1"]); + cmd +} + +#[cfg(all(unix, feature = "tokio"))] +fn long_sleep_async_cmd(seconds: u32) -> tokio::process::Command { + let mut cmd = tokio::process::Command::new("sleep"); + cmd.arg(seconds.to_string()); + cmd +} + +#[test] +fn test_manager_init() { + let mgr = create_manager(); + mgr.init().expect("init should succeed"); + // Double init should also succeed (idempotent) + mgr.init().expect("double init should succeed"); +} + +#[test] +fn test_create_lifecycle() { + let mgr = create_manager(); + mgr.init().expect("init failed"); + let _lifecycle = mgr.create_lifecycle(); +} + +#[test] +fn test_configure_and_spawn() { + let mgr = create_manager(); + mgr.init().expect("init failed"); + let lifecycle = mgr.create_lifecycle(); + + let mut cmd = long_sleep_cmd(30); + lifecycle.configure_command(&mut cmd); + + let mut child = cmd.spawn().expect("spawn failed"); + let pid = child.id(); + assert!(pid > 0); + + // Register should succeed + lifecycle.register_child(pid).expect("register failed"); + + // Process should still be running + assert!(child.try_wait().expect("try_wait failed").is_none()); + + // Clean up + let _ = child.kill(); + let _ = child.wait(); +} + +#[test] +fn test_graceful_shutdown_sync() { + let mgr = create_manager(); + mgr.init().expect("init failed"); + let lifecycle = mgr.create_lifecycle(); + + let mut cmd = long_sleep_cmd(60); + lifecycle.configure_command(&mut cmd); + let mut child = cmd.spawn().expect("spawn failed"); + let pid = child.id(); + lifecycle.register_child(pid).expect("register failed"); + + // Graceful shutdown with 3s timeout — process won't exit voluntarily, + // so it should be force-killed after timeout + let exited_gracefully = graceful_shutdown_sync( + lifecycle.as_ref(), + &mut child, + Duration::from_secs(3), + ); + + // Process should be dead now + assert!(child.try_wait().expect("try_wait failed").is_some()); + // It was force-killed (ping/sleep don't handle SIGTERM/CTRL_BREAK) + assert!(!exited_gracefully); +} + +#[test] +fn test_send_kill_signal() { + let mgr = create_manager(); + mgr.init().expect("init failed"); + let lifecycle = mgr.create_lifecycle(); + + let mut cmd = long_sleep_cmd(60); + lifecycle.configure_command(&mut cmd); + let mut child = cmd.spawn().expect("spawn failed"); + let pid = child.id(); + lifecycle.register_child(pid).expect("register failed"); + + // Direct kill signal + lifecycle.send_kill_signal(pid).expect("kill failed"); + + // Wait for process to die + let status = child.wait().expect("wait failed"); + assert!(!status.success()); +} + +#[cfg(feature = "tokio")] +#[tokio::test] +async fn test_async_graceful_shutdown() { + use dirigent_process::graceful_shutdown_async; + + let mgr = create_manager(); + mgr.init().expect("init failed"); + let lifecycle = mgr.create_lifecycle(); + + let mut cmd = long_sleep_async_cmd(60); + lifecycle.configure_async_command(&mut cmd); + let mut child = cmd.spawn().expect("spawn failed"); + let pid = child.id().expect("no pid"); + lifecycle.register_child(pid).expect("register failed"); + + let exited_gracefully = graceful_shutdown_async( + lifecycle.as_ref(), + &mut child, + Duration::from_secs(3), + ) + .await; + + assert!(child.try_wait().expect("try_wait failed").is_some()); + assert!(!exited_gracefully); +} diff --git a/crates/dirigent_projects/Cargo.toml b/crates/dirigent_projects/Cargo.toml new file mode 100644 index 0000000..bb6b6a0 --- /dev/null +++ b/crates/dirigent_projects/Cargo.toml @@ -0,0 +1,33 @@ +[package] +name = "dirigent_projects" +version = "0.1.0" +edition = "2021" + +[lib] +path = "src/lib.rs" + +[dependencies] +# Async traits +async-trait = "0.1" +# Date/time handling +chrono = { version = "0.4", features = ["serde"] } +dirigent_auth = { path = "../dirigent_auth" } +# Home directory resolution +dirs = "6" +# Protocol types (WASM-compatible project types) +dirigent_protocol = { path = "../dirigent_protocol" } +# Serialization +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +# Error handling +thiserror = "2.0" +# Async runtime and file operations +tokio = { version = "1", features = ["fs", "io-util", "process", "sync"] } +# Logging +tracing = "0.1" +# UUID support with v7 and serde +uuid = { version = "1.0", features = ["serde", "v7"] } + +[dev-dependencies] +tempfile = "3.0" +tokio = { version = "1", features = ["fs", "macros", "rt-multi-thread", "sync"] } diff --git a/crates/dirigent_projects/src/detection.rs b/crates/dirigent_projects/src/detection.rs new file mode 100644 index 0000000..6ad36cb --- /dev/null +++ b/crates/dirigent_projects/src/detection.rs @@ -0,0 +1,751 @@ +//! Project detection and import support. +//! +//! Provides path normalization, worktree detection, multi-path grouping, +//! and matching logic to link discovered import paths to existing projects. + +use std::collections::HashMap; +use std::path::PathBuf; + +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use dirigent_protocol::project::{Project, ProjectRepository}; + +use crate::error::{ProjectError, Result}; +use crate::params::{AddRepositoryParams, CreateProjectParams}; +use crate::traits::ProjectStore; + +// --------------------------------------------------------------------------- +// DTOs +// --------------------------------------------------------------------------- + +/// A project discovered during import, before resolution against existing projects. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DetectedProject { + /// Filesystem path as discovered (pre-normalization may have been applied). + pub discovered_path: String, + /// Suggested name derived from the path (e.g. last directory component). + pub suggested_name: String, + /// Number of sessions associated with this discovered path. + pub session_count: usize, + /// How this detection was resolved against existing projects. + pub resolution: ProjectResolution, +} + +/// How a detected project path was resolved against the existing project store. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ProjectResolution { + /// Matched an existing project and repository. + Linked { + project_id: Uuid, + project_name: String, + matched_repository_id: Uuid, + }, + /// No match found — suggests creating a new project. + CreateNew { name: String }, + /// The user chose to skip this detection. + Skip, +} + +/// Full result of running project detection over a set of import discoveries. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ProjectDetectionResult { + /// One entry per discovered path. + pub detections: Vec<DetectedProject>, + /// Hints about git worktree relationships. + pub worktree_hints: Vec<WorktreeHint>, + /// Hints about paths that share a common parent. + pub multi_path_hints: Vec<MultiPathHint>, +} + +/// Hint that a path is (or may be) a git worktree. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WorktreeHint { + /// The worktree path itself. + pub worktree_path: String, + /// The main repository path (parsed from `.git` file), if resolved. + pub main_repo_path: Option<String>, +} + +/// Hint that multiple discovered paths share a common immediate parent. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MultiPathHint { + /// The shared parent directory. + pub shared_parent: String, + /// The child paths that share this parent. + pub paths: Vec<String>, +} + +/// Request to create a project from an import detection. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImportProjectCreationRequest { + /// Project name. + pub name: String, + /// Primary repository path. + pub primary_path: String, + /// Additional repository paths. + #[serde(default)] + pub additional_paths: Vec<String>, + /// Optional icon. + #[serde(skip_serializing_if = "Option::is_none")] + pub icon: Option<String>, + /// Tags for the new project. + #[serde(default)] + pub tags: Vec<String>, + /// Programming languages. + #[serde(default)] + pub languages: Vec<String>, +} + +/// Result of creating a project from an import request. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ImportProjectCreationResult { + /// The created project's ID. + pub project_id: Uuid, + /// The created project's name. + pub project_name: String, + /// How many repositories were created (primary + additional). + pub repositories_created: usize, +} + +/// Lightweight input describing a project discovered during import. +/// +/// This mirrors the shape used by import discovery (name + path + session count). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DiscoveredImportProject { + /// Project name (typically the directory basename or user-facing label). + pub name: String, + /// Filesystem path associated with this project. + pub path: String, + /// Number of sessions discovered under this path. + pub session_count: usize, +} + +// --------------------------------------------------------------------------- +// Path normalization +// --------------------------------------------------------------------------- + +/// Normalize a filesystem path for consistent cross-platform comparison. +/// +/// Steps (in order): +/// 1. Try `std::fs::canonicalize()` — if it succeeds, use that (resolves symlinks, +/// `..`, etc.) and convert to forward slashes. +/// 2. On failure, apply textual normalization: +/// - Backslash -> forward slash +/// - MinGW `/c/Users/...` -> `C:/Users/...` +/// - WSL `/mnt/c/Users/...` -> `C:/Users/...` +/// - UNC `\\server\share` -> `//server/share` +/// - Tilde `~/foo` -> expanded home + `/foo` +/// - Collapse `//` -> `/` (except leading UNC) +/// - Resolve `.` and `..` segments +/// - Strip trailing `/` +/// 3. On Windows, lowercase the entire result for case-insensitive comparison. +pub fn normalize_project_path(path: &str) -> String { + // Try canonical resolution first. + if let Ok(canonical) = std::fs::canonicalize(path) { + let mut s = canonical.to_string_lossy().replace('\\', "/"); + // Strip trailing slash unless it's a root like "C:/" + if s.len() > 1 && s.ends_with('/') && !s.ends_with(":/") { + s.pop(); + } + return platform_case_normalize(s); + } + + // Textual fallback. + let mut s = path.replace('\\', "/"); + + // Tilde expansion. + if s.starts_with("~/") || s == "~" { + if let Some(home) = home_dir_string() { + if s == "~" { + s = home; + } else { + s = format!("{}/{}", home.trim_end_matches('/'), &s[2..]); + } + } + } + + // MinGW: /c/Users/... -> C:/Users/... + if let Some(rest) = try_strip_mingw(&s) { + s = rest; + } + + // WSL: /mnt/c/Users/... -> C:/Users/... + if let Some(rest) = try_strip_wsl(&s) { + s = rest; + } + + // UNC already converted by backslash replacement: //server/share is fine. + + // Collapse double slashes (preserve leading // for UNC). + s = collapse_slashes(&s); + + // Resolve `.` and `..` segments textually. + s = resolve_dots(&s); + + // Strip trailing slash (unless root). + if s.len() > 1 && s.ends_with('/') && !s.ends_with(":/") { + s.pop(); + } + + platform_case_normalize(s) +} + +fn home_dir_string() -> Option<String> { + dirs::home_dir().map(|p| p.to_string_lossy().replace('\\', "/")) +} + +fn try_strip_mingw(s: &str) -> Option<String> { + let bytes = s.as_bytes(); + // Pattern: /X/... where X is a single ASCII letter + if bytes.len() >= 3 + && bytes[0] == b'/' + && bytes[1].is_ascii_alphabetic() + && bytes[2] == b'/' + { + let drive = (bytes[1] as char).to_ascii_uppercase(); + Some(format!("{}:/{}", drive, &s[3..])) + } else { + None + } +} + +fn try_strip_wsl(s: &str) -> Option<String> { + if let Some(rest) = s.strip_prefix("/mnt/") { + let bytes = rest.as_bytes(); + if !bytes.is_empty() && bytes[0].is_ascii_alphabetic() { + let drive = (bytes[0] as char).to_ascii_uppercase(); + let remainder = if bytes.len() > 1 && bytes[1] == b'/' { + &rest[2..] + } else if bytes.len() == 1 { + "" + } else { + return None; // e.g. /mnt/cdrom — not a drive letter + }; + return Some(format!("{}:/{}", drive, remainder)); + } + } + None +} + +fn collapse_slashes(s: &str) -> String { + let mut result = String::with_capacity(s.len()); + let mut chars = s.chars().peekable(); + + // Preserve leading double slash for UNC. + if s.starts_with("//") { + result.push('/'); + result.push('/'); + chars.next(); + chars.next(); + // Skip any additional leading slashes beyond the two. + while chars.peek() == Some(&'/') { + chars.next(); + } + } + + let mut prev_slash = false; + for c in chars { + if c == '/' { + if !prev_slash { + result.push(c); + } + prev_slash = true; + } else { + result.push(c); + prev_slash = false; + } + } + result +} + +fn resolve_dots(s: &str) -> String { + // Split on '/', resolve `.` and `..` textually. + let mut parts: Vec<&str> = Vec::new(); + let prefix = if s.starts_with("//") { + "//" + } else if s.starts_with('/') { + "/" + } else { + "" + }; + + for segment in s.split('/') { + match segment { + "" | "." => {} + ".." => { + // Don't pop past the root. + if !parts.is_empty() && *parts.last().unwrap() != ".." { + parts.pop(); + } + } + other => parts.push(other), + } + } + + let joined = parts.join("/"); + if prefix.is_empty() { + joined + } else { + format!("{}{}", prefix, joined) + } +} + +#[cfg(target_os = "windows")] +fn platform_case_normalize(s: String) -> String { + s.to_lowercase() +} + +#[cfg(not(target_os = "windows"))] +fn platform_case_normalize(s: String) -> String { + s +} + +// --------------------------------------------------------------------------- +// Worktree detection +// --------------------------------------------------------------------------- + +/// Check whether the given path is a git worktree (`.git` is a file, not a directory). +/// +/// If it is, parses the `gitdir:` pointer to determine the main repository path. +pub fn detect_worktree(path: &str) -> Option<WorktreeHint> { + let dot_git = PathBuf::from(path).join(".git"); + + // Only interested if .git is a *file* (worktree pointer), not a directory. + let meta = std::fs::symlink_metadata(&dot_git).ok()?; + if !meta.is_file() { + return None; + } + + let content = std::fs::read_to_string(&dot_git).ok()?; + let gitdir_line = content + .lines() + .find(|l| l.starts_with("gitdir:"))?; + + let gitdir_raw = gitdir_line["gitdir:".len()..].trim(); + + // The gitdir path typically looks like `/path/to/main-repo/.git/worktrees/<name>`. + // Walk up to find the main repo root. + let gitdir_path = if PathBuf::from(gitdir_raw).is_absolute() { + PathBuf::from(gitdir_raw) + } else { + PathBuf::from(path).join(gitdir_raw) + }; + + // Try to resolve: .../main-repo/.git/worktrees/xxx -> .../main-repo + let main_repo = gitdir_path + .ancestors() + .find(|ancestor| { + // Check if this ancestor has `.git` as a child (actual git dir, not worktree file). + let git_child = ancestor.join(".git"); + git_child.is_dir() + }) + .map(|p| normalize_project_path(&p.to_string_lossy())); + + Some(WorktreeHint { + worktree_path: normalize_project_path(path), + main_repo_path: main_repo, + }) +} + +// --------------------------------------------------------------------------- +// Multi-path grouping +// --------------------------------------------------------------------------- + +/// Group paths that share a common immediate parent directory. +/// +/// Only produces hints for groups of 2+ paths. +pub fn find_multi_path_groups(paths: &[String]) -> Vec<MultiPathHint> { + let mut by_parent: HashMap<String, Vec<String>> = HashMap::new(); + + for path in paths { + let normalized = normalize_project_path(path); + // Find immediate parent by stripping last component. + if let Some(parent) = PathBuf::from(&normalized).parent() { + let parent_str = parent.to_string_lossy().replace('\\', "/"); + by_parent + .entry(parent_str) + .or_default() + .push(normalized); + } + } + + by_parent + .into_iter() + .filter(|(_, children)| children.len() >= 2) + .map(|(parent, mut children)| { + children.sort(); + MultiPathHint { + shared_parent: parent, + paths: children, + } + }) + .collect() +} + +// --------------------------------------------------------------------------- +// Detection logic +// --------------------------------------------------------------------------- + +/// Match discovered import projects against existing projects. +/// +/// For each discovered path, attempts to find a match in the existing project +/// store using (in priority order): +/// 1. Exact normalized path match against any repository +/// 2. Canonical (fs::canonicalize) path match +/// 3. Name-based hint (project name == suggested name) +/// +/// Unmatched paths get `ProjectResolution::CreateNew`. +pub fn detect_projects( + discovered: &[DiscoveredImportProject], + existing_projects: &[(Project, Vec<ProjectRepository>)], +) -> ProjectDetectionResult { + // Pre-build a lookup from normalized repo paths -> (project, repo). + let mut path_index: HashMap<String, (&Project, &ProjectRepository)> = HashMap::new(); + let mut canonical_index: HashMap<String, (&Project, &ProjectRepository)> = HashMap::new(); + let mut name_index: HashMap<String, &Project> = HashMap::new(); + + for (project, repos) in existing_projects { + name_index.insert(project.name.to_lowercase(), project); + for repo in repos { + let repo_path_str = repo.path.to_string_lossy().to_string(); + let normalized = normalize_project_path(&repo_path_str); + path_index.insert(normalized.clone(), (project, repo)); + + // Also try canonical path of the repo. + if let Ok(canonical) = std::fs::canonicalize(&repo.path) { + let canon_norm = normalize_project_path(&canonical.to_string_lossy()); + canonical_index.insert(canon_norm, (project, repo)); + } + } + } + + let mut detections = Vec::with_capacity(discovered.len()); + let discovered_paths: Vec<String> = discovered.iter().map(|d| d.path.clone()).collect(); + let worktree_hints: Vec<WorktreeHint> = discovered_paths + .iter() + .filter_map(|p| detect_worktree(p)) + .collect(); + + for disc in discovered { + let normalized = normalize_project_path(&disc.path); + + // 1. Exact normalized path match. + if let Some((project, repo)) = path_index.get(&normalized) { + detections.push(DetectedProject { + discovered_path: disc.path.clone(), + suggested_name: disc.name.clone(), + session_count: disc.session_count, + resolution: ProjectResolution::Linked { + project_id: project.id, + project_name: project.name.clone(), + matched_repository_id: repo.id, + }, + }); + continue; + } + + // 2. Canonical path match. + let canon_norm = std::fs::canonicalize(&disc.path) + .map(|c| normalize_project_path(&c.to_string_lossy())) + .unwrap_or_default(); + if !canon_norm.is_empty() { + if let Some((project, repo)) = canonical_index.get(&canon_norm) { + detections.push(DetectedProject { + discovered_path: disc.path.clone(), + suggested_name: disc.name.clone(), + session_count: disc.session_count, + resolution: ProjectResolution::Linked { + project_id: project.id, + project_name: project.name.clone(), + matched_repository_id: repo.id, + }, + }); + continue; + } + } + + // 3. Name hint match. + let suggested_lower = derive_suggested_name(&disc.path).to_lowercase(); + if let Some(project) = name_index.get(&suggested_lower) { + // Find the primary repo or any repo to satisfy the linked variant. + let existing_repos = existing_projects + .iter() + .find(|(p, _)| p.id == project.id) + .map(|(_, repos)| repos); + if let Some(repos) = existing_repos { + if let Some(repo) = repos.iter().find(|r| r.is_primary).or(repos.first()) { + detections.push(DetectedProject { + discovered_path: disc.path.clone(), + suggested_name: disc.name.clone(), + session_count: disc.session_count, + resolution: ProjectResolution::Linked { + project_id: project.id, + project_name: project.name.clone(), + matched_repository_id: repo.id, + }, + }); + continue; + } + } + } + + // 4. No match — suggest creating. + let name = derive_suggested_name(&disc.path); + detections.push(DetectedProject { + discovered_path: disc.path.clone(), + suggested_name: disc.name.clone(), + session_count: disc.session_count, + resolution: ProjectResolution::CreateNew { name }, + }); + } + + let multi_path_hints = find_multi_path_groups(&discovered_paths); + + ProjectDetectionResult { + detections, + worktree_hints, + multi_path_hints, + } +} + +/// Derive a suggested project name from a path (last non-empty component). +fn derive_suggested_name(path: &str) -> String { + let normalized = path.replace('\\', "/"); + let trimmed = normalized.trim_end_matches('/'); + trimmed + .rsplit('/') + .next() + .unwrap_or(trimmed) + .to_string() +} + +// --------------------------------------------------------------------------- +// Project creation from import +// --------------------------------------------------------------------------- + +/// Create projects from a batch of import creation requests. +/// +/// For each request: creates the project, adds the primary repository, and +/// adds any additional repositories. Returns one result per request. +pub async fn create_projects_from_import( + store: &dyn ProjectStore, + requests: Vec<ImportProjectCreationRequest>, + owner: Uuid, +) -> Vec<Result<ImportProjectCreationResult>> { + let mut results = Vec::with_capacity(requests.len()); + + for req in requests { + results.push(create_single_project(store, req, owner).await); + } + + results +} + +async fn create_single_project( + store: &dyn ProjectStore, + req: ImportProjectCreationRequest, + owner: Uuid, +) -> Result<ImportProjectCreationResult> { + let project = store + .create_project(CreateProjectParams { + name: req.name.clone(), + description: String::new(), + icon: req.icon, + owner, + tags: req.tags, + languages: req.languages, + metadata: serde_json::Value::Object(serde_json::Map::new()), + }) + .await?; + + let mut repos_created: usize = 0; + + // Primary repository. + store + .add_repository(AddRepositoryParams { + project_id: project.id, + path: PathBuf::from(&req.primary_path), + is_primary: true, + label: None, + }) + .await?; + repos_created += 1; + + // Additional repositories. + for additional in &req.additional_paths { + match store + .add_repository(AddRepositoryParams { + project_id: project.id, + path: PathBuf::from(additional), + is_primary: false, + label: None, + }) + .await + { + Ok(_) => repos_created += 1, + Err(e) => { + tracing::warn!( + project_id = %project.id, + path = %additional, + error = %e, + "Failed to add additional repository during import" + ); + } + } + } + + Ok(ImportProjectCreationResult { + project_id: project.id, + project_name: project.name, + repositories_created: repos_created, + }) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn normalize_backslashes() { + let result = normalize_project_path("C:\\Users\\alice\\project"); + assert!(result.contains('/')); + assert!(!result.contains('\\')); + } + + #[test] + fn normalize_mingw_path() { + let result = normalize_project_path("/c/Users/alice/project"); + assert!( + result.starts_with("C:/") || result.starts_with("c:/"), + "Expected drive letter prefix, got: {}", + result + ); + } + + #[test] + fn normalize_wsl_path() { + let result = normalize_project_path("/mnt/c/Users/alice/project"); + assert!( + result.starts_with("C:/") || result.starts_with("c:/"), + "Expected drive letter prefix, got: {}", + result + ); + } + + #[test] + fn normalize_strips_trailing_slash() { + let result = normalize_project_path("/home/alice/project/"); + assert!(!result.ends_with('/')); + } + + #[test] + fn normalize_resolves_dots() { + // Textual fallback since this path won't exist on disk. + let result = normalize_project_path("/home/alice/./project/../project/src"); + assert!(result.contains("/home/alice/project/src") || result.ends_with("project/src")); + } + + #[test] + fn normalize_collapses_double_slashes() { + let result = normalize_project_path("/home//alice///project"); + assert!(!result.contains("//") || result.starts_with("//")); + } + + #[test] + fn derive_suggested_name_basic() { + assert_eq!(derive_suggested_name("/home/alice/my-project"), "my-project"); + assert_eq!(derive_suggested_name("C:\\Users\\bob\\work"), "work"); + assert_eq!(derive_suggested_name("/home/alice/my-project/"), "my-project"); + } + + #[test] + fn multi_path_groups_basic() { + let paths = vec![ + "/home/alice/projects/foo".to_string(), + "/home/alice/projects/bar".to_string(), + "/home/alice/work/baz".to_string(), + ]; + let groups = find_multi_path_groups(&paths); + // foo and bar share /home/alice/projects, baz is alone under /home/alice/work + let multi = groups + .iter() + .find(|g| g.paths.len() == 2); + assert!(multi.is_some(), "Expected a group with 2 paths"); + } + + #[test] + fn detect_projects_creates_new_for_unmatched() { + let discovered = vec![DiscoveredImportProject { + name: "my-project".to_string(), + path: "/nonexistent/path/my-project".to_string(), + session_count: 5, + }]; + let existing: Vec<(Project, Vec<ProjectRepository>)> = vec![]; + let result = detect_projects(&discovered, &existing); + assert_eq!(result.detections.len(), 1); + match &result.detections[0].resolution { + ProjectResolution::CreateNew { name } => { + assert_eq!(name, "my-project"); + } + other => panic!("Expected CreateNew, got {:?}", other), + } + } + + #[test] + fn detect_projects_links_by_name() { + use chrono::Utc; + let project_id = Uuid::now_v7(); + let repo_id = Uuid::now_v7(); + let now = Utc::now(); + + let project = Project { + id: project_id, + name: "dirigent".to_string(), + description: String::new(), + icon: None, + owner: Uuid::nil(), + members: vec![], + tags: vec![], + languages: vec![], + linked_projects: vec![], + metadata: serde_json::json!({}), + created_at: now, + updated_at: now, + }; + let repo = ProjectRepository { + id: repo_id, + project_id, + path: PathBuf::from("/other/path/dirigent"), + is_primary: true, + label: None, + access: dirigent_protocol::project::AccessMode::ReadWrite, + created_at: now, + updated_at: now, + }; + + let discovered = vec![DiscoveredImportProject { + name: "dirigent".to_string(), + path: "/somewhere/else/dirigent".to_string(), + session_count: 3, + }]; + let result = detect_projects(&discovered, &[(project, vec![repo])]); + assert_eq!(result.detections.len(), 1); + match &result.detections[0].resolution { + ProjectResolution::Linked { + project_id: pid, + matched_repository_id: rid, + .. + } => { + assert_eq!(*pid, project_id); + assert_eq!(*rid, repo_id); + } + other => panic!("Expected Linked, got {:?}", other), + } + } +} diff --git a/crates/dirigent_projects/src/error.rs b/crates/dirigent_projects/src/error.rs new file mode 100644 index 0000000..1842c5d --- /dev/null +++ b/crates/dirigent_projects/src/error.rs @@ -0,0 +1,43 @@ +//! Error types for the Projects module. + +use thiserror::Error; +use uuid::Uuid; + +/// Errors that can occur in project operations. +#[derive(Debug, Error)] +pub enum ProjectError { + /// Project not found + #[error("project not found: {0}")] + NotFound(Uuid), + + /// Project already exists + #[error("project already exists: {0}")] + AlreadyExists(Uuid), + + /// Repository not found + #[error("repository not found: {0}")] + RepositoryNotFound(Uuid), + + /// Worktree not found + #[error("worktree not found: {0}")] + WorktreeNotFound(Uuid), + + /// Binding not found + #[error("binding not found: {0}")] + BindingNotFound(Uuid), + + /// Validation error + #[error("validation error: {0}")] + Validation(String), + + /// Storage I/O error + #[error("storage error: {0}")] + Storage(#[from] std::io::Error), + + /// Serialization error + #[error("serialization error: {0}")] + Serialization(#[from] serde_json::Error), +} + +/// Result type alias for project operations. +pub type Result<T> = std::result::Result<T, ProjectError>; diff --git a/crates/dirigent_projects/src/file_store.rs b/crates/dirigent_projects/src/file_store.rs new file mode 100644 index 0000000..e6fe4c5 --- /dev/null +++ b/crates/dirigent_projects/src/file_store.rs @@ -0,0 +1,441 @@ +//! File-based ProjectStore implementation. +//! +//! Uses one directory per project under a configurable root. +//! Follows the archivist pattern with atomic JSON writes. + +use crate::error::{ProjectError, Result}; +use crate::params::*; +use crate::storage::io::{read_json, read_json_or_default, write_json}; +use crate::storage::paths::ProjectPaths; +use crate::traits::ProjectStore; +use chrono::Utc; +use dirigent_protocol::project::{ + AccessMode, Project, ProjectBinding, ProjectRepository, Worktree, +}; +use std::path::PathBuf; +use tracing::{debug, info}; +use uuid::Uuid; + +/// File-based project store. +/// +/// Each project gets its own directory under the root: +/// ```text +/// root/ +/// {project_uuid}/ +/// project.json +/// repositories.json (Phase 2) +/// bindings.json (Phase 5) +/// worktrees.json (Phase 4) +/// ``` +pub struct FileBasedProjectStore { + paths: ProjectPaths, +} + +impl FileBasedProjectStore { + /// Create a new file-based store at the given root directory. + /// + /// The root directory will be created if it doesn't exist. + pub async fn new(root: impl Into<PathBuf>) -> std::io::Result<Self> { + let root = root.into(); + tokio::fs::create_dir_all(&root).await?; + Ok(Self { + paths: ProjectPaths::new(root), + }) + } + + /// Find which project owns a given repository ID. + async fn find_project_for_repo(&self, repo_id: &Uuid) -> Result<Uuid> { + let project_ids = self.scan_project_ids().await?; + for project_id in &project_ids { + let repos_path = self.paths.repositories_json(project_id); + let repos: Vec<ProjectRepository> = read_json_or_default(&repos_path).await?; + if repos.iter().any(|r| r.id == *repo_id) { + return Ok(*project_id); + } + } + Err(ProjectError::RepositoryNotFound(*repo_id)) + } + + /// Scan the root directory for project UUIDs. + async fn scan_project_ids(&self) -> Result<Vec<Uuid>> { + let mut ids = Vec::new(); + let mut entries = tokio::fs::read_dir(self.paths.root()).await?; + + while let Some(entry) = entries.next_entry().await? { + if entry.file_type().await?.is_dir() { + if let Some(name) = entry.file_name().to_str() { + if let Ok(uuid) = Uuid::parse_str(name) { + ids.push(uuid); + } + } + } + } + + Ok(ids) + } +} + +#[async_trait::async_trait] +impl ProjectStore for FileBasedProjectStore { + async fn create_project(&self, params: CreateProjectParams) -> Result<Project> { + // Validate + if params.name.trim().is_empty() { + return Err(ProjectError::Validation( + "project name cannot be empty".to_string(), + )); + } + + let now = Utc::now(); + let project = Project { + id: Uuid::now_v7(), + name: params.name, + description: params.description, + icon: params.icon, + owner: params.owner, + members: vec![], + tags: params.tags, + languages: params.languages, + linked_projects: vec![], + metadata: if params.metadata.is_null() { + serde_json::json!({}) + } else { + params.metadata + }, + created_at: now, + updated_at: now, + }; + + // Create project directory + let project_dir = self.paths.project_dir(&project.id); + tokio::fs::create_dir_all(&project_dir).await?; + + // Write project.json + let path = self.paths.project_json(&project.id); + write_json(&path, &project).await?; + + info!(project_id = %project.id, name = %project.name, "Created project"); + Ok(project) + } + + async fn get_project(&self, id: &Uuid) -> Result<Project> { + let path = self.paths.project_json(id); + read_json(&path).await.map_err(|e| match e.kind() { + std::io::ErrorKind::NotFound => ProjectError::NotFound(*id), + _ => ProjectError::Storage(e), + }) + } + + async fn list_projects(&self, filter: ProjectFilter) -> Result<Vec<Project>> { + let ids = self.scan_project_ids().await?; + let mut projects = Vec::new(); + + for id in ids { + match self.get_project(&id).await { + Ok(project) => { + // Apply filters + if let Some(ref owner) = filter.owner { + if project.owner != *owner { + continue; + } + } + if let Some(ref name_contains) = filter.name_contains { + if !project + .name + .to_lowercase() + .contains(&name_contains.to_lowercase()) + { + continue; + } + } + if !filter.tags.is_empty() + && !filter.tags.iter().all(|t| project.tags.contains(t)) + { + continue; + } + projects.push(project); + } + Err(e) => { + debug!(project_id = %id, error = %e, "Skipping unreadable project"); + } + } + } + + // Sort by name for consistent ordering + projects.sort_by(|a, b| a.name.cmp(&b.name)); + Ok(projects) + } + + async fn update_project(&self, id: &Uuid, update: ProjectUpdate) -> Result<Project> { + let mut project = self.get_project(id).await?; + + if let Some(name) = update.name { + if name.trim().is_empty() { + return Err(ProjectError::Validation( + "project name cannot be empty".to_string(), + )); + } + project.name = name; + } + if let Some(description) = update.description { + project.description = description; + } + if let Some(icon) = update.icon { + project.icon = icon; + } + if let Some(tags) = update.tags { + project.tags = tags; + } + if let Some(languages) = update.languages { + project.languages = languages; + } + if let Some(metadata) = update.metadata { + project.metadata = metadata; + } + + project.updated_at = Utc::now(); + + let path = self.paths.project_json(id); + write_json(&path, &project).await?; + + info!(project_id = %id, "Updated project"); + Ok(project) + } + + async fn delete_project(&self, id: &Uuid) -> Result<()> { + let project_dir = self.paths.project_dir(id); + if !project_dir.exists() { + return Err(ProjectError::NotFound(*id)); + } + + tokio::fs::remove_dir_all(&project_dir).await?; + info!(project_id = %id, "Deleted project"); + Ok(()) + } + + // --- Repository management (Phase 2 - scaffolded) --- + + async fn add_repository(&self, params: AddRepositoryParams) -> Result<ProjectRepository> { + // Ensure project exists + let _ = self.get_project(¶ms.project_id).await?; + + let now = Utc::now(); + let repo = ProjectRepository { + id: Uuid::now_v7(), + project_id: params.project_id, + path: params.path, + is_primary: params.is_primary, + label: params.label, + access: AccessMode::ReadWrite, + created_at: now, + updated_at: now, + }; + + // Read existing repos, append, write back + let repos_path = self.paths.repositories_json(¶ms.project_id); + let mut repos: Vec<ProjectRepository> = read_json_or_default(&repos_path).await?; + + // If this is primary, unset others + if repo.is_primary { + for r in repos.iter_mut() { + r.is_primary = false; + } + } + + repos.push(repo.clone()); + write_json(&repos_path, &repos).await?; + + info!(repo_id = %repo.id, project_id = %params.project_id, "Added repository"); + Ok(repo) + } + + async fn remove_repository(&self, id: &Uuid) -> Result<()> { + // Scan all projects to find the repo + let project_ids = self.scan_project_ids().await?; + for project_id in project_ids { + let repos_path = self.paths.repositories_json(&project_id); + let mut repos: Vec<ProjectRepository> = read_json_or_default(&repos_path).await?; + let original_len = repos.len(); + repos.retain(|r| r.id != *id); + if repos.len() < original_len { + write_json(&repos_path, &repos).await?; + info!(repo_id = %id, "Removed repository"); + return Ok(()); + } + } + Err(ProjectError::RepositoryNotFound(*id)) + } + + async fn set_primary_repository(&self, project_id: &Uuid, repo_id: &Uuid) -> Result<()> { + let repos_path = self.paths.repositories_json(project_id); + let mut repos: Vec<ProjectRepository> = read_json_or_default(&repos_path).await?; + + let mut found = false; + for r in repos.iter_mut() { + r.is_primary = r.id == *repo_id; + if r.id == *repo_id { + found = true; + } + } + + if !found { + return Err(ProjectError::RepositoryNotFound(*repo_id)); + } + + write_json(&repos_path, &repos).await?; + Ok(()) + } + + async fn list_repositories(&self, project_id: &Uuid) -> Result<Vec<ProjectRepository>> { + let repos_path = self.paths.repositories_json(project_id); + Ok(read_json_or_default(&repos_path).await?) + } + + // --- Worktrees (Phase 4) --- + + async fn add_worktree(&self, params: AddWorktreeParams) -> Result<Worktree> { + // Find which project owns this repository + let project_id = self.find_project_for_repo(¶ms.repository_id).await?; + + let worktree = Worktree { + id: Uuid::now_v7(), + repository_id: params.repository_id, + path: params.path, + branch: params.branch, + work_branch: params.work_branch, + naming_strategy: params.naming_strategy, + created_at: Utc::now(), + }; + + let wt_path = self.paths.worktrees_json(&project_id); + let mut worktrees: Vec<Worktree> = read_json_or_default(&wt_path).await?; + worktrees.push(worktree.clone()); + write_json(&wt_path, &worktrees).await?; + + info!(worktree_id = %worktree.id, repo_id = %params.repository_id, "Added worktree"); + Ok(worktree) + } + + async fn remove_worktree(&self, worktree_id: &Uuid) -> Result<()> { + let project_ids = self.scan_project_ids().await?; + for project_id in project_ids { + let wt_path = self.paths.worktrees_json(&project_id); + let mut worktrees: Vec<Worktree> = read_json_or_default(&wt_path).await?; + let original_len = worktrees.len(); + worktrees.retain(|w| w.id != *worktree_id); + if worktrees.len() < original_len { + write_json(&wt_path, &worktrees).await?; + info!(worktree_id = %worktree_id, "Removed worktree"); + return Ok(()); + } + } + Err(ProjectError::WorktreeNotFound(*worktree_id)) + } + + async fn list_worktrees(&self, repository_id: &Uuid) -> Result<Vec<Worktree>> { + let project_id = self.find_project_for_repo(repository_id).await?; + let wt_path = self.paths.worktrees_json(&project_id); + let all: Vec<Worktree> = read_json_or_default(&wt_path).await?; + Ok(all + .into_iter() + .filter(|w| w.repository_id == *repository_id) + .collect()) + } + + async fn update_worktree( + &self, + worktree_id: &Uuid, + update: WorktreeUpdate, + ) -> Result<Worktree> { + let project_ids = self.scan_project_ids().await?; + for project_id in &project_ids { + let wt_path = self.paths.worktrees_json(project_id); + let mut worktrees: Vec<Worktree> = read_json_or_default(&wt_path).await?; + if let Some(wt) = worktrees.iter_mut().find(|w| w.id == *worktree_id) { + if let Some(branch) = update.branch { + wt.branch = branch; + } + if let Some(work_branch) = update.work_branch { + wt.work_branch = work_branch; + } + let updated = wt.clone(); + write_json(&wt_path, &worktrees).await?; + return Ok(updated); + } + } + Err(ProjectError::WorktreeNotFound(*worktree_id)) + } + + // --- Bindings (Phase 5 - scaffolded) --- + + async fn bind(&self, params: BindParams) -> Result<ProjectBinding> { + let _ = self.get_project(¶ms.project_id).await?; + + let binding = ProjectBinding { + id: Uuid::now_v7(), + project_id: params.project_id, + connector_id: params.connector_id, + session_id: params.session_id, + working_dir: params.working_dir, + }; + + let bindings_path = self.paths.bindings_json(¶ms.project_id); + let mut bindings: Vec<ProjectBinding> = read_json_or_default(&bindings_path).await?; + bindings.push(binding.clone()); + write_json(&bindings_path, &bindings).await?; + + Ok(binding) + } + + async fn unbind(&self, binding_id: &Uuid) -> Result<()> { + let project_ids = self.scan_project_ids().await?; + for project_id in project_ids { + let bindings_path = self.paths.bindings_json(&project_id); + let mut bindings: Vec<ProjectBinding> = read_json_or_default(&bindings_path).await?; + let original_len = bindings.len(); + bindings.retain(|b| b.id != *binding_id); + if bindings.len() < original_len { + write_json(&bindings_path, &bindings).await?; + return Ok(()); + } + } + Err(ProjectError::BindingNotFound(*binding_id)) + } + + async fn list_bindings(&self, project_id: &Uuid) -> Result<Vec<ProjectBinding>> { + let bindings_path = self.paths.bindings_json(project_id); + Ok(read_json_or_default(&bindings_path).await?) + } + + // --- Resolution (Phase 2 - scaffolded) --- + + async fn resolve_working_dir( + &self, + project_id: &Uuid, + repo_id: Option<&Uuid>, + ) -> Result<PathBuf> { + // Verify the project exists before falling back to default_working_dir. + // Without this, a missing project directory yields an empty repo list + // (via read_json_or_default) and silently falls through to the default. + self.get_project(project_id).await?; + + let repos = self.list_repositories(project_id).await?; + + // If repo_id specified, use that + if let Some(rid) = repo_id { + if let Some(repo) = repos.iter().find(|r| r.id == *rid) { + return Ok(repo.path.clone()); + } + return Err(ProjectError::RepositoryNotFound(*rid)); + } + + // Use primary repo, or first repo + if let Some(repo) = repos.iter().find(|r| r.is_primary).or(repos.first()) { + return Ok(repo.path.clone()); + } + + Err(ProjectError::Validation(format!( + "project {} has no repositories configured", + project_id + ))) + } +} diff --git a/crates/dirigent_projects/src/git/mod.rs b/crates/dirigent_projects/src/git/mod.rs new file mode 100644 index 0000000..6217ef0 --- /dev/null +++ b/crates/dirigent_projects/src/git/mod.rs @@ -0,0 +1,13 @@ +//! Git integration. +//! +//! - `GitRunner` executes git commands against a local repository +//! - `compute_git_state()` aggregates runner output into a `GitState` +//! - Worktree workflows (follow/take) for branch management + +pub mod runner; +pub mod state; +pub mod worktree; + +pub use runner::GitRunner; +pub use state::compute_git_state; +pub use worktree::{follow, take}; diff --git a/crates/dirigent_projects/src/git/runner.rs b/crates/dirigent_projects/src/git/runner.rs new file mode 100644 index 0000000..dc63abb --- /dev/null +++ b/crates/dirigent_projects/src/git/runner.rs @@ -0,0 +1,353 @@ +//! Git command runner. +//! +//! Wraps `tokio::process::Command` to execute git operations on a local +//! repository path. All methods return structured results with proper +//! error handling for git-not-installed, not-a-repo, etc. + +use crate::error::{ProjectError, Result}; +use std::path::{Path, PathBuf}; +use tokio::process::Command; + +/// Executes git commands against a local repository. +#[derive(Clone, Debug)] +pub struct GitRunner { + repo_path: PathBuf, +} + +/// Parsed output of `git status --porcelain=v2 --branch`. +#[derive(Clone, Debug, Default)] +pub struct GitStatus { + /// Current branch (empty if detached HEAD) + pub branch: String, + /// Whether there are uncommitted changes + pub is_dirty: bool, + /// Commits ahead of upstream + pub ahead: u32, + /// Commits behind upstream + pub behind: u32, +} + +/// Parsed output of `git worktree list --porcelain`. +#[derive(Clone, Debug)] +pub struct WorktreeEntry { + /// Worktree filesystem path + pub path: PathBuf, + /// Branch checked out (None if detached) + pub branch: Option<String>, + /// Whether HEAD is detached + pub is_detached: bool, + /// Whether this is a bare repository worktree + pub is_bare: bool, +} + +impl GitRunner { + /// Create a new runner for the given repository path. + pub fn new(repo_path: impl Into<PathBuf>) -> Self { + Self { + repo_path: repo_path.into(), + } + } + + /// Path this runner operates on. + pub fn repo_path(&self) -> &Path { + &self.repo_path + } + + /// Get the current branch name. + /// + /// Returns empty string if HEAD is detached. + pub async fn current_branch(&self) -> Result<String> { + let output = self.git(&["rev-parse", "--abbrev-ref", "HEAD"]).await?; + let branch = output.trim().to_string(); + // rev-parse returns "HEAD" when detached + if branch == "HEAD" { + Ok(String::new()) + } else { + Ok(branch) + } + } + + /// Get the current status (branch, dirty, ahead/behind). + pub async fn status(&self) -> Result<GitStatus> { + let output = self.git(&["status", "--porcelain=v2", "--branch"]).await?; + parse_status(&output) + } + + /// List remote names. + pub async fn remotes(&self) -> Result<Vec<String>> { + let output = self.git(&["remote"]).await?; + Ok(output + .lines() + .map(|l| l.trim().to_string()) + .filter(|l| !l.is_empty()) + .collect()) + } + + /// Fetch from a remote (defaults to "origin"). + pub async fn fetch(&self, remote: Option<&str>) -> Result<()> { + let remote = remote.unwrap_or("origin"); + self.git(&["fetch", remote, "--quiet"]).await?; + Ok(()) + } + + /// List worktrees via `git worktree list --porcelain`. + pub async fn worktree_list(&self) -> Result<Vec<WorktreeEntry>> { + let output = self.git(&["worktree", "list", "--porcelain"]).await?; + Ok(parse_worktree_list(&output)) + } + + /// Add a worktree at the given path for the given branch. + pub async fn worktree_add(&self, path: &Path, branch: &str) -> Result<()> { + self.git(&["worktree", "add", &path.to_string_lossy(), branch]) + .await?; + Ok(()) + } + + /// Remove a worktree at the given path. + pub async fn worktree_remove(&self, path: &Path, force: bool) -> Result<()> { + let path_str = path.to_string_lossy(); + let mut args = vec!["worktree", "remove", &*path_str]; + if force { + args.push("--force"); + } + self.git(&args).await?; + Ok(()) + } + + /// Checkout a branch. + pub async fn checkout(&self, branch: &str) -> Result<()> { + self.git(&["checkout", branch]).await?; + Ok(()) + } + + /// Commit staged changes with the given message. Returns the commit hash. + pub async fn commit(&self, message: &str) -> Result<String> { + self.git(&["commit", "-m", message]).await?; + let hash = self.git(&["rev-parse", "HEAD"]).await?; + Ok(hash.trim().to_string()) + } + + /// Squash-merge from a source branch. + pub async fn merge_squash(&self, source_branch: &str) -> Result<()> { + self.git(&["merge", "--squash", source_branch]).await?; + Ok(()) + } + + /// Hard-reset to a target ref. + pub async fn reset_hard(&self, target: &str) -> Result<()> { + self.git(&["reset", "--hard", target]).await?; + Ok(()) + } + + // ======================================================================== + // Internal helpers + // ======================================================================== + + /// Execute a git command and return stdout on success. + async fn git(&self, args: &[&str]) -> Result<String> { + let output = Command::new("git") + .args(args) + .current_dir(&self.repo_path) + .output() + .await + .map_err(|e| { + if e.kind() == std::io::ErrorKind::NotFound { + ProjectError::Validation("git is not installed or not in PATH".into()) + } else { + ProjectError::Storage(e) + } + })?; + + if output.status.success() { + Ok(String::from_utf8_lossy(&output.stdout).to_string()) + } else { + let stderr = String::from_utf8_lossy(&output.stderr).to_string(); + Err(ProjectError::Validation(format!( + "git {} failed: {}", + args.first().unwrap_or(&""), + stderr.trim() + ))) + } + } +} + +// ============================================================================ +// Parsers +// ============================================================================ + +/// Parse `git status --porcelain=v2 --branch` output. +fn parse_status(output: &str) -> Result<GitStatus> { + let mut status = GitStatus::default(); + + for line in output.lines() { + if let Some(rest) = line.strip_prefix("# branch.head ") { + status.branch = rest.trim().to_string(); + if status.branch == "(detached)" { + status.branch = String::new(); + } + } else if let Some(rest) = line.strip_prefix("# branch.ab ") { + // Format: "+N -M" + for part in rest.split_whitespace() { + if let Some(ahead) = part.strip_prefix('+') { + status.ahead = ahead.parse().unwrap_or(0); + } else if let Some(behind) = part.strip_prefix('-') { + status.behind = behind.parse().unwrap_or(0); + } + } + } else if line.starts_with('1') + || line.starts_with('2') + || line.starts_with('u') + || line.starts_with('?') + { + // Any tracked/untracked change means dirty + status.is_dirty = true; + } + } + + Ok(status) +} + +/// Parse `git worktree list --porcelain` output. +/// +/// Porcelain format outputs blocks separated by blank lines: +/// ```text +/// worktree /path/to/main +/// HEAD abc123 +/// branch refs/heads/main +/// +/// worktree /path/to/feature +/// HEAD def456 +/// branch refs/heads/feature +/// ``` +fn parse_worktree_list(output: &str) -> Vec<WorktreeEntry> { + let mut entries = Vec::new(); + let mut current_path: Option<PathBuf> = None; + let mut current_branch: Option<String> = None; + let mut is_detached = false; + let mut is_bare = false; + + for line in output.lines() { + if line.is_empty() { + // End of block — flush + if let Some(path) = current_path.take() { + entries.push(WorktreeEntry { + path, + branch: current_branch.take(), + is_detached, + is_bare, + }); + } + is_detached = false; + is_bare = false; + } else if let Some(rest) = line.strip_prefix("worktree ") { + current_path = Some(PathBuf::from(rest)); + } else if let Some(rest) = line.strip_prefix("branch ") { + // Strip refs/heads/ prefix + current_branch = Some(rest.strip_prefix("refs/heads/").unwrap_or(rest).to_string()); + } else if line == "detached" { + is_detached = true; + } else if line == "bare" { + is_bare = true; + } + } + + // Flush last block (output may not end with blank line) + if let Some(path) = current_path.take() { + entries.push(WorktreeEntry { + path, + branch: current_branch.take(), + is_detached, + is_bare, + }); + } + + entries +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_status_clean() { + let output = "# branch.head main\n# branch.ab +0 -0\n"; + let status = parse_status(output).unwrap(); + assert_eq!(status.branch, "main"); + assert!(!status.is_dirty); + assert_eq!(status.ahead, 0); + assert_eq!(status.behind, 0); + } + + #[test] + fn test_parse_status_dirty_ahead_behind() { + let output = "\ +# branch.head feature +# branch.ab +3 -1 +1 .M N... 100644 100644 100644 abc123 def456 src/main.rs +? new_file.txt +"; + let status = parse_status(output).unwrap(); + assert_eq!(status.branch, "feature"); + assert!(status.is_dirty); + assert_eq!(status.ahead, 3); + assert_eq!(status.behind, 1); + } + + #[test] + fn test_parse_status_detached() { + let output = "# branch.head (detached)\n"; + let status = parse_status(output).unwrap(); + assert_eq!(status.branch, ""); + } + + #[test] + fn test_parse_worktree_list() { + let output = "\ +worktree /home/user/project +HEAD abc123def456 +branch refs/heads/main + +worktree /home/user/project-feature +HEAD 789012345678 +branch refs/heads/feature + +worktree /home/user/project-detached +HEAD aabbccdd +detached + +"; + let entries = parse_worktree_list(output); + assert_eq!(entries.len(), 3); + + assert_eq!(entries[0].path, PathBuf::from("/home/user/project")); + assert_eq!(entries[0].branch, Some("main".to_string())); + assert!(!entries[0].is_detached); + + assert_eq!(entries[1].path, PathBuf::from("/home/user/project-feature")); + assert_eq!(entries[1].branch, Some("feature".to_string())); + assert!(!entries[1].is_detached); + + assert_eq!( + entries[2].path, + PathBuf::from("/home/user/project-detached") + ); + assert!(entries[2].branch.is_none()); + assert!(entries[2].is_detached); + } + + #[test] + fn test_parse_worktree_list_no_trailing_newline() { + let output = "worktree /repo\nHEAD abc\nbranch refs/heads/main"; + let entries = parse_worktree_list(output); + assert_eq!(entries.len(), 1); + assert_eq!(entries[0].branch, Some("main".to_string())); + } + + #[test] + fn test_parse_worktree_list_bare() { + let output = "worktree /repo.git\nHEAD abc\nbare\n\n"; + let entries = parse_worktree_list(output); + assert_eq!(entries.len(), 1); + assert!(entries[0].is_bare); + } +} diff --git a/crates/dirigent_projects/src/git/state.rs b/crates/dirigent_projects/src/git/state.rs new file mode 100644 index 0000000..61b8dd9 --- /dev/null +++ b/crates/dirigent_projects/src/git/state.rs @@ -0,0 +1,77 @@ +//! GitState computation from GitRunner output. +//! +//! Aggregates branch, status, remotes, and worktrees into a single +//! `GitState` struct with graceful degradation via `GitWarning`. + +use crate::git::runner::GitRunner; +use dirigent_protocol::project::{GitState, GitWarning, WorktreeInfo}; + +/// Compute the full git state for a repository. +/// +/// Calls branch, status, remotes, and worktree_list. Any individual +/// failure is captured as a `GitWarning` rather than failing the whole +/// computation. +pub async fn compute_git_state(runner: &GitRunner) -> GitState { + let mut state = GitState::default(); + let mut warnings = Vec::new(); + + // Status (includes branch + dirty + ahead/behind) + match runner.status().await { + Ok(status) => { + state.branch = status.branch; + state.is_dirty = status.is_dirty; + state.ahead = status.ahead; + state.behind = status.behind; + } + Err(e) => { + warnings.push(GitWarning { + code: "status_failed".to_string(), + message: format!("Failed to get git status: {e}"), + }); + // Try branch separately as fallback + match runner.current_branch().await { + Ok(branch) => state.branch = branch, + Err(e) => { + warnings.push(GitWarning { + code: "branch_failed".to_string(), + message: format!("Failed to get current branch: {e}"), + }); + } + } + } + } + + // Remotes + match runner.remotes().await { + Ok(remotes) => state.remotes = remotes, + Err(e) => { + warnings.push(GitWarning { + code: "remotes_failed".to_string(), + message: format!("Failed to list remotes: {e}"), + }); + } + } + + // Worktrees + match runner.worktree_list().await { + Ok(entries) => { + state.worktrees = entries + .into_iter() + .map(|e| WorktreeInfo { + path: e.path, + branch: e.branch, + is_detached: e.is_detached, + }) + .collect(); + } + Err(e) => { + warnings.push(GitWarning { + code: "worktrees_failed".to_string(), + message: format!("Failed to list worktrees: {e}"), + }); + } + } + + state.unexpected = warnings; + state +} diff --git a/crates/dirigent_projects/src/git/worktree.rs b/crates/dirigent_projects/src/git/worktree.rs new file mode 100644 index 0000000..9cc1ab2 --- /dev/null +++ b/crates/dirigent_projects/src/git/worktree.rs @@ -0,0 +1,62 @@ +//! Worktree workflow implementations. +//! +//! - **follow**: Hard-reset a work branch to track a target branch (e.g. main) +//! - **take**: Squash-merge changes from a worktree branch into a target branch + +use crate::error::{ProjectError, Result}; +use crate::git::runner::GitRunner; + +/// Follow workflow: hard-reset work_branch to match target_branch. +/// +/// This is used when a worktree's work branch needs to catch up with +/// the main branch. After this operation, work_branch HEAD will be +/// identical to target_branch HEAD. +/// +/// **Destructive**: Discards any uncommitted changes on work_branch. +pub async fn follow(runner: &GitRunner, work_branch: &str, target_branch: &str) -> Result<()> { + // Ensure we're on the work branch + let current = runner.current_branch().await?; + if current != work_branch { + runner.checkout(work_branch).await?; + } + + runner.reset_hard(target_branch).await?; + Ok(()) +} + +/// Take workflow: squash-merge changes from source_branch into target_branch. +/// +/// This brings all the work from source_branch into target_branch as a +/// single commit. If `auto_commit` is true, the squash is committed +/// automatically with a generated message. +/// +/// Returns the commit hash if auto_commit is true, None otherwise +/// (leaving the changes staged for manual commit). +pub async fn take( + runner: &GitRunner, + source_branch: &str, + target_branch: &str, + auto_commit: bool, +) -> Result<Option<String>> { + // Switch to target branch + let current = runner.current_branch().await?; + if current != target_branch { + runner.checkout(target_branch).await?; + } + + // Squash-merge + runner.merge_squash(source_branch).await.map_err(|e| { + ProjectError::Validation(format!( + "squash-merge from '{}' into '{}' failed: {}", + source_branch, target_branch, e + )) + })?; + + if auto_commit { + let message = format!("Squash merge from {}", source_branch); + let hash = runner.commit(&message).await?; + Ok(Some(hash)) + } else { + Ok(None) + } +} diff --git a/crates/dirigent_projects/src/lib.rs b/crates/dirigent_projects/src/lib.rs new file mode 100644 index 0000000..b47d98e --- /dev/null +++ b/crates/dirigent_projects/src/lib.rs @@ -0,0 +1,44 @@ +//! Dirigent Projects +//! +//! Project management crate for the Dirigent system. Trait-based, +//! file-backed, async-first (following the archivist pattern). +//! +//! # Architecture +//! +//! - `ProjectStore` trait defines the storage interface +//! - `FileBasedProjectStore` implements file-backed persistence +//! - Storage uses one directory per project with atomic JSON writes +//! - Protocol types from `dirigent_protocol::project` are shared with WASM +//! +//! # Phases +//! +//! - Phase 1: Project CRUD (implemented) +//! - Phase 2: Repository management, working dir resolution (scaffolded) +//! - Phase 3: Git integration (scaffolded) +//! - Phase 4: Worktree support (scaffolded) +//! - Phase 5: Bindings (scaffolded) + +pub mod detection; +pub mod error; +pub mod file_store; +pub mod git; +pub mod params; +pub mod storage; +pub mod traits; + +// Re-export commonly used types +pub use error::{ProjectError, Result}; +pub use file_store::FileBasedProjectStore; +pub use params::{ + AddRepositoryParams, AddWorktreeParams, BindParams, CreateProjectParams, ProjectFilter, + ProjectUpdate, WorktreeUpdate, +}; +pub use traits::ProjectStore; + +// Re-export detection types +pub use detection::{ + create_projects_from_import, detect_projects, detect_worktree, find_multi_path_groups, + normalize_project_path, DetectedProject, DiscoveredImportProject, + ImportProjectCreationRequest, ImportProjectCreationResult, MultiPathHint, + ProjectDetectionResult, ProjectResolution, WorktreeHint, +}; diff --git a/crates/dirigent_projects/src/params.rs b/crates/dirigent_projects/src/params.rs new file mode 100644 index 0000000..fe7a483 --- /dev/null +++ b/crates/dirigent_projects/src/params.rs @@ -0,0 +1,127 @@ +//! Parameter types for project store operations. + +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; +use uuid::Uuid; + +/// Parameters for creating a new project. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct CreateProjectParams { + /// Human-readable project name + pub name: String, + /// Project description + #[serde(default)] + pub description: String, + /// Optional icon (emoji or abbreviation) + #[serde(skip_serializing_if = "Option::is_none")] + pub icon: Option<String>, + /// Owner user ID + pub owner: Uuid, + /// Initial tags + #[serde(default)] + pub tags: Vec<String>, + /// Initial languages + #[serde(default)] + pub languages: Vec<String>, + /// Arbitrary metadata + #[serde(default)] + pub metadata: serde_json::Value, +} + +/// Filter for listing projects. +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct ProjectFilter { + /// Filter by owner + #[serde(skip_serializing_if = "Option::is_none")] + pub owner: Option<Uuid>, + /// Filter by tag (project must have all specified tags) + #[serde(default)] + pub tags: Vec<String>, + /// Filter by name substring (case-insensitive) + #[serde(skip_serializing_if = "Option::is_none")] + pub name_contains: Option<String>, +} + +/// Fields to update on a project. +/// +/// Only `Some` fields are applied; `None` fields are left unchanged. +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct ProjectUpdate { + /// New name + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option<String>, + /// New description + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option<String>, + /// New icon + #[serde(skip_serializing_if = "Option::is_none")] + pub icon: Option<Option<String>>, + /// New tags (replaces all) + #[serde(skip_serializing_if = "Option::is_none")] + pub tags: Option<Vec<String>>, + /// New languages (replaces all) + #[serde(skip_serializing_if = "Option::is_none")] + pub languages: Option<Vec<String>>, + /// New metadata (replaces all) + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option<serde_json::Value>, +} + +/// Parameters for adding a repository to a project. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AddRepositoryParams { + /// Project to add the repository to + pub project_id: Uuid, + /// Local filesystem path + pub path: PathBuf, + /// Whether this is the primary repository + #[serde(default)] + pub is_primary: bool, + /// Optional human-readable label + #[serde(skip_serializing_if = "Option::is_none")] + pub label: Option<String>, +} + +/// Parameters for adding a worktree to a repository. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AddWorktreeParams { + /// Repository this worktree belongs to + pub repository_id: Uuid, + /// Local filesystem path for the worktree + pub path: PathBuf, + /// Branch name + pub branch: String, + /// Optional work branch name + #[serde(skip_serializing_if = "Option::is_none")] + pub work_branch: Option<String>, + /// Optional naming strategy + #[serde(skip_serializing_if = "Option::is_none")] + pub naming_strategy: Option<String>, +} + +/// Fields to update on a worktree. +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct WorktreeUpdate { + /// New branch + #[serde(skip_serializing_if = "Option::is_none")] + pub branch: Option<String>, + /// New work branch + #[serde(skip_serializing_if = "Option::is_none")] + pub work_branch: Option<Option<String>>, +} + +/// Parameters for creating a project binding. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct BindParams { + /// Project to bind + pub project_id: Uuid, + /// Optional connector ID + #[serde(skip_serializing_if = "Option::is_none")] + pub connector_id: Option<String>, + /// Optional session ID + #[serde(skip_serializing_if = "Option::is_none")] + pub session_id: Option<Uuid>, + /// Optional working directory override + #[serde(skip_serializing_if = "Option::is_none")] + pub working_dir: Option<PathBuf>, +} diff --git a/crates/dirigent_projects/src/storage/io.rs b/crates/dirigent_projects/src/storage/io.rs new file mode 100644 index 0000000..143a951 --- /dev/null +++ b/crates/dirigent_projects/src/storage/io.rs @@ -0,0 +1,118 @@ +//! JSON read/write helpers with atomic writes. +//! +//! Follows the archivist pattern: write to .tmp, then rename. + +use serde::{Deserialize, Serialize}; +use std::path::Path; +use tokio::io::AsyncWriteExt; + +/// Write a value to a JSON file atomically. +/// +/// 1. Serializes to pretty-printed JSON +/// 2. Writes to `{path}.tmp` +/// 3. Renames temp file to target (atomic on most filesystems) +pub async fn write_json<T: Serialize>(path: &Path, value: &T) -> std::io::Result<()> { + let json = serde_json::to_string_pretty(value) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + + let temp_path = path.with_extension("tmp"); + + let mut file = tokio::fs::File::create(&temp_path).await?; + file.write_all(json.as_bytes()).await?; + file.sync_all().await?; + drop(file); + + tokio::fs::rename(&temp_path, path).await?; + + Ok(()) +} + +/// Read a value from a JSON file. +/// +/// Returns `NotFound` if the file doesn't exist. +pub async fn read_json<T: for<'de> Deserialize<'de>>(path: &Path) -> std::io::Result<T> { + let content = tokio::fs::read_to_string(path).await?; + let value: T = serde_json::from_str(&content) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + Ok(value) +} + +/// Read a value from a JSON file, returning a default if the file doesn't exist. +pub async fn read_json_or_default<T: for<'de> Deserialize<'de> + Default>( + path: &Path, +) -> std::io::Result<T> { + match read_json(path).await { + Ok(value) => Ok(value), + Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(T::default()), + Err(e) => Err(e), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde::{Deserialize, Serialize}; + + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] + struct TestData { + id: String, + value: i32, + } + + #[tokio::test] + async fn test_write_and_read_roundtrip() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("test.json"); + + let data = TestData { + id: "test".to_string(), + value: 42, + }; + + write_json(&path, &data).await.unwrap(); + let read: TestData = read_json(&path).await.unwrap(); + assert_eq!(read, data); + } + + #[tokio::test] + async fn test_read_missing_file() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("missing.json"); + + let result: std::io::Result<TestData> = read_json(&path).await; + assert_eq!(result.unwrap_err().kind(), std::io::ErrorKind::NotFound); + } + + #[tokio::test] + async fn test_read_json_or_default() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("missing.json"); + + let result: Vec<String> = read_json_or_default(&path).await.unwrap(); + assert!(result.is_empty()); + } + + #[tokio::test] + async fn test_atomic_overwrite() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("test.json"); + + let data1 = TestData { + id: "first".to_string(), + value: 1, + }; + let data2 = TestData { + id: "second".to_string(), + value: 2, + }; + + write_json(&path, &data1).await.unwrap(); + write_json(&path, &data2).await.unwrap(); + + let read: TestData = read_json(&path).await.unwrap(); + assert_eq!(read, data2); + + // Temp file should not remain + assert!(!path.with_extension("tmp").exists()); + } +} diff --git a/crates/dirigent_projects/src/storage/mod.rs b/crates/dirigent_projects/src/storage/mod.rs new file mode 100644 index 0000000..726de22 --- /dev/null +++ b/crates/dirigent_projects/src/storage/mod.rs @@ -0,0 +1,6 @@ +//! Storage layer for file-based project persistence. +//! +//! Follows the archivist pattern: atomic writes, JSON files, directory-per-project. + +pub mod io; +pub mod paths; diff --git a/crates/dirigent_projects/src/storage/paths.rs b/crates/dirigent_projects/src/storage/paths.rs new file mode 100644 index 0000000..1c91786 --- /dev/null +++ b/crates/dirigent_projects/src/storage/paths.rs @@ -0,0 +1,66 @@ +//! Path conventions for project storage. + +use std::path::{Path, PathBuf}; +use uuid::Uuid; + +/// Path helper for the projects storage root. +pub struct ProjectPaths { + root: PathBuf, +} + +impl ProjectPaths { + /// Create a new path helper. + pub fn new(root: impl Into<PathBuf>) -> Self { + Self { root: root.into() } + } + + /// Root directory for all projects. + pub fn root(&self) -> &Path { + &self.root + } + + /// Directory for a specific project. + pub fn project_dir(&self, project_id: &Uuid) -> PathBuf { + self.root.join(project_id.to_string()) + } + + /// Path to the project metadata JSON file. + pub fn project_json(&self, project_id: &Uuid) -> PathBuf { + self.project_dir(project_id).join("project.json") + } + + /// Path to the repositories JSON file (Phase 2). + pub fn repositories_json(&self, project_id: &Uuid) -> PathBuf { + self.project_dir(project_id).join("repositories.json") + } + + /// Path to the bindings JSON file (Phase 5). + pub fn bindings_json(&self, project_id: &Uuid) -> PathBuf { + self.project_dir(project_id).join("bindings.json") + } + + /// Path to the worktrees JSON file (Phase 4). + pub fn worktrees_json(&self, project_id: &Uuid) -> PathBuf { + self.project_dir(project_id).join("worktrees.json") + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_project_paths() { + let paths = ProjectPaths::new("/data/projects"); + let id = Uuid::nil(); + + assert_eq!( + paths.project_dir(&id), + PathBuf::from("/data/projects/00000000-0000-0000-0000-000000000000") + ); + assert_eq!( + paths.project_json(&id), + PathBuf::from("/data/projects/00000000-0000-0000-0000-000000000000/project.json") + ); + } +} diff --git a/crates/dirigent_projects/src/traits.rs b/crates/dirigent_projects/src/traits.rs new file mode 100644 index 0000000..089cfab --- /dev/null +++ b/crates/dirigent_projects/src/traits.rs @@ -0,0 +1,81 @@ +//! ProjectStore trait definition. + +use crate::error::Result; +use crate::params::*; +use dirigent_protocol::project::{Project, ProjectBinding, ProjectRepository, Worktree}; +use std::path::PathBuf; +use uuid::Uuid; + +/// Trait for project storage backends. +/// +/// Async-first, trait-object safe. Implementations must be Send + Sync. +/// Phase 1 implements project CRUD. Later phases add repository management, +/// bindings, and resolution. +#[async_trait::async_trait] +pub trait ProjectStore: Send + Sync { + // --- Project CRUD (Phase 1) --- + + /// Create a new project. + async fn create_project(&self, params: CreateProjectParams) -> Result<Project>; + + /// Get a project by ID. + async fn get_project(&self, id: &Uuid) -> Result<Project>; + + /// List projects matching a filter. + async fn list_projects(&self, filter: ProjectFilter) -> Result<Vec<Project>>; + + /// Update a project's fields. + async fn update_project(&self, id: &Uuid, update: ProjectUpdate) -> Result<Project>; + + /// Delete a project and all associated data. + async fn delete_project(&self, id: &Uuid) -> Result<()>; + + // --- Repository management (Phase 2) --- + + /// Add a repository to a project. + async fn add_repository(&self, params: AddRepositoryParams) -> Result<ProjectRepository>; + + /// Remove a repository. + async fn remove_repository(&self, id: &Uuid) -> Result<()>; + + /// Set a repository as the primary for its project. + async fn set_primary_repository(&self, project_id: &Uuid, repo_id: &Uuid) -> Result<()>; + + /// List repositories for a project. + async fn list_repositories(&self, project_id: &Uuid) -> Result<Vec<ProjectRepository>>; + + // --- Worktrees (Phase 4) --- + + /// Add a worktree record to a repository. + async fn add_worktree(&self, params: AddWorktreeParams) -> Result<Worktree>; + + /// Remove a worktree record. + async fn remove_worktree(&self, worktree_id: &Uuid) -> Result<()>; + + /// List worktree records for a repository. + async fn list_worktrees(&self, repository_id: &Uuid) -> Result<Vec<Worktree>>; + + /// Update a worktree record. + async fn update_worktree(&self, worktree_id: &Uuid, update: WorktreeUpdate) + -> Result<Worktree>; + + // --- Bindings (Phase 5) --- + + /// Bind a project to a connector/session. + async fn bind(&self, params: BindParams) -> Result<ProjectBinding>; + + /// Remove a binding. + async fn unbind(&self, binding_id: &Uuid) -> Result<()>; + + /// List bindings for a project. + async fn list_bindings(&self, project_id: &Uuid) -> Result<Vec<ProjectBinding>>; + + // --- Resolution (Phase 2) --- + + /// Resolve the working directory for a project. + async fn resolve_working_dir( + &self, + project_id: &Uuid, + repo_id: Option<&Uuid>, + ) -> Result<PathBuf>; +} diff --git a/crates/dirigent_projects/tests/git_tests.rs b/crates/dirigent_projects/tests/git_tests.rs new file mode 100644 index 0000000..65942ff --- /dev/null +++ b/crates/dirigent_projects/tests/git_tests.rs @@ -0,0 +1,184 @@ +//! Integration tests for the git module. +//! +//! These tests create real temporary git repos and exercise GitRunner +//! and compute_git_state against them. Marked `#[ignore]` by default +//! since they require `git` to be installed. + +use dirigent_projects::git::{compute_git_state, GitRunner}; +use std::path::Path; +use tokio::process::Command; + +/// Helper: initialize a git repo in the given directory with an initial commit. +async fn init_repo(dir: &Path) { + run(dir, &["git", "init"]).await; + run(dir, &["git", "config", "user.email", "test@test.com"]).await; + run(dir, &["git", "config", "user.name", "Test"]).await; + // Create an initial commit so HEAD exists + let file = dir.join("README.md"); + tokio::fs::write(&file, "# Test\n").await.unwrap(); + run(dir, &["git", "add", "."]).await; + run(dir, &["git", "commit", "-m", "Initial commit"]).await; +} + +async fn run(dir: &Path, args: &[&str]) { + let status = Command::new(args[0]) + .args(&args[1..]) + .current_dir(dir) + .output() + .await + .unwrap_or_else(|e| panic!("Failed to run {:?}: {e}", args)); + assert!( + status.status.success(), + "{:?} failed: {}", + args, + String::from_utf8_lossy(&status.stderr) + ); +} + +#[tokio::test] +#[ignore = "requires git"] +async fn test_current_branch() { + let dir = tempfile::tempdir().unwrap(); + init_repo(dir.path()).await; + + let runner = GitRunner::new(dir.path()); + let branch = runner.current_branch().await.unwrap(); + // Default branch may be "main" or "master" depending on git config + assert!( + branch == "main" || branch == "master", + "unexpected branch: {branch}" + ); +} + +#[tokio::test] +#[ignore = "requires git"] +async fn test_status_clean() { + let dir = tempfile::tempdir().unwrap(); + init_repo(dir.path()).await; + + let runner = GitRunner::new(dir.path()); + let status = runner.status().await.unwrap(); + assert!(!status.is_dirty); + assert_eq!(status.ahead, 0); + assert_eq!(status.behind, 0); +} + +#[tokio::test] +#[ignore = "requires git"] +async fn test_status_dirty() { + let dir = tempfile::tempdir().unwrap(); + init_repo(dir.path()).await; + + // Create an untracked file + tokio::fs::write(dir.path().join("dirty.txt"), "dirty") + .await + .unwrap(); + + let runner = GitRunner::new(dir.path()); + let status = runner.status().await.unwrap(); + assert!(status.is_dirty); +} + +#[tokio::test] +#[ignore = "requires git"] +async fn test_remotes_empty() { + let dir = tempfile::tempdir().unwrap(); + init_repo(dir.path()).await; + + let runner = GitRunner::new(dir.path()); + let remotes = runner.remotes().await.unwrap(); + assert!(remotes.is_empty()); +} + +#[tokio::test] +#[ignore = "requires git"] +async fn test_worktree_list_single() { + let dir = tempfile::tempdir().unwrap(); + init_repo(dir.path()).await; + + let runner = GitRunner::new(dir.path()); + let worktrees = runner.worktree_list().await.unwrap(); + // A non-bare repo always has at least the main worktree + assert_eq!(worktrees.len(), 1); + assert!(worktrees[0].branch.is_some()); +} + +#[tokio::test] +#[ignore = "requires git"] +async fn test_worktree_add_and_list() { + let dir = tempfile::tempdir().unwrap(); + init_repo(dir.path()).await; + + let wt_path = dir.path().join("wt-feature"); + // Create a branch first + run(dir.path(), &["git", "branch", "feature"]).await; + + let runner = GitRunner::new(dir.path()); + runner.worktree_add(&wt_path, "feature").await.unwrap(); + + let worktrees = runner.worktree_list().await.unwrap(); + assert_eq!(worktrees.len(), 2); + + // Find the feature worktree by branch name (paths may differ due to symlink canonicalization) + let feature_wt = worktrees + .iter() + .find(|w| w.branch.as_deref() == Some("feature")) + .expect("should find worktree with branch 'feature'"); + assert!(!feature_wt.is_detached); +} + +#[tokio::test] +#[ignore = "requires git"] +async fn test_compute_git_state() { + let dir = tempfile::tempdir().unwrap(); + init_repo(dir.path()).await; + + // Make it dirty + tokio::fs::write(dir.path().join("new.txt"), "content") + .await + .unwrap(); + + let runner = GitRunner::new(dir.path()); + let state = compute_git_state(&runner).await; + + assert!(!state.branch.is_empty()); + assert!(state.is_dirty); + assert!( + state.unexpected.is_empty(), + "unexpected warnings: {:?}", + state.unexpected + ); + // Should have at least the main worktree + assert!(!state.worktrees.is_empty()); +} + +#[tokio::test] +#[ignore = "requires git"] +async fn test_graceful_degradation_not_a_repo() { + let dir = tempfile::tempdir().unwrap(); + // Don't init — not a git repo + + let runner = GitRunner::new(dir.path()); + let state = compute_git_state(&runner).await; + + // Should have warnings, not panic + assert!(!state.unexpected.is_empty()); +} + +#[tokio::test] +#[ignore = "requires git"] +async fn test_commit_returns_hash() { + let dir = tempfile::tempdir().unwrap(); + init_repo(dir.path()).await; + + let file = dir.path().join("commit_test.txt"); + tokio::fs::write(&file, "data").await.unwrap(); + run(dir.path(), &["git", "add", "."]).await; + + let runner = GitRunner::new(dir.path()); + let hash = runner.commit("test commit").await.unwrap(); + + // SHA-1 hash is 40 hex chars + assert_eq!(hash.len(), 40, "unexpected hash: {hash}"); + assert!(hash.chars().all(|c| c.is_ascii_hexdigit())); +} diff --git a/crates/dirigent_projects/tests/project_lifecycle.rs b/crates/dirigent_projects/tests/project_lifecycle.rs new file mode 100644 index 0000000..c1c6f0b --- /dev/null +++ b/crates/dirigent_projects/tests/project_lifecycle.rs @@ -0,0 +1,226 @@ +//! Integration tests for project CRUD lifecycle. + +use dirigent_projects::{ + CreateProjectParams, FileBasedProjectStore, ProjectFilter, ProjectStore, ProjectUpdate, +}; +use uuid::Uuid; + +async fn make_store() -> FileBasedProjectStore { + let dir = tempfile::tempdir().unwrap(); + FileBasedProjectStore::new(dir.into_path()).await.unwrap() +} + +#[tokio::test] +async fn test_create_and_get_project() { + let store = make_store().await; + let owner = Uuid::now_v7(); + + let project = store + .create_project(CreateProjectParams { + name: "Test Project".to_string(), + description: "A test".to_string(), + icon: Some("🚀".to_string()), + owner, + tags: vec!["rust".to_string()], + languages: vec!["Rust".to_string()], + metadata: serde_json::json!({}), + }) + .await + .unwrap(); + + assert_eq!(project.name, "Test Project"); + assert_eq!(project.owner, owner); + + let fetched = store.get_project(&project.id).await.unwrap(); + assert_eq!(fetched.id, project.id); + assert_eq!(fetched.name, "Test Project"); + assert_eq!(fetched.icon, Some("🚀".to_string())); +} + +#[tokio::test] +async fn test_list_projects_empty() { + let store = make_store().await; + let projects = store.list_projects(ProjectFilter::default()).await.unwrap(); + assert!(projects.is_empty()); +} + +#[tokio::test] +async fn test_list_projects_with_filter() { + let store = make_store().await; + let owner1 = Uuid::now_v7(); + let owner2 = Uuid::now_v7(); + + store + .create_project(CreateProjectParams { + name: "Alpha".to_string(), + owner: owner1, + tags: vec!["web".to_string()], + ..default_params() + }) + .await + .unwrap(); + + store + .create_project(CreateProjectParams { + name: "Beta".to_string(), + owner: owner2, + tags: vec!["cli".to_string()], + ..default_params() + }) + .await + .unwrap(); + + // Filter by owner + let filtered = store + .list_projects(ProjectFilter { + owner: Some(owner1), + ..Default::default() + }) + .await + .unwrap(); + assert_eq!(filtered.len(), 1); + assert_eq!(filtered[0].name, "Alpha"); + + // Filter by name + let filtered = store + .list_projects(ProjectFilter { + name_contains: Some("bet".to_string()), + ..Default::default() + }) + .await + .unwrap(); + assert_eq!(filtered.len(), 1); + assert_eq!(filtered[0].name, "Beta"); + + // Filter by tag + let filtered = store + .list_projects(ProjectFilter { + tags: vec!["web".to_string()], + ..Default::default() + }) + .await + .unwrap(); + assert_eq!(filtered.len(), 1); + assert_eq!(filtered[0].name, "Alpha"); + + // No filter returns all, sorted by name + let all = store.list_projects(ProjectFilter::default()).await.unwrap(); + assert_eq!(all.len(), 2); + assert_eq!(all[0].name, "Alpha"); + assert_eq!(all[1].name, "Beta"); +} + +#[tokio::test] +async fn test_update_project() { + let store = make_store().await; + + let project = store + .create_project(CreateProjectParams { + name: "Original".to_string(), + ..default_params() + }) + .await + .unwrap(); + + let updated = store + .update_project( + &project.id, + ProjectUpdate { + name: Some("Renamed".to_string()), + description: Some("New description".to_string()), + tags: Some(vec!["new-tag".to_string()]), + ..Default::default() + }, + ) + .await + .unwrap(); + + assert_eq!(updated.name, "Renamed"); + assert_eq!(updated.description, "New description"); + assert_eq!(updated.tags, vec!["new-tag"]); + assert!(updated.updated_at > project.created_at); + + // Verify persistence + let fetched = store.get_project(&project.id).await.unwrap(); + assert_eq!(fetched.name, "Renamed"); +} + +#[tokio::test] +async fn test_delete_project() { + let store = make_store().await; + + let project = store + .create_project(CreateProjectParams { + name: "ToDelete".to_string(), + ..default_params() + }) + .await + .unwrap(); + + store.delete_project(&project.id).await.unwrap(); + + let err = store.get_project(&project.id).await.unwrap_err(); + assert!(matches!(err, dirigent_projects::ProjectError::NotFound(_))); +} + +#[tokio::test] +async fn test_get_nonexistent_project() { + let store = make_store().await; + let err = store.get_project(&Uuid::now_v7()).await.unwrap_err(); + assert!(matches!(err, dirigent_projects::ProjectError::NotFound(_))); +} + +#[tokio::test] +async fn test_create_empty_name_fails() { + let store = make_store().await; + let err = store + .create_project(CreateProjectParams { + name: " ".to_string(), + ..default_params() + }) + .await + .unwrap_err(); + assert!(matches!( + err, + dirigent_projects::ProjectError::Validation(_) + )); +} + +#[tokio::test] +async fn test_update_empty_name_fails() { + let store = make_store().await; + let project = store + .create_project(CreateProjectParams { + name: "Valid".to_string(), + ..default_params() + }) + .await + .unwrap(); + + let err = store + .update_project( + &project.id, + ProjectUpdate { + name: Some("".to_string()), + ..Default::default() + }, + ) + .await + .unwrap_err(); + assert!(matches!( + err, + dirigent_projects::ProjectError::Validation(_) + )); +} + +fn default_params() -> CreateProjectParams { + CreateProjectParams { + name: String::new(), + description: String::new(), + icon: None, + owner: Uuid::now_v7(), + tags: vec![], + languages: vec![], + metadata: serde_json::json!({}), + } +} diff --git a/crates/dirigent_projects/tests/repository_tests.rs b/crates/dirigent_projects/tests/repository_tests.rs new file mode 100644 index 0000000..dbb0f5f --- /dev/null +++ b/crates/dirigent_projects/tests/repository_tests.rs @@ -0,0 +1,424 @@ +//! Integration tests for repository and binding CRUD, plus working directory resolution. + +use dirigent_projects::{ + AddRepositoryParams, BindParams, CreateProjectParams, FileBasedProjectStore, ProjectStore, +}; +use std::path::PathBuf; +use uuid::Uuid; + +async fn make_store() -> FileBasedProjectStore { + let dir = tempfile::tempdir().unwrap(); + FileBasedProjectStore::new(dir.into_path()).await.unwrap() +} + +fn default_params() -> CreateProjectParams { + CreateProjectParams { + name: "Test Project".to_string(), + description: String::new(), + icon: None, + owner: Uuid::now_v7(), + tags: vec![], + languages: vec![], + metadata: serde_json::json!({}), + } +} + +// ============================================================================ +// Repository Tests +// ============================================================================ + +#[tokio::test] +async fn test_add_and_list_repositories() { + let store = make_store().await; + let project = store.create_project(default_params()).await.unwrap(); + + let repo = store + .add_repository(AddRepositoryParams { + project_id: project.id, + path: PathBuf::from("/home/user/project"), + is_primary: false, + label: Some("main".to_string()), + }) + .await + .unwrap(); + + assert_eq!(repo.project_id, project.id); + assert_eq!(repo.path, PathBuf::from("/home/user/project")); + assert!(!repo.is_primary); + assert_eq!(repo.label, Some("main".to_string())); + + let repos = store.list_repositories(&project.id).await.unwrap(); + assert_eq!(repos.len(), 1); + assert_eq!(repos[0].id, repo.id); +} + +#[tokio::test] +async fn test_add_primary_repository_unsets_others() { + let store = make_store().await; + let project = store.create_project(default_params()).await.unwrap(); + + let repo1 = store + .add_repository(AddRepositoryParams { + project_id: project.id, + path: PathBuf::from("/repo1"), + is_primary: true, + label: None, + }) + .await + .unwrap(); + assert!(repo1.is_primary); + + // Adding a second primary should unset the first + let repo2 = store + .add_repository(AddRepositoryParams { + project_id: project.id, + path: PathBuf::from("/repo2"), + is_primary: true, + label: None, + }) + .await + .unwrap(); + assert!(repo2.is_primary); + + let repos = store.list_repositories(&project.id).await.unwrap(); + assert_eq!(repos.len(), 2); + + let first = repos.iter().find(|r| r.id == repo1.id).unwrap(); + let second = repos.iter().find(|r| r.id == repo2.id).unwrap(); + assert!(!first.is_primary); + assert!(second.is_primary); +} + +#[tokio::test] +async fn test_remove_repository() { + let store = make_store().await; + let project = store.create_project(default_params()).await.unwrap(); + + let repo = store + .add_repository(AddRepositoryParams { + project_id: project.id, + path: PathBuf::from("/repo"), + is_primary: false, + label: None, + }) + .await + .unwrap(); + + store.remove_repository(&repo.id).await.unwrap(); + let repos = store.list_repositories(&project.id).await.unwrap(); + assert!(repos.is_empty()); +} + +#[tokio::test] +async fn test_remove_nonexistent_repository() { + let store = make_store().await; + let err = store.remove_repository(&Uuid::now_v7()).await.unwrap_err(); + assert!(matches!( + err, + dirigent_projects::ProjectError::RepositoryNotFound(_) + )); +} + +#[tokio::test] +async fn test_set_primary_repository() { + let store = make_store().await; + let project = store.create_project(default_params()).await.unwrap(); + + let repo1 = store + .add_repository(AddRepositoryParams { + project_id: project.id, + path: PathBuf::from("/repo1"), + is_primary: true, + label: None, + }) + .await + .unwrap(); + + let repo2 = store + .add_repository(AddRepositoryParams { + project_id: project.id, + path: PathBuf::from("/repo2"), + is_primary: false, + label: None, + }) + .await + .unwrap(); + + // Switch primary to repo2 + store + .set_primary_repository(&project.id, &repo2.id) + .await + .unwrap(); + + let repos = store.list_repositories(&project.id).await.unwrap(); + let first = repos.iter().find(|r| r.id == repo1.id).unwrap(); + let second = repos.iter().find(|r| r.id == repo2.id).unwrap(); + assert!(!first.is_primary); + assert!(second.is_primary); +} + +#[tokio::test] +async fn test_set_primary_nonexistent_repo() { + let store = make_store().await; + let project = store.create_project(default_params()).await.unwrap(); + + let err = store + .set_primary_repository(&project.id, &Uuid::now_v7()) + .await + .unwrap_err(); + assert!(matches!( + err, + dirigent_projects::ProjectError::RepositoryNotFound(_) + )); +} + +#[tokio::test] +async fn test_add_repo_to_nonexistent_project() { + let store = make_store().await; + let err = store + .add_repository(AddRepositoryParams { + project_id: Uuid::now_v7(), + path: PathBuf::from("/repo"), + is_primary: false, + label: None, + }) + .await + .unwrap_err(); + assert!(matches!(err, dirigent_projects::ProjectError::NotFound(_))); +} + +// ============================================================================ +// Working Directory Resolution Tests +// ============================================================================ + +#[tokio::test] +async fn test_resolve_working_dir_specific_repo() { + let store = make_store().await; + let project = store.create_project(default_params()).await.unwrap(); + + let repo = store + .add_repository(AddRepositoryParams { + project_id: project.id, + path: PathBuf::from("/specific/repo"), + is_primary: false, + label: None, + }) + .await + .unwrap(); + + let resolved = store + .resolve_working_dir(&project.id, Some(&repo.id)) + .await + .unwrap(); + assert_eq!(resolved, PathBuf::from("/specific/repo")); +} + +#[tokio::test] +async fn test_resolve_working_dir_primary_repo() { + let store = make_store().await; + let project = store.create_project(default_params()).await.unwrap(); + + store + .add_repository(AddRepositoryParams { + project_id: project.id, + path: PathBuf::from("/secondary"), + is_primary: false, + label: None, + }) + .await + .unwrap(); + + store + .add_repository(AddRepositoryParams { + project_id: project.id, + path: PathBuf::from("/primary"), + is_primary: true, + label: None, + }) + .await + .unwrap(); + + let resolved = store.resolve_working_dir(&project.id, None).await.unwrap(); + assert_eq!(resolved, PathBuf::from("/primary")); +} + +#[tokio::test] +async fn test_resolve_working_dir_first_repo_fallback() { + let store = make_store().await; + let project = store.create_project(default_params()).await.unwrap(); + + store + .add_repository(AddRepositoryParams { + project_id: project.id, + path: PathBuf::from("/only-repo"), + is_primary: false, + label: None, + }) + .await + .unwrap(); + + let resolved = store.resolve_working_dir(&project.id, None).await.unwrap(); + assert_eq!(resolved, PathBuf::from("/only-repo")); +} + +#[tokio::test] +async fn test_resolve_working_dir_no_repos_errors() { + let store = make_store().await; + let project = store.create_project(default_params()).await.unwrap(); + + let err = store + .resolve_working_dir(&project.id, None) + .await + .unwrap_err(); + assert!(matches!( + err, + dirigent_projects::ProjectError::Validation(_) + )); +} + +#[tokio::test] +async fn test_resolve_working_dir_nonexistent_repo_id() { + let store = make_store().await; + let project = store.create_project(default_params()).await.unwrap(); + + store + .add_repository(AddRepositoryParams { + project_id: project.id, + path: PathBuf::from("/repo"), + is_primary: true, + label: None, + }) + .await + .unwrap(); + + let err = store + .resolve_working_dir(&project.id, Some(&Uuid::now_v7())) + .await + .unwrap_err(); + assert!(matches!( + err, + dirigent_projects::ProjectError::RepositoryNotFound(_) + )); +} + +// ============================================================================ +// Binding Tests +// ============================================================================ + +#[tokio::test] +async fn test_bind_and_list_bindings() { + let store = make_store().await; + let project = store.create_project(default_params()).await.unwrap(); + + let binding = store + .bind(BindParams { + project_id: project.id, + connector_id: Some("opencode-1".to_string()), + session_id: None, + working_dir: Some(PathBuf::from("/custom/dir")), + }) + .await + .unwrap(); + + assert_eq!(binding.project_id, project.id); + assert_eq!(binding.connector_id, Some("opencode-1".to_string())); + assert!(binding.session_id.is_none()); + assert_eq!(binding.working_dir, Some(PathBuf::from("/custom/dir"))); + + let bindings = store.list_bindings(&project.id).await.unwrap(); + assert_eq!(bindings.len(), 1); + assert_eq!(bindings[0].id, binding.id); +} + +#[tokio::test] +async fn test_unbind() { + let store = make_store().await; + let project = store.create_project(default_params()).await.unwrap(); + + let binding = store + .bind(BindParams { + project_id: project.id, + connector_id: Some("conn-1".to_string()), + session_id: None, + working_dir: None, + }) + .await + .unwrap(); + + store.unbind(&binding.id).await.unwrap(); + let bindings = store.list_bindings(&project.id).await.unwrap(); + assert!(bindings.is_empty()); +} + +#[tokio::test] +async fn test_unbind_nonexistent() { + let store = make_store().await; + let err = store.unbind(&Uuid::now_v7()).await.unwrap_err(); + assert!(matches!( + err, + dirigent_projects::ProjectError::BindingNotFound(_) + )); +} + +#[tokio::test] +async fn test_bind_to_nonexistent_project() { + let store = make_store().await; + let err = store + .bind(BindParams { + project_id: Uuid::now_v7(), + connector_id: Some("conn".to_string()), + session_id: None, + working_dir: None, + }) + .await + .unwrap_err(); + assert!(matches!(err, dirigent_projects::ProjectError::NotFound(_))); +} + +#[tokio::test] +async fn test_bind_with_session_id() { + let store = make_store().await; + let project = store.create_project(default_params()).await.unwrap(); + let session_id = Uuid::now_v7(); + + let binding = store + .bind(BindParams { + project_id: project.id, + connector_id: Some("conn-1".to_string()), + session_id: Some(session_id), + working_dir: None, + }) + .await + .unwrap(); + + assert_eq!(binding.session_id, Some(session_id)); +} + +#[tokio::test] +async fn test_multiple_bindings_per_project() { + let store = make_store().await; + let project = store.create_project(default_params()).await.unwrap(); + + store + .bind(BindParams { + project_id: project.id, + connector_id: Some("conn-1".to_string()), + session_id: None, + working_dir: None, + }) + .await + .unwrap(); + + store + .bind(BindParams { + project_id: project.id, + connector_id: Some("conn-2".to_string()), + session_id: None, + working_dir: None, + }) + .await + .unwrap(); + + let bindings = store.list_bindings(&project.id).await.unwrap(); + assert_eq!(bindings.len(), 2); +} diff --git a/crates/dirigent_protocol/CHANGELOG.md b/crates/dirigent_protocol/CHANGELOG.md new file mode 100644 index 0000000..0a27ccc --- /dev/null +++ b/crates/dirigent_protocol/CHANGELOG.md @@ -0,0 +1,98 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [0.2.0] - 2025-11-10 + +### BREAKING CHANGES + +#### Removed Deprecated Event Variants + +The `MessagePartAdded` event variant has been removed from the `Event` enum. This variant was part of the old streaming system and has been replaced by the new ACP-style `SessionUpdate` event. + +**Migration Guide:** + +If you were using `MessagePartAdded`, you should migrate to using `SessionUpdate` instead: + +**Old code:** +```rust +match event { + Event::MessagePartAdded { session_id, message_id, part } => { + // Handle message part + } + // ... +} +``` + +**New code:** +```rust +match event { + Event::SessionUpdate { session_id, update } => { + match update { + SessionUpdate::UserMessageChunk { message_id, content, .. } => { + // Handle user message content + } + SessionUpdate::AgentMessageChunk { message_id, content, .. } => { + // Handle agent message content + } + SessionUpdate::AgentThoughtChunk { message_id, content, .. } => { + // Handle agent thinking/reasoning + } + SessionUpdate::ToolCall { message_id, tool_call, .. } => { + // Handle tool call initiated + } + SessionUpdate::ToolCallUpdate { message_id, tool_call_id, tool_call, .. } => { + // Handle tool call updates + } + } + } + // ... +} +``` + +**Key Differences:** + +1. `SessionUpdate` uses typed `ContentBlock` instead of generic `MessagePart` +2. Updates are categorized by type (user/agent/thought/tool) +3. Better separation of concerns for streaming content +4. Aligns with Agent-Client Protocol (ACP) standards + +**What This Means:** + +- The protocol now uses a unified streaming model via `SessionUpdate` +- Better alignment with ACP specification +- Clearer separation between message lifecycle events and streaming updates +- More structured content representation with `ContentBlock` + +### Added + +- `SessionUpdate` event variant with ACP-style streaming updates +- `SessionUpdate` enum with variants: + - `UserMessageChunk`: User message content streaming + - `AgentMessageChunk`: Agent message content streaming + - `AgentThoughtChunk`: Agent reasoning/thinking streaming + - `ToolCall`: Tool call initiated + - `ToolCallUpdate`: Tool call status/content updates +- `ContentBlock` enum for structured content representation +- `ToolCall` type with status tracking and metadata + +### Changed + +- Streaming events now use `SessionUpdate` instead of `MessagePartAdded` +- Version bumped from 0.1.0 to 0.2.0 (breaking change) + +## [0.1.0] - 2025-11-09 + +### Added + +- Initial protocol definition +- `Event` enum with session, message, connector, and system events +- `Session` and `SessionMetadata` types +- `Message`, `MessageMetadata`, `MessageRole`, `MessageStatus` types +- `MessagePart` enum for message content +- OpenCode adapter for translating OpenCode events to Dirigent protocol +- REST adapter for converting REST API responses +- Comprehensive test suite diff --git a/crates/dirigent_protocol/CLAUDE.md b/crates/dirigent_protocol/CLAUDE.md new file mode 100644 index 0000000..28ae11a --- /dev/null +++ b/crates/dirigent_protocol/CLAUDE.md @@ -0,0 +1,268 @@ +# dirigent_protocol + +**Version:** 0.2.0 +**Status:** Active Development + +ACP/MCP-aligned protocol library for agent-client interactions. + +## Quick Links + +- **Package README**: [README.md](README.md) - Main documentation +- **Streaming Model**: [docs/streaming_model.md](docs/streaming_model.md) - Detailed SessionUpdate guide +- **Migration Guide**: [docs/migration_from_0.1.md](docs/migration_from_0.1.md) - Upgrading from 0.1.x +- **Architecture Doc**: [../../docs/architecture/protocol.md](../../docs/architecture/protocol.md) - System design + +## Purpose + +This package provides the core event protocol for Dirigent, enabling: +- Real-time streaming of agent interactions +- Provider-agnostic event representation +- Tool lifecycle management +- Structured content representation + +## Key Types + +```rust +use dirigent_protocol::{ + Event, // Top-level event enum + SessionUpdate, // Streaming content updates + ContentBlock, // Structured content (Text, ResourceLink) + ToolCall, // Tool execution state + ToolCallStatus, // Pending → Running → Completed/Error + TurnCompleteTrigger, // How turn completion was detected +}; +``` + +## Event Semantics: MessageCompleted vs TurnComplete vs SessionIdle + +Understanding the distinction between these three events is critical for correct system behavior: + +### MessageCompleted - "Metadata is ready" +- **Purpose**: Informational - signals that message metadata exists +- **Timing**: Emitted when message record is created, content may still be streaming +- **Consumer action**: Update UI status indicators ("Assistant is typing" → "Complete") +- **Example**: Show message timestamp, update message count + +### TurnComplete - "All content received" (ACTIONABLE) +- **Purpose**: **Primary finalization signal** - all content for this turn is complete +- **Timing**: Emitted AFTER all content chunks, tool calls, and metadata updates +- **Consumer action**: **Finalize storage, lock state, trigger post-processing** +- **Example**: Write message to disk, mark as immutable, generate summaries + +### SessionIdle - "No recent activity" +- **Purpose**: Informational - indicates session is quiet +- **Timing**: Emitted AFTER TurnComplete as final safety signal +- **Consumer action**: Hide spinners, update activity indicators +- **Example**: Remove "typing" animation, update last activity timestamp + +### Event Ordering Guarantee + +```text +1. MessageStarted (message created) +2. SessionUpdate::*Chunk (content streaming) +3. SessionUpdate::ToolCall* (tool execution) +4. MessageCompleted (metadata ready) ← UI: "Complete" +5. TurnComplete ← FINALIZE HERE! +6. SessionIdle ← UI: hide spinner +``` + +### Consumer Behavior Table + +| Consumer | MessageCompleted | TurnComplete | SessionIdle | +|----------|------------------|--------------|-------------| +| **Archivist** | Ignore | **Finalize and write** | Safety net | +| **UI Cache** | Update status | **Lock state** | Hide spinner | +| **Conductor Bridge** | - | **Flush response** | Fallback flush | + +### TurnCompleteTrigger Variants + +The `TurnCompleteTrigger` enum indicates **how** the system determined completion: + +- **`ExplicitSignal`**: Upstream provider sent explicit completion (e.g., OpenCode session.idle) +- **`ResponseReceived`**: JSON-RPC response received (ACP stdio - response is last message) +- **`OperationsComplete`**: All tracked operations finished (e.g., pending tool calls resolved) +- **`IdleTimeout { duration_ms }`**: Timeout-based detection (fallback mechanism) + +**For most consumers**, treat all triggers the same - the turn is complete. The trigger type is primarily for debugging and observability. + +## Usage Pattern + +```rust +use dirigent_protocol::{Event, SessionUpdate, ContentBlock}; + +fn handle_event(event: Event) { + match event { + Event::SessionUpdate { session_id, update } => { + match update { + SessionUpdate::AgentMessageChunk { message_id, content, .. } => { + if let ContentBlock::Text { text } = content { + println!("Agent: {}", text); + } + } + SessionUpdate::ToolCall { tool_call, .. } => { + println!("Tool: {}", tool_call.tool_name); + } + _ => {} + } + } + _ => {} + } +} +``` + +## Architecture + +``` +dirigent_protocol/ +├── src/ +│ ├── types/ # Core types +│ │ ├── content.rs # ContentBlock definitions +│ │ ├── updates.rs # SessionUpdate variants +│ │ ├── tool.rs # ToolCall, ToolCallStatus +│ │ └── meta.rs # Provider metadata +│ ├── session.rs # Session types +│ ├── conversation.rs # Message types +│ ├── events.rs # Event enum +│ └── adapters/ # Provider adapters +│ ├── opencode.rs # OpenCode translation +│ └── rest.rs # REST translation +├── docs/ # Detailed documentation +├── examples/ # Usage examples +└── tests/ # Integration tests +``` + +## Version 0.2.0 Changes + +**Breaking:** +- Removed `Event::MessagePartAdded` + +**New:** +- `SessionUpdate` event system (ACP-style) +- `ContentBlock` types (MCP-compatible) +- `ToolCall` with lifecycle tracking +- Provider metadata via `_meta` + +See [docs/migration_from_0.1.md](docs/migration_from_0.1.md) for migration guide. + +## Development + +### Running Tests +```bash +cargo test --package dirigent_protocol +``` + +### Checking Code +```bash +cargo check --package dirigent_protocol +``` + +### Running Examples +```bash +cargo run --package dirigent_protocol --example session_metadata_demo +``` + +## Integration + +This package is used by: +- **api** package: Server functions consume protocol events +- **web** package: UI renders protocol events +- **dirigent_core** (future): Runtime emits protocol events + +## Adapters + +The adapter system translates provider-specific events to Dirigent Protocol: + +- **OpenCodeAdapter**: Translates OpenCode.ai events +- **RESTAdapter**: Converts REST API responses + +Adapters preserve provider metadata in the `_meta` field for debugging and traceability. + +## Current Scope + +**Phase 1 (Implemented):** +- User/Agent/Thought message streaming +- Tool lifecycle (Pending → Running → Completed/Error) +- Text and ResourceLink content types +- Provider metadata support + +**Deferred to Future Phases:** +- Plans and mode switching +- Permission system +- Embedded resources (full content) +- Rich media (images, audio) +- Multi-agent communication + +See [../../docs/building/03_acp_prep/04_first_order_refactor.md](../../docs/building/03_acp_prep/04_first_order_refactor.md) for the full plan. + +## Standards Alignment + +**Agent-Client Protocol (ACP):** +- Session-centric streaming +- Separate content types (user/agent/thought) +- Tool status tracking + +**Model Context Protocol (MCP):** +- ContentBlock structure +- Resource links +- Extensible content types + +Differences from standards are documented in [../../docs/architecture/protocol.md](../../docs/architecture/protocol.md). + +## Anti-Patterns + +### Timeout-Based Event Waiting (FORBIDDEN) + +**Never use timeout-based waiting to receive events that should be available immediately.** + +```rust +// ❌ BAD - Race condition waiting for event +async fn wait_for_metadata_event( + events: &mut broadcast::Receiver<Event>, + timeout: Duration, +) -> Option<SessionMetadataReceived> { + let start = Instant::now(); + while start.elapsed() < timeout { + match tokio::time::timeout(Duration::from_millis(100), events.recv()).await { + Ok(Ok(Event::SessionMetadataReceived { .. })) => return Some(...), + _ => continue, + } + } + None +} +``` + +**Why this is wrong:** +1. **Race condition**: The event may have been emitted before the receiver subscribed +2. **Arbitrary delays**: 500ms waits add latency for no good reason +3. **Silent failures**: Timeout expiring doesn't indicate the real problem +4. **Fragile**: Works "most of the time" but fails under load or timing variations + +**Instead, pass data directly:** +```rust +// ✅ GOOD - Extract data from existing events +async fn create_session_in_connector(...) -> Result<(String, Option<Models>, Option<Modes>), String> { + // The SessionCreated event already contains models/modes + match event { + Event::SessionCreated { session, .. } => { + Ok((session.id, session.models, session.modes)) + } + } +} +``` + +**Rule**: If you find yourself writing `timeout(Duration::from_millis(N), events.recv())` to wait for an event that "should" arrive, the architecture is wrong. Refactor to pass data directly through return values or existing event payloads. + +## Contributing + +When adding features: +1. Update type definitions in `src/types/` +2. Add comprehensive tests +3. Update documentation in `docs/` +4. Add examples if applicable +5. Update CHANGELOG.md + +## See Also + +- Main project: [../../CLAUDE.md](../../CLAUDE.md) +- Architecture docs: [../../docs/architecture/](../../docs/architecture/) +- Building docs: [../../docs/building/](../../docs/building/) diff --git a/crates/dirigent_protocol/Cargo.toml b/crates/dirigent_protocol/Cargo.toml new file mode 100644 index 0000000..4fb1f44 --- /dev/null +++ b/crates/dirigent_protocol/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "dirigent_protocol" +version = "0.2.0" +edition = "2021" + +[dependencies] +async-trait = "0.1" +chrono = { version = "0.4", features = ["serde"] } +opencode_client = { workspace = true, optional = true } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +thiserror = "2.0" +tokio = { version = "1", features = ["sync"] } +tracing = "0.1" +uuid = { version = "1.18", features = ["js", "serde", "v4", "v7"] } + +[dev-dependencies] +tokio = { version = "1", features = ["macros", "rt", "sync"] } + +[features] +default = [] +adapters = ["dep:opencode_client"] + +[[test]] +name = "opencode_session_update_tests" +required-features = ["adapters"] + +[[test]] +name = "protocol_tests" +required-features = ["adapters"] + +[[test]] +name = "session_list_tests" +required-features = ["adapters"] + +[[test]] +name = "deduplication_tests" +required-features = ["adapters"] diff --git a/crates/dirigent_protocol/README.md b/crates/dirigent_protocol/README.md new file mode 100644 index 0000000..39012ff --- /dev/null +++ b/crates/dirigent_protocol/README.md @@ -0,0 +1,352 @@ +# Dirigent Protocol + +**Version:** 0.2.0 + +A Rust protocol library for agent-client interactions, aligned with **Agent-Client Protocol (ACP)** and **Model Context Protocol (MCP)** standards. + +## Overview + +The Dirigent Protocol provides a structured, streaming-first event model for real-time agent interactions. It's designed to support multi-agent orchestration, tool execution, and rich content streaming while maintaining compatibility with standard protocols. + +## Features + +- **ACP-Style Streaming**: Real-time content updates via `SessionUpdate` events +- **MCP-Compatible Content**: Structured `ContentBlock` representation +- **Tool Lifecycle Management**: Complete tool call tracking from initiation to completion +- **Provider Agnostic**: Adapter system for integrating different AI providers +- **Type-Safe**: Strongly-typed Rust API with comprehensive serde support +- **Extensible**: Provider metadata and extensibility hooks + +## Quick Start + +Add to your `Cargo.toml`: + +```toml +[dependencies] +dirigent_protocol = "0.2" +``` + +### Basic Usage + +```rust +use dirigent_protocol::{Event, SessionUpdate, ContentBlock}; + +fn handle_event(event: Event) { + match event { + Event::SessionUpdate { session_id, update } => { + match update { + SessionUpdate::AgentMessageChunk { message_id, content, .. } => { + if let ContentBlock::Text { text } = content { + println!("Agent says: {}", text); + } + } + SessionUpdate::ToolCall { tool_call, .. } => { + println!("Tool called: {}", tool_call.tool_name); + } + _ => {} + } + } + Event::SessionCreated { session } => { + println!("New session: {}", session.id); + } + _ => {} + } +} +``` + +## Core Concepts + +### SessionUpdate Event Model + +The protocol uses `SessionUpdate` for all streaming content: + +```rust +pub enum SessionUpdate { + UserMessageChunk { message_id: String, content: ContentBlock, _meta: Option<Meta> }, + AgentMessageChunk { message_id: String, content: ContentBlock, _meta: Option<Meta> }, + AgentThoughtChunk { message_id: String, content: ContentBlock, _meta: Option<Meta> }, + ToolCall { message_id: String, tool_call: ToolCall, _meta: Option<Meta> }, + ToolCallUpdate { message_id: String, tool_call_id: String, tool_call: ToolCall, _meta: Option<Meta> }, +} +``` + +### ContentBlock Types + +Structured content representation: + +```rust +pub enum ContentBlock { + Text { text: String }, + ResourceLink { + uri: String, + name: Option<String>, + mime_type: Option<String>, + }, +} +``` + +### Tool Call Lifecycle + +Complete tool execution tracking: + +```rust +pub struct ToolCall { + pub id: ToolCallId, + pub tool_name: String, + pub status: ToolCallStatus, // Pending → Running → Completed/Error + pub content: Vec<ContentBlock>, + pub raw_input: Option<Value>, + pub raw_output: Option<Value>, + pub title: Option<String>, + pub error: Option<String>, + pub metadata: Option<Value>, +} +``` + +## Documentation + +- **[Streaming Model](docs/streaming_model.md)** - Detailed guide to the SessionUpdate event system +- **[Migration Guide](docs/migration_from_0.1.md)** - Upgrading from 0.1.x to 0.2.0 +- **[CHANGELOG.md](CHANGELOG.md)** - Version history and breaking changes +- **[Examples](examples/)** - Working code examples + +## Examples + +### Streaming Text + +```rust +use dirigent_protocol::{Event, SessionUpdate, ContentBlock}; + +// Agent streaming response +Event::SessionUpdate { + session_id: "session_123".to_string(), + update: SessionUpdate::AgentMessageChunk { + message_id: "msg_1".to_string(), + content: ContentBlock::Text { + text: "Hello, world!".to_string(), + }, + _meta: None, + }, +} +``` + +### Tool Execution + +```rust +use dirigent_protocol::{Event, SessionUpdate, ToolCall, ToolCallStatus}; + +// Tool call initiated +Event::SessionUpdate { + session_id: "session_123".to_string(), + update: SessionUpdate::ToolCall { + message_id: "msg_1".to_string(), + tool_call: ToolCall { + id: "call_1".to_string(), + tool_name: "bash".to_string(), + status: ToolCallStatus::Pending, + content: vec![], + raw_input: Some(json!({"command": "ls"})), + raw_output: None, + title: Some("List files".to_string()), + error: None, + metadata: None, + }, + _meta: None, + }, +} + +// Tool call completed +Event::SessionUpdate { + session_id: "session_123".to_string(), + update: SessionUpdate::ToolCallUpdate { + message_id: "msg_1".to_string(), + tool_call_id: "call_1".to_string(), + tool_call: ToolCall { + id: "call_1".to_string(), + tool_name: "bash".to_string(), + status: ToolCallStatus::Completed, + content: vec![ + ContentBlock::Text { + text: "file1.txt\nfile2.txt".to_string(), + }, + ], + raw_input: Some(json!({"command": "ls"})), + raw_output: Some(json!({"exit_code": 0})), + title: Some("List files".to_string()), + error: None, + metadata: None, + }, + _meta: None, + }, +} +``` + +### Agent Thinking + +```rust +use dirigent_protocol::{Event, SessionUpdate, ContentBlock}; + +// Agent internal reasoning +Event::SessionUpdate { + session_id: "session_123".to_string(), + update: SessionUpdate::AgentThoughtChunk { + message_id: "msg_1".to_string(), + content: ContentBlock::Text { + text: "Analyzing the user's request...".to_string(), + }, + _meta: None, + }, +} +``` + +## Adapters + +The protocol includes adapter modules for translating provider-specific events: + +- **OpenCode Adapter**: Converts OpenCode.ai events to Dirigent Protocol +- **REST Adapter**: Translates REST API responses + +### Using an Adapter + +```rust +use dirigent_protocol::adapters::opencode::OpenCodeAdapter; + +let adapter = OpenCodeAdapter::new(); +let dirigent_events = adapter.translate_event(opencode_event); +``` + +## Version History + +### 0.2.0 (Current) + +**Breaking Changes:** +- Removed `Event::MessagePartAdded` (replaced with `SessionUpdate`) + +**New Features:** +- ACP-style `SessionUpdate` event system +- MCP-compatible `ContentBlock` types +- Structured `ToolCall` with lifecycle tracking +- Provider metadata via `_meta` field + +See [CHANGELOG.md](CHANGELOG.md) for details. + +### 0.1.0 + +Initial release with basic event types and adapters. + +## Migration from 0.1.x + +If you're upgrading from version 0.1.x, see the [Migration Guide](docs/migration_from_0.1.md) for detailed instructions and examples. + +**Quick Summary:** +- Replace `Event::MessagePartAdded` with `Event::SessionUpdate` +- Use `SessionUpdate` variants instead of `MessagePart` +- Handle tool lifecycle with `ToolCall` and `ToolCallUpdate` +- Access `ContentBlock::Text { text }` instead of `MessagePart::Text { content }` + +## Architecture + +``` +dirigent_protocol/ +├── src/ +│ ├── types/ # Core types +│ │ ├── content.rs # ContentBlock definitions +│ │ ├── updates.rs # SessionUpdate variants +│ │ ├── tool.rs # Tool call types +│ │ └── meta.rs # Provider metadata +│ ├── session.rs # Session types +│ ├── conversation.rs # Message types +│ ├── events.rs # Event enum +│ └── adapters/ # Provider adapters +│ ├── opencode.rs # OpenCode adapter +│ └── rest.rs # REST adapter +├── docs/ # Documentation +├── examples/ # Code examples +└── tests/ # Integration tests +``` + +## Event Types + +The protocol defines several event categories: + +### Session Events +- `SessionCreated` - New session started +- `SessionUpdated` - Session metadata changed +- `SessionDeleted` - Session removed +- `SessionsListed` - Available sessions returned + +### Streaming Events +- `SessionUpdate` - Real-time content/tool updates (see SessionUpdate variants above) + +### Message Events +- `MessageStarted` - Message creation initiated +- `MessageCompleted` - Message finalized +- `MessageDeleted` - Message removed + +### System Events +- `ConnectorStateChanged` - Connector status changed +- `Error` - Error occurred + +## Best Practices + +### For Consumers + +1. **Use SessionUpdate for streaming**: The new event model provides granular updates +2. **Track tools by ID**: Maintain a HashMap of tool calls keyed by `tool_call_id` +3. **Replace tool state on update**: `ToolCallUpdate` sends complete state, not deltas +4. **Distinguish thoughts from messages**: Render `AgentThoughtChunk` differently (e.g., collapsible sections) +5. **Handle optional fields**: Provider metadata (`_meta`) may not always be present + +### For Adapters + +1. **Preserve provider info**: Store original IDs in `_meta.provider` for debugging +2. **Send complete tool state**: Include all tool call fields in updates +3. **Use appropriate chunk types**: Choose User/Agent/Thought for correct semantics +4. **Keep metadata minimal**: Avoid large payloads in `_meta` (use excerpts) + +## Testing + +Run the test suite: + +```bash +cargo test +``` + +Run with output: + +```bash +cargo test -- --nocapture +``` + +## Contributing + +Contributions welcome! Please ensure: +- All tests pass +- New features include tests +- Documentation is updated +- Code follows Rust conventions + +## License + +[License information here] + +## Related Projects + +- **dirigent_core** - Multi-agent orchestration runtime +- **opencode_client** - OpenCode.ai HTTP client library +- **dirigent_archive** - Session persistence + +## Support + +For questions or issues: +- Check the [documentation](docs/) +- Review [examples](examples/) +- Open an issue on the repository + +## Standards Alignment + +This protocol is designed to align with: +- **Agent-Client Protocol (ACP)** - Streaming model and event types +- **Model Context Protocol (MCP)** - Content block structure + +Differences from standards are documented for compatibility and future convergence. diff --git a/crates/dirigent_protocol/docs/migration_from_0.1.md b/crates/dirigent_protocol/docs/migration_from_0.1.md new file mode 100644 index 0000000..c71b11c --- /dev/null +++ b/crates/dirigent_protocol/docs/migration_from_0.1.md @@ -0,0 +1,620 @@ +# Migration Guide: Dirigent Protocol 0.1.x → 0.2.0 + +## Overview + +Version 0.2.0 introduces a **new ACP-style streaming model** while maintaining backward compatibility with most existing code. This guide will help you migrate to the new `SessionUpdate` event system. + +## Breaking Changes Summary + +### Removed: Event::MessagePartAdded + +The `MessagePartAdded` event variant has been **removed** from the `Event` enum. This was the primary breaking change in 0.2.0. + +**What was removed:** +```rust +// This no longer exists in 0.2.0 +Event::MessagePartAdded { + session_id: String, + message_id: String, + part: MessagePart, +} +``` + +**Replaced with:** +```rust +// New in 0.2.0 +Event::SessionUpdate { + session_id: String, + update: SessionUpdate, +} +``` + +## Migration Patterns + +### Pattern 1: Basic Text Streaming + +#### Before (0.1.x) +```rust +use dirigent_protocol::{Event, MessagePart}; + +match event { + Event::MessagePartAdded { session_id, message_id, part } => { + match part { + MessagePart::Text { content, .. } => { + println!("Text from {}: {}", message_id, content); + } + _ => {} + } + } + _ => {} +} +``` + +#### After (0.2.0) +```rust +use dirigent_protocol::{Event, SessionUpdate, ContentBlock}; + +match event { + Event::SessionUpdate { session_id, update } => { + match update { + SessionUpdate::UserMessageChunk { message_id, content, .. } | + SessionUpdate::AgentMessageChunk { message_id, content, .. } => { + match content { + ContentBlock::Text { text } => { + println!("Text from {}: {}", message_id, text); + } + _ => {} + } + } + _ => {} + } + } + _ => {} +} +``` + +**Key Changes:** +- Use `SessionUpdate` instead of `MessagePartAdded` +- Separate `UserMessageChunk` and `AgentMessageChunk` variants +- `ContentBlock::Text { text }` instead of `MessagePart::Text { content }` +- Field renamed: `content` → `text` + +### Pattern 2: Thinking/Reasoning Content + +#### Before (0.1.x) +```rust +match event { + Event::MessagePartAdded { session_id, message_id, part } => { + match part { + MessagePart::Thinking { content, .. } => { + println!("Agent thinking: {}", content); + } + _ => {} + } + } + _ => {} +} +``` + +#### After (0.2.0) +```rust +match event { + Event::SessionUpdate { session_id, update } => { + match update { + SessionUpdate::AgentThoughtChunk { message_id, content, .. } => { + match content { + ContentBlock::Text { text } => { + println!("Agent thinking: {}", text); + } + _ => {} + } + } + _ => {} + } + } + _ => {} +} +``` + +**Key Changes:** +- `MessagePart::Thinking` → `SessionUpdate::AgentThoughtChunk` +- Content wrapped in `ContentBlock::Text` + +### Pattern 3: Tool Calls + +#### Before (0.1.x) +```rust +match event { + Event::MessagePartAdded { session_id, message_id, part } => { + match part { + MessagePart::Tool { + tool_name, + tool_call_id, + status, + input, + output, + .. + } => { + println!("Tool {}: {:?}", tool_name, status); + if let Some(out) = output { + println!("Output: {}", out); + } + } + _ => {} + } + } + _ => {} +} +``` + +#### After (0.2.0) +```rust +use dirigent_protocol::ToolCallStatus; + +match event { + Event::SessionUpdate { session_id, update } => { + match update { + // Initial tool call + SessionUpdate::ToolCall { message_id, tool_call, .. } => { + println!("Tool {}: {:?}", tool_call.tool_name, tool_call.status); + } + + // Tool call updates (progress, completion, errors) + SessionUpdate::ToolCallUpdate { message_id, tool_call_id, tool_call, .. } => { + println!("Tool {} updated: {:?}", tool_call.tool_name, tool_call.status); + + // Output is now in content blocks + for content in &tool_call.content { + match content { + ContentBlock::Text { text } => { + println!("Output: {}", text); + } + _ => {} + } + } + + // Check for errors + if tool_call.status == ToolCallStatus::Error { + if let Some(error) = &tool_call.error { + println!("Error: {}", error); + } + } + } + _ => {} + } + } + _ => {} +} +``` + +**Key Changes:** +- Tool lifecycle split into `ToolCall` (initial) and `ToolCallUpdate` (updates) +- Tool output now in `tool_call.content: Vec<ContentBlock>` +- Structured `ToolCall` type with multiple fields +- Explicit `ToolCallStatus` enum (Pending/Running/Completed/Error) +- Error information in dedicated `error` field + +### Pattern 4: File References + +#### Before (0.1.x) +```rust +match event { + Event::MessagePartAdded { session_id, message_id, part } => { + match part { + MessagePart::File { path, name, mime_type, .. } => { + println!("File reference: {} ({})", name.unwrap_or_default(), path); + } + _ => {} + } + } + _ => {} +} +``` + +#### After (0.2.0) +```rust +match event { + Event::SessionUpdate { session_id, update } => { + match update { + SessionUpdate::AgentMessageChunk { message_id, content, .. } => { + match content { + ContentBlock::ResourceLink { uri, name, mime_type } => { + println!("Resource: {} ({})", + name.as_deref().unwrap_or("unnamed"), + uri + ); + } + _ => {} + } + } + _ => {} + } + } + _ => {} +} +``` + +**Key Changes:** +- `MessagePart::File` → `ContentBlock::ResourceLink` +- Field renamed: `path` → `uri` (more generic) +- Can appear in any message chunk type + +## Complete Migration Example + +Here's a complete before/after example showing a typical event handler: + +### Before (0.1.x) +```rust +use dirigent_protocol::{Event, MessagePart}; + +fn handle_event(event: Event) { + match event { + Event::MessagePartAdded { session_id, message_id, part } => { + match part { + MessagePart::Text { content, .. } => { + append_text(&message_id, &content); + } + MessagePart::Thinking { content, .. } => { + show_thinking(&message_id, &content); + } + MessagePart::Tool { tool_name, tool_call_id, status, output, .. } => { + update_tool_display(&tool_call_id, &tool_name, status, output.as_deref()); + } + MessagePart::File { path, name, .. } => { + add_file_reference(&message_id, &path, name.as_deref()); + } + _ => {} + } + } + Event::MessageStarted { session_id, message_id, .. } => { + create_message_container(&message_id); + } + Event::MessageCompleted { session_id, message_id, .. } => { + finalize_message(&message_id); + } + _ => {} + } +} +``` + +### After (0.2.0) +```rust +use dirigent_protocol::{Event, SessionUpdate, ContentBlock, ToolCallStatus}; + +fn handle_event(event: Event) { + match event { + Event::SessionUpdate { session_id, update } => { + match update { + SessionUpdate::UserMessageChunk { message_id, content, .. } | + SessionUpdate::AgentMessageChunk { message_id, content, .. } => { + match content { + ContentBlock::Text { text } => { + append_text(&message_id, &text); + } + ContentBlock::ResourceLink { uri, name, .. } => { + add_file_reference(&message_id, &uri, name.as_deref()); + } + } + } + + SessionUpdate::AgentThoughtChunk { message_id, content, .. } => { + if let ContentBlock::Text { text } = content { + show_thinking(&message_id, &text); + } + } + + SessionUpdate::ToolCall { message_id, tool_call, .. } => { + create_tool_display( + &tool_call.id, + &tool_call.tool_name, + tool_call.status, + ); + } + + SessionUpdate::ToolCallUpdate { tool_call_id, tool_call, .. } => { + // Extract output text from content blocks + let output_text = tool_call.content.iter() + .filter_map(|c| match c { + ContentBlock::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::<Vec<_>>() + .join(""); + + update_tool_display( + &tool_call_id, + &tool_call.tool_name, + tool_call.status, + if output_text.is_empty() { None } else { Some(&output_text) }, + ); + + // Handle errors + if tool_call.status == ToolCallStatus::Error { + if let Some(error) = &tool_call.error { + show_tool_error(&tool_call_id, error); + } + } + } + } + } + + // These events remain unchanged + Event::MessageStarted { session_id, message_id, .. } => { + create_message_container(&message_id); + } + Event::MessageCompleted { session_id, message_id, .. } => { + finalize_message(&message_id); + } + _ => {} + } +} +``` + +## UI State Management Changes + +### Before: Simple Append Model + +```rust +struct MessageState { + id: String, + text: String, + thinking: String, + tools: Vec<ToolDisplay>, +} + +// On MessagePartAdded with Text: +message.text.push_str(&content); +``` + +### After: ContentBlock Streaming + +```rust +use std::collections::HashMap; + +struct MessageState { + id: String, + content_blocks: Vec<ContentBlock>, + thoughts: Vec<ContentBlock>, + tools: HashMap<String, ToolCall>, // Keyed by tool_call_id +} + +// On AgentMessageChunk: +message.content_blocks.push(content); + +// On AgentThoughtChunk: +message.thoughts.push(content); + +// On ToolCall: +message.tools.insert(tool_call.id.clone(), tool_call); + +// On ToolCallUpdate: +message.tools.insert(tool_call_id.clone(), tool_call); // Replace, not merge +``` + +**Key Insight:** ToolCallUpdate sends the **complete tool state**, not a delta. Always replace the existing tool call entry. + +## What Stays the Same + +The following events and types are **unchanged** and require no migration: + +### Session Events +- `Event::SessionCreated` +- `Event::SessionUpdated` +- `Event::SessionDeleted` +- `Event::SessionsListed` + +### Message Lifecycle Events +- `Event::MessageStarted` +- `Event::MessageCompleted` +- `Event::MessageDeleted` + +### Connector Events +- `Event::ConnectorStateChanged` + +### Types +- `Session` +- `SessionMetadata` (extended with optional fields, but backward compatible) +- `Message` +- `MessageMetadata` +- `MessageRole` +- `MessageStatus` +- `MessagePart` (still used for completed messages) + +## Common Migration Issues + +### Issue 1: Pattern Matching Exhaustiveness + +**Problem:** Compiler errors about non-exhaustive patterns after removing `MessagePartAdded`. + +**Solution:** Remove any match arms for `MessagePartAdded` and add `SessionUpdate` handling. + +### Issue 2: Field Name Mismatches + +**Problem:** `MessagePart::Text { content }` vs `ContentBlock::Text { text }` + +**Solution:** Update field access from `content` to `text`. + +### Issue 3: Tool Call State Management + +**Problem:** Treating `ToolCallUpdate` as a delta instead of complete state. + +**Solution:** Replace the entire tool call entry when receiving `ToolCallUpdate`, don't try to merge. + +**Incorrect:** +```rust +// DON'T do this +if let Some(existing) = tools.get_mut(&tool_call_id) { + existing.content.extend(tool_call.content); // Wrong! + existing.status = tool_call.status; +} +``` + +**Correct:** +```rust +// DO this +tools.insert(tool_call_id.clone(), tool_call); // Replace completely +``` + +### Issue 4: Missing _meta Fields + +**Problem:** Trying to access `_meta` that might be `None`. + +**Solution:** Always use `Option` handling or provide defaults. + +```rust +match update { + SessionUpdate::AgentMessageChunk { _meta, .. } => { + if let Some(meta) = _meta { + if let Some(provider) = &meta.provider { + println!("Provider: {}", provider.name); + } + } + } + _ => {} +} +``` + +## Testing Your Migration + +### Test Checklist + +- [ ] Text streaming displays correctly +- [ ] Agent thoughts appear in designated section +- [ ] User messages are distinguished from agent messages +- [ ] Tool calls show initial pending state +- [ ] Tool calls update with running status +- [ ] Tool calls complete successfully +- [ ] Tool errors display with error messages +- [ ] File references render as links +- [ ] Multiple content chunks accumulate properly +- [ ] Tool call state replaces (not merges) on update + +### Migration Test Example + +```rust +#[test] +fn test_migration_text_streaming() { + let event = Event::SessionUpdate { + session_id: "test_session".to_string(), + update: SessionUpdate::AgentMessageChunk { + message_id: "msg_1".to_string(), + content: ContentBlock::Text { + text: "Hello".to_string(), + }, + _meta: None, + }, + }; + + // Your handler should process this correctly + handle_event(event); + + // Assert expected state changes + assert_eq!(get_message_text("msg_1"), "Hello"); +} + +#[test] +fn test_migration_tool_lifecycle() { + use dirigent_protocol::ToolCallStatus; + + // 1. Initial tool call + handle_event(Event::SessionUpdate { + session_id: "test".to_string(), + update: SessionUpdate::ToolCall { + message_id: "msg_1".to_string(), + tool_call: ToolCall { + id: "call_1".to_string(), + tool_name: "bash".to_string(), + status: ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + }, + _meta: None, + }, + }); + assert_eq!(get_tool_status("call_1"), ToolCallStatus::Pending); + + // 2. Tool starts running + handle_event(Event::SessionUpdate { + session_id: "test".to_string(), + update: SessionUpdate::ToolCallUpdate { + message_id: "msg_1".to_string(), + tool_call_id: "call_1".to_string(), + tool_call: ToolCall { + id: "call_1".to_string(), + tool_name: "bash".to_string(), + status: ToolCallStatus::Running, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + }, + _meta: None, + }, + }); + assert_eq!(get_tool_status("call_1"), ToolCallStatus::Running); + + // 3. Tool completes with output + handle_event(Event::SessionUpdate { + session_id: "test".to_string(), + update: SessionUpdate::ToolCallUpdate { + message_id: "msg_1".to_string(), + tool_call_id: "call_1".to_string(), + tool_call: ToolCall { + id: "call_1".to_string(), + tool_name: "bash".to_string(), + status: ToolCallStatus::Completed, + content: vec![ + ContentBlock::Text { + text: "Done!".to_string(), + }, + ], + raw_input: None, + raw_output: Some(serde_json::json!({"exit_code": 0})), + title: None, + error: None, + metadata: None, + }, + _meta: None, + }, + }); + assert_eq!(get_tool_status("call_1"), ToolCallStatus::Completed); + assert_eq!(get_tool_output("call_1"), "Done!"); +} +``` + +## Getting Help + +If you encounter issues during migration: + +1. Check the [streaming_model.md](streaming_model.md) documentation for detailed API information +2. Review the [examples/](../examples/) directory for working code samples +3. Examine the [CHANGELOG.md](../CHANGELOG.md) for version-specific details +4. Look at the protocol source code tests for reference implementations + +## Quick Reference: Event Type Mapping + +| 0.1.x Event | 0.2.0 Event | Notes | +|-------------|-------------|-------| +| `MessagePartAdded` with `Text` | `SessionUpdate::AgentMessageChunk` or `UserMessageChunk` | Split by role | +| `MessagePartAdded` with `Thinking` | `SessionUpdate::AgentThoughtChunk` | Dedicated type | +| `MessagePartAdded` with `Tool` | `SessionUpdate::ToolCall` + `ToolCallUpdate` | Split lifecycle | +| `MessagePartAdded` with `File` | `ContentBlock::ResourceLink` in chunks | Different structure | +| All other events | Unchanged | No migration needed | + +## Summary + +The 0.2.0 migration primarily involves: + +1. Replacing `Event::MessagePartAdded` pattern matching with `Event::SessionUpdate` +2. Using `SessionUpdate` variants instead of `MessagePart` variants +3. Accessing `ContentBlock::Text { text }` instead of `MessagePart::Text { content }` +4. Handling tool lifecycle with separate `ToolCall` and `ToolCallUpdate` events +5. Managing tool state as complete snapshots, not deltas + +The new model provides better structure, clearer semantics, and improved alignment with ACP standards while maintaining most of your existing code. diff --git a/crates/dirigent_protocol/docs/streaming_model.md b/crates/dirigent_protocol/docs/streaming_model.md new file mode 100644 index 0000000..c7fb0e2 --- /dev/null +++ b/crates/dirigent_protocol/docs/streaming_model.md @@ -0,0 +1,476 @@ +# Dirigent Protocol Streaming Model + +## Overview + +The Dirigent Protocol uses an **ACP-style streaming model** built around `SessionUpdate` events. This model provides granular, real-time updates during agent interactions, enabling responsive UIs and structured content representation. + +Version: 0.2.0 + +## Core Concepts + +### SessionUpdate Events + +All streaming content is delivered through `SessionUpdate` variants wrapped in `Event::SessionUpdate`: + +```rust +pub enum Event { + // ... other events + SessionUpdate { + session_id: String, + update: SessionUpdate, + }, +} +``` + +The `SessionUpdate` enum contains five variants for different types of streaming updates: + +```rust +pub enum SessionUpdate { + UserMessageChunk { message_id: String, content: ContentBlock, _meta: Option<Meta> }, + AgentMessageChunk { message_id: String, content: ContentBlock, _meta: Option<Meta> }, + AgentThoughtChunk { message_id: String, content: ContentBlock, _meta: Option<Meta> }, + ToolCall { message_id: String, tool_call: ToolCall, _meta: Option<Meta> }, + ToolCallUpdate { message_id: String, tool_call_id: String, tool_call: ToolCall, _meta: Option<Meta> }, +} +``` + +### ContentBlock Types + +Content is represented using structured `ContentBlock` variants: + +```rust +pub enum ContentBlock { + Text { text: String }, + ResourceLink { + uri: String, + name: Option<String>, + mime_type: Option<String>, + }, + // Future: Resource, Image, Audio (marked as out-of-scope for phase 1) +} +``` + +**Key Points:** +- `Text` is the primary content type for all textual output +- `ResourceLink` represents file references without embedding full content +- Future expansions will include embedded resources, images, and audio + +### Provider Metadata (_meta) + +All `SessionUpdate` variants support optional `_meta` fields for provider-specific information: + +```rust +pub struct Meta { + pub provider: Option<ProviderMeta>, + pub extra: HashMap<String, Value>, // Arbitrary additional fields +} + +pub struct ProviderMeta { + pub name: String, // e.g., "opencode", "anthropic" + pub original_ids: Option<HashMap<String, String>>, // Original provider IDs + pub raw_excerpt: Option<Value>, // Minimal raw payload for debugging +} +``` + +**Usage:** +- Adapters populate `_meta` to preserve provider-specific information +- Consumers can use this for debugging, telemetry, or provider-specific features +- The `extra` map allows arbitrary fields for forward compatibility + +## SessionUpdate Variants + +### 1. UserMessageChunk + +Represents streaming chunks of user message content. + +```rust +SessionUpdate::UserMessageChunk { + message_id: "msg_abc123".to_string(), + content: ContentBlock::Text { + text: "What's the capital of France?".to_string(), + }, + _meta: None, +} +``` + +**When to use:** +- Streaming user input being typed +- Echo of user input from server +- Multi-part user messages being assembled + +### 2. AgentMessageChunk + +Represents streaming chunks of agent response content. + +```rust +SessionUpdate::AgentMessageChunk { + message_id: "msg_def456".to_string(), + content: ContentBlock::Text { + text: "The capital of France is ".to_string(), + }, + _meta: Some(Meta { + provider: Some(ProviderMeta { + name: "opencode".to_string(), + original_ids: Some(HashMap::from([ + ("message_id".to_string(), "original_123".to_string()), + ])), + raw_excerpt: None, + }), + extra: HashMap::new(), + }), +} +``` + +**When to use:** +- Agent's response text being generated +- Final answer content +- Any visible agent output + +**Key distinction from AgentThoughtChunk:** +- `AgentMessageChunk`: Visible output intended for the user +- `AgentThoughtChunk`: Internal reasoning, typically hidden or collapsible + +### 3. AgentThoughtChunk + +Represents streaming chunks of agent internal reasoning (thinking, planning). + +```rust +SessionUpdate::AgentThoughtChunk { + message_id: "msg_ghi789".to_string(), + content: ContentBlock::Text { + text: "I need to look up Paris in my knowledge base...".to_string(), + }, + _meta: None, +} +``` + +**When to use:** +- Agent's internal reasoning process +- "Chain of thought" content +- Planning or decision-making process +- Content typically displayed in collapsible sections + +**UI Conventions:** +- Often hidden by default or shown in a separate "Thinking" section +- May be styled differently (e.g., italics, muted colors) +- Can be collapsed to save screen space + +### 4. ToolCall + +Represents the initiation or current state of a tool call. + +```rust +SessionUpdate::ToolCall { + message_id: "msg_jkl012".to_string(), + tool_call: ToolCall { + id: "call_xyz789".to_string(), + tool_name: "read_file".to_string(), + status: ToolCallStatus::Pending, + content: vec![], + raw_input: Some(json!({ + "file_path": "/path/to/file.txt" + })), + raw_output: None, + title: Some("Read file.txt".to_string()), + error: None, + metadata: None, + }, + _meta: None, +} +``` + +**When to use:** +- Tool call is first initiated +- Sending a snapshot of current tool state +- Re-sending full tool state after reconnection + +**ToolCallStatus Lifecycle:** +- `Pending` → Tool call created but not yet executing +- `Running` → Tool call actively executing +- `Completed` → Tool call finished successfully +- `Error` → Tool call failed + +### 5. ToolCallUpdate + +Represents an update to an existing tool call (status change, new content). + +```rust +SessionUpdate::ToolCallUpdate { + message_id: "msg_jkl012".to_string(), + tool_call_id: "call_xyz789".to_string(), + tool_call: ToolCall { + id: "call_xyz789".to_string(), + tool_name: "read_file".to_string(), + status: ToolCallStatus::Running, + content: vec![ + ContentBlock::Text { + text: "Reading file...".to_string(), + }, + ], + raw_input: Some(json!({ + "file_path": "/path/to/file.txt" + })), + raw_output: None, + title: Some("Read file.txt".to_string()), + error: None, + metadata: Some(json!({ + "bytes_read": 1024 + })), + }, + _meta: None, +} +``` + +**When to use:** +- Tool status changes (Pending → Running → Completed/Error) +- New output content available +- Progress updates +- Error state reached + +**Note:** The full `ToolCall` is sent each time, not a delta. Consumers should replace the previous tool call state with the new one. + +## Tool Call Lifecycle + +Understanding the tool call lifecycle is essential for proper UI implementation: + +```rust +// 1. Tool call initiated +SessionUpdate::ToolCall { + tool_call: ToolCall { + id: "call_123", + status: ToolCallStatus::Pending, + content: vec![], + // ... + } +} + +// 2. Tool starts executing +SessionUpdate::ToolCallUpdate { + tool_call_id: "call_123", + tool_call: ToolCall { + id: "call_123", + status: ToolCallStatus::Running, + content: vec![], + // ... + } +} + +// 3. Tool produces output +SessionUpdate::ToolCallUpdate { + tool_call_id: "call_123", + tool_call: ToolCall { + id: "call_123", + status: ToolCallStatus::Running, + content: vec![ + ContentBlock::Text { text: "Output line 1" }, + ], + // ... + } +} + +// 4a. Tool completes successfully +SessionUpdate::ToolCallUpdate { + tool_call_id: "call_123", + tool_call: ToolCall { + id: "call_123", + status: ToolCallStatus::Completed, + content: vec![ + ContentBlock::Text { text: "Output line 1" }, + ContentBlock::Text { text: "Done!" }, + ], + raw_output: Some(json!({"success": true})), + // ... + } +} + +// 4b. Or tool fails with error +SessionUpdate::ToolCallUpdate { + tool_call_id: "call_123", + tool_call: ToolCall { + id: "call_123", + status: ToolCallStatus::Error, + content: vec![ + ContentBlock::Text { text: "Error output" }, + ], + error: Some("File not found".to_string()), + // ... + } +} +``` + +**UI Implementation Guidelines:** +- Track tool calls by `id` in a HashMap +- On `ToolCall`: create new entry +- On `ToolCallUpdate`: replace existing entry (not delta) +- Display status with appropriate visual indicators +- Show `content` blocks as streaming output +- Display `error` when status is `Error` +- Use `title` for tool call heading if available + +## Typical Message Flow + +Here's a complete example showing a typical agent interaction: + +```rust +// 1. User sends message +Event::SessionUpdate { + session_id: "session_123", + update: SessionUpdate::UserMessageChunk { + message_id: "msg_user_1", + content: ContentBlock::Text { + text: "Read and summarize config.toml", + }, + _meta: None, + }, +} + +// 2. Agent starts thinking +Event::SessionUpdate { + session_id: "session_123", + update: SessionUpdate::AgentThoughtChunk { + message_id: "msg_agent_1", + content: ContentBlock::Text { + text: "I need to read the file first...", + }, + _meta: None, + }, +} + +// 3. Agent initiates tool call +Event::SessionUpdate { + session_id: "session_123", + update: SessionUpdate::ToolCall { + message_id: "msg_agent_1", + tool_call: ToolCall { + id: "call_read_1", + tool_name: "read_file", + status: ToolCallStatus::Pending, + content: vec![], + raw_input: Some(json!({"path": "config.toml"})), + title: Some("Read config.toml"), + // ... + }, + _meta: None, + }, +} + +// 4. Tool starts executing +Event::SessionUpdate { + session_id: "session_123", + update: SessionUpdate::ToolCallUpdate { + message_id: "msg_agent_1", + tool_call_id: "call_read_1", + tool_call: ToolCall { + id: "call_read_1", + status: ToolCallStatus::Running, + // ... + }, + _meta: None, + }, +} + +// 5. Tool completes +Event::SessionUpdate { + session_id: "session_123", + update: SessionUpdate::ToolCallUpdate { + message_id: "msg_agent_1", + tool_call_id: "call_read_1", + tool_call: ToolCall { + id: "call_read_1", + status: ToolCallStatus::Completed, + content: vec![ + ContentBlock::Text { + text: "[port = 3000\n...]", + }, + ], + raw_output: Some(json!({"bytes_read": 1024})), + // ... + }, + _meta: None, + }, +} + +// 6. Agent responds with summary +Event::SessionUpdate { + session_id: "session_123", + update: SessionUpdate::AgentMessageChunk { + message_id: "msg_agent_1", + content: ContentBlock::Text { + text: "The config file sets the server port to 3000", + }, + _meta: None, + }, +} + +// 7. More response chunks... +Event::SessionUpdate { + session_id: "session_123", + update: SessionUpdate::AgentMessageChunk { + message_id: "msg_agent_1", + content: ContentBlock::Text { + text: " and enables debug mode.", + }, + _meta: None, + }, +} +``` + +## Content vs MessagePart + +**Important Distinction:** + +- **ContentBlock**: Streaming content representation (used in `SessionUpdate`) + - Designed for real-time rendering + - Granular updates + - MCP-compatible structure + +- **MessagePart**: Completed message content (legacy, still supported) + - Used in stored/completed messages + - May include additional fields for history + - Compatibility with existing code + +**Migration Path:** The protocol supports both models. New code should prefer `SessionUpdate` with `ContentBlock` for streaming, while `MessagePart` remains available for compatibility and completed message storage. + +## Best Practices + +### For Consumers + +1. **Track by message_id**: Group all chunks/updates for the same message +2. **Handle tool calls separately**: Maintain a HashMap of tool calls by `tool_call_id` +3. **Replace, don't merge**: `ToolCallUpdate` sends complete state, not deltas +4. **Use _meta for debugging**: Provider metadata helps with troubleshooting +5. **Distinguish thoughts from messages**: Render `AgentThoughtChunk` differently + +### For Adapters + +1. **Always include message_id**: Every update must reference its message +2. **Preserve provider info in _meta**: Store original IDs for debugging +3. **Send complete tool state**: Include all tool call fields in updates +4. **Use appropriate chunk types**: User/Agent/Thought for correct semantics +5. **Keep _meta minimal**: Avoid large raw payloads in production + +### For UI Developers + +1. **Stream incrementally**: Append chunks as they arrive +2. **Show tool status visually**: Use icons/colors for Pending/Running/Completed/Error +3. **Make thoughts collapsible**: Don't clutter the main conversation +4. **Handle reconnection**: Be prepared to receive full state snapshots +5. **Display errors prominently**: Show tool errors clearly to users + +## Future Extensions + +The following features are planned but not yet implemented: + +- **ResourceBlock**: Embedded resource content (text/blob) +- **Image/Audio blocks**: Rich media content +- **Plan updates**: Agent planning and mode switching +- **Permissions**: Request/reply for user permissions +- **Stop reasons**: Detailed completion reasons + +See the protocol roadmap for timeline and details. + +## See Also + +- [Migration from 0.1.x](migration_from_0.1.md) - Upgrading from older versions +- [CHANGELOG.md](../CHANGELOG.md) - Version history and breaking changes +- [examples/](../examples/) - Code examples demonstrating usage diff --git a/crates/dirigent_protocol/examples/session_metadata_demo.rs b/crates/dirigent_protocol/examples/session_metadata_demo.rs new file mode 100644 index 0000000..cf69084 --- /dev/null +++ b/crates/dirigent_protocol/examples/session_metadata_demo.rs @@ -0,0 +1,100 @@ +use chrono::Utc; +use dirigent_protocol::types::meta::{Meta, ProviderMeta}; +use dirigent_protocol::{Session, SessionMetadata}; +use std::collections::HashMap; + +fn main() { + println!("=== SessionMetadata JSON Examples ===\n"); + + // Example 1: Without new fields (backward compatible) + println!("1. Basic SessionMetadata (without new fields):"); + let basic = SessionMetadata { + project_path: "/workspace/project".to_string(), + model: Some("gpt-4".to_string()), + total_messages: 10, + system_message: None, + current_mode_id: None, + _meta: None, + project_id: None, + }; + let json = serde_json::to_string_pretty(&basic).unwrap(); + println!("{}\n", json); + + // Example 2: With current_mode_id + println!("2. SessionMetadata with current_mode_id:"); + let with_mode = SessionMetadata { + project_path: "/workspace/project".to_string(), + model: Some("claude-3-sonnet".to_string()), + total_messages: 5, + system_message: Some("You are a helpful coding assistant".to_string()), + current_mode_id: Some("code_mode".to_string()), + _meta: None, + project_id: None, + }; + let json = serde_json::to_string_pretty(&with_mode).unwrap(); + println!("{}\n", json); + + // Example 3: With provider metadata + println!("3. SessionMetadata with provider metadata:"); + let meta = Meta { + provider: Some(ProviderMeta { + name: "opencode".to_string(), + original_ids: Some(HashMap::from([ + ("session_id".to_string(), "ses_abc123xyz".to_string()), + ("project_id".to_string(), "proj_456".to_string()), + ])), + raw_excerpt: Some(serde_json::json!({ + "version": "0.15.31", + "share": null + })), + }), + extra: HashMap::new(), + }; + + let with_meta = SessionMetadata { + project_path: "/workspace/project".to_string(), + model: Some("gpt-4-turbo".to_string()), + total_messages: 15, + system_message: Some("System prompt here".to_string()), + current_mode_id: Some("architect".to_string()), + _meta: Some(meta), + project_id: None, + }; + let json = serde_json::to_string_pretty(&with_meta).unwrap(); + println!("{}\n", json); + + // Example 4: Full Session object + println!("4. Complete Session with all metadata fields:"); + let now = Utc::now(); + let session = Session { + id: "ses_demo_123".to_string(), + title: "My Coding Session".to_string(), + created_at: now, + updated_at: now, + metadata: with_meta, + models: None, + modes: None, + config_options: None, + acp_client_id: None, + cwd: None, + }; + let json = serde_json::to_string_pretty(&session).unwrap(); + println!("{}\n", json); + + // Example 5: Verify backward compatibility + println!("5. Backward compatibility test (deserializing old format):"); + let old_json = r#"{ + "project_path": "/old/project", + "model": "gpt-3.5", + "total_messages": 3 + }"#; + + let parsed: SessionMetadata = serde_json::from_str(old_json).unwrap(); + println!("Successfully parsed old JSON format:"); + println!(" project_path: {}", parsed.project_path); + println!(" model: {:?}", parsed.model); + println!(" total_messages: {}", parsed.total_messages); + println!(" system_message: {:?}", parsed.system_message); + println!(" current_mode_id: {:?}", parsed.current_mode_id); + println!(" _meta: {:?}", parsed._meta); +} diff --git a/crates/dirigent_protocol/src/accumulator.rs b/crates/dirigent_protocol/src/accumulator.rs new file mode 100644 index 0000000..758f98a --- /dev/null +++ b/crates/dirigent_protocol/src/accumulator.rs @@ -0,0 +1,717 @@ +//! Protocol-level message accumulator for incremental message assembly. +//! +//! Handles streaming message deltas and assembles them into complete +//! [`AccumulatedMessage`] values using protocol types. Unlike the archivist's +//! accumulator, this module produces protocol-native output without UUID +//! parsing, markdown generation, or storage metadata. +//! +//! The accumulator preserves the order of content parts (text, thinking, tool +//! calls) as they arrive in the event stream, enabling inline tool rendering +//! and faithful forwarding to downstream consumers. + +use chrono::{DateTime, Utc}; +use serde_json::Value; +use std::collections::HashMap; + +use crate::conversation::MessagePart; +use crate::types::ContentBlock; + +/// Tool call data accumulated during streaming. +#[derive(Debug, Clone)] +pub struct ToolCallData { + pub id: String, + pub tool_name: String, + pub input: Value, + pub output: Option<Value>, +} + +/// A single accumulated content part, preserving event-stream order. +#[derive(Debug, Clone)] +pub enum AccumulatedPart { + Text { text: String }, + Thinking { text: String }, + Tool { data: ToolCallData }, +} + +/// A fully accumulated message assembled from streaming chunks. +#[derive(Debug, Clone)] +pub struct AccumulatedMessage { + pub message_id: String, + pub session_id: String, + pub connector_id: String, + pub role: String, + pub parts: Vec<AccumulatedPart>, + pub created_at: Option<DateTime<Utc>>, + pub last_activity: DateTime<Utc>, +} + +impl AccumulatedMessage { + /// Build an `AccumulatedMessage` directly from a slice of [`MessagePart`]s. + /// + /// This is the path for non-streaming clients that deliver content in a + /// single `MessageCompleted` event rather than via incremental chunks. + pub fn from_message_parts( + message_id: String, + session_id: String, + connector_id: String, + role: String, + parts: &[MessagePart], + ) -> Self { + let now = Utc::now(); + let mut accumulated_parts = Vec::new(); + + for part in parts { + match part { + MessagePart::Text { text } => { + if !text.is_empty() { + accumulated_parts.push(AccumulatedPart::Text { text: text.clone() }); + } + } + MessagePart::Thinking { text } => { + if !text.is_empty() { + accumulated_parts.push(AccumulatedPart::Thinking { text: text.clone() }); + } + } + MessagePart::Code { language, code } => { + if !code.is_empty() { + let fenced = format!("```{}\n{}\n```", language, code); + accumulated_parts.push(AccumulatedPart::Text { text: fenced }); + } + } + MessagePart::Tool { + tool, + tool_call_id, + input, + output, + } => { + accumulated_parts.push(AccumulatedPart::Tool { + data: ToolCallData { + id: tool_call_id + .clone() + .unwrap_or_else(|| String::new()), + tool_name: tool.clone(), + input: input.clone(), + output: output.clone(), + }, + }); + } + MessagePart::File { path, content: _ } => { + let label = format!("\u{1f4c4} File: {}", path); + accumulated_parts.push(AccumulatedPart::Text { text: label }); + } + } + } + + Self { + message_id, + session_id, + connector_id, + role, + parts: accumulated_parts, + created_at: Some(now), + last_activity: now, + } + } + + /// Convert accumulated parts back to protocol [`MessagePart`]s. + pub fn to_message_parts(&self) -> Vec<MessagePart> { + self.parts + .iter() + .map(|part| match part { + AccumulatedPart::Text { text } => MessagePart::Text { text: text.clone() }, + AccumulatedPart::Thinking { text } => { + MessagePart::Thinking { text: text.clone() } + } + AccumulatedPart::Tool { data } => MessagePart::Tool { + tool: data.tool_name.clone(), + tool_call_id: if data.id.is_empty() { + None + } else { + Some(data.id.clone()) + }, + input: data.input.clone(), + output: data.output.clone(), + }, + }) + .collect() + } + + /// Returns `true` when the message has no accumulated content. + pub fn is_empty(&self) -> bool { + self.parts.is_empty() + } +} + +// --------------------------------------------------------------------------- +// Internal buffer +// --------------------------------------------------------------------------- + +/// Buffer for accumulating streaming chunks into a complete message. +#[derive(Debug)] +struct MessageBuffer { + message_id: String, + session_id: String, + connector_id: String, + role: String, + parts: Vec<AccumulatedPart>, + created_at: Option<DateTime<Utc>>, + last_activity: DateTime<Utc>, +} + +impl MessageBuffer { + fn new(message_id: String, session_id: String, connector_id: String, role: String) -> Self { + let now = Utc::now(); + Self { + message_id, + session_id, + connector_id, + role, + parts: Vec::new(), + created_at: None, + last_activity: now, + } + } + + fn touch(&mut self) { + let now = Utc::now(); + if self.created_at.is_none() { + self.created_at = Some(now); + } + self.last_activity = now; + } +} + +// --------------------------------------------------------------------------- +// MessageAccumulator +// --------------------------------------------------------------------------- + +/// Accumulator for assembling streaming message deltas into complete messages. +/// +/// Each in-flight message is identified by its `message_id` and tracked in an +/// internal buffer. Text and thinking chunks are coalesced when consecutive; +/// tool calls are deduplicated by `tool_call_id`. +#[derive(Debug, Default)] +pub struct MessageAccumulator { + buffers: HashMap<String, MessageBuffer>, +} + +impl MessageAccumulator { + /// Create a new, empty accumulator. + pub fn new() -> Self { + Self { + buffers: HashMap::new(), + } + } + + /// Add a content chunk to the message buffer. + /// + /// Consecutive text chunks are coalesced into a single `AccumulatedPart::Text`. + pub fn add_chunk( + &mut self, + message_id: &str, + session_id: &str, + connector_id: &str, + role: &str, + content: ContentBlock, + ) { + let buffer = self + .buffers + .entry(message_id.to_string()) + .or_insert_with(|| { + MessageBuffer::new( + message_id.to_string(), + session_id.to_string(), + connector_id.to_string(), + role.to_string(), + ) + }); + + buffer.touch(); + + match content { + ContentBlock::Text { text } => { + if let Some(AccumulatedPart::Text { text: existing }) = buffer.parts.last_mut() { + existing.push_str(&text); + } else { + buffer.parts.push(AccumulatedPart::Text { text }); + } + } + ContentBlock::ResourceLink { .. } => { + // ResourceLink is not accumulated as text content for now. + } + } + } + + /// Add thinking content to the message buffer. + /// + /// Consecutive thinking chunks are coalesced into a single + /// `AccumulatedPart::Thinking`. + pub fn add_thinking( + &mut self, + message_id: &str, + session_id: &str, + connector_id: &str, + content: &str, + ) { + let buffer = self + .buffers + .entry(message_id.to_string()) + .or_insert_with(|| { + MessageBuffer::new( + message_id.to_string(), + session_id.to_string(), + connector_id.to_string(), + "assistant".to_string(), + ) + }); + + buffer.touch(); + + if let Some(AccumulatedPart::Thinking { text: existing }) = buffer.parts.last_mut() { + existing.push_str(content); + } else { + buffer + .parts + .push(AccumulatedPart::Thinking { text: content.to_string() }); + } + } + + /// Add or update a tool call in the message buffer. + /// + /// If a tool call with the same `id` already exists in the buffer, the + /// existing entry is updated (input is overwritten only when non-empty; + /// output is overwritten when `Some`). Otherwise a new entry is appended, + /// preserving event-stream ordering. + pub fn add_or_update_tool_call(&mut self, message_id: &str, tool_call: ToolCallData) { + if let Some(buffer) = self.buffers.get_mut(message_id) { + buffer.last_activity = Utc::now(); + + // Try to find and update an existing tool call with the same id. + for part in buffer.parts.iter_mut() { + if let AccumulatedPart::Tool { data } = part { + if data.id == tool_call.id { + data.tool_name = tool_call.tool_name; + + if tool_call.input != Value::Null + && tool_call.input != serde_json::json!({}) + { + data.input = tool_call.input; + } + + if tool_call.output.is_some() { + data.output = tool_call.output; + } + + return; + } + } + } + + // First time seeing this tool_call_id -- append. + buffer + .parts + .push(AccumulatedPart::Tool { data: tool_call }); + } + } + + /// Finalize a message and return its accumulated content. + /// + /// The internal buffer for `message_id` is removed. Returns `None` if no + /// buffer exists for the given id. + pub fn finalize(&mut self, message_id: &str) -> Option<AccumulatedMessage> { + let buffer = self.buffers.remove(message_id)?; + + Some(AccumulatedMessage { + message_id: buffer.message_id, + session_id: buffer.session_id, + connector_id: buffer.connector_id, + role: buffer.role, + parts: buffer.parts, + created_at: buffer.created_at, + last_activity: buffer.last_activity, + }) + } + + /// Returns `true` if a buffer exists for the given message id. + pub fn has_buffer(&self, message_id: &str) -> bool { + self.buffers.contains_key(message_id) + } + + /// Return all buffered message ids that belong to `session_id`. + pub fn message_ids_for_session(&self, session_id: &str) -> Vec<String> { + self.buffers + .iter() + .filter(|(_, buf)| buf.session_id == session_id) + .map(|(id, _)| id.clone()) + .collect() + } + + /// Return all currently buffered message ids. + pub fn active_message_ids(&self) -> Vec<String> { + self.buffers.keys().cloned().collect() + } + + /// Return message ids whose buffers have not been touched for longer than + /// `threshold`. + pub fn stale_message_ids(&self, threshold: std::time::Duration) -> Vec<String> { + let now = Utc::now(); + self.buffers + .iter() + .filter(|(_, buf)| { + let inactive = now.signed_duration_since(buf.last_activity); + inactive + .to_std() + .unwrap_or(std::time::Duration::ZERO) + > threshold + }) + .map(|(id, _)| id.clone()) + .collect() + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_empty_accumulator() { + let mut acc = MessageAccumulator::new(); + assert!(acc.finalize("nonexistent").is_none()); + assert!(acc.active_message_ids().is_empty()); + } + + #[test] + fn test_text_chunk_coalescing() { + let mut acc = MessageAccumulator::new(); + + acc.add_chunk("msg1", "s1", "c1", "user", ContentBlock::Text { + text: "Hello, ".to_string(), + }); + acc.add_chunk("msg1", "s1", "c1", "user", ContentBlock::Text { + text: "world!".to_string(), + }); + + let msg = acc.finalize("msg1").unwrap(); + assert_eq!(msg.parts.len(), 1); + match &msg.parts[0] { + AccumulatedPart::Text { text } => assert_eq!(text, "Hello, world!"), + other => panic!("Expected Text, got {:?}", other), + } + } + + #[test] + fn test_thinking_coalescing() { + let mut acc = MessageAccumulator::new(); + + acc.add_thinking("msg1", "s1", "c1", "First. "); + acc.add_thinking("msg1", "s1", "c1", "Second."); + + let msg = acc.finalize("msg1").unwrap(); + assert_eq!(msg.parts.len(), 1); + match &msg.parts[0] { + AccumulatedPart::Thinking { text } => assert_eq!(text, "First. Second."), + other => panic!("Expected Thinking, got {:?}", other), + } + } + + #[test] + fn test_interleaved_parts_preserve_order() { + let mut acc = MessageAccumulator::new(); + + // text, tool, text -- should produce 3 distinct parts + acc.add_chunk("msg1", "s1", "c1", "assistant", ContentBlock::Text { + text: "Before tool.".to_string(), + }); + + acc.add_or_update_tool_call("msg1", ToolCallData { + id: "tc1".to_string(), + tool_name: "grep".to_string(), + input: serde_json::json!({"q": "x"}), + output: None, + }); + + acc.add_chunk("msg1", "s1", "c1", "assistant", ContentBlock::Text { + text: "After tool.".to_string(), + }); + + let msg = acc.finalize("msg1").unwrap(); + assert_eq!(msg.parts.len(), 3); + assert!(matches!(&msg.parts[0], AccumulatedPart::Text { .. })); + assert!(matches!(&msg.parts[1], AccumulatedPart::Tool { .. })); + assert!(matches!(&msg.parts[2], AccumulatedPart::Text { .. })); + } + + #[test] + fn test_tool_call_deduplication() { + let mut acc = MessageAccumulator::new(); + + // Create buffer first + acc.add_chunk("msg1", "s1", "c1", "assistant", ContentBlock::Text { + text: "hi".to_string(), + }); + + // Initial tool call + acc.add_or_update_tool_call("msg1", ToolCallData { + id: "tc1".to_string(), + tool_name: "read".to_string(), + input: serde_json::json!({"path": "foo.rs"}), + output: None, + }); + + // Update same tool call with output (empty input should NOT overwrite) + acc.add_or_update_tool_call("msg1", ToolCallData { + id: "tc1".to_string(), + tool_name: "read".to_string(), + input: serde_json::json!({}), + output: Some(serde_json::json!({"content": "fn main() {}"})), + }); + + let msg = acc.finalize("msg1").unwrap(); + // text + 1 tool (not 2) + assert_eq!(msg.parts.len(), 2); + + match &msg.parts[1] { + AccumulatedPart::Tool { data } => { + assert_eq!(data.id, "tc1"); + // Input preserved from first call (non-empty), not overwritten by empty update + assert_eq!(data.input, serde_json::json!({"path": "foo.rs"})); + // Output set from update + assert_eq!( + data.output, + Some(serde_json::json!({"content": "fn main() {}"})) + ); + } + other => panic!("Expected Tool, got {:?}", other), + } + } + + #[test] + fn test_from_message_parts_non_streaming() { + let parts = vec![ + MessagePart::Text { + text: "Hello".to_string(), + }, + MessagePart::Thinking { + text: "hmm".to_string(), + }, + MessagePart::Code { + language: "rs".to_string(), + code: "fn main() {}".to_string(), + }, + MessagePart::Tool { + tool: "grep".to_string(), + tool_call_id: Some("tc1".to_string()), + input: serde_json::json!({"q": "x"}), + output: Some(serde_json::json!("found")), + }, + MessagePart::File { + path: "README.md".to_string(), + content: "# Title".to_string(), + }, + // Empty text and code should be skipped + MessagePart::Text { + text: String::new(), + }, + MessagePart::Code { + language: "py".to_string(), + code: String::new(), + }, + ]; + + let msg = AccumulatedMessage::from_message_parts( + "msg1".into(), + "s1".into(), + "c1".into(), + "assistant".into(), + &parts, + ); + + // 5 non-empty parts: text, thinking, code-as-text, tool, file-as-text + assert_eq!(msg.parts.len(), 5); + + match &msg.parts[0] { + AccumulatedPart::Text { text } => assert_eq!(text, "Hello"), + other => panic!("Expected Text, got {:?}", other), + } + match &msg.parts[1] { + AccumulatedPart::Thinking { text } => assert_eq!(text, "hmm"), + other => panic!("Expected Thinking, got {:?}", other), + } + match &msg.parts[2] { + AccumulatedPart::Text { text } => { + assert!(text.contains("```rs")); + assert!(text.contains("fn main() {}")); + } + other => panic!("Expected Text (code), got {:?}", other), + } + match &msg.parts[3] { + AccumulatedPart::Tool { data } => { + assert_eq!(data.tool_name, "grep"); + assert_eq!(data.id, "tc1"); + assert_eq!(data.output, Some(serde_json::json!("found"))); + } + other => panic!("Expected Tool, got {:?}", other), + } + match &msg.parts[4] { + AccumulatedPart::Text { text } => { + assert!(text.contains("File: README.md")); + } + other => panic!("Expected Text (file), got {:?}", other), + } + } + + #[test] + fn test_session_and_stale_queries() { + let mut acc = MessageAccumulator::new(); + + acc.add_chunk("msg1", "s1", "c1", "user", ContentBlock::Text { + text: "a".to_string(), + }); + acc.add_chunk("msg2", "s1", "c1", "assistant", ContentBlock::Text { + text: "b".to_string(), + }); + acc.add_chunk("msg3", "s2", "c1", "user", ContentBlock::Text { + text: "c".to_string(), + }); + + // message_ids_for_session + let mut s1_ids = acc.message_ids_for_session("s1"); + s1_ids.sort(); + assert_eq!(s1_ids, vec!["msg1", "msg2"]); + + let s2_ids = acc.message_ids_for_session("s2"); + assert_eq!(s2_ids, vec!["msg3"]); + + assert!(acc.message_ids_for_session("s3").is_empty()); + + // active_message_ids + let mut all = acc.active_message_ids(); + all.sort(); + assert_eq!(all, vec!["msg1", "msg2", "msg3"]); + + // has_buffer + assert!(acc.has_buffer("msg1")); + assert!(!acc.has_buffer("msg99")); + + // stale_message_ids with zero threshold -- everything is stale + // (last_activity <= now) + let _stale_zero = acc.stale_message_ids(std::time::Duration::ZERO); + // All three should be considered stale since last_activity <= now + // (Due to timing, they might not all be strictly < now, so we check + // with a small threshold instead.) + let stale_lenient = acc.stale_message_ids(std::time::Duration::from_secs(0)); + // At minimum, none should be stale with a huge threshold + let not_stale = acc.stale_message_ids(std::time::Duration::from_secs(3600)); + assert!(not_stale.is_empty()); + + // Verify stale detection works by checking the lenient case doesn't + // return more ids than we have buffers + assert!(stale_lenient.len() <= 3); + } + + #[test] + fn test_to_message_parts() { + let mut acc = MessageAccumulator::new(); + + acc.add_chunk("msg1", "s1", "c1", "assistant", ContentBlock::Text { + text: "Hello ".to_string(), + }); + acc.add_chunk("msg1", "s1", "c1", "assistant", ContentBlock::Text { + text: "world.".to_string(), + }); + acc.add_thinking("msg1", "s1", "c1", "thinking..."); + acc.add_or_update_tool_call("msg1", ToolCallData { + id: "tc1".to_string(), + tool_name: "search".to_string(), + input: serde_json::json!({"q": "test"}), + output: Some(serde_json::json!("result")), + }); + + let msg = acc.finalize("msg1").unwrap(); + let parts = msg.to_message_parts(); + + assert_eq!(parts.len(), 3); + + // Coalesced text + match &parts[0] { + MessagePart::Text { text } => assert_eq!(text, "Hello world."), + other => panic!("Expected Text, got {:?}", other), + } + + // Thinking + match &parts[1] { + MessagePart::Thinking { text } => assert_eq!(text, "thinking..."), + other => panic!("Expected Thinking, got {:?}", other), + } + + // Tool roundtrip + match &parts[2] { + MessagePart::Tool { + tool, + tool_call_id, + input, + output, + } => { + assert_eq!(tool, "search"); + assert_eq!(tool_call_id, &Some("tc1".to_string())); + assert_eq!(input, &serde_json::json!({"q": "test"})); + assert_eq!(output, &Some(serde_json::json!("result"))); + } + other => panic!("Expected Tool, got {:?}", other), + } + } + + #[test] + fn test_is_empty() { + let msg = AccumulatedMessage::from_message_parts( + "msg1".into(), + "s1".into(), + "c1".into(), + "user".into(), + &[], + ); + assert!(msg.is_empty()); + + let msg2 = AccumulatedMessage::from_message_parts( + "msg2".into(), + "s1".into(), + "c1".into(), + "user".into(), + &[MessagePart::Text { + text: "hi".to_string(), + }], + ); + assert!(!msg2.is_empty()); + } + + #[test] + fn test_to_message_parts_empty_tool_id() { + // Tool with no tool_call_id should roundtrip as None + let parts = vec![MessagePart::Tool { + tool: "bash".to_string(), + tool_call_id: None, + input: serde_json::json!({"cmd": "ls"}), + output: None, + }]; + + let msg = AccumulatedMessage::from_message_parts( + "msg1".into(), + "s1".into(), + "c1".into(), + "assistant".into(), + &parts, + ); + + let roundtripped = msg.to_message_parts(); + match &roundtripped[0] { + MessagePart::Tool { tool_call_id, .. } => { + assert_eq!(tool_call_id, &None); + } + other => panic!("Expected Tool, got {:?}", other), + } + } +} diff --git a/crates/dirigent_protocol/src/adapters/acp.rs b/crates/dirigent_protocol/src/adapters/acp.rs new file mode 100644 index 0000000..a26ea6d --- /dev/null +++ b/crates/dirigent_protocol/src/adapters/acp.rs @@ -0,0 +1,2447 @@ +//! ACP (Agent-Client Protocol) Adapter +//! +//! This module provides translation from ACP JSON-RPC notifications to Dirigent protocol events. +//! It's designed to work with ACP-compliant agents that follow the JSON-RPC 2.0 notification pattern. +//! +//! # Architecture +//! +//! The adapter translates ACP notifications like: +//! - `session/update` - Session updates with various content types +//! +//! Into Dirigent protocol events that can be consumed by the UI and other components. +//! +//! # Supported Update Types +//! +//! - `agent_message_chunk` - Agent response content +//! - `user_message_chunk` - User message echo/history +//! - `thought_chunk` - Agent internal reasoning +//! - `tool_call` - Tool execution started +//! - `tool_result` - Tool execution completed + +use crate::log_utils::format_for_log; +use crate::{ContentBlock, Event, SessionUpdate, ToolCall, ToolCallStatus, ToolCallContent}; +use serde_json::{Map, Value}; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; +use uuid::Uuid; + +/// ACP notification translation error +#[derive(Debug, thiserror::Error)] +pub enum AcpTranslationError { + /// Missing required field in notification + #[error("Missing required field: {0}")] + MissingField(String), + + /// Invalid value in notification + #[error("Invalid value for {field}: {reason}")] + InvalidValue { field: String, reason: String }, + + /// Unknown notification method + #[error("Unknown notification method: {0}")] + UnknownMethod(String), + + /// Duplicate notification (already processed) + #[error("Duplicate notification")] + Duplicate, +} + +/// Result type for ACP translation operations +pub type AcpResult<T> = Result<T, AcpTranslationError>; + +/// ACP notification adapter +/// +/// Translates ACP JSON-RPC notifications to Dirigent protocol events. +/// +/// # State Management +/// +/// The adapter maintains internal state for: +/// - Message tracking (to detect duplicates) +/// - Session metadata +/// - Tool execution state +/// +/// # Thread Safety +/// +/// AcpAdapter is designed to be used from a single task (the connector's event loop). +/// It uses Arc<RwLock> for state to support potential future multi-threaded scenarios. +#[derive(Debug)] +pub struct AcpAdapter { + /// Processed notification IDs (for deduplication) + processed_notifications: Arc<RwLock<HashMap<String, ()>>>, + + /// Active sessions (for tracking metadata - future use) + sessions: Arc<RwLock<HashMap<String, SessionMetadata>>>, +} + +/// Session metadata tracked by the adapter (for future use) +#[derive(Debug, Clone)] +#[allow(dead_code)] +struct SessionMetadata { + /// Session ID + id: String, + /// Session title (if available) + title: Option<String>, + /// Current working directory + cwd: String, + /// Model being used + model: Option<String>, +} + +impl AcpAdapter { + /// Create a new ACP adapter + pub fn new() -> Self { + Self { + processed_notifications: Arc::new(RwLock::new(HashMap::new())), + sessions: Arc::new(RwLock::new(HashMap::new())), + } + } + + /// Translate an ACP notification to a Dirigent event + /// + /// # Arguments + /// + /// * `notification` - JSON-RPC notification object + /// + /// # Returns + /// + /// Dirigent Event or translation error. Returns `Err(Duplicate)` if the notification + /// has already been processed (can be ignored). + pub async fn translate_notification(&self, notification: Value) -> AcpResult<Event> { + // Extract method + let method = notification + .get("method") + .and_then(|m| m.as_str()) + .ok_or_else(|| AcpTranslationError::MissingField("method".to_string()))?; + + // Extract params + let params = notification + .get("params") + .ok_or_else(|| AcpTranslationError::MissingField("params".to_string()))?; + + tracing::debug!( + method = method, + params = %format_for_log(params), + "ACP notification received" + ); + + // Route based on method + match method { + "session/update" => self.translate_session_update(params).await, + _ => { + // Unknown method - log and skip + tracing::warn!(method = method, "Unknown ACP notification method"); + Err(AcpTranslationError::UnknownMethod(method.to_string())) + } + } + } + + /// Translate a session/update notification + async fn translate_session_update(&self, params: &Value) -> AcpResult<Event> { + // Extract session ID + let session_id = params + .get("sessionId") + .and_then(|s| s.as_str()) + .ok_or_else(|| AcpTranslationError::MissingField("sessionId".to_string()))? + .to_string(); + + tracing::info!(session_id = %session_id, "🔍 Translating session/update notification"); + + // Extract update object + let update_obj = params + .get("update") + .and_then(|u| u.as_object()) + .ok_or_else(|| AcpTranslationError::MissingField("update".to_string()))?; + + // Determine update type - check both "type" and "sessionUpdate" fields for compatibility + let update_type = update_obj + .get("type") + .or_else(|| update_obj.get("sessionUpdate")) + .and_then(|t| t.as_str()); + + // If no type field found, log the notification structure and skip it + let update_type = match update_type { + Some(t) => t, + None => { + tracing::warn!( + update_obj = ?update_obj, + "⚠️ session/update notification missing type field, skipping" + ); + // Return a special error that can be handled gracefully + return Err(AcpTranslationError::InvalidValue { + field: "update.type".to_string(), + reason: "Missing type field (might be a different notification variant)".to_string(), + }); + } + }; + + tracing::info!(update_type = %update_type, "📝 Processing update type"); + + // Route to variant-specific handler + let session_update = match update_type { + "agent_message_chunk" => self.translate_agent_message_chunk(update_obj)?, + "user_message_chunk" => self.translate_user_message_chunk(update_obj)?, + "thought_chunk" => self.translate_thought_chunk(update_obj)?, + "tool_call" => self.translate_tool_call(update_obj)?, + "tool_result" => self.translate_tool_result(update_obj)?, + "tool_call_update" => self.translate_tool_call_update(update_obj)?, + unknown => { + tracing::warn!( + update_type = unknown, + "⚠️ Unknown session update type, forwarding as Unknown variant for pass-through" + ); + // Forward unknown types as raw JSON for forward compatibility + SessionUpdate::Unknown { + data: serde_json::Value::Object(update_obj.clone()), + } + } + }; + + tracing::info!(session_id = %session_id, "✅ Successfully translated session/update"); + + // Note: connector_id will be set by the connector when emitting this event + Ok(Event::SessionUpdate { + connector_id: String::new(), // Placeholder - connector will set this + session_id, + update: session_update, + }) + } + + /// Translate content block from ACP format to Dirigent format + fn translate_content_block(&self, content: &Value) -> AcpResult<ContentBlock> { + let content_type = content + .get("type") + .and_then(|t| t.as_str()) + .ok_or_else(|| AcpTranslationError::MissingField("content.type".to_string()))?; + + match content_type { + "text" => { + let text = content + .get("text") + .and_then(|t| t.as_str()) + .ok_or_else(|| AcpTranslationError::MissingField("content.text".to_string()))? + .to_string(); + Ok(ContentBlock::Text { text }) + } + unknown => { + tracing::warn!( + content_type = unknown, + "Unknown content type, falling back to text representation" + ); + // Fallback: represent unknown content as text + Ok(ContentBlock::Text { + text: format!("[Unknown content type: {}]", serde_json::to_string(content).unwrap_or_default()), + }) + } + } + } + + /// Translate content array from ACP format to ToolCallContent + /// + /// Handles both wrapped ACP format and backward compatibility with raw ContentBlock format. + fn translate_tool_call_content(&self, content: &[Value]) -> Vec<ToolCallContent> { + content + .iter() + .filter_map(|item| { + // Try to parse as ToolCallContent (ACP format with wrapper) + serde_json::from_value::<ToolCallContent>(item.clone()) + .ok() + .or_else(|| { + // Fallback: try as ContentBlock (backward compatibility) + serde_json::from_value::<ContentBlock>(item.clone()) + .ok() + .map(ToolCallContent::from_content_block) + }) + }) + .collect() + } + + /// Translate agent_message_chunk update + fn translate_agent_message_chunk(&self, update: &Map<String, Value>) -> AcpResult<SessionUpdate> { + // messageId is optional - some agents (like Claude) don't send it in chunks + // The connector will replace it with a Dirigent message_id anyway + let message_id = update + .get("messageId") + .and_then(|m| m.as_str()) + .map(|s| s.to_string()) + .unwrap_or_else(|| format!("tmp-{}", uuid::Uuid::new_v4())); + + let content = update + .get("content") + .ok_or_else(|| AcpTranslationError::MissingField("content".to_string()))?; + + let content = self.translate_content_block(content)?; + + Ok(SessionUpdate::AgentMessageChunk { + message_id, + content, + _meta: None, + }) + } + + /// Translate user_message_chunk update + fn translate_user_message_chunk(&self, update: &Map<String, Value>) -> AcpResult<SessionUpdate> { + // messageId is optional - some agents don't send it in chunks + let message_id = update + .get("messageId") + .and_then(|m| m.as_str()) + .map(|s| s.to_string()) + .unwrap_or_else(|| format!("tmp-{}", Uuid::new_v4())); + + let content = update + .get("content") + .ok_or_else(|| AcpTranslationError::MissingField("content".to_string()))?; + + let content = self.translate_content_block(content)?; + + Ok(SessionUpdate::UserMessageChunk { + message_id, + content, + _meta: None, + }) + } + + /// Translate thought_chunk update + fn translate_thought_chunk(&self, update: &Map<String, Value>) -> AcpResult<SessionUpdate> { + // messageId is optional - some agents don't send it in chunks + let message_id = update + .get("messageId") + .and_then(|m| m.as_str()) + .map(|s| s.to_string()) + .unwrap_or_else(|| format!("tmp-{}", Uuid::new_v4())); + + let content = update + .get("content") + .ok_or_else(|| AcpTranslationError::MissingField("content".to_string()))?; + + let content = self.translate_content_block(content)?; + + Ok(SessionUpdate::AgentThoughtChunk { + message_id, + content, + _meta: None, + }) + } + + /// Translate tool_call update + /// + /// Supports two formats: + /// 1. Nested format (backward compatibility): `{ messageId, toolCall: {...} }` + /// 2. Flat format (Claude): `{ toolCallId, status, title, rawInput, content, _meta }` + fn translate_tool_call(&self, update: &Map<String, Value>) -> AcpResult<SessionUpdate> { + // Make messageId optional - generate if missing (Claude doesn't send it) + let message_id = update + .get("messageId") + .and_then(|m| m.as_str()) + .map(|s| s.to_string()) + .unwrap_or_else(|| format!("tool-{}", Uuid::new_v4())); + + // Try nested toolCall format first (backward compatibility), then flat structure (Claude) + let tool_call = if let Some(tool_call_obj) = update.get("toolCall").and_then(|t| t.as_object()) { + // Original nested format + self.translate_tool_call_object(tool_call_obj)? + } else { + // Claude's flat format - extract from top level + self.translate_tool_call_flat(update)? + }; + + Ok(SessionUpdate::ToolCall { + message_id, + tool_call, + _meta: None, + }) + } + + /// Translate tool call from Claude's flat structure + /// + /// Claude sends tool calls with fields at the top level: + /// ```json + /// { + /// "sessionUpdate": "tool_call", + /// "toolCallId": "toolu_012...", + /// "title": "grep \"undefined\"", + /// "kind": "search", + /// "rawInput": {}, + /// "status": "pending", + /// "content": [], + /// "_meta": {"claudeCode": {"toolName": "Grep"}} + /// } + /// ``` + fn translate_tool_call_flat(&self, update: &Map<String, Value>) -> AcpResult<ToolCall> { + let id = update + .get("toolCallId") + .and_then(|i| i.as_str()) + .ok_or_else(|| AcpTranslationError::MissingField("toolCallId".to_string()))? + .to_string(); + + // Get tool name from _meta.claudeCode.toolName, then kind, then title prefix, then "unknown" + let tool_name = update + .get("_meta") + .and_then(|m| m.get("claudeCode")) + .and_then(|c| c.get("toolName")) + .and_then(|n| n.as_str()) + .or_else(|| update.get("kind").and_then(|k| k.as_str())) + .or_else(|| { + update + .get("title") + .and_then(|t| t.as_str()) + .and_then(|t| t.split_whitespace().next()) + }) + .unwrap_or("unknown") + .to_string(); + + // Get status, defaulting to Pending if missing + let status = update + .get("status") + .and_then(|s| s.as_str()) + .map(|s| self.translate_tool_status(s)) + .transpose()? + .unwrap_or(ToolCallStatus::Pending); + + let title = update + .get("title") + .and_then(|t| t.as_str()) + .map(String::from); + + // Claude uses "rawInput" instead of "input" + let raw_input = update.get("rawInput").cloned(); + + // Translate content array if present + let content = update + .get("content") + .and_then(|c| c.as_array()) + .map(|arr| self.translate_tool_call_content(arr)) + .unwrap_or_default(); + + // Extract and store kind field in metadata for later reconstruction + let metadata = if let Some(kind) = update.get("kind") { + let mut meta_map = serde_json::Map::new(); + meta_map.insert("acp_kind".to_string(), kind.clone()); + Some(Value::Object(meta_map)) + } else { + None + }; + + Ok(ToolCall { + id, + tool_name, + status, + content, + raw_input, + raw_output: None, + title, + error: None, + metadata, + origin: None, // Will be set by connector based on source + }) + } + + /// Translate tool_result update + /// + /// **Note:** Claude Code does NOT use this notification type. Instead, it sends + /// `tool_call_update` with `status: "completed"` to indicate tool completion. + /// This handler is maintained for compatibility with other ACP-compliant agents + /// that may follow the traditional tool_call → tool_result pattern. + /// + /// Expected format (traditional ACP): + /// ```json + /// { + /// "type": "tool_result", + /// "messageId": "msg-123", + /// "toolCallId": "call-456", + /// "result": { + /// "output": "...", + /// "error": "..." + /// } + /// } + /// ``` + /// + /// See `translate_tool_call_update()` for Claude Code's actual completion mechanism. + fn translate_tool_result(&self, update: &Map<String, Value>) -> AcpResult<SessionUpdate> { + let message_id = update + .get("messageId") + .and_then(|m| m.as_str()) + .ok_or_else(|| AcpTranslationError::MissingField("messageId".to_string()))? + .to_string(); + + let tool_call_id = update + .get("toolCallId") + .and_then(|t| t.as_str()) + .ok_or_else(|| AcpTranslationError::MissingField("toolCallId".to_string()))? + .to_string(); + + let result_obj = update + .get("result") + .and_then(|r| r.as_object()) + .ok_or_else(|| AcpTranslationError::MissingField("result".to_string()))?; + + // Determine status based on error presence + let status = if result_obj.contains_key("error") { + ToolCallStatus::Error + } else { + ToolCallStatus::Completed + }; + + // Extract content from output and wrap in ToolCallContent + let content = if let Some(output) = result_obj.get("output") { + if let Some(text) = output.as_str() { + vec![ToolCallContent::from_content_block(ContentBlock::Text { + text: text.to_string() + })] + } else { + // Output is structured data - represent as JSON text + vec![ToolCallContent::from_content_block(ContentBlock::Text { + text: serde_json::to_string_pretty(output).unwrap_or_default(), + })] + } + } else { + vec![] + }; + + let error = result_obj + .get("error") + .and_then(|e| e.as_str()) + .map(String::from); + + // Note: tool_name is not provided in result, UI will need to merge with original tool_call + let tool_call = ToolCall { + id: tool_call_id.clone(), + tool_name: String::new(), // Will be populated by UI merging with original + status, + content, + raw_input: None, + raw_output: result_obj.get("output").cloned(), + title: None, + error, + metadata: None, + origin: None, // Will be set by connector based on source + }; + + Ok(SessionUpdate::ToolCallUpdate { + message_id, + tool_call_id, + tool_call, + _meta: None, + }) + } + + /// Translate tool_call_update notification + fn translate_tool_call_update(&self, update: &Map<String, Value>) -> AcpResult<SessionUpdate> { + // Make messageId optional - generate if missing (Claude doesn't send it) + let message_id = update + .get("messageId") + .and_then(|m| m.as_str()) + .map(|s| s.to_string()) + .unwrap_or_else(|| format!("tool-{}", Uuid::new_v4())); + + let tool_call_id = update + .get("toolCallId") + .and_then(|t| t.as_str()) + .ok_or_else(|| AcpTranslationError::MissingField("toolCallId".to_string()))? + .to_string(); + + // Try nested toolCall format first (backward compatibility), then flat structure (Claude) + let tool_call = if let Some(tool_call_obj) = update.get("toolCall").and_then(|t| t.as_object()) { + // Original nested format + self.translate_tool_call_object(tool_call_obj)? + } else { + // Claude's flat format - extract from top level + self.translate_tool_call_update_flat(update)? + }; + + // Preserve metadata from _meta field (includes toolResponse for progress tracking) + let _meta = update.get("_meta") + .and_then(|m| serde_json::from_value(m.clone()).ok()); + + Ok(SessionUpdate::ToolCallUpdate { + message_id, + tool_call_id, + tool_call, + _meta, + }) + } + + /// Translate tool call update from Claude's flat structure + /// + /// Claude sends tool call updates with fields at the top level: + /// ```json + /// { + /// "sessionUpdate": "tool_call_update", + /// "toolCallId": "toolu_012...", + /// "status": "completed", + /// "content": [{"type": "content", "content": {"type": "text", "text": "..."}}], + /// "_meta": {"claudeCode": {"toolName": "Grep", "toolResponse": {...}}} + /// } + /// ``` + fn translate_tool_call_update_flat(&self, update: &Map<String, Value>) -> AcpResult<ToolCall> { + let id = update + .get("toolCallId") + .and_then(|i| i.as_str()) + .ok_or_else(|| AcpTranslationError::MissingField("toolCallId".to_string()))? + .to_string(); + + // Get tool name from _meta.claudeCode.toolName, then kind, then title prefix, then "unknown" + let tool_name = update + .get("_meta") + .and_then(|m| m.get("claudeCode")) + .and_then(|c| c.get("toolName")) + .and_then(|n| n.as_str()) + .or_else(|| update.get("kind").and_then(|k| k.as_str())) + .or_else(|| { + update + .get("title") + .and_then(|t| t.as_str()) + .and_then(|t| t.split_whitespace().next()) + }) + .unwrap_or("unknown") + .to_string(); + + // Get status, defaulting to Running if missing (updates typically have status) + let status = update + .get("status") + .and_then(|s| s.as_str()) + .map(|s| self.translate_tool_status(s)) + .transpose()? + .unwrap_or(ToolCallStatus::Running); + + let title = update + .get("title") + .and_then(|t| t.as_str()) + .map(String::from); + + // Claude uses "rawInput" instead of "input" + let raw_input = update.get("rawInput").cloned(); + + // Extract raw_output from toolResponse if present (for completed updates) + let raw_output = update + .get("_meta") + .and_then(|m| m.get("claudeCode")) + .and_then(|c| c.get("toolResponse")) + .cloned(); + + // Translate content array if present + let content = update + .get("content") + .and_then(|c| c.as_array()) + .map(|arr| self.translate_tool_call_content(arr)) + .unwrap_or_default(); + + Ok(ToolCall { + id, + tool_name, + status, + content, + raw_input, + raw_output, + title, + error: None, + metadata: None, + origin: None, // Will be set by connector based on source + }) + } + + /// Translate a tool call object to ToolCall struct + fn translate_tool_call_object(&self, obj: &Map<String, Value>) -> AcpResult<ToolCall> { + let id = obj + .get("id") + .and_then(|i| i.as_str()) + .ok_or_else(|| AcpTranslationError::MissingField("toolCall.id".to_string()))? + .to_string(); + + let tool_name = obj + .get("name") + .and_then(|n| n.as_str()) + .ok_or_else(|| AcpTranslationError::MissingField("toolCall.name".to_string()))? + .to_string(); + + let status_str = obj + .get("status") + .and_then(|s| s.as_str()) + .ok_or_else(|| AcpTranslationError::MissingField("toolCall.status".to_string()))?; + + let status = self.translate_tool_status(status_str)?; + + let title = obj + .get("title") + .and_then(|t| t.as_str()) + .map(String::from); + + let raw_input = obj.get("input").cloned(); + + Ok(ToolCall { + id, + tool_name, + status, + content: vec![], // Content comes in tool_result updates + raw_input, + raw_output: None, + title, + error: None, + metadata: None, + origin: None, // Will be set by connector based on source + }) + } + + /// Translate tool status string to ToolCallStatus enum + fn translate_tool_status(&self, status: &str) -> AcpResult<ToolCallStatus> { + match status { + "pending" => Ok(ToolCallStatus::Pending), + "running" => Ok(ToolCallStatus::Running), + "completed" => Ok(ToolCallStatus::Completed), + "error" => Ok(ToolCallStatus::Error), + unknown => Err(AcpTranslationError::InvalidValue { + field: "status".to_string(), + reason: format!("Unknown tool status: {}", unknown), + }), + } + } + + /// Clear all adapter state + /// + /// Called when reconnecting to ensure clean state. + pub async fn clear(&self) { + self.processed_notifications.write().await.clear(); + self.sessions.write().await.clear(); + } +} + +impl Default for AcpAdapter { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + // ADAPT-001: Adapter creation test + #[tokio::test] + async fn test_adapter_creation() { + let adapter = AcpAdapter::new(); + assert_eq!(adapter.processed_notifications.read().await.len(), 0); + assert_eq!(adapter.sessions.read().await.len(), 0); + } + + // ADAPT-002: Dispatcher tests + #[tokio::test] + async fn test_missing_method() { + let adapter = AcpAdapter::new(); + + let notification = json!({ + "jsonrpc": "2.0", + "params": {} + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + AcpTranslationError::MissingField(_) + )); + } + + #[tokio::test] + async fn test_missing_params() { + let adapter = AcpAdapter::new(); + + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update" + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + AcpTranslationError::MissingField(_) + )); + } + + // ADAPT-015: Unknown method test + #[tokio::test] + async fn test_unknown_method() { + let adapter = AcpAdapter::new(); + + let notification = json!({ + "jsonrpc": "2.0", + "method": "unknown/method", + "params": {} + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + AcpTranslationError::UnknownMethod(_) + )); + } + + // ADAPT-003, ADAPT-016: Session update tests with malformed params + #[tokio::test] + async fn test_session_update_missing_session_id() { + let adapter = AcpAdapter::new(); + + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "update": { + "type": "agent_message_chunk", + "messageId": "msg-1", + "content": { + "type": "text", + "text": "Hello" + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + AcpTranslationError::MissingField(_) + )); + } + + #[tokio::test] + async fn test_session_update_missing_update_object() { + let adapter = AcpAdapter::new(); + + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-123" + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + AcpTranslationError::MissingField(_) + )); + } + + #[tokio::test] + async fn test_session_update_unknown_type() { + let adapter = AcpAdapter::new(); + + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-123", + "update": { + "type": "unknown_type", + "messageId": "msg-1" + } + } + }); + + // Unknown types are now forwarded as SessionUpdate::Unknown for forward compatibility + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { session_id, update, .. } => { + assert_eq!(session_id, "session-123"); + match update { + SessionUpdate::Unknown { data } => { + // Verify the original data is preserved + assert_eq!(data.get("type").and_then(|v| v.as_str()), Some("unknown_type")); + assert_eq!(data.get("messageId").and_then(|v| v.as_str()), Some("msg-1")); + } + _ => panic!("Expected Unknown variant"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + // ADAPT-004: Content block translation tests + #[tokio::test] + async fn test_translate_text_content() { + let adapter = AcpAdapter::new(); + + let content = json!({ + "type": "text", + "text": "Hello, world!" + }); + + let result = adapter.translate_content_block(&content); + assert!(result.is_ok()); + + match result.unwrap() { + ContentBlock::Text { text } => { + assert_eq!(text, "Hello, world!"); + } + _ => panic!("Expected Text content block"), + } + } + + #[tokio::test] + async fn test_translate_unknown_content_type() { + let adapter = AcpAdapter::new(); + + let content = json!({ + "type": "image", + "url": "https://example.com/image.png" + }); + + let result = adapter.translate_content_block(&content); + assert!(result.is_ok()); + + // Unknown types should fall back to text representation + match result.unwrap() { + ContentBlock::Text { text } => { + assert!(text.contains("Unknown content type")); + } + _ => panic!("Expected Text content block fallback"), + } + } + + #[tokio::test] + async fn test_translate_malformed_content() { + let adapter = AcpAdapter::new(); + + let content = json!({ + "text": "Missing type field" + }); + + let result = adapter.translate_content_block(&content); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + AcpTranslationError::MissingField(_) + )); + } + + // ADAPT-010: Agent message chunk tests + #[tokio::test] + async fn test_agent_message_chunk_translation() { + let adapter = AcpAdapter::new(); + + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-123", + "update": { + "type": "agent_message_chunk", + "messageId": "msg-1", + "content": { + "type": "text", + "text": "Hello from agent" + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { connector_id: _, session_id, update } => { + assert_eq!(session_id, "session-123"); + match update { + SessionUpdate::AgentMessageChunk { message_id, content, _meta } => { + assert_eq!(message_id, "msg-1"); + match content { + ContentBlock::Text { text } => { + assert_eq!(text, "Hello from agent"); + } + _ => panic!("Expected Text content"), + } + assert!(_meta.is_none()); + } + _ => panic!("Expected AgentMessageChunk"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + #[tokio::test] + async fn test_agent_message_chunk_missing_message_id() { + let adapter = AcpAdapter::new(); + + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-123", + "update": { + "type": "agent_message_chunk", + "content": { + "type": "text", + "text": "Missing message ID" + } + } + } + }); + + // messageId is now optional - should succeed with generated ID + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { session_id, update, .. } => { + assert_eq!(session_id, "session-123"); + match update { + SessionUpdate::AgentMessageChunk { message_id, .. } => { + // Generated message ID should have "tmp-" prefix + assert!(message_id.starts_with("tmp-")); + } + _ => panic!("Expected AgentMessageChunk"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + #[tokio::test] + async fn test_agent_message_chunk_missing_content() { + let adapter = AcpAdapter::new(); + + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-123", + "update": { + "type": "agent_message_chunk", + "messageId": "msg-1" + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_err()); + } + + // ADAPT-011: User message chunk tests + #[tokio::test] + async fn test_user_message_chunk_translation() { + let adapter = AcpAdapter::new(); + + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-456", + "update": { + "type": "user_message_chunk", + "messageId": "msg-2", + "content": { + "type": "text", + "text": "Hello from user" + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { connector_id: _, session_id, update } => { + assert_eq!(session_id, "session-456"); + match update { + SessionUpdate::UserMessageChunk { message_id, content, _meta } => { + assert_eq!(message_id, "msg-2"); + match content { + ContentBlock::Text { text } => { + assert_eq!(text, "Hello from user"); + } + _ => panic!("Expected Text content"), + } + assert!(_meta.is_none()); + } + _ => panic!("Expected UserMessageChunk"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + // ADAPT-012: Thought chunk tests + #[tokio::test] + async fn test_thought_chunk_translation() { + let adapter = AcpAdapter::new(); + + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-789", + "update": { + "type": "thought_chunk", + "messageId": "msg-3", + "content": { + "type": "text", + "text": "Analyzing the problem..." + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { connector_id: _, session_id, update } => { + assert_eq!(session_id, "session-789"); + match update { + SessionUpdate::AgentThoughtChunk { message_id, content, _meta } => { + assert_eq!(message_id, "msg-3"); + match content { + ContentBlock::Text { text } => { + assert_eq!(text, "Analyzing the problem..."); + } + _ => panic!("Expected Text content"), + } + assert!(_meta.is_none()); + } + _ => panic!("Expected AgentThoughtChunk"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + // ADAPT-013: Tool call translation tests + #[tokio::test] + async fn test_tool_call_translation_pending() { + let adapter = AcpAdapter::new(); + + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-abc", + "update": { + "type": "tool_call", + "messageId": "msg-4", + "toolCall": { + "id": "call-1", + "name": "bash", + "status": "pending", + "title": "Run bash command", + "input": { + "command": "ls -la" + } + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { connector_id: _, session_id, update } => { + assert_eq!(session_id, "session-abc"); + match update { + SessionUpdate::ToolCall { message_id, tool_call, _meta } => { + assert_eq!(message_id, "msg-4"); + assert_eq!(tool_call.id, "call-1"); + assert_eq!(tool_call.tool_name, "bash"); + assert_eq!(tool_call.status, ToolCallStatus::Pending); + assert_eq!(tool_call.title, Some("Run bash command".to_string())); + assert!(tool_call.raw_input.is_some()); + assert!(_meta.is_none()); + } + _ => panic!("Expected ToolCall"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + #[tokio::test] + async fn test_tool_call_translation_running() { + let adapter = AcpAdapter::new(); + + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-def", + "update": { + "type": "tool_call", + "messageId": "msg-5", + "toolCall": { + "id": "call-2", + "name": "read_file", + "status": "running" + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { connector_id: _, session_id, update } => { + assert_eq!(session_id, "session-def"); + match update { + SessionUpdate::ToolCall { message_id, tool_call, .. } => { + assert_eq!(message_id, "msg-5"); + assert_eq!(tool_call.id, "call-2"); + assert_eq!(tool_call.status, ToolCallStatus::Running); + assert!(tool_call.title.is_none()); + assert!(tool_call.raw_input.is_none()); + } + _ => panic!("Expected ToolCall"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + #[tokio::test] + async fn test_tool_call_missing_fields() { + let adapter = AcpAdapter::new(); + + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-ghi", + "update": { + "type": "tool_call", + "messageId": "msg-6", + "toolCall": { + "id": "call-3" + // Missing name and status + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_tool_call_invalid_status() { + let adapter = AcpAdapter::new(); + + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-jkl", + "update": { + "type": "tool_call", + "messageId": "msg-7", + "toolCall": { + "id": "call-4", + "name": "test", + "status": "invalid_status" + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + AcpTranslationError::InvalidValue { .. } + )); + } + + // ADAPT-019: Claude flat format tool_call tests + #[tokio::test] + async fn test_tool_call_claude_flat_format() { + let adapter = AcpAdapter::new(); + + // Claude's actual notification format: flat structure, no messageId + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-claude", + "update": { + "sessionUpdate": "tool_call", + "toolCallId": "toolu_012TKJvVXSYFFARfUWj9okv8", + "title": "grep \"undefined\"", + "kind": "search", + "rawInput": { + "pattern": "undefined", + "path": "." + }, + "status": "pending", + "content": [], + "_meta": { + "claudeCode": { + "toolName": "Grep" + } + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { connector_id: _, session_id, update } => { + assert_eq!(session_id, "session-claude"); + match update { + SessionUpdate::ToolCall { message_id, tool_call, _meta } => { + // Generated message ID should have "tool-" prefix + assert!(message_id.starts_with("tool-")); + assert_eq!(tool_call.id, "toolu_012TKJvVXSYFFARfUWj9okv8"); + assert_eq!(tool_call.tool_name, "Grep"); + assert_eq!(tool_call.status, ToolCallStatus::Pending); + assert_eq!(tool_call.title, Some("grep \"undefined\"".to_string())); + assert!(tool_call.raw_input.is_some()); + assert!(_meta.is_none()); + + // T001: Verify kind field is stored in metadata + assert!(tool_call.metadata.is_some(), "Metadata should be present when kind field exists"); + if let Some(metadata) = tool_call.metadata { + assert_eq!( + metadata.get("acp_kind"), + Some(&json!("search")), + "kind field should be stored as 'acp_kind' in metadata" + ); + } + } + _ => panic!("Expected ToolCall"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + #[tokio::test] + async fn test_tool_call_claude_flat_format_without_tool_name() { + let adapter = AcpAdapter::new(); + + // Claude format without _meta.claudeCode.toolName - should fall back to "unknown" + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-claude-2", + "update": { + "sessionUpdate": "tool_call", + "toolCallId": "toolu_123456", + "status": "running", + "content": [] + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { connector_id: _, session_id, update } => { + assert_eq!(session_id, "session-claude-2"); + match update { + SessionUpdate::ToolCall { message_id, tool_call, .. } => { + assert!(message_id.starts_with("tool-")); + assert_eq!(tool_call.id, "toolu_123456"); + assert_eq!(tool_call.tool_name, "unknown"); + assert_eq!(tool_call.status, ToolCallStatus::Running); + assert!(tool_call.title.is_none()); + + // T001: Verify that missing kind field doesn't cause errors + assert!(tool_call.metadata.is_none(), "Metadata should be None when kind field is missing"); + } + _ => panic!("Expected ToolCall"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + #[tokio::test] + async fn test_tool_call_claude_flat_format_with_content() { + let adapter = AcpAdapter::new(); + + // Claude format with content array (wrapped format) + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-claude-3", + "update": { + "sessionUpdate": "tool_call", + "toolCallId": "toolu_789", + "status": "completed", + "content": [ + { + "type": "content", + "content": { + "type": "text", + "text": "Search results here" + } + } + ], + "_meta": { + "claudeCode": { + "toolName": "Grep" + } + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { connector_id: _, session_id, update } => { + assert_eq!(session_id, "session-claude-3"); + match update { + SessionUpdate::ToolCall { tool_call, .. } => { + assert_eq!(tool_call.id, "toolu_789"); + assert_eq!(tool_call.tool_name, "Grep"); + assert_eq!(tool_call.status, ToolCallStatus::Completed); + assert_eq!(tool_call.content.len(), 1); + match &tool_call.content[0] { + ToolCallContent::Content { content } => match content { + ContentBlock::Text { text } => { + assert_eq!(text, "Search results here"); + } + _ => panic!("Expected Text content"), + } + _ => panic!("Expected Content wrapper"), + } + } + _ => panic!("Expected ToolCall"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + // ADAPT-014: Tool result translation tests + #[tokio::test] + async fn test_tool_result_translation_completed() { + let adapter = AcpAdapter::new(); + + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-mno", + "update": { + "type": "tool_result", + "messageId": "msg-8", + "toolCallId": "call-5", + "result": { + "output": "Command completed successfully" + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { connector_id: _, session_id, update } => { + assert_eq!(session_id, "session-mno"); + match update { + SessionUpdate::ToolCallUpdate { message_id, tool_call_id, tool_call, _meta } => { + assert_eq!(message_id, "msg-8"); + assert_eq!(tool_call_id, "call-5"); + assert_eq!(tool_call.id, "call-5"); + assert_eq!(tool_call.status, ToolCallStatus::Completed); + assert_eq!(tool_call.content.len(), 1); + match &tool_call.content[0] { + ToolCallContent::Content { content } => match content { + ContentBlock::Text { text } => { + assert_eq!(text, "Command completed successfully"); + } + _ => panic!("Expected Text content"), + } + _ => panic!("Expected Content wrapper"), + } + assert!(tool_call.error.is_none()); + assert!(_meta.is_none()); + } + _ => panic!("Expected ToolCallUpdate"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + #[tokio::test] + async fn test_tool_result_translation_error() { + let adapter = AcpAdapter::new(); + + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-pqr", + "update": { + "type": "tool_result", + "messageId": "msg-9", + "toolCallId": "call-6", + "result": { + "error": "Permission denied" + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { connector_id: _, session_id, update } => { + assert_eq!(session_id, "session-pqr"); + match update { + SessionUpdate::ToolCallUpdate { message_id, tool_call_id, tool_call, .. } => { + assert_eq!(message_id, "msg-9"); + assert_eq!(tool_call_id, "call-6"); + assert_eq!(tool_call.status, ToolCallStatus::Error); + assert_eq!(tool_call.error, Some("Permission denied".to_string())); + } + _ => panic!("Expected ToolCallUpdate"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + #[tokio::test] + async fn test_tool_result_with_structured_output() { + let adapter = AcpAdapter::new(); + + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-stu", + "update": { + "type": "tool_result", + "messageId": "msg-10", + "toolCallId": "call-7", + "result": { + "output": { + "files": ["file1.txt", "file2.txt"], + "count": 2 + } + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { connector_id: _, session_id, update } => { + assert_eq!(session_id, "session-stu"); + match update { + SessionUpdate::ToolCallUpdate { tool_call, .. } => { + assert_eq!(tool_call.status, ToolCallStatus::Completed); + assert_eq!(tool_call.content.len(), 1); + // Structured output should be converted to pretty JSON text + match &tool_call.content[0] { + ToolCallContent::Content { content } => match content { + ContentBlock::Text { text } => { + assert!(text.contains("files")); + assert!(text.contains("file1.txt")); + } + _ => panic!("Expected Text content"), + } + _ => panic!("Expected Content wrapper"), + } + } + _ => panic!("Expected ToolCallUpdate"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + #[tokio::test] + async fn test_tool_result_missing_fields() { + let adapter = AcpAdapter::new(); + + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-vwx", + "update": { + "type": "tool_result", + "messageId": "msg-11" + // Missing toolCallId and result + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_err()); + } + + // Adapter state management tests + #[tokio::test] + async fn test_adapter_clear() { + let adapter = AcpAdapter::new(); + + // Add some state + adapter + .processed_notifications + .write() + .await + .insert("notif-1".to_string(), ()); + + assert_eq!(adapter.processed_notifications.read().await.len(), 1); + + // Clear + adapter.clear().await; + + assert_eq!(adapter.processed_notifications.read().await.len(), 0); + } + + // Tool status translation tests + #[tokio::test] + async fn test_translate_tool_status() { + let adapter = AcpAdapter::new(); + + assert_eq!(adapter.translate_tool_status("pending").unwrap(), ToolCallStatus::Pending); + assert_eq!(adapter.translate_tool_status("running").unwrap(), ToolCallStatus::Running); + assert_eq!(adapter.translate_tool_status("completed").unwrap(), ToolCallStatus::Completed); + assert_eq!(adapter.translate_tool_status("error").unwrap(), ToolCallStatus::Error); + + let result = adapter.translate_tool_status("invalid"); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + AcpTranslationError::InvalidValue { .. } + )); + } + + // ADAPT-017: Missing type field handling + #[tokio::test] + async fn test_session_update_missing_type_field() { + let adapter = AcpAdapter::new(); + + // Notification without type field (like some Claude Code notifications) + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-xyz", + "update": { + "someOtherField": "someValue" + // Missing "type" field + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + AcpTranslationError::InvalidValue { .. } + )); + } + + // ADAPT-018: Alternative type field name (sessionUpdate) + #[tokio::test] + async fn test_session_update_with_session_update_field() { + let adapter = AcpAdapter::new(); + + // Notification with "sessionUpdate" instead of "type" + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "session-abc", + "update": { + "sessionUpdate": "agent_message_chunk", + "messageId": "msg-1", + "content": { + "type": "text", + "text": "Hello" + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { connector_id: _, session_id, update } => { + assert_eq!(session_id, "session-abc"); + match update { + SessionUpdate::AgentMessageChunk { message_id, .. } => { + assert_eq!(message_id, "msg-1"); + } + _ => panic!("Expected AgentMessageChunk"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + // T007: Test translate_tool_call() without messageId (Claude format) + // Tests Claude's flat format with no messageId and _meta.claudeCode.toolName + #[tokio::test] + async fn test_t007_translate_tool_call_without_message_id_claude_format() { + let adapter = AcpAdapter::new(); + + // Claude's format: flat structure, no messageId, tool name in _meta + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session-t007", + "update": { + "sessionUpdate": "tool_call", + "toolCallId": "toolu_test_007", + "status": "pending", + "title": "grep \"test pattern\"", + "kind": "search", + "rawInput": { + "pattern": "test pattern", + "path": "/test/path" + }, + "content": [], + "_meta": { + "claudeCode": { + "toolName": "Grep" + } + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok(), "Translation should succeed for Claude format"); + + match result.unwrap() { + Event::SessionUpdate { connector_id: _, session_id, update } => { + assert_eq!(session_id, "test-session-t007"); + match update { + SessionUpdate::ToolCall { message_id, tool_call, _meta } => { + // T007: Verify generated message ID format matches tool-{UUID} pattern + assert!( + message_id.starts_with("tool-"), + "Generated message ID should start with 'tool-', got: {}", + message_id + ); + assert_eq!( + message_id.len(), + 41, // "tool-" (5) + UUID (36) + "Message ID should be 'tool-' + UUID format" + ); + + // Verify tool call details + assert_eq!(tool_call.id, "toolu_test_007"); + assert_eq!(tool_call.tool_name, "Grep"); + assert_eq!(tool_call.status, ToolCallStatus::Pending); + assert_eq!(tool_call.title, Some("grep \"test pattern\"".to_string())); + assert!(tool_call.raw_input.is_some()); + assert_eq!(tool_call.content.len(), 0); + assert!(_meta.is_none()); + } + _ => panic!("Expected ToolCall variant"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + #[tokio::test] + async fn test_t007_tool_call_different_statuses_no_message_id() { + let adapter = AcpAdapter::new(); + + // Test with "running" status + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session-t007-running", + "update": { + "sessionUpdate": "tool_call", + "toolCallId": "toolu_running", + "status": "running", + "content": [], + "_meta": { + "claudeCode": { + "toolName": "Read" + } + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { update, .. } => { + match update { + SessionUpdate::ToolCall { tool_call, .. } => { + assert_eq!(tool_call.status, ToolCallStatus::Running); + assert_eq!(tool_call.tool_name, "Read"); + } + _ => panic!("Expected ToolCall variant"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + // T008: Test translate_tool_call() with messageId (backward compatibility) + // Tests original nested format with explicit messageId + #[tokio::test] + async fn test_t008_translate_tool_call_with_message_id_nested_format() { + let adapter = AcpAdapter::new(); + + // Original nested format with explicit messageId + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session-t008", + "update": { + "type": "tool_call", + "messageId": "explicit-msg-id-123", + "toolCall": { + "id": "tool-nested-001", + "name": "bash", + "status": "pending", + "title": "Run command", + "input": { + "command": "ls -la" + } + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok(), "Translation should succeed for nested format"); + + match result.unwrap() { + Event::SessionUpdate { connector_id: _, session_id, update } => { + assert_eq!(session_id, "test-session-t008"); + match update { + SessionUpdate::ToolCall { message_id, tool_call, _meta } => { + // T008: Verify explicit messageId is preserved + assert_eq!(message_id, "explicit-msg-id-123"); + + // Verify tool call details from nested format + assert_eq!(tool_call.id, "tool-nested-001"); + assert_eq!(tool_call.tool_name, "bash"); + assert_eq!(tool_call.status, ToolCallStatus::Pending); + assert_eq!(tool_call.title, Some("Run command".to_string())); + assert!(tool_call.raw_input.is_some()); + assert!(_meta.is_none()); + } + _ => panic!("Expected ToolCall variant"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + #[tokio::test] + async fn test_t008_nested_format_all_statuses() { + let adapter = AcpAdapter::new(); + + // Test all status values with nested format + let statuses = vec![ + ("pending", ToolCallStatus::Pending), + ("running", ToolCallStatus::Running), + ("completed", ToolCallStatus::Completed), + ("error", ToolCallStatus::Error), + ]; + + for (status_str, expected_status) in statuses { + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session", + "update": { + "type": "tool_call", + "messageId": format!("msg-{}", status_str), + "toolCall": { + "id": format!("tool-{}", status_str), + "name": "test_tool", + "status": status_str + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok(), "Failed for status: {}", status_str); + + match result.unwrap() { + Event::SessionUpdate { update, .. } => { + match update { + SessionUpdate::ToolCall { tool_call, .. } => { + assert_eq!(tool_call.status, expected_status, "Wrong status for: {}", status_str); + } + _ => panic!("Expected ToolCall variant"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + } + + // T009: Test translate_tool_call_update() Claude format + // Tests status transitions, content updates, and metadata preservation + #[tokio::test] + async fn test_t009_tool_call_update_pending_to_running() { + let adapter = AcpAdapter::new(); + + // Update: pending → running + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session-t009-running", + "update": { + "sessionUpdate": "tool_call_update", + "toolCallId": "toolu_t009", + "status": "running", + "content": [], + "_meta": { + "claudeCode": { + "toolName": "Bash", + "toolResponse": { + "started": true + } + } + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok(), "Tool call update should succeed"); + + match result.unwrap() { + Event::SessionUpdate { connector_id: _, session_id, update } => { + assert_eq!(session_id, "test-session-t009-running"); + match update { + SessionUpdate::ToolCallUpdate { message_id, tool_call_id, tool_call, _meta } => { + // T009: Verify generated message ID + assert!(message_id.starts_with("tool-")); + assert_eq!(tool_call_id, "toolu_t009"); + assert_eq!(tool_call.id, "toolu_t009"); + assert_eq!(tool_call.status, ToolCallStatus::Running); + assert_eq!(tool_call.tool_name, "Bash"); + + // T009: Verify metadata preservation from _meta.claudeCode.toolResponse + assert!(_meta.is_some(), "Metadata should be preserved"); + } + _ => panic!("Expected ToolCallUpdate variant"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + #[tokio::test] + async fn test_t009_tool_call_update_running_to_completed() { + let adapter = AcpAdapter::new(); + + // Update: running → completed with content + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session-t009-completed", + "update": { + "sessionUpdate": "tool_call_update", + "toolCallId": "toolu_t009_complete", + "status": "completed", + "content": [ + { + "type": "content", + "content": { + "type": "text", + "text": "Tool execution completed successfully" + } + }, + { + "type": "content", + "content": { + "type": "text", + "text": "Output: test result" + } + } + ], + "_meta": { + "claudeCode": { + "toolName": "Grep", + "toolResponse": { + "exitCode": 0, + "duration": 125 + } + } + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { update, .. } => { + match update { + SessionUpdate::ToolCallUpdate { tool_call, _meta, .. } => { + // T009: Test status transition to completed + assert_eq!(tool_call.status, ToolCallStatus::Completed); + + // T009: Test content updates (Claude wraps content blocks) + assert_eq!(tool_call.content.len(), 2, "Should have 2 content blocks"); + match &tool_call.content[0] { + ToolCallContent::Content { content } => match content { + ContentBlock::Text { text } => { + assert_eq!(text, "Tool execution completed successfully"); + } + _ => panic!("Expected Text content"), + } + _ => panic!("Expected Content wrapper"), + } + match &tool_call.content[1] { + ToolCallContent::Content { content } => match content { + ContentBlock::Text { text } => { + assert_eq!(text, "Output: test result"); + } + _ => panic!("Expected Text content"), + } + _ => panic!("Expected Content wrapper"), + } + + // T009: Verify metadata preservation + assert!(_meta.is_some()); + if let Some(meta) = _meta { + let meta_obj = serde_json::to_value(&meta).unwrap(); + assert!(meta_obj.get("claudeCode").is_some()); + } + } + _ => panic!("Expected ToolCallUpdate variant"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + #[tokio::test] + async fn test_t009_tool_call_update_error_status() { + let adapter = AcpAdapter::new(); + + // Update: error status + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session-t009-error", + "update": { + "sessionUpdate": "tool_call_update", + "toolCallId": "toolu_t009_error", + "status": "error", + "content": [ + { + "type": "content", + "content": { + "type": "text", + "text": "Error: File not found" + } + } + ], + "_meta": { + "claudeCode": { + "toolName": "Read", + "toolResponse": { + "error": "ENOENT" + } + } + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { update, .. } => { + match update { + SessionUpdate::ToolCallUpdate { tool_call, _meta, .. } => { + assert_eq!(tool_call.status, ToolCallStatus::Error); + assert_eq!(tool_call.tool_name, "Read"); + assert_eq!(tool_call.content.len(), 1); + + // Verify metadata preserved for error case + assert!(_meta.is_some()); + } + _ => panic!("Expected ToolCallUpdate variant"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + #[tokio::test] + async fn test_t009_tool_call_update_with_tool_response_data() { + let adapter = AcpAdapter::new(); + + // Test that toolResponse data in _meta.claudeCode is preserved and extracted + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session-t009-response", + "update": { + "sessionUpdate": "tool_call_update", + "toolCallId": "toolu_t009_response", + "status": "completed", + "content": [], + "_meta": { + "claudeCode": { + "toolName": "Edit", + "toolResponse": { + "linesAdded": 5, + "linesRemoved": 2, + "filePath": "/test/file.rs" + } + } + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { update, .. } => { + match update { + SessionUpdate::ToolCallUpdate { tool_call, _meta, .. } => { + assert_eq!(tool_call.status, ToolCallStatus::Completed); + + // T009: Verify raw_output extracted from toolResponse + assert!( + tool_call.raw_output.is_some(), + "raw_output should be extracted from toolResponse" + ); + + if let Some(output) = tool_call.raw_output { + assert_eq!(output.get("linesAdded"), Some(&json!(5))); + assert_eq!(output.get("linesRemoved"), Some(&json!(2))); + assert_eq!(output.get("filePath"), Some(&json!("/test/file.rs"))); + } + + // Metadata should also be preserved + assert!(_meta.is_some()); + } + _ => panic!("Expected ToolCallUpdate variant"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + // T010: Test edge cases + // Missing toolCallId, missing status, missing tool name, empty content array + #[tokio::test] + async fn test_t010_missing_tool_call_id_should_error() { + let adapter = AcpAdapter::new(); + + // Missing toolCallId - should error + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session-t010-error", + "update": { + "sessionUpdate": "tool_call", + "status": "pending", + "content": [], + "_meta": { + "claudeCode": { + "toolName": "Grep" + } + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_err(), "Should error when toolCallId is missing"); + + match result.unwrap_err() { + AcpTranslationError::MissingField(field) => { + assert_eq!(field, "toolCallId"); + } + _ => panic!("Expected MissingField error"), + } + } + + #[tokio::test] + async fn test_t010_missing_status_defaults_to_pending() { + let adapter = AcpAdapter::new(); + + // Missing status - should default to Pending for tool_call + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session-t010-default", + "update": { + "sessionUpdate": "tool_call", + "toolCallId": "toolu_t010_default", + "content": [], + "_meta": { + "claudeCode": { + "toolName": "Bash" + } + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok(), "Should succeed with default status"); + + match result.unwrap() { + Event::SessionUpdate { update, .. } => { + match update { + SessionUpdate::ToolCall { tool_call, .. } => { + // T010: Verify status defaults to Pending + assert_eq!( + tool_call.status, + ToolCallStatus::Pending, + "Status should default to Pending when missing" + ); + } + _ => panic!("Expected ToolCall variant"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + #[tokio::test] + async fn test_t010_missing_status_in_update_defaults_to_running() { + let adapter = AcpAdapter::new(); + + // Missing status in tool_call_update - should default to Running + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session-t010-update-default", + "update": { + "sessionUpdate": "tool_call_update", + "toolCallId": "toolu_t010_update_default", + "content": [], + "_meta": { + "claudeCode": { + "toolName": "Read" + } + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok(), "Should succeed with default status"); + + match result.unwrap() { + Event::SessionUpdate { update, .. } => { + match update { + SessionUpdate::ToolCallUpdate { tool_call, .. } => { + // T010: Verify status defaults to Running for updates + assert_eq!( + tool_call.status, + ToolCallStatus::Running, + "Status should default to Running when missing in update" + ); + } + _ => panic!("Expected ToolCallUpdate variant"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + #[tokio::test] + async fn test_t010_missing_tool_name_uses_unknown() { + let adapter = AcpAdapter::new(); + + // Missing _meta.claudeCode.toolName - should use "unknown" + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session-t010-unknown", + "update": { + "sessionUpdate": "tool_call", + "toolCallId": "toolu_t010_unknown", + "status": "pending", + "content": [] + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok(), "Should succeed with unknown tool name"); + + match result.unwrap() { + Event::SessionUpdate { update, .. } => { + match update { + SessionUpdate::ToolCall { tool_call, .. } => { + // T010: Verify tool name defaults to "unknown" + assert_eq!( + tool_call.tool_name, + "unknown", + "Tool name should be 'unknown' when _meta.claudeCode.toolName is missing" + ); + } + _ => panic!("Expected ToolCall variant"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + #[tokio::test] + async fn test_t010_empty_content_array() { + let adapter = AcpAdapter::new(); + + // Empty content array - should succeed with empty vec + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session-t010-empty-content", + "update": { + "sessionUpdate": "tool_call", + "toolCallId": "toolu_t010_empty", + "status": "pending", + "content": [], + "_meta": { + "claudeCode": { + "toolName": "Grep" + } + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok(), "Should succeed with empty content array"); + + match result.unwrap() { + Event::SessionUpdate { update, .. } => { + match update { + SessionUpdate::ToolCall { tool_call, .. } => { + // T010: Verify empty content array is handled + assert_eq!( + tool_call.content.len(), + 0, + "Content should be empty vec when content array is empty" + ); + } + _ => panic!("Expected ToolCall variant"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + #[tokio::test] + async fn test_t010_partial_meta_structure() { + let adapter = AcpAdapter::new(); + + // Partial _meta structure (has _meta but no claudeCode) + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session-t010-partial-meta", + "update": { + "sessionUpdate": "tool_call", + "toolCallId": "toolu_t010_partial", + "status": "running", + "content": [], + "_meta": { + "someOtherField": "value" + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_ok(), "Should succeed with partial _meta"); + + match result.unwrap() { + Event::SessionUpdate { update, .. } => { + match update { + SessionUpdate::ToolCall { tool_call, .. } => { + // T010: Should fall back to "unknown" when claudeCode is missing + assert_eq!(tool_call.tool_name, "unknown"); + } + _ => panic!("Expected ToolCall variant"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + #[tokio::test] + async fn test_t010_invalid_status_value() { + let adapter = AcpAdapter::new(); + + // Invalid status value - should error + let notification = json!({ + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "test-session-t010-invalid-status", + "update": { + "sessionUpdate": "tool_call", + "toolCallId": "toolu_t010_invalid", + "status": "invalid_status_value", + "content": [], + "_meta": { + "claudeCode": { + "toolName": "Grep" + } + } + } + } + }); + + let result = adapter.translate_notification(notification).await; + assert!(result.is_err(), "Should error for invalid status value"); + + match result.unwrap_err() { + AcpTranslationError::InvalidValue { field, .. } => { + assert_eq!(field, "status"); + } + _ => panic!("Expected InvalidValue error"), + } + } +} diff --git a/crates/dirigent_protocol/src/adapters/mod.rs b/crates/dirigent_protocol/src/adapters/mod.rs new file mode 100644 index 0000000..e9d0148 --- /dev/null +++ b/crates/dirigent_protocol/src/adapters/mod.rs @@ -0,0 +1,11 @@ +pub mod acp; + +#[cfg(feature = "adapters")] +pub mod opencode; +#[cfg(feature = "adapters")] +pub mod rest; + +pub use acp::{AcpAdapter, AcpTranslationError}; + +#[cfg(feature = "adapters")] +pub use opencode::{OpenCodeAdapter, TranslationError}; diff --git a/crates/dirigent_protocol/src/adapters/opencode.rs b/crates/dirigent_protocol/src/adapters/opencode.rs new file mode 100644 index 0000000..90c4781 --- /dev/null +++ b/crates/dirigent_protocol/src/adapters/opencode.rs @@ -0,0 +1,1465 @@ +use crate::events::Event; +use crate::types::{ContentBlock, Meta, ProviderMeta, SessionUpdate, ToolCall, ToolCallStatus}; +use crate::{ + Message, MessageMetadata, MessagePart, MessageRole, MessageStatus, Session, SessionMetadata, +}; +use chrono::{DateTime, TimeZone, Utc}; +use opencode_client::types as oc; +use std::collections::{HashMap, HashSet}; +use std::sync::{Arc, Mutex}; + +/// Stateful adapter that tracks message and part IDs to prevent duplicates +#[derive(Clone)] +pub struct OpenCodeAdapter { + state: Arc<Mutex<AdapterState>>, +} + +#[derive(Default)] +struct AdapterState { + /// Track which messages we've sent "started" events for + started_messages: HashSet<String>, + /// Track which messages we've sent "completed" events for + completed_messages: HashSet<String>, + /// Track which parts we've seen (to differentiate new vs update) + seen_parts: HashSet<String>, + /// Track which sessions have had their system message set + sessions_with_system: HashSet<String>, + /// Track tool calls we've seen (for initial ToolCall vs ToolCallUpdate) + tool_calls_seen: HashSet<String>, + /// Track message roles to determine UserMessageChunk vs AgentMessageChunk + message_roles: HashMap<String, MessageRole>, +} + +#[derive(Debug)] +pub enum TranslationError { + UnknownEvent, + UnsupportedEvent, + UnsupportedPartType, + InvalidTimestamp, + MissingField(String), + Duplicate, // Event should be skipped (duplicate) +} + +impl std::fmt::Display for TranslationError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TranslationError::UnknownEvent => write!(f, "Unknown event type"), + TranslationError::UnsupportedEvent => write!(f, "Unsupported event type"), + TranslationError::UnsupportedPartType => write!(f, "Unsupported part type"), + TranslationError::InvalidTimestamp => write!(f, "Invalid timestamp"), + TranslationError::MissingField(field) => write!(f, "Missing required field: {}", field), + TranslationError::Duplicate => write!(f, "Duplicate event (filtered)"), + } + } +} + +impl std::error::Error for TranslationError {} + +impl Default for OpenCodeAdapter { + fn default() -> Self { + Self::new() + } +} + +impl OpenCodeAdapter { + pub fn new() -> Self { + Self { + state: Arc::new(Mutex::new(AdapterState::default())), + } + } + + /// Translate an OpenCode event to a Dirigent event + pub fn translate_event(&self, oc_event: oc::Event) -> Result<Event, TranslationError> { + match oc_event { + oc::Event::SessionCreated { properties } => Ok(Event::SessionCreated { + connector_id: String::new(), // Placeholder - connector will set this + session: Self::translate_session(properties.info)?, + }), + + oc::Event::SessionUpdated { properties } => Ok(Event::SessionUpdated { + connector_id: String::new(), // Placeholder - connector will set this + session: Self::translate_session(properties.info)?, + }), + + oc::Event::SessionDeleted { properties } => Ok(Event::SessionDeleted { + session_id: properties.info.id, + }), + + oc::Event::SessionIdle { properties } => Ok(Event::SessionIdle { + connector_id: String::new(), // Placeholder - connector will set this + session_id: properties.session_id, + }), + + oc::Event::MessageUpdated { properties } => { + let msg_id = match &properties.info { + oc::Message::User(u) => u.id.clone(), + oc::Message::Assistant(a) => a.id.clone(), + }; + + // Extract session_id and system message before translation + let (session_id, system_message) = match &properties.info { + oc::Message::Assistant(a) if !a.system.is_empty() => { + (a.session_id.clone(), a.system.first().cloned()) + } + oc::Message::User(u) => (u.session_id.clone(), None), + _ => (String::new(), None), + }; + + // Check if completed + let is_completed = if let oc::Message::Assistant(ref a) = properties.info { + a.time.completed.is_some() + } else { + true // User messages are always completed + }; + + let msg = Self::translate_message(properties.info)?; + + // Store the message role for later part translation + let mut state = self.state.lock().unwrap(); + state.message_roles.insert(msg_id.clone(), msg.role.clone()); + + // Track system message for this session (only emit once) + if let Some(sys_msg) = system_message { + if !state.sessions_with_system.contains(&session_id) { + state.sessions_with_system.insert(session_id.clone()); + drop(state); // Release lock before returning + + // Return SessionSystemMessageSet event instead of message event + // The message event will come on the next update + return Ok(Event::SessionSystemMessageSet { + session_id, + system_message: sys_msg, + }); + } + } + + if is_completed { + if state.completed_messages.contains(&msg_id) { + return Err(TranslationError::Duplicate); + } + state.completed_messages.insert(msg_id); + Ok(Event::MessageCompleted { + connector_id: String::new(), // Placeholder - connector will set this + message: msg, + }) + } else { + if state.started_messages.contains(&msg_id) { + return Err(TranslationError::Duplicate); + } + state.started_messages.insert(msg_id); + Ok(Event::MessageStarted { + connector_id: String::new(), // Placeholder - connector will set this + message: msg, + }) + } + } + + oc::Event::MessagePartUpdated { properties } => { + // Extract session_id from the part before translation + let session_id = Self::extract_session_id_from_part(&properties.part)?; + + // Use the new translate_to_session_update method + if let Some(update) = + self.translate_to_session_update(oc::Event::MessagePartUpdated { properties })? + { + Ok(Event::SessionUpdate { + connector_id: String::new(), // Placeholder - connector will set this + session_id, + update, + }) + } else { + // Duplicate or non-delta update - skip + Err(TranslationError::Duplicate) + } + } + + oc::Event::ServerConnected { .. } => Ok(Event::Connected), + + oc::Event::Unknown => Err(TranslationError::UnknownEvent), + + _ => { + // Log unhandled event + eprintln!("[OpenCodeAdapter] Unhandled event type: {:?}", oc_event); + Err(TranslationError::UnsupportedEvent) + } + } + } + + /// Translate an OpenCode event to a SessionUpdate (ACP-style streaming updates) + /// Returns Ok(None) if the event should be skipped (e.g., duplicates) + pub fn translate_to_session_update( + &self, + oc_event: oc::Event, + ) -> Result<Option<SessionUpdate>, TranslationError> { + match oc_event { + // Only translate MessagePartUpdated events for now + // Other events will be handled in future tasks + oc::Event::MessagePartUpdated { properties } => { + self.translate_part_to_update(properties) + } + _ => Err(TranslationError::UnsupportedEvent), + } + } + + /// Extract session_id from an OpenCode part + /// Returns an error if the part type doesn't contain a session_id + fn extract_session_id_from_part(part: &oc::Part) -> Result<String, TranslationError> { + let session_id = match part { + oc::Part::Text(t) => &t.session_id, + oc::Part::Reasoning(r) => &r.session_id, + oc::Part::Tool(t) => &t.session_id, + oc::Part::File(f) => &f.session_id, + // For unsupported types, return error + _ => return Err(TranslationError::UnsupportedPartType), + }; + Ok(session_id.clone()) + } + + /// Build provider metadata from an OpenCode part + /// Extracts session_id, message_id, and part_id for debugging and traceability + fn build_meta(part: &oc::Part) -> Meta { + let (session_id, message_id, part_id) = match part { + oc::Part::Text(t) => (&t.session_id, &t.message_id, &t.id), + oc::Part::Reasoning(r) => (&r.session_id, &r.message_id, &r.id), + oc::Part::Tool(t) => (&t.session_id, &t.message_id, &t.id), + oc::Part::File(f) => (&f.session_id, &f.message_id, &f.id), + // For unsupported types, return empty meta (they'll fail translation anyway) + _ => return Meta::default(), + }; + + Meta { + provider: Some(ProviderMeta { + name: "opencode".to_string(), + original_ids: Some(HashMap::from([ + ("session_id".to_string(), session_id.clone()), + ("message_id".to_string(), message_id.clone()), + ("part_id".to_string(), part_id.clone()), + ])), + raw_excerpt: None, + }), + extra: HashMap::new(), + } + } + + /// Translate a MessagePartUpdated event to a SessionUpdate + /// Returns Ok(None) if this is a duplicate or should be skipped + fn translate_part_to_update( + &self, + properties: oc::MessagePartEventInfo, + ) -> Result<Option<SessionUpdate>, TranslationError> { + let part = &properties.part; + + // Extract common fields from the part + let (message_id, part_id) = match part { + oc::Part::Text(t) => (t.message_id.clone(), t.id.clone()), + oc::Part::Reasoning(r) => (r.message_id.clone(), r.id.clone()), + oc::Part::Tool(t) => (t.message_id.clone(), t.id.clone()), + oc::Part::File(f) => (f.message_id.clone(), f.id.clone()), + _ => return Err(TranslationError::UnsupportedPartType), + }; + + // Track parts to avoid sending duplicates + let mut state = self.state.lock().unwrap(); + let is_new_part = state.seen_parts.insert(part_id.clone()); + + // If this is an update to an existing part with no delta, skip it + if !is_new_part && properties.delta.is_none() { + return Ok(None); + } + drop(state); // Release lock + + // Build metadata once for this part + let meta = Self::build_meta(part); + + // Translate the part based on its type + match part { + oc::Part::Text(t) => { + // Use delta instead of full text to avoid duplication + let text = properties.delta.unwrap_or_else(|| t.text.clone()); + + // Determine if this is a user or assistant message + // by looking up the message_id in our tracked roles + let state = self.state.lock().unwrap(); + let role = state.message_roles.get(&message_id).cloned(); + drop(state); + + match role { + Some(MessageRole::User) => Ok(Some(SessionUpdate::UserMessageChunk { + message_id, + content: ContentBlock::Text { text }, + _meta: Some(meta), + })), + Some(MessageRole::Assistant) | None => { + // Default to assistant for backward compatibility + // (in case message wasn't tracked or is from old events) + Ok(Some(SessionUpdate::AgentMessageChunk { + message_id, + content: ContentBlock::Text { text }, + _meta: Some(meta), + })) + } + } + } + + oc::Part::Reasoning(r) => { + // Use delta instead of full text to avoid duplication + let text = properties.delta.unwrap_or_else(|| r.text.clone()); + + Ok(Some(SessionUpdate::AgentThoughtChunk { + message_id, + content: ContentBlock::Text { text }, + _meta: Some(meta), + })) + } + + oc::Part::Tool(t) => { + // Map OpenCode tool state to ToolCallStatus + let (status, raw_input, raw_output, error) = match &t.state { + oc::ToolState::Pending => (ToolCallStatus::Pending, None, None, None), + oc::ToolState::Running { input, .. } => { + (ToolCallStatus::Running, Some(input.clone()), None, None) + } + oc::ToolState::Completed { input, output, .. } => ( + ToolCallStatus::Completed, + Some(input.clone()), + Some(serde_json::Value::String(output.clone())), + None, + ), + oc::ToolState::Error { input, error, .. } => ( + ToolCallStatus::Error, + Some(input.clone()), + None, + Some(error.clone()), + ), + }; + + let tool_call = ToolCall { + id: t.id.clone(), + tool_name: t.tool.clone(), + status, + content: vec![], // Tool output will be added as content blocks in future + raw_input, + raw_output, + title: None, + error, + metadata: t.metadata.clone(), + origin: None, // OpenCode tools - origin not specified + }; + + // Track tool calls: first occurrence → ToolCall, subsequent → ToolCallUpdate + let mut state = self.state.lock().unwrap(); + let is_new_tool_call = state.tool_calls_seen.insert(t.id.clone()); + drop(state); // Release lock + + if is_new_tool_call { + Ok(Some(SessionUpdate::ToolCall { + message_id, + tool_call, + _meta: Some(meta), + })) + } else { + Ok(Some(SessionUpdate::ToolCallUpdate { + message_id, + tool_call_id: t.id.clone(), + tool_call, + _meta: Some(meta), + })) + } + } + + oc::Part::File(f) => { + // File parts become ResourceLink content blocks + // Determine the appropriate chunk type (assume agent for now) + Ok(Some(SessionUpdate::AgentMessageChunk { + message_id, + content: ContentBlock::ResourceLink { + uri: f.url.clone(), + name: f.filename.clone(), + mime_type: Some(f.mime.clone()), + }, + _meta: Some(meta), + })) + } + + _ => Err(TranslationError::UnsupportedPartType), + } + } + + /// Translate an OpenCode session to a Dirigent session + fn translate_session(oc_session: oc::Session) -> Result<Session, TranslationError> { + Ok(Session { + id: oc_session.id, + title: oc_session.title, + created_at: timestamp_to_datetime(oc_session.time.created)?, + updated_at: timestamp_to_datetime(oc_session.time.updated)?, + metadata: SessionMetadata { + project_path: oc_session.directory, + model: None, // Could extract from assistant messages if needed + total_messages: 0, // Would need to be calculated separately + system_message: None, // Will be set from first assistant message + current_mode_id: None, + _meta: None, + project_id: None, + }, + cwd: None, // OpenCode doesn't expose cwd separately from project_path + models: None, // OpenCode doesn't provide ACP model state + modes: None, // OpenCode doesn't provide ACP mode state + config_options: None, + acp_client_id: None, // OpenCode doesn't have ACP client ID + }) + } + + /// Translate an OpenCode message to a Dirigent message + /// Note: System messages are extracted separately and stored at the session level + fn translate_message(oc_msg: oc::Message) -> Result<Message, TranslationError> { + let (id, session_id, role, created_at, status, metadata) = match oc_msg { + oc::Message::User(u) => ( + u.id, + u.session_id, + MessageRole::User, + timestamp_to_datetime(u.time.created)?, + MessageStatus::Completed, // User messages are always complete + None, // User messages don't have metadata + ), + oc::Message::Assistant(a) => { + let status = if let Some(err) = a.error { + MessageStatus::Failed { + error: format_message_error(&err), + } + } else if a.time.completed.is_some() { + MessageStatus::Completed + } else { + MessageStatus::Streaming + }; + + // Extract metadata from assistant message + let metadata = Some(MessageMetadata { + cost: Some(a.cost), + tokens_input: Some(a.tokens.input), + tokens_output: Some(a.tokens.output), + response_time_ms: None, + latency_ms: None, + model: a.model_id.clone(), + other: None, + }); + + ( + a.id, + a.session_id, + MessageRole::Assistant, + timestamp_to_datetime(a.time.created)?, + status, + metadata, + ) + } + }; + + Ok(Message { + id, + session_id, + role, + created_at, + content: vec![], // Parts come separately via MessagePartUpdated events + status, + metadata, + }) + } + + /// Translate an OpenCode part directly to a Dirigent message part (for list operations) + /// This is used when loading message history, not for streaming updates. + /// Returns just the MessagePart (session_id and message_id are on the message itself) + pub fn translate_part_for_list(oc_part: oc::Part) -> Result<MessagePart, TranslationError> { + let (_, _, _, part) = Self::translate_part(oc_part)?; + Ok(part) + } + + /// Translate an OpenCode part to a Dirigent message part (internal helper) + /// Returns (session_id, message_id, part_id, part) + fn translate_part( + oc_part: oc::Part, + ) -> Result<(String, String, String, MessagePart), TranslationError> { + match oc_part { + oc::Part::Text(t) => Ok(( + t.session_id.clone(), + t.message_id.clone(), + t.id, + MessagePart::Text { text: t.text }, + )), + + oc::Part::Reasoning(r) => Ok(( + r.session_id.clone(), + r.message_id.clone(), + r.id, + MessagePart::Thinking { text: r.text }, + )), + + oc::Part::Tool(t) => { + let (input, output) = match t.state { + oc::ToolState::Pending => (serde_json::Value::Null, None), + oc::ToolState::Running { input, .. } => (input, None), + oc::ToolState::Completed { input, output, .. } => { + (input, Some(serde_json::Value::String(output))) + } + oc::ToolState::Error { input, error, .. } => { + (input, Some(serde_json::json!({ "error": error }))) + } + }; + + Ok(( + t.session_id.clone(), + t.message_id.clone(), + t.id, + MessagePart::Tool { + tool: t.tool, + tool_call_id: None, + input, + output, + }, + )) + } + + oc::Part::File(f) => Ok(( + f.session_id.clone(), + f.message_id.clone(), + f.id, + MessagePart::File { + path: f.filename.unwrap_or_else(|| f.url.clone()), + content: f.url, // In Dirigent protocol, we store the URL as content + }, + )), + + // For now, we skip these part types - they could be added later + oc::Part::StepStart(_) + | oc::Part::StepFinish(_) + | oc::Part::Snapshot(_) + | oc::Part::Patch(_) + | oc::Part::Agent(_) + | oc::Part::Retry(_) => Err(TranslationError::UnsupportedPartType), + + oc::Part::Unknown => Err(TranslationError::UnsupportedPartType), + } + } +} // End impl OpenCodeAdapter + +/// Convert Unix timestamp (milliseconds) to DateTime<Utc> +fn timestamp_to_datetime(timestamp: u64) -> Result<DateTime<Utc>, TranslationError> { + Utc.timestamp_millis_opt(timestamp as i64) + .single() + .ok_or(TranslationError::InvalidTimestamp) +} + +/// Format a message error into a user-friendly string +fn format_message_error(error: &oc::MessageError) -> String { + match error { + oc::MessageError::ProviderAuthError { data } => { + format!( + "Authentication error for {}: {}", + data.provider_id, data.message + ) + } + oc::MessageError::UnknownError { data } => { + format!("Unknown error: {}", data.message) + } + oc::MessageError::MessageOutputLengthError => "Message output length exceeded".to_string(), + oc::MessageError::MessageAbortedError { data } => { + format!("Message aborted: {}", data.message) + } + oc::MessageError::ApiError { data } => { + if let Some(status) = data.status_code { + format!("API error ({}): {}", status, data.message) + } else { + format!("API error: {}", data.message) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Helper to create a text part for testing + fn create_text_part(session_id: &str, message_id: &str, part_id: &str, text: &str) -> oc::Part { + oc::Part::Text(oc::TextPart { + id: part_id.to_string(), + session_id: session_id.to_string(), + message_id: message_id.to_string(), + text: text.to_string(), + synthetic: None, + time: Some(oc::PartTime { + start: 1000, + end: None, + }), + }) + } + + /// Helper to create a reasoning part for testing + fn create_reasoning_part( + session_id: &str, + message_id: &str, + part_id: &str, + text: &str, + ) -> oc::Part { + oc::Part::Reasoning(oc::ReasoningPart { + id: part_id.to_string(), + session_id: session_id.to_string(), + message_id: message_id.to_string(), + text: text.to_string(), + time: Some(oc::PartTime { + start: 1000, + end: None, + }), + }) + } + + /// Helper to create a tool part for testing + fn create_tool_part( + session_id: &str, + message_id: &str, + part_id: &str, + tool: &str, + state: oc::ToolState, + ) -> oc::Part { + oc::Part::Tool(oc::ToolPart { + id: part_id.to_string(), + session_id: session_id.to_string(), + message_id: message_id.to_string(), + call_id: format!("call_{}", part_id), + tool: tool.to_string(), + state, + metadata: None, + }) + } + + /// Helper to create a file part for testing + fn create_file_part( + session_id: &str, + message_id: &str, + part_id: &str, + filename: &str, + url: &str, + mime: &str, + ) -> oc::Part { + oc::Part::File(oc::FilePart { + id: part_id.to_string(), + session_id: session_id.to_string(), + message_id: message_id.to_string(), + mime: mime.to_string(), + filename: Some(filename.to_string()), + url: url.to_string(), + source: None, + }) + } + + #[test] + fn test_translate_text_part_to_agent_message_chunk() { + let adapter = OpenCodeAdapter::new(); + let part = create_text_part("sess1", "msg1", "part1", "Hello world"); + + let event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part, + delta: Some("Hello world".to_string()), + }, + }; + + let result = adapter.translate_to_session_update(event); + assert!(result.is_ok()); + + let update = result.unwrap(); + assert!(update.is_some()); + + match update.unwrap() { + SessionUpdate::AgentMessageChunk { + message_id, + content, + _meta, + } => { + assert_eq!(message_id, "msg1"); + assert_eq!( + content, + ContentBlock::Text { + text: "Hello world".to_string() + } + ); + // Verify metadata is present with provider info + assert!(_meta.is_some()); + let meta = _meta.unwrap(); + assert!(meta.provider.is_some()); + let provider = meta.provider.unwrap(); + assert_eq!(provider.name, "opencode"); + assert!(provider.original_ids.is_some()); + let ids = provider.original_ids.unwrap(); + assert_eq!(ids.get("session_id").unwrap(), "sess1"); + assert_eq!(ids.get("message_id").unwrap(), "msg1"); + assert_eq!(ids.get("part_id").unwrap(), "part1"); + } + _ => panic!("Expected AgentMessageChunk"), + } + } + + #[test] + fn test_translate_reasoning_part_to_agent_thought_chunk() { + let adapter = OpenCodeAdapter::new(); + let part = create_reasoning_part("sess1", "msg1", "part2", "Thinking about the problem..."); + + let event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part, + delta: Some("Thinking about the problem...".to_string()), + }, + }; + + let result = adapter.translate_to_session_update(event); + assert!(result.is_ok()); + + let update = result.unwrap(); + assert!(update.is_some()); + + match update.unwrap() { + SessionUpdate::AgentThoughtChunk { + message_id, + content, + _meta, + } => { + assert_eq!(message_id, "msg1"); + assert_eq!( + content, + ContentBlock::Text { + text: "Thinking about the problem...".to_string() + } + ); + // Verify metadata is present with provider info + assert!(_meta.is_some()); + let meta = _meta.unwrap(); + assert!(meta.provider.is_some()); + let provider = meta.provider.unwrap(); + assert_eq!(provider.name, "opencode"); + assert!(provider.original_ids.is_some()); + let ids = provider.original_ids.unwrap(); + assert_eq!(ids.get("session_id").unwrap(), "sess1"); + assert_eq!(ids.get("message_id").unwrap(), "msg1"); + assert_eq!(ids.get("part_id").unwrap(), "part2"); + } + _ => panic!("Expected AgentThoughtChunk"), + } + } + + #[test] + fn test_translate_tool_part_pending_to_tool_call() { + let adapter = OpenCodeAdapter::new(); + let part = create_tool_part("sess1", "msg1", "part3", "bash", oc::ToolState::Pending); + + let event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { part, delta: None }, + }; + + let result = adapter.translate_to_session_update(event); + assert!(result.is_ok()); + + let update = result.unwrap(); + assert!(update.is_some()); + + match update.unwrap() { + SessionUpdate::ToolCall { + message_id, + tool_call, + _meta, + } => { + assert_eq!(message_id, "msg1"); + assert_eq!(tool_call.id, "part3"); + assert_eq!(tool_call.tool_name, "bash"); + assert_eq!(tool_call.status, ToolCallStatus::Pending); + assert_eq!(tool_call.raw_input, None); + assert_eq!(tool_call.raw_output, None); + assert_eq!(tool_call.error, None); + // Verify metadata is present + assert!(_meta.is_some()); + let meta = _meta.unwrap(); + assert!(meta.provider.is_some()); + let provider = meta.provider.unwrap(); + assert_eq!(provider.name, "opencode"); + assert!(provider.original_ids.is_some()); + let ids = provider.original_ids.unwrap(); + assert_eq!(ids.get("session_id").unwrap(), "sess1"); + assert_eq!(ids.get("message_id").unwrap(), "msg1"); + assert_eq!(ids.get("part_id").unwrap(), "part3"); + } + _ => panic!("Expected ToolCall"), + } + } + + #[test] + fn test_translate_tool_part_running_to_tool_call() { + let adapter = OpenCodeAdapter::new(); + let input = serde_json::json!({"command": "ls -la"}); + let part = create_tool_part( + "sess1", + "msg1", + "part4", + "bash", + oc::ToolState::Running { + input: input.clone(), + title: None, + metadata: None, + time: oc::PartTime { + start: 1000, + end: None, + }, + }, + ); + + let event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { part, delta: None }, + }; + + let result = adapter.translate_to_session_update(event); + assert!(result.is_ok()); + + let update = result.unwrap(); + assert!(update.is_some()); + + match update.unwrap() { + SessionUpdate::ToolCall { + message_id, + tool_call, + _meta, + } => { + assert_eq!(message_id, "msg1"); + assert_eq!(tool_call.id, "part4"); + assert_eq!(tool_call.tool_name, "bash"); + assert_eq!(tool_call.status, ToolCallStatus::Running); + assert_eq!(tool_call.raw_input, Some(input)); + assert_eq!(tool_call.raw_output, None); + assert_eq!(tool_call.error, None); + // Verify metadata is present + assert!(_meta.is_some()); + let meta = _meta.unwrap(); + assert!(meta.provider.is_some()); + let provider = meta.provider.unwrap(); + assert_eq!(provider.name, "opencode"); + assert!(provider.original_ids.is_some()); + let ids = provider.original_ids.unwrap(); + assert_eq!(ids.get("session_id").unwrap(), "sess1"); + assert_eq!(ids.get("message_id").unwrap(), "msg1"); + assert_eq!(ids.get("part_id").unwrap(), "part4"); + } + _ => panic!("Expected ToolCall"), + } + } + + #[test] + fn test_translate_tool_part_completed_to_tool_call_update() { + let adapter = OpenCodeAdapter::new(); + let input = serde_json::json!({"command": "ls -la"}); + let output = "file1.txt\nfile2.txt"; + + // First, send the pending state to mark it as seen + let pending_part = + create_tool_part("sess1", "msg1", "part5", "bash", oc::ToolState::Pending); + let pending_event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: pending_part, + delta: None, + }, + }; + let _ = adapter.translate_to_session_update(pending_event); + + // Now send the completed state + let completed_part = create_tool_part( + "sess1", + "msg1", + "part5", + "bash", + oc::ToolState::Completed { + input: input.clone(), + output: output.to_string(), + title: "bash command".to_string(), + metadata: serde_json::Value::Null, + time: oc::PartTime { + start: 1000, + end: Some(2000), + }, + attachments: None, + }, + ); + + let event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: completed_part, + delta: Some(output.to_string()), + }, + }; + + let result = adapter.translate_to_session_update(event); + assert!(result.is_ok()); + + let update = result.unwrap(); + assert!(update.is_some()); + + match update.unwrap() { + SessionUpdate::ToolCallUpdate { + message_id, + tool_call_id, + tool_call, + _meta, + } => { + assert_eq!(message_id, "msg1"); + assert_eq!(tool_call_id, "part5"); + assert_eq!(tool_call.id, "part5"); + assert_eq!(tool_call.tool_name, "bash"); + assert_eq!(tool_call.status, ToolCallStatus::Completed); + assert_eq!(tool_call.raw_input, Some(input)); + assert_eq!( + tool_call.raw_output, + Some(serde_json::Value::String(output.to_string())) + ); + assert_eq!(tool_call.error, None); + // Verify metadata is present + assert!(_meta.is_some()); + let meta = _meta.unwrap(); + assert!(meta.provider.is_some()); + let provider = meta.provider.unwrap(); + assert_eq!(provider.name, "opencode"); + assert!(provider.original_ids.is_some()); + let ids = provider.original_ids.unwrap(); + assert_eq!(ids.get("session_id").unwrap(), "sess1"); + assert_eq!(ids.get("message_id").unwrap(), "msg1"); + assert_eq!(ids.get("part_id").unwrap(), "part5"); + } + _ => panic!("Expected ToolCallUpdate"), + } + } + + #[test] + fn test_translate_tool_part_error_to_tool_call_update() { + let adapter = OpenCodeAdapter::new(); + let input = serde_json::json!({"command": "invalid_command"}); + let error_msg = "Command not found"; + + // First, send the pending state + let pending_part = + create_tool_part("sess1", "msg1", "part6", "bash", oc::ToolState::Pending); + let pending_event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: pending_part, + delta: None, + }, + }; + let _ = adapter.translate_to_session_update(pending_event); + + // Now send the error state + let error_part = create_tool_part( + "sess1", + "msg1", + "part6", + "bash", + oc::ToolState::Error { + input: input.clone(), + error: error_msg.to_string(), + metadata: None, + time: oc::PartTime { + start: 1000, + end: Some(2000), + }, + }, + ); + + let event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: error_part, + delta: Some(error_msg.to_string()), + }, + }; + + let result = adapter.translate_to_session_update(event); + assert!(result.is_ok()); + + let update = result.unwrap(); + assert!(update.is_some()); + + match update.unwrap() { + SessionUpdate::ToolCallUpdate { + message_id, + tool_call_id, + tool_call, + _meta, + } => { + assert_eq!(message_id, "msg1"); + assert_eq!(tool_call_id, "part6"); + assert_eq!(tool_call.id, "part6"); + assert_eq!(tool_call.tool_name, "bash"); + assert_eq!(tool_call.status, ToolCallStatus::Error); + assert_eq!(tool_call.raw_input, Some(input)); + assert_eq!(tool_call.raw_output, None); + assert_eq!(tool_call.error, Some(error_msg.to_string())); + // Verify metadata is present + assert!(_meta.is_some()); + let meta = _meta.unwrap(); + assert!(meta.provider.is_some()); + let provider = meta.provider.unwrap(); + assert_eq!(provider.name, "opencode"); + assert!(provider.original_ids.is_some()); + let ids = provider.original_ids.unwrap(); + assert_eq!(ids.get("session_id").unwrap(), "sess1"); + assert_eq!(ids.get("message_id").unwrap(), "msg1"); + assert_eq!(ids.get("part_id").unwrap(), "part6"); + } + _ => panic!("Expected ToolCallUpdate"), + } + } + + #[test] + fn test_translate_file_part_to_resource_link() { + let adapter = OpenCodeAdapter::new(); + let part = create_file_part( + "sess1", + "msg1", + "part7", + "test.txt", + "file:///path/to/test.txt", + "text/plain", + ); + + let event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { part, delta: None }, + }; + + let result = adapter.translate_to_session_update(event); + assert!(result.is_ok()); + + let update = result.unwrap(); + assert!(update.is_some()); + + match update.unwrap() { + SessionUpdate::AgentMessageChunk { + message_id, + content, + _meta, + } => { + assert_eq!(message_id, "msg1"); + assert_eq!( + content, + ContentBlock::ResourceLink { + uri: "file:///path/to/test.txt".to_string(), + name: Some("test.txt".to_string()), + mime_type: Some("text/plain".to_string()), + } + ); + // Verify metadata is present + assert!(_meta.is_some()); + let meta = _meta.unwrap(); + assert!(meta.provider.is_some()); + let provider = meta.provider.unwrap(); + assert_eq!(provider.name, "opencode"); + assert!(provider.original_ids.is_some()); + let ids = provider.original_ids.unwrap(); + assert_eq!(ids.get("session_id").unwrap(), "sess1"); + assert_eq!(ids.get("message_id").unwrap(), "msg1"); + assert_eq!(ids.get("part_id").unwrap(), "part7"); + } + _ => panic!("Expected AgentMessageChunk with ResourceLink"), + } + } + + #[test] + fn test_duplicate_part_without_delta_returns_none() { + let adapter = OpenCodeAdapter::new(); + let part = create_text_part("sess1", "msg1", "part8", "Hello"); + + // First update with delta + let event1 = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: part.clone(), + delta: Some("Hello".to_string()), + }, + }; + + let result1 = adapter.translate_to_session_update(event1); + assert!(result1.is_ok()); + assert!(result1.unwrap().is_some()); + + // Second update without delta (should be skipped) + let event2 = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { part, delta: None }, + }; + + let result2 = adapter.translate_to_session_update(event2); + assert!(result2.is_ok()); + assert!(result2.unwrap().is_none()); // Should return None (skip) + } + + #[test] + fn test_unsupported_event_type_returns_error() { + let adapter = OpenCodeAdapter::new(); + + // SessionCreated is not supported in translate_to_session_update + let event = oc::Event::SessionCreated { + properties: oc::SessionEventInfo { + info: oc::Session { + id: "sess1".to_string(), + project_id: "proj1".to_string(), + directory: "/test".to_string(), + parent_id: None, + summary: None, + share: None, + title: "Test".to_string(), + version: "1.0".to_string(), + time: oc::SessionTime { + created: 1000, + updated: 1000, + compacting: None, + }, + revert: None, + }, + }, + }; + + let result = adapter.translate_to_session_update(event); + assert!(result.is_err()); + assert!(matches!(result, Err(TranslationError::UnsupportedEvent))); + } + + #[test] + fn test_unsupported_part_type_returns_error() { + let adapter = OpenCodeAdapter::new(); + + // Snapshot is not supported + let part = oc::Part::Snapshot(oc::SnapshotPart { + id: "snap1".to_string(), + session_id: "sess1".to_string(), + message_id: "msg1".to_string(), + snapshot: "snapshot_data".to_string(), + }); + + let event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { part, delta: None }, + }; + + let result = adapter.translate_to_session_update(event); + assert!(result.is_err()); + assert!(matches!(result, Err(TranslationError::UnsupportedPartType))); + } + + #[test] + fn test_tool_call_lifecycle_tracking() { + let adapter = OpenCodeAdapter::new(); + let tool_id = "tool_lifecycle_test"; + let message_id = "msg_lifecycle"; + let input = serde_json::json!({"command": "echo hello"}); + let output = "hello"; + + // Step 1: Pending state (first occurrence) → ToolCall + let pending_part = + create_tool_part("sess1", message_id, tool_id, "bash", oc::ToolState::Pending); + let pending_event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: pending_part, + delta: None, + }, + }; + + let result = adapter.translate_to_session_update(pending_event); + assert!(result.is_ok()); + let update = result.unwrap(); + assert!(update.is_some()); + + match update.unwrap() { + SessionUpdate::ToolCall { + message_id: msg_id, + tool_call, + .. + } => { + assert_eq!(msg_id, message_id); + assert_eq!(tool_call.id, tool_id); + assert_eq!(tool_call.status, ToolCallStatus::Pending); + } + _ => panic!("Expected ToolCall for first occurrence"), + } + + // Step 2: Running state (update) → ToolCallUpdate + let running_part = create_tool_part( + "sess1", + message_id, + tool_id, + "bash", + oc::ToolState::Running { + input: input.clone(), + title: None, + metadata: None, + time: oc::PartTime { + start: 1000, + end: None, + }, + }, + ); + let running_event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: running_part, + delta: Some("running".to_string()), + }, + }; + + let result = adapter.translate_to_session_update(running_event); + assert!(result.is_ok()); + let update = result.unwrap(); + assert!(update.is_some()); + + match update.unwrap() { + SessionUpdate::ToolCallUpdate { + message_id: msg_id, + tool_call_id, + tool_call, + .. + } => { + assert_eq!(msg_id, message_id); + assert_eq!(tool_call_id, tool_id); + assert_eq!(tool_call.id, tool_id); + assert_eq!(tool_call.status, ToolCallStatus::Running); + assert_eq!(tool_call.raw_input, Some(input.clone())); + } + _ => panic!("Expected ToolCallUpdate for second occurrence"), + } + + // Step 3: Completed state (update) → ToolCallUpdate + let completed_part = create_tool_part( + "sess1", + message_id, + tool_id, + "bash", + oc::ToolState::Completed { + input: input.clone(), + output: output.to_string(), + title: "bash command".to_string(), + metadata: serde_json::Value::Null, + time: oc::PartTime { + start: 1000, + end: Some(2000), + }, + attachments: None, + }, + ); + let completed_event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: completed_part, + delta: Some(output.to_string()), + }, + }; + + let result = adapter.translate_to_session_update(completed_event); + assert!(result.is_ok()); + let update = result.unwrap(); + assert!(update.is_some()); + + match update.unwrap() { + SessionUpdate::ToolCallUpdate { + message_id: msg_id, + tool_call_id, + tool_call, + .. + } => { + assert_eq!(msg_id, message_id); + assert_eq!(tool_call_id, tool_id); + assert_eq!(tool_call.id, tool_id); + assert_eq!(tool_call.status, ToolCallStatus::Completed); + assert_eq!(tool_call.raw_input, Some(input.clone())); + assert_eq!( + tool_call.raw_output, + Some(serde_json::Value::String(output.to_string())) + ); + } + _ => panic!("Expected ToolCallUpdate for third occurrence"), + } + } + + #[test] + fn test_tool_call_error_lifecycle() { + let adapter = OpenCodeAdapter::new(); + let tool_id = "tool_error_test"; + let message_id = "msg_error"; + let input = serde_json::json!({"command": "invalid"}); + let error_msg = "Command failed"; + + // Step 1: Pending → ToolCall + let pending_part = + create_tool_part("sess1", message_id, tool_id, "bash", oc::ToolState::Pending); + let pending_event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: pending_part, + delta: None, + }, + }; + + let result = adapter.translate_to_session_update(pending_event); + assert!(result.is_ok()); + assert!(matches!( + result.unwrap(), + Some(SessionUpdate::ToolCall { .. }) + )); + + // Step 2: Running → ToolCallUpdate + let running_part = create_tool_part( + "sess1", + message_id, + tool_id, + "bash", + oc::ToolState::Running { + input: input.clone(), + title: None, + metadata: None, + time: oc::PartTime { + start: 1000, + end: None, + }, + }, + ); + let running_event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: running_part, + delta: Some("running".to_string()), + }, + }; + + let result = adapter.translate_to_session_update(running_event); + assert!(result.is_ok()); + assert!(matches!( + result.unwrap(), + Some(SessionUpdate::ToolCallUpdate { .. }) + )); + + // Step 3: Error → ToolCallUpdate + let error_part = create_tool_part( + "sess1", + message_id, + tool_id, + "bash", + oc::ToolState::Error { + input: input.clone(), + error: error_msg.to_string(), + metadata: None, + time: oc::PartTime { + start: 1000, + end: Some(2000), + }, + }, + ); + let error_event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: error_part, + delta: Some(error_msg.to_string()), + }, + }; + + let result = adapter.translate_to_session_update(error_event); + assert!(result.is_ok()); + let update = result.unwrap(); + assert!(update.is_some()); + + match update.unwrap() { + SessionUpdate::ToolCallUpdate { + message_id: msg_id, + tool_call_id, + tool_call, + .. + } => { + assert_eq!(msg_id, message_id); + assert_eq!(tool_call_id, tool_id); + assert_eq!(tool_call.status, ToolCallStatus::Error); + assert_eq!(tool_call.error, Some(error_msg.to_string())); + } + _ => panic!("Expected ToolCallUpdate for error state"), + } + } + + #[test] + fn test_multiple_concurrent_tool_calls_tracked_independently() { + let adapter = OpenCodeAdapter::new(); + let message_id = "msg_multi"; + + // Tool 1: Pending → ToolCall + let tool1_pending = + create_tool_part("sess1", message_id, "tool1", "bash", oc::ToolState::Pending); + let event1 = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: tool1_pending, + delta: None, + }, + }; + let result1 = adapter.translate_to_session_update(event1); + assert!(matches!( + result1.unwrap(), + Some(SessionUpdate::ToolCall { .. }) + )); + + // Tool 2: Pending → ToolCall (different tool call) + let tool2_pending = + create_tool_part("sess1", message_id, "tool2", "read", oc::ToolState::Pending); + let event2 = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: tool2_pending, + delta: None, + }, + }; + let result2 = adapter.translate_to_session_update(event2); + assert!(matches!( + result2.unwrap(), + Some(SessionUpdate::ToolCall { .. }) + )); + + // Tool 1: Running → ToolCallUpdate + let tool1_running = create_tool_part( + "sess1", + message_id, + "tool1", + "bash", + oc::ToolState::Running { + input: serde_json::json!({"cmd": "ls"}), + title: None, + metadata: None, + time: oc::PartTime { + start: 1000, + end: None, + }, + }, + ); + let event3 = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: tool1_running, + delta: Some("running".to_string()), + }, + }; + let result3 = adapter.translate_to_session_update(event3); + assert!(matches!( + result3.unwrap(), + Some(SessionUpdate::ToolCallUpdate { .. }) + )); + + // Tool 2: Running → ToolCallUpdate + let tool2_running = create_tool_part( + "sess1", + message_id, + "tool2", + "read", + oc::ToolState::Running { + input: serde_json::json!({"file": "test.txt"}), + title: None, + metadata: None, + time: oc::PartTime { + start: 1000, + end: None, + }, + }, + ); + let event4 = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: tool2_running, + delta: Some("running".to_string()), + }, + }; + let result4 = adapter.translate_to_session_update(event4); + assert!(matches!( + result4.unwrap(), + Some(SessionUpdate::ToolCallUpdate { .. }) + )); + } +} diff --git a/crates/dirigent_protocol/src/adapters/rest.rs b/crates/dirigent_protocol/src/adapters/rest.rs new file mode 100644 index 0000000..ffd55f2 --- /dev/null +++ b/crates/dirigent_protocol/src/adapters/rest.rs @@ -0,0 +1,167 @@ +/// REST API conversion helpers +/// +/// Converts OpenCode REST API responses to Dirigent protocol types +use crate::{ + Message, MessageMetadata, MessagePart, MessageRole, MessageStatus, Session, SessionMetadata, +}; +use chrono::{DateTime, TimeZone, Utc}; +use opencode_client::types as oc; + +/// Convert OpenCode Session to Dirigent Session +pub fn convert_session(oc_session: oc::Session) -> Session { + Session { + id: oc_session.id, + title: oc_session.title, + created_at: timestamp_to_datetime(oc_session.time.created), + updated_at: timestamp_to_datetime(oc_session.time.updated), + metadata: SessionMetadata { + project_path: oc_session.directory, + model: None, // Not available in session info + total_messages: 0, // Would need to be calculated separately + system_message: None, // Will be set from first assistant message + current_mode_id: None, + _meta: None, + project_id: None, + }, + cwd: None, // OpenCode REST doesn't expose cwd separately from project_path + models: None, // OpenCode doesn't provide ACP model state + modes: None, // OpenCode doesn't provide ACP mode state + config_options: None, + acp_client_id: None, // OpenCode doesn't have ACP client ID + } +} + +/// Convert OpenCode Message to Dirigent Message +pub fn convert_message(oc_msg: oc::Message) -> Message { + let (id, session_id, role, created_at, status, metadata) = match oc_msg { + oc::Message::User(u) => ( + u.id, + u.session_id, + MessageRole::User, + timestamp_to_datetime(u.time.created), + MessageStatus::Completed, + None, // User messages don't have metadata + ), + oc::Message::Assistant(a) => { + let status = if let Some(err) = a.error { + MessageStatus::Failed { + error: format_message_error(&err), + } + } else if a.time.completed.is_some() { + MessageStatus::Completed + } else { + MessageStatus::Streaming + }; + + // Extract metadata from assistant message + let metadata = Some(MessageMetadata { + cost: Some(a.cost), + tokens_input: Some(a.tokens.input), + tokens_output: Some(a.tokens.output), + response_time_ms: None, + latency_ms: None, + model: a.model_id.clone(), + other: None, + }); + + ( + a.id, + a.session_id, + MessageRole::Assistant, + timestamp_to_datetime(a.time.created), + status, + metadata, + ) + } + }; + + Message { + id, + session_id, + role, + created_at, + content: vec![], // Parts are separate + status, + metadata, + } +} + +/// Convert OpenCode MessageWithParts to Dirigent Message with parts +pub fn convert_message_with_parts(oc_msg: oc::MessageWithParts) -> Message { + let mut message = convert_message(oc_msg.info); + + // Convert parts + message.content = oc_msg + .parts + .into_iter() + .filter_map(|part| convert_part(part)) + .collect(); + + message +} + +/// Convert OpenCode Part to Dirigent MessagePart +fn convert_part(oc_part: oc::Part) -> Option<MessagePart> { + match oc_part { + oc::Part::Text(t) => Some(MessagePart::Text { text: t.text }), + oc::Part::Reasoning(r) => Some(MessagePart::Thinking { text: r.text }), + oc::Part::Tool(t) => { + let (input, output) = match t.state { + oc::ToolState::Pending => (serde_json::Value::Null, None), + oc::ToolState::Running { input, .. } => (input, None), + oc::ToolState::Completed { input, output, .. } => { + (input, Some(serde_json::Value::String(output))) + } + oc::ToolState::Error { input, error, .. } => { + (input, Some(serde_json::json!({ "error": error }))) + } + }; + + Some(MessagePart::Tool { + tool: t.tool, + tool_call_id: None, + input, + output, + }) + } + oc::Part::File(f) => Some(MessagePart::File { + path: f.filename.unwrap_or_else(|| f.url.clone()), + content: f.url, + }), + // Skip unsupported part types + _ => None, + } +} + +/// Convert Unix timestamp (milliseconds) to DateTime<Utc> +fn timestamp_to_datetime(timestamp: u64) -> DateTime<Utc> { + Utc.timestamp_millis_opt(timestamp as i64) + .single() + .unwrap_or_else(|| Utc::now()) +} + +/// Format a message error into a user-friendly string +fn format_message_error(error: &oc::MessageError) -> String { + match error { + oc::MessageError::ProviderAuthError { data } => { + format!( + "Authentication error for {}: {}", + data.provider_id, data.message + ) + } + oc::MessageError::UnknownError { data } => { + format!("Unknown error: {}", data.message) + } + oc::MessageError::MessageOutputLengthError => "Message output length exceeded".to_string(), + oc::MessageError::MessageAbortedError { data } => { + format!("Message aborted: {}", data.message) + } + oc::MessageError::ApiError { data } => { + if let Some(status) = data.status_code { + format!("API error ({}): {}", status, data.message) + } else { + format!("API error: {}", data.message) + } + } + } +} diff --git a/crates/dirigent_protocol/src/conversation.rs b/crates/dirigent_protocol/src/conversation.rs new file mode 100644 index 0000000..1fd4c7f --- /dev/null +++ b/crates/dirigent_protocol/src/conversation.rs @@ -0,0 +1,80 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Message { + pub id: String, + pub session_id: String, + pub role: MessageRole, + pub created_at: DateTime<Utc>, + pub content: Vec<MessagePart>, + pub status: MessageStatus, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option<MessageMetadata>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct MessageMetadata { + // Cost information + #[serde(skip_serializing_if = "Option::is_none")] + pub cost: Option<f64>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tokens_input: Option<u64>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tokens_output: Option<u64>, + + // Performance metrics + #[serde(skip_serializing_if = "Option::is_none")] + pub response_time_ms: Option<u64>, + #[serde(skip_serializing_if = "Option::is_none")] + pub latency_ms: Option<u64>, + + // Model information + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option<String>, + + // Arbitrary metadata from connector clients + #[serde(skip_serializing_if = "Option::is_none")] + pub other: Option<Value>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub enum MessageRole { + User, + Assistant, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub enum MessageStatus { + Pending, + Streaming, + Completed, + Failed { error: String }, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "type")] +pub enum MessagePart { + Text { + text: String, + }, + Thinking { + text: String, + }, + Code { + language: String, + code: String, + }, + Tool { + tool: String, + #[serde(skip_serializing_if = "Option::is_none")] + tool_call_id: Option<String>, + input: Value, + output: Option<Value>, + }, + File { + path: String, + content: String, + }, +} diff --git a/crates/dirigent_protocol/src/events/mod.rs b/crates/dirigent_protocol/src/events/mod.rs new file mode 100644 index 0000000..09b28d9 --- /dev/null +++ b/crates/dirigent_protocol/src/events/mod.rs @@ -0,0 +1,974 @@ +use crate::session::{ConfigOption, SessionModeState, SessionModelState}; +use crate::{Message, Session, SessionUpdate}; +use serde::{Deserialize, Serialize}; + +/// Reason why the turn was marked complete (for debugging/observability) +/// +/// This enum indicates **how** the system determined that a turn has completed. +/// Different connector types use different strategies: +/// +/// - **OpenCode Connector**: Uses `ExplicitSignal` (upstream session.idle event) +/// - **ACP Connector (stdio)**: Uses `ResponseReceived` (JSON-RPC response is final) +/// - **Gateway Connector**: Uses `OperationsComplete` (tracks pending tool calls) +/// - **Fallback**: Uses `IdleTimeout` when no other signal available +/// +/// # Consumer Usage +/// +/// Most consumers should treat all triggers the same (turn is complete). +/// The trigger type is primarily for debugging and observability. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub enum TurnCompleteTrigger { + /// Explicit signal from upstream provider (e.g., OpenCode session.idle event) + /// + /// This is the most reliable trigger as it comes directly from the agent system. + ExplicitSignal, + + /// JSON-RPC response received (ACP stdio transport) + /// + /// In ACP stdio mode, the response message is the last message in the turn. + ResponseReceived, + + /// All tracked operations completed (e.g., pending tool calls resolved) + /// + /// Used when the connector tracks operation state and can determine + /// completion by monitoring tool call statuses. + OperationsComplete, + + /// Timeout-based idle detection (fallback mechanism) + /// + /// Used when no other completion signal is available. + /// The duration indicates how long the system waited before declaring completion. + IdleTimeout { duration_ms: u64 }, +} + +/// A single node in an inspector snapshot (protocol-level DTO). +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InspectorSnapshotNode { + pub id: String, + pub parent: Option<String>, + pub children: Vec<String>, + pub label: String, + pub kind: String, + pub state: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub state_detail: Option<String>, + pub properties: std::collections::BTreeMap<String, serde_json::Value>, + pub created_at: String, + pub last_updated: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "event", content = "data")] +pub enum Event { + // Session events + SessionsListed { + connector_id: String, + sessions: Vec<Session>, + }, + SessionCreated { + connector_id: String, + session: Session, + }, + SessionUpdated { + connector_id: String, + session: Session, + }, + SessionMetadataUpdated { + connector_id: String, + session_id: String, + title: Option<String>, + total_messages: Option<u32>, + model: Option<String>, + }, + SessionDeleted { + session_id: String, + }, + /// Session was closed (agent released resources, session remains in list). + /// The session can be loaded again later via session/load. + SessionClosed { + connector_id: String, + session_id: String, + }, + SessionSystemMessageSet { + session_id: String, + system_message: String, + }, + SessionIdle { + connector_id: String, + session_id: String, + }, + /// Session mode/model metadata received from an ACP connector + /// + /// Emitted when metadata is received from a connector (e.g., after session/new or session/load). + /// This event is separate from SessionCreated to support: + /// - Session takeover scenarios (session already exists, but metadata is new) + /// - Specific subscriptions to metadata changes + /// - Connectors that provide metadata asynchronously + /// + /// # Fields + /// - `models`: UNSTABLE in ACP spec but used by Claude-ACP + /// - `modes`: Stable in ACP spec + /// + /// Both fields are optional since not all connectors provide this data. + SessionMetadataReceived { + /// Connector that provided the metadata + connector_id: String, + /// Session the metadata belongs to + session_id: String, + /// Available models and current model (UNSTABLE in ACP spec) + #[serde(skip_serializing_if = "Option::is_none")] + models: Option<SessionModelState>, + /// Available modes and current mode + #[serde(skip_serializing_if = "Option::is_none")] + modes: Option<SessionModeState>, + /// ACP config options (replaces modes/models in future ACP versions) + #[serde(default, skip_serializing_if = "Option::is_none")] + config_options: Option<Vec<ConfigOption>>, + }, + + /// **All content for this turn/message has been received.** + /// + /// This is the **primary signal** for finalization actions (archiving, UI state lock). + /// Emitted BEFORE `SessionIdle` to ensure proper event ordering. + /// + /// # Event Semantics + /// + /// - **`MessageCompleted`**: Message metadata is ready (informational) + /// - Purpose: UI status updates ("Assistant is typing" → "Complete") + /// - Timing: Sent when message record exists, content may still be streaming + /// - Consumer action: Update UI state, show completion status + /// + /// - **`TurnComplete`**: All content received (actionable) + /// - Purpose: Signal that the entire turn is finalized + /// - Timing: Sent AFTER all content chunks, tool calls, and metadata + /// - Consumer action: Finalize storage, lock state, trigger post-processing + /// + /// - **`SessionIdle`**: No recent activity (informational) + /// - Purpose: UI spinner control, activity indication + /// - Timing: Sent AFTER `TurnComplete` as final safety signal + /// - Consumer action: Hide spinners, update activity indicators + /// + /// # Event Ordering + /// + /// ```text + /// 1. MessageStarted (message created) + /// 2. SessionUpdate::*Chunk (content streaming) + /// 3. SessionUpdate::ToolCall* (tool execution) + /// 4. MessageCompleted (metadata ready) + /// 5. TurnComplete ← YOU ARE HERE (finalize!) + /// 6. SessionIdle (activity stopped) + /// ``` + /// + /// # Consumer Behavior + /// + /// | Consumer | MessageCompleted | TurnComplete | SessionIdle | + /// |----------|------------------|--------------|-------------| + /// | **Archivist** | Ignore | **Finalize and write** | Safety net | + /// | **UI Cache** | Update status | **Lock state** | Hide spinner | + /// | **Conductor Bridge** | - | **Flush response** | Fallback flush | + /// + /// # Example Usage + /// + /// ```rust + /// use dirigent_protocol::{Event, TurnCompleteTrigger}; + /// + /// match event { + /// Event::TurnComplete { session_id, message_id, trigger, .. } => { + /// // Finalize the message in your storage + /// archivist.finalize_message(&session_id, &message_id).await?; + /// + /// // Lock UI state + /// ui_cache.lock_message(&message_id); + /// + /// // Log trigger for debugging + /// println!("Turn complete via {:?}", trigger); + /// } + /// _ => {} + /// } + /// ``` + TurnComplete { + connector_id: String, + session_id: String, + message_id: String, + trigger: TurnCompleteTrigger, + }, + /// Session-level error that can be displayed in the chat UI. + /// Used when a connector encounters an error during session operations. + /// + /// # Fields + /// + /// - `error_message`: Human-readable error summary + /// - `is_recoverable`: Whether the session can continue after this error + /// - `error_code`: Optional categorization code (e.g., "TRANSPORT_PARSE_FAILED") + /// - `technical_details`: Optional full technical details (truncated if large) + /// - `context`: Optional JSON blob with structured error context for debugging + SessionError { + connector_id: String, + session_id: String, + error_message: String, + /// Whether the session can continue after this error. + /// If false, the session should be considered terminated. + is_recoverable: bool, + /// Error categorization code for UI grouping and filtering. + /// Examples: "TRANSPORT_PARSE_FAILED", "SESSION_NOT_FOUND", "TIMEOUT" + #[serde(skip_serializing_if = "Option::is_none")] + error_code: Option<String>, + /// Full technical details including stack traces, received content, etc. + /// May be truncated if the original content was very large. + #[serde(skip_serializing_if = "Option::is_none")] + technical_details: Option<String>, + /// Structured error context for debug view (JSON blob). + /// Contains machine-readable error information. + #[serde(skip_serializing_if = "Option::is_none")] + context: Option<serde_json::Value>, + }, + /// Session was transferred from one connector to another + /// + /// Emitted by CoreRuntime when a session transfer completes successfully. + /// ACP Server should update client mappings on receiving this event. + SessionTransferred { + /// Source connector ID (where transfer originated) + from_connector: String, + /// Source session ID + from_session: String, + /// Target connector ID (where session is now active) + to_connector: String, + /// New session ID in target connector + to_session: String, + /// Whether a new session was created (true) or existing loaded (false) + is_new_session: bool, + /// Available models and current model from the new connector (optional) + #[serde(skip_serializing_if = "Option::is_none")] + models: Option<SessionModelState>, + /// Available modes and current mode from the new connector (optional) + #[serde(skip_serializing_if = "Option::is_none")] + modes: Option<SessionModeState>, + }, + /// Emitted by archivist when a session registration is durable and list-stable. + /// + /// Frontend can use this to refresh the session list with confidence that + /// the session will appear (it's been written to the archive index). + /// This replaces any timeout-based delay hacks after session creation. + /// + /// # Fields + /// + /// - `connector_id`: The connector that owns this session + /// - `session_id`: The native session ID from the connector + /// - `scroll_id`: The archivist's canonical scroll_id for this session + SessionRegistered { + connector_id: String, + session_id: String, + /// The archivist's canonical scroll_id for this session + scroll_id: String, + }, + + /// A forwarded session encountered a failure + /// + /// Emitted when a connector that received a transferred session fails. + /// Clients should be routed back to Gateway automatically. + ForwardingPanic { + /// Connector that failed + connector_id: String, + /// Session that was affected + session_id: String, + /// Human-readable reason for the failure + reason: String, + /// ID of the Gateway session to fall back to (if available) + fallback_gateway_session: Option<String>, + }, + /// New ACP-style session update (replaces MessagePartAdded for new consumers) + SessionUpdate { + connector_id: String, + session_id: String, + update: SessionUpdate, + }, + /// Agent-initiated request requiring client response (e.g., permission prompt) + /// + /// Emitted when an agent sends a request (like session/request_permission) that + /// requires user input. The client should respond via the appropriate API endpoint. + /// + /// # Permission Flow Routing + /// + /// - `is_forwarded: false` - Internal session (UI-owned) → Show modal in web UI + /// - `is_forwarded: true` - External session (ACP client-owned) → Forward to EventBridge, DO NOT show UI modal + /// + /// The `is_forwarded` field determines whether this permission request should be handled + /// by the Dirigent web UI or forwarded to an external ACP client that owns the session. + AgentRequest { + /// ID of the connector that received the request + connector_id: String, + /// Session ID from the request parameters + session_id: String, + /// Request ID from the agent (for correlating the response) + request_id: serde_json::Value, + /// Method being requested (e.g., "session/request_permission") + method: String, + /// Request parameters from the agent + params: serde_json::Value, + /// Whether this is a forwarded (external) session. + /// + /// If `true`, the UI MUST NOT show a permission modal. Instead, the EventBridge + /// should forward this request to the external ACP client that owns the session. + /// + /// If `false`, this is an internal session and the UI should show the permission modal. + is_forwarded: bool, + }, + + // ACP Client Connection Events (for UI visibility of incoming connections) + /// An ACP client has connected to the server + /// + /// Emitted when a new client connects via the ACP Server. + /// Used by UI to show incoming connections in the sidebar. + AcpClientConnected { + /// Unique client identifier (UUID7) + client_id: String, + /// When the client connected (ISO 8601 timestamp) + connected_at: String, + /// Optional client capabilities from the initialize handshake + capabilities: Option<serde_json::Value>, + /// The Acceptor connector's UID (for archivist meta session creation) + connector_uid: String, + }, + /// An ACP client has disconnected from the server + /// + /// Emitted when a client disconnects (explicitly or due to connection loss). + /// The client record should be marked as disconnected, not removed (for history). + AcpClientDisconnected { + /// Unique client identifier + client_id: String, + /// When the client disconnected (ISO 8601 timestamp) + disconnected_at: String, + /// Optional reason for disconnection + reason: Option<String>, + }, + /// An ACP client has opened a new session via Gateway + /// + /// Emitted when a client creates a new session through the ACP Server. + /// This adds an entry to the connection history. + AcpClientSessionOpened { + /// Client that opened the session + client_id: String, + /// The Gateway session ID (or initial session before routing) + gateway_session_id: String, + /// The client-facing session ID + client_session_id: String, + /// When this occurred (ISO 8601 timestamp) + timestamp: String, + }, + /// An ACP client's session was routed to a different connector + /// + /// Emitted when a session is transferred from Gateway to another connector. + /// This adds an entry to the connection history showing the route change. + AcpClientSessionRouted { + /// Client whose session was routed + client_id: String, + /// Original session ID (typically Gateway session) + from_session_id: String, + /// New session ID in the target connector + to_session_id: String, + /// Target connector ID + connector_id: String, + /// Target connector title (for display) + connector_title: String, + /// Connector kind (e.g., "opencode", "acp", "gateway") + #[serde(default)] + connector_kind: Option<String>, + /// Current model being used (if known) + #[serde(default)] + model: Option<String>, + /// Agent version/name info (if available) + #[serde(default)] + agent_info: Option<String>, + /// When this occurred (ISO 8601 timestamp) + timestamp: String, + }, + + // Message events + MessagesListed { + messages: Vec<Message>, + }, + MessageStarted { + connector_id: String, + message: Message, + }, + MessageCompleted { + connector_id: String, + message: Message, + }, + MessageFailed { + message_id: String, + error: String, + }, + + // Connector lifecycle events + ConnectorCreated { + connector_id: String, + kind: String, + title: String, + }, + ConnectorRemoved { + connector_id: String, + }, + ConnectorStateChanged { + connector_id: String, + state: String, + /// Machine-readable error classification ("offline", "unstable", "connection_failed") + #[serde(default, skip_serializing_if = "Option::is_none")] + error_kind: Option<String>, + }, + + // System events + Connected, + Disconnected, + Error { + message: String, + }, + + // Inspector events (runtime tree visualization) + /// Full snapshot of the inspector tree — sent on initial connection + /// and can be requested via server function. + InspectorSnapshot { + /// ISO 8601 timestamp of the snapshot + timestamp: String, + /// All nodes in the tree + nodes: Vec<InspectorSnapshotNode>, + /// Total node count + node_count: usize, + }, + /// A new node was registered in the inspector tree + InspectorNodeRegistered { + id: String, + parent: String, + kind: String, + }, + /// A node was removed from the inspector tree + InspectorNodeRemoved { + id: String, + }, + /// A node's lifecycle state changed + InspectorStateChanged { + id: String, + old: String, + new: String, + }, + /// A node's properties were updated + InspectorPropertiesUpdated { + id: String, + keys: Vec<String>, + }, + + // System task events + /// A background system task changed status (completed, failed, cancelled). + /// + /// Emitted by the SystemTaskRegistry when a task reaches a terminal state. + /// Allows the UI to react to task completion without polling. + SystemTaskStatusChanged { + /// Unique task identifier (UUIDv7) + task_id: String, + /// What kind of operation (e.g., "ClaudeImport") + kind: String, + /// Terminal status: "completed", "failed", or "cancelled" + status: String, + /// JSON result payload (present when status == "completed") + #[serde(skip_serializing_if = "Option::is_none")] + result_json: Option<String>, + /// Error message (present when status == "failed") + #[serde(skip_serializing_if = "Option::is_none")] + error: Option<String>, + }, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::ContentBlock; + + #[test] + fn test_session_update_variant_serialization() { + let update = SessionUpdate::UserMessageChunk { + message_id: "msg_123".to_string(), + content: ContentBlock::Text { + text: "Hello from event".to_string(), + }, + _meta: None, + }; + + let event = Event::SessionUpdate { + connector_id: "conn_123".to_string(), + session_id: "session_456".to_string(), + update, + }; + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains(r#""event":"SessionUpdate"#)); + assert!(json.contains(r#""session_id":"session_456"#)); + assert!(json.contains(r#""type":"user_message_chunk"#)); + assert!(json.contains(r#""message_id":"msg_123"#)); + assert!(json.contains(r#""text":"Hello from event"#)); + } + + #[test] + fn test_session_update_variant_deserialization() { + let json = r#"{ + "event": "SessionUpdate", + "data": { + "connector_id": "conn_123", + "session_id": "session_789", + "update": { + "type": "agent_message_chunk", + "message_id": "msg_789", + "content": { + "type": "text", + "text": "Agent response" + } + } + } + }"#; + + let event: Event = serde_json::from_str(json).unwrap(); + match event { + Event::SessionUpdate { + connector_id, + session_id, + update, + } => { + assert_eq!(connector_id, "conn_123"); + assert_eq!(session_id, "session_789"); + match update { + SessionUpdate::AgentMessageChunk { + message_id, + content, + _meta, + } => { + assert_eq!(message_id, "msg_789"); + assert_eq!( + content, + ContentBlock::Text { + text: "Agent response".to_string() + } + ); + assert_eq!(_meta, None); + } + _ => panic!("Expected AgentMessageChunk"), + } + } + _ => panic!("Expected SessionUpdate event"), + } + } + + #[test] + fn test_session_update_variant_roundtrip() { + let original = Event::SessionUpdate { + connector_id: "conn_roundtrip".to_string(), + session_id: "session_roundtrip".to_string(), + update: SessionUpdate::AgentThoughtChunk { + message_id: "msg_thought".to_string(), + content: ContentBlock::Text { + text: "Thinking...".to_string(), + }, + _meta: None, + }, + }; + + let json = serde_json::to_string(&original).unwrap(); + let deserialized: Event = serde_json::from_str(&json).unwrap(); + + match (&original, &deserialized) { + ( + Event::SessionUpdate { + connector_id: cid1, + session_id: sid1, + update: update1, + }, + Event::SessionUpdate { + connector_id: cid2, + session_id: sid2, + update: update2, + }, + ) => { + assert_eq!(cid1, cid2); + assert_eq!(sid1, sid2); + assert_eq!(update1, update2); + } + _ => panic!("Roundtrip failed"), + } + } + + #[test] + fn test_existing_events_still_work() { + // Verify that existing event variants are not affected + use crate::SessionMetadata; + use chrono::Utc; + + let now = Utc::now(); + let session_created = Event::SessionCreated { + connector_id: "conn_test".to_string(), + session: Session { + id: "session_123".to_string(), + title: "Test Session".to_string(), + created_at: now, + updated_at: now, + metadata: SessionMetadata { + project_path: "/test".to_string(), + model: Some("gpt-4".to_string()), + total_messages: 0, + system_message: None, + current_mode_id: None, + _meta: None, + project_id: None, + }, + cwd: None, + models: None, + modes: None, + config_options: None, + acp_client_id: None, + }, + }; + + let json = serde_json::to_string(&session_created).unwrap(); + let _deserialized: Event = serde_json::from_str(&json).unwrap(); + } + + #[test] + fn test_session_error_serialization() { + let event = Event::SessionError { + connector_id: "acp_conn_1".to_string(), + session_id: "session_456".to_string(), + error_message: "Session not found".to_string(), + is_recoverable: false, + error_code: None, + technical_details: None, + context: None, + }; + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains(r#""event":"SessionError"#)); + assert!(json.contains(r#""connector_id":"acp_conn_1"#)); + assert!(json.contains(r#""session_id":"session_456"#)); + assert!(json.contains(r#""error_message":"Session not found"#)); + assert!(json.contains(r#""is_recoverable":false"#)); + // Optional fields should not be present when None + assert!(!json.contains(r#""error_code"#)); + assert!(!json.contains(r#""technical_details"#)); + assert!(!json.contains(r#""context"#)); + } + + #[test] + fn test_session_error_with_details_serialization() { + let event = Event::SessionError { + connector_id: "acp_conn_1".to_string(), + session_id: "session_456".to_string(), + error_message: "Transport parse failed".to_string(), + is_recoverable: true, + error_code: Some("TRANSPORT_PARSE_FAILED".to_string()), + technical_details: Some("Failed to parse JSON: expected value at line 1".to_string()), + context: Some(serde_json::json!({ + "received_bytes": 1024, + "received_preview": "option { key: ...", + "expected": "JSON-RPC message" + })), + }; + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains(r#""error_code":"TRANSPORT_PARSE_FAILED"#)); + assert!(json.contains(r#""technical_details"#)); + assert!(json.contains(r#""context"#)); + assert!(json.contains(r#""received_bytes":1024"#)); + } + + #[test] + fn test_session_error_deserialization() { + // Test backward compatibility - old format without new fields + let json = r#"{ + "event": "SessionError", + "data": { + "connector_id": "conn_test", + "session_id": "session_789", + "error_message": "Connection timeout", + "is_recoverable": true + } + }"#; + + let event: Event = serde_json::from_str(json).unwrap(); + match event { + Event::SessionError { + connector_id, + session_id, + error_message, + is_recoverable, + error_code, + technical_details, + context, + } => { + assert_eq!(connector_id, "conn_test"); + assert_eq!(session_id, "session_789"); + assert_eq!(error_message, "Connection timeout"); + assert!(is_recoverable); + assert!(error_code.is_none()); + assert!(technical_details.is_none()); + assert!(context.is_none()); + } + _ => panic!("Expected SessionError event"), + } + } + + #[test] + fn test_session_error_roundtrip() { + let original = Event::SessionError { + connector_id: "roundtrip_conn".to_string(), + session_id: "roundtrip_session".to_string(), + error_message: "API rate limit exceeded".to_string(), + is_recoverable: true, + error_code: Some("RATE_LIMITED".to_string()), + technical_details: Some("429 Too Many Requests".to_string()), + context: Some(serde_json::json!({"retry_after": 60})), + }; + + let json = serde_json::to_string(&original).unwrap(); + let deserialized: Event = serde_json::from_str(&json).unwrap(); + + match (&original, &deserialized) { + ( + Event::SessionError { + connector_id: cid1, + session_id: sid1, + error_message: err1, + is_recoverable: rec1, + error_code: code1, + technical_details: details1, + context: ctx1, + }, + Event::SessionError { + connector_id: cid2, + session_id: sid2, + error_message: err2, + is_recoverable: rec2, + error_code: code2, + technical_details: details2, + context: ctx2, + }, + ) => { + assert_eq!(cid1, cid2); + assert_eq!(sid1, sid2); + assert_eq!(code1, code2); + assert_eq!(details1, details2); + assert_eq!(ctx1, ctx2); + assert_eq!(err1, err2); + assert_eq!(rec1, rec2); + } + _ => panic!("Roundtrip failed"), + } + } + + #[test] + fn test_session_transferred_serialization() { + let event = Event::SessionTransferred { + from_connector: "gateway-1".to_string(), + from_session: "session-old".to_string(), + to_connector: "opencode-1".to_string(), + to_session: "session-new".to_string(), + is_new_session: true, + models: None, + modes: None, + }; + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains(r#""event":"SessionTransferred"#)); + assert!(json.contains(r#""from_connector":"gateway-1"#)); + assert!(json.contains(r#""to_connector":"opencode-1"#)); + + let deserialized: Event = serde_json::from_str(&json).unwrap(); + // Verify roundtrip + match deserialized { + Event::SessionTransferred { + from_connector, + from_session, + to_connector, + to_session, + is_new_session, + models, + modes, + } => { + assert_eq!(from_connector, "gateway-1"); + assert_eq!(from_session, "session-old"); + assert_eq!(to_connector, "opencode-1"); + assert_eq!(to_session, "session-new"); + assert!(is_new_session); + assert!(models.is_none()); + assert!(modes.is_none()); + } + _ => panic!("Expected SessionTransferred event"), + } + } + + #[test] + fn test_forwarding_panic_serialization() { + let event = Event::ForwardingPanic { + connector_id: "opencode-1".to_string(), + session_id: "session-123".to_string(), + reason: "Connection lost".to_string(), + fallback_gateway_session: Some("gateway-session-1".to_string()), + }; + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains(r#""event":"ForwardingPanic"#)); + assert!(json.contains(r#""connector_id":"opencode-1"#)); + assert!(json.contains(r#""session_id":"session-123"#)); + assert!(json.contains(r#""reason":"Connection lost"#)); + + let deserialized: Event = serde_json::from_str(&json).unwrap(); + // Verify roundtrip + match deserialized { + Event::ForwardingPanic { + connector_id, + session_id, + reason, + fallback_gateway_session, + } => { + assert_eq!(connector_id, "opencode-1"); + assert_eq!(session_id, "session-123"); + assert_eq!(reason, "Connection lost"); + assert_eq!( + fallback_gateway_session, + Some("gateway-session-1".to_string()) + ); + } + _ => panic!("Expected ForwardingPanic event"), + } + } + + #[test] + fn test_session_metadata_received_full() { + use crate::session::{ModelInfo, SessionMode, SessionModeState, SessionModelState}; + + let event = Event::SessionMetadataReceived { + connector_id: "claude-acp-1".to_string(), + session_id: "session-123".to_string(), + models: Some(SessionModelState { + available_models: vec![ + ModelInfo { + model_id: "default".to_string(), + name: "Default (recommended)".to_string(), + description: Some("Opus 4.5".to_string()), + }, + ModelInfo { + model_id: "sonnet".to_string(), + name: "Sonnet".to_string(), + description: None, + }, + ], + current_model_id: "default".to_string(), + }), + modes: Some(SessionModeState { + current_mode_id: "default".to_string(), + available_modes: vec![ + SessionMode { + id: "default".to_string(), + name: "Always Ask".to_string(), + description: Some("Prompts for permission".to_string()), + }, + SessionMode { + id: "plan".to_string(), + name: "Plan Mode".to_string(), + description: None, + }, + ], + }), + config_options: None, + }; + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains(r#""event":"SessionMetadataReceived"#)); + assert!(json.contains(r#""connector_id":"claude-acp-1"#)); + assert!(json.contains(r#""session_id":"session-123"#)); + // Check camelCase in nested types + assert!(json.contains("availableModels")); + assert!(json.contains("currentModelId")); + assert!(json.contains("availableModes")); + assert!(json.contains("currentModeId")); + + let deserialized: Event = serde_json::from_str(&json).unwrap(); + match deserialized { + Event::SessionMetadataReceived { + connector_id, + session_id, + models, + modes, + .. + } => { + assert_eq!(connector_id, "claude-acp-1"); + assert_eq!(session_id, "session-123"); + assert!(models.is_some()); + assert!(modes.is_some()); + let models = models.unwrap(); + assert_eq!(models.current_model_id, "default"); + assert_eq!(models.available_models.len(), 2); + let modes = modes.unwrap(); + assert_eq!(modes.current_mode_id, "default"); + assert_eq!(modes.available_modes.len(), 2); + } + _ => panic!("Expected SessionMetadataReceived event"), + } + } + + #[test] + fn test_session_metadata_received_partial() { + // Test with only modes (models is None) + use crate::session::{SessionMode, SessionModeState}; + + let event = Event::SessionMetadataReceived { + connector_id: "gateway-1".to_string(), + session_id: "session-456".to_string(), + models: None, + modes: Some(SessionModeState { + current_mode_id: "default".to_string(), + available_modes: vec![SessionMode { + id: "default".to_string(), + name: "Default".to_string(), + description: None, + }], + }), + config_options: None, + }; + + let json = serde_json::to_string(&event).unwrap(); + // models should be skipped when None + assert!(!json.contains("availableModels")); + assert!(json.contains("availableModes")); + + let deserialized: Event = serde_json::from_str(&json).unwrap(); + match deserialized { + Event::SessionMetadataReceived { models, modes, .. } => { + assert!(models.is_none()); + assert!(modes.is_some()); + } + _ => panic!("Expected SessionMetadataReceived event"), + } + } + + #[test] + fn test_session_metadata_received_empty() { + // Test with both None (connector provides no metadata) + let event = Event::SessionMetadataReceived { + connector_id: "generic-1".to_string(), + session_id: "session-789".to_string(), + models: None, + modes: None, + config_options: None, + }; + + let json = serde_json::to_string(&event).unwrap(); + // Both should be skipped when None + assert!(!json.contains("models")); + assert!(!json.contains("modes")); + + let deserialized: Event = serde_json::from_str(&json).unwrap(); + match deserialized { + Event::SessionMetadataReceived { models, modes, .. } => { + assert!(models.is_none()); + assert!(modes.is_none()); + } + _ => panic!("Expected SessionMetadataReceived event"), + } + } +} diff --git a/crates/dirigent_protocol/src/inspector.rs b/crates/dirigent_protocol/src/inspector.rs new file mode 100644 index 0000000..66cd244 --- /dev/null +++ b/crates/dirigent_protocol/src/inspector.rs @@ -0,0 +1,145 @@ +//! Inspector node types for the process tree. +//! +//! These types represent nodes in the inspector's hierarchical process tree, +//! providing a canonical definition that can be shared between the server-side +//! inspector and WASM-based UI. + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::fmt; + +/// Hierarchical identifier for a node in the inspector tree. +/// +/// Uses `/`-separated segments (e.g., `"root/connector-1/process-a"`). +#[derive(Clone, Debug, Hash, Eq, PartialEq, Serialize, Deserialize)] +pub struct NodeId(pub String); + +impl NodeId { + pub fn new(id: impl Into<String>) -> Self { + Self(id.into()) + } + + /// Create a child node ID by appending a segment. + pub fn child(&self, segment: &str) -> Self { + Self(format!("{}/{}", self.0, segment)) + } + + /// Get the parent node ID (everything before the last `/`). + pub fn parent(&self) -> Option<Self> { + self.0.rfind('/').map(|idx| Self(self.0[..idx].to_string())) + } + + /// Get the last segment of the path (the node's own name). + pub fn name(&self) -> &str { + self.0.rsplit('/').next().unwrap_or(&self.0) + } + + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl fmt::Display for NodeId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From<&str> for NodeId { + fn from(s: &str) -> Self { + Self(s.to_string()) + } +} + +impl From<String> for NodeId { + fn from(s: String) -> Self { + Self(s) + } +} + +/// The kind of node in the inspector tree. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub enum NodeKind { + Root, + Connector, + Process, + Service, + AsyncTask, + System, + Custom(String), +} + +impl fmt::Display for NodeKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + NodeKind::Root => write!(f, "Root"), + NodeKind::Connector => write!(f, "Connector"), + NodeKind::Process => write!(f, "Process"), + NodeKind::Service => write!(f, "Service"), + NodeKind::AsyncTask => write!(f, "AsyncTask"), + NodeKind::System => write!(f, "System"), + NodeKind::Custom(name) => write!(f, "Custom({})", name), + } + } +} + +/// The runtime state of an inspector node. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub enum NodeState { + Initializing, + Running, + Idle, + Busy(String), + Degraded(String), + Error(String), + Stopped, +} + +impl fmt::Display for NodeState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + NodeState::Initializing => write!(f, "Initializing"), + NodeState::Running => write!(f, "Running"), + NodeState::Idle => write!(f, "Idle"), + NodeState::Busy(desc) => write!(f, "Busy({})", desc), + NodeState::Degraded(reason) => write!(f, "Degraded({})", reason), + NodeState::Error(msg) => write!(f, "Error({})", msg), + NodeState::Stopped => write!(f, "Stopped"), + } + } +} + +/// Metadata associated with an inspector node. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct NodeMetadata { + pub kind: NodeKind, + pub label: String, + pub state: NodeState, + pub created_at: chrono::DateTime<chrono::Utc>, + pub last_updated: chrono::DateTime<chrono::Utc>, + pub properties: HashMap<String, serde_json::Value>, +} + +impl NodeMetadata { + pub fn new(kind: NodeKind, label: impl Into<String>) -> Self { + let now = chrono::Utc::now(); + Self { + kind, + label: label.into(), + state: NodeState::Initializing, + created_at: now, + last_updated: now, + properties: HashMap::new(), + } + } + + pub fn with_state(mut self, state: NodeState) -> Self { + self.state = state; + self + } + + pub fn with_property(mut self, key: impl Into<String>, value: serde_json::Value) -> Self { + self.properties.insert(key.into(), value); + self + } +} diff --git a/crates/dirigent_protocol/src/lib.rs b/crates/dirigent_protocol/src/lib.rs new file mode 100644 index 0000000..c063524 --- /dev/null +++ b/crates/dirigent_protocol/src/lib.rs @@ -0,0 +1,187 @@ +pub mod accumulator; +pub mod adapters; +pub mod conversation; +pub mod events; +pub mod inspector; +pub mod log_utils; +pub mod project; +pub mod session; +pub mod sharing; +pub mod streaming; +pub mod types; + +pub use conversation::{Message, MessageMetadata, MessagePart, MessageRole, MessageStatus}; +pub use events::{Event, InspectorSnapshotNode, TurnCompleteTrigger}; +pub use inspector::{NodeId, NodeKind, NodeMetadata, NodeState}; +pub use session::{ + ConfigOption, ConfigOptionType, ConfigOptionValue, ModelId, ModelInfo, Session, + SessionMetadata, SessionMode, SessionModeId, SessionModeState, SessionModelState, + SessionOrigin, SessionOwnership, ToolHandler, +}; +pub use types::{ + ContentBlock, Meta, PermissionOption, PermissionOptionKind, PermissionToolCallStatus, + ProviderMeta, RequestPermissionOutcome, RequestPermissionResponse, SessionUpdate, ToolCall, + ToolCallContent, ToolCallId, ToolCallInfo, ToolCallLocation, ToolCallStatus, ToolKind, +}; +pub use sharing::{SessionShare, ShareId, ShareSummary}; +pub use accumulator::{AccumulatedMessage, AccumulatedPart, MessageAccumulator, ToolCallData as AccumulatorToolCallData}; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_public_api_imports() { + // Test that all types are accessible from the crate root + + // ContentBlock + let _content = ContentBlock::Text { + text: "test".to_string(), + }; + + // Meta and ProviderMeta + let _meta = Meta::default(); + let _provider = ProviderMeta { + name: "test".to_string(), + original_ids: None, + raw_excerpt: None, + }; + + // ToolCall, ToolCallId, ToolCallStatus + let _tool_call_id: ToolCallId = "call_123".to_string(); + let _status = ToolCallStatus::Pending; + let _tool_call = ToolCall { + id: "call_123".to_string(), + tool_name: "test".to_string(), + status: ToolCallStatus::Running, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + // SessionUpdate + let _update = SessionUpdate::UserMessageChunk { + message_id: "msg_123".to_string(), + content: ContentBlock::Text { + text: "Hello".to_string(), + }, + _meta: None, + }; + + // If this compiles, all types are accessible via use dirigent_protocol::{...} + } + + #[test] + fn test_session_metadata_types_accessible() { + // Test that new ACP session metadata types are accessible from crate root + + // ModelId and SessionModeId (type aliases) + let _model_id: ModelId = "default".to_string(); + let _mode_id: SessionModeId = "plan".to_string(); + + // ModelInfo + let _model_info = ModelInfo { + model_id: "default".to_string(), + name: "Default".to_string(), + description: Some("Main model".to_string()), + }; + + // SessionModelState + let _model_state = SessionModelState { + available_models: vec![_model_info], + current_model_id: "default".to_string(), + }; + + // SessionMode + let _mode = SessionMode { + id: "default".to_string(), + name: "Default".to_string(), + description: None, + }; + + // SessionModeState + let _mode_state = SessionModeState { + current_mode_id: "default".to_string(), + available_modes: vec![_mode], + }; + + // SessionMetadataReceived event + let _event = Event::SessionMetadataReceived { + connector_id: "test".to_string(), + session_id: "test".to_string(), + models: Some(_model_state), + modes: Some(_mode_state), + config_options: None, + }; + } + + #[test] + fn test_permission_types_accessible() { + // Test that ACP permission types are accessible from crate root + + // PermissionOption and PermissionOptionKind + let _option = PermissionOption { + option_id: "allow_1".to_string(), + name: "Allow once".to_string(), + kind: PermissionOptionKind::AllowOnce, + }; + + // RequestPermissionResponse and RequestPermissionOutcome + let _response = RequestPermissionResponse { + outcome: RequestPermissionOutcome::Selected { + option_id: "allow_1".to_string(), + }, + }; + + // ToolCallInfo and related types + let _info = ToolCallInfo { + tool_call_id: "call_123".to_string(), + title: "Read file".to_string(), + kind: Some(ToolKind::Read), + status: Some(PermissionToolCallStatus::Pending), + locations: Some(vec![ToolCallLocation { + path: "/test.txt".to_string(), + line: Some(10), + }]), + raw_input: None, + }; + + // If this compiles, all permission types are accessible + } + + #[test] + fn test_session_ownership_types_accessible() { + // Test that Session Ownership Model types are accessible from crate root + + // SessionOrigin + let _origin_internal = SessionOrigin::Internal; + let _origin_external = SessionOrigin::External { + client_id: "test".to_string(), + client_capabilities: None, + }; + + // ToolHandler + let _handler_agent = ToolHandler::Agent; + let _handler_dirigent = ToolHandler::Dirigent; + let _handler_forward = ToolHandler::ForwardToClient; + + // SessionOwnership and its constructors + let _ownership_default = SessionOwnership::default(); + let _ownership_internal = SessionOwnership::internal(); + let _ownership_forwarded = + SessionOwnership::external_forwarded("client-123".to_string(), None); + let _ownership_handled = SessionOwnership::external_handled("client-456".to_string(), None); + + // Test helper methods + assert!(!_ownership_internal.is_external()); + assert!(_ownership_forwarded.is_external()); + assert_eq!(_ownership_internal.client_id(), None); + assert_eq!(_ownership_forwarded.client_id(), Some("client-123")); + + // If this compiles, all ownership types are accessible + } +} diff --git a/crates/dirigent_protocol/src/log_utils.rs b/crates/dirigent_protocol/src/log_utils.rs new file mode 100644 index 0000000..3626ed3 --- /dev/null +++ b/crates/dirigent_protocol/src/log_utils.rs @@ -0,0 +1,250 @@ +/// Utilities for masking sensitive or verbose content in logs +use serde_json::Value; + +/// Truncate long strings to a reasonable length for logging +const MAX_LOG_LENGTH: usize = 100; + +/// Extract filename from a file path (handles both Unix and Windows paths) +fn extract_filename(path: &str) -> &str { + // Try Unix-style path separator first, then Windows + if path.contains('/') { + path.split('/').last().unwrap_or("file") + } else if path.contains('\\') { + path.split('\\').last().unwrap_or("file") + } else { + // No path separator, just a filename + path + } +} + +/// Mask text content in JSON values for concise logging +/// +/// This function recursively processes JSON and replaces long text fields +/// with truncated versions or placeholders, while preserving structure +/// and non-text metadata. +pub fn mask_content(value: &Value) -> Value { + match value { + Value::String(s) => { + if s.len() > MAX_LOG_LENGTH { + Value::String(format!("... ({} chars)", s.len())) + } else { + Value::String(s.clone()) + } + } + Value::Array(arr) => { + Value::Array(arr.iter().map(mask_content).collect()) + } + Value::Object(obj) => { + let mut masked = serde_json::Map::new(); + for (k, v) in obj { + // Mask known content fields (only if they're strings) + if k == "text" || k == "content_md" || k == "message" || k == "thinking" { + if let Value::String(s) = v { + if s.len() > 50 { + masked.insert(k.clone(), Value::String(format!("... ({} chars)", s.len()))); + } else if s.len() > 0 { + masked.insert(k.clone(), Value::String("...".to_string())); + } else { + masked.insert(k.clone(), Value::String("".to_string())); + } + } else { + // Not a string, recurse (e.g., content as object) + masked.insert(k.clone(), mask_content(v)); + } + } + // Mask raw input/output fields (large blobs of data) + else if k == "rawOutput" || k == "rawInput" || k == "raw_output" || k == "raw_input" { + if let Value::String(s) = v { + if s.len() > 50 { + masked.insert(k.clone(), Value::String(format!("... ({} chars)", s.len()))); + } else if s.len() > 0 { + masked.insert(k.clone(), Value::String("...".to_string())); + } else { + masked.insert(k.clone(), Value::String("".to_string())); + } + } else { + // Not a string, recurse + masked.insert(k.clone(), mask_content(v)); + } + } + // Mask filename/path fields + else if k == "filename" || k == "file" || k == "path" || k == "file_path" || k == "filepath" { + if let Value::String(s) = v { + // Extract just the filename from path for debugging context + let filename = extract_filename(s); + masked.insert(k.clone(), Value::String(format!("<{}>", filename))); + } else if let Value::Array(arr) = v { + // Array of filenames + let masked_arr: Vec<Value> = arr.iter().map(|item| { + if let Value::String(s) = item { + let filename = extract_filename(s); + Value::String(format!("<{}>", filename)) + } else { + item.clone() + } + }).collect(); + masked.insert(k.clone(), Value::Array(masked_arr)); + } else { + masked.insert(k.clone(), mask_content(v)); + } + } + // Mask arrays of filenames + else if k == "filenames" || k == "files" || k == "paths" { + if let Value::Array(arr) = v { + let masked_arr: Vec<Value> = arr.iter().map(|item| { + if let Value::String(s) = item { + let filename = extract_filename(s); + Value::String(format!("<{}>", filename)) + } else { + mask_content(item) + } + }).collect(); + masked.insert(k.clone(), Value::Array(masked_arr)); + } else { + masked.insert(k.clone(), mask_content(v)); + } + } + else { + masked.insert(k.clone(), mask_content(v)); + } + } + Value::Object(masked) + } + _ => value.clone(), + } +} + +/// Format a Value for logging with content masked +pub fn format_for_log(value: &Value) -> String { + let masked = mask_content(value); + serde_json::to_string(&masked).unwrap_or_else(|_| "{}".to_string()) +} + +/// Mask content in a JSON string, returning a masked JSON string +pub fn mask_json_string(json_str: &str) -> String { + match serde_json::from_str::<Value>(json_str) { + Ok(value) => format_for_log(&value), + Err(_) => { + // If not valid JSON, just truncate + if json_str.len() > MAX_LOG_LENGTH { + format!("{}... ({} bytes)", &json_str[..MAX_LOG_LENGTH], json_str.len()) + } else { + json_str.to_string() + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_mask_short_string() { + let value = Value::String("short".to_string()); + let masked = mask_content(&value); + assert_eq!(masked, Value::String("short".to_string())); + } + + #[test] + fn test_mask_long_string() { + let long_str = "a".repeat(200); + let value = Value::String(long_str.clone()); + let masked = mask_content(&value); + if let Value::String(s) = masked { + assert!(s.contains("(200 chars)")); + assert!(s.len() < long_str.len()); + } else { + panic!("Expected string"); + } + } + + #[test] + fn test_mask_text_field() { + let value = json!({ + "text": "This is some long message content that should be masked", + "message_id": "123", + "role": "user" + }); + let masked = mask_content(&value); + assert_eq!(masked["message_id"], "123"); + assert_eq!(masked["role"], "user"); + // Text is 56 chars, which is > 50, so should be masked with char count + assert_eq!(masked["text"], "... (55 chars)"); + } + + #[test] + fn test_mask_nested_content() { + let value = json!({ + "sessionId": "abc-123", + "update": { + "content": [{ + "type": "content", + "content": { + "type": "text", + "text": "A".repeat(500) + } + }], + "sessionUpdate": "tool_call_update", + "status": "completed" + } + }); + let masked = mask_content(&value); + assert_eq!(masked["sessionId"], "abc-123"); + assert_eq!(masked["update"]["sessionUpdate"], "tool_call_update"); + assert_eq!(masked["update"]["status"], "completed"); + + // Check that the text field is masked + let text_value = &masked["update"]["content"][0]["content"]["text"]; + if let Value::String(s) = text_value { + assert!(s.contains("(500 chars)")); + } else { + panic!("Expected text to be masked"); + } + } + + #[test] + fn test_mask_content_as_object() { + // Ensure "content" as object is NOT masked, only strings + let value = json!({ + "content": { + "type": "text", + "text": "Some message" + }, + "message_id": "123" + }); + let masked = mask_content(&value); + assert_eq!(masked["message_id"], "123"); + assert_eq!(masked["content"]["type"], "text"); + assert_eq!(masked["content"]["text"], "..."); // text field is masked + } + + #[test] + fn test_mask_filename() { + let value = json!({ + "filename": "/Users/name/Projects/dirigent/packages/web/src/main.rs", + "operation": "read" + }); + let masked = mask_content(&value); + assert_eq!(masked["operation"], "read"); + assert_eq!(masked["filename"], "<main.rs>"); + } + + #[test] + fn test_mask_filenames_array() { + let value = json!({ + "filenames": [ + "/Users/name/Projects/dirigent/packages/web/src/main.rs", + "/Users/name/Projects/dirigent/packages/api/src/core.rs", + "C:\\Users\\name\\Documents\\file.txt" + ], + "count": 3 + }); + let masked = mask_content(&value); + assert_eq!(masked["count"], 3); + assert_eq!(masked["filenames"][0], "<main.rs>"); + assert_eq!(masked["filenames"][1], "<core.rs>"); + assert_eq!(masked["filenames"][2], "<file.txt>"); + } +} diff --git a/crates/dirigent_protocol/src/project.rs b/crates/dirigent_protocol/src/project.rs new file mode 100644 index 0000000..ea129aa --- /dev/null +++ b/crates/dirigent_protocol/src/project.rs @@ -0,0 +1,267 @@ +//! Project types for the Dirigent system. +//! +//! WASM-compatible shared types for the Projects module. These types are +//! used by both server and client (web UI) code. + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; +use uuid::Uuid; + +/// A project in the Dirigent system. +/// +/// Projects organize work across repositories, sessions, and connectors. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Project { + /// Unique project identifier (UUID v7) + pub id: Uuid, + /// Human-readable project name + pub name: String, + /// Project description (empty by default) + #[serde(default)] + pub description: String, + /// Optional icon (emoji or abbreviation) + #[serde(skip_serializing_if = "Option::is_none")] + pub icon: Option<String>, + /// Owner user ID + pub owner: Uuid, + /// Member user IDs + #[serde(default)] + pub members: Vec<Uuid>, + /// Categorization tags + #[serde(default)] + pub tags: Vec<String>, + /// Programming languages used + #[serde(default)] + pub languages: Vec<String>, + /// Linked project IDs (for multi-project setups) + #[serde(default)] + pub linked_projects: Vec<Uuid>, + /// Arbitrary metadata + #[serde(default = "default_metadata")] + pub metadata: serde_json::Value, + /// When this project was created + pub created_at: DateTime<Utc>, + /// When this project was last updated + pub updated_at: DateTime<Utc>, +} + +fn default_metadata() -> serde_json::Value { + serde_json::Value::Object(serde_json::Map::new()) +} + +/// A local git repository associated with a project. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ProjectRepository { + /// Unique repository identifier (UUID v7) + pub id: Uuid, + /// Project this repository belongs to + pub project_id: Uuid, + /// Local filesystem path + pub path: PathBuf, + /// Whether this is the primary repository + #[serde(default)] + pub is_primary: bool, + /// Optional human-readable label + #[serde(skip_serializing_if = "Option::is_none")] + pub label: Option<String>, + /// Access mode + #[serde(default)] + pub access: AccessMode, + /// When this repository was added + pub created_at: DateTime<Utc>, + /// When this repository was last updated + pub updated_at: DateTime<Utc>, +} + +/// Repository access mode. +#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)] +pub enum AccessMode { + /// Read-only access + Read, + /// Read and write access + #[default] + ReadWrite, +} + +/// A git worktree linked to a repository. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Worktree { + /// Unique worktree identifier (UUID v7) + pub id: Uuid, + /// Repository this worktree belongs to + pub repository_id: Uuid, + /// Local filesystem path + pub path: PathBuf, + /// Branch name + pub branch: String, + /// Optional work branch name + #[serde(skip_serializing_if = "Option::is_none")] + pub work_branch: Option<String>, + /// Optional naming strategy + #[serde(skip_serializing_if = "Option::is_none")] + pub naming_strategy: Option<String>, + /// When this worktree was created + pub created_at: DateTime<Utc>, +} + +/// Binding between a project and a connector/session. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ProjectBinding { + /// Unique binding identifier (UUID v7) + pub id: Uuid, + /// Project this binding belongs to + pub project_id: Uuid, + /// Optional connector ID + #[serde(skip_serializing_if = "Option::is_none")] + pub connector_id: Option<String>, + /// Optional session ID + #[serde(skip_serializing_if = "Option::is_none")] + pub session_id: Option<Uuid>, + /// Optional working directory override + #[serde(skip_serializing_if = "Option::is_none")] + pub working_dir: Option<PathBuf>, +} + +/// Runtime git state (not persisted, computed on demand). +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct GitState { + /// Current branch name + pub branch: String, + /// Whether there are uncommitted changes + #[serde(default)] + pub is_dirty: bool, + /// Commits ahead of remote + #[serde(default)] + pub ahead: u32, + /// Commits behind remote + #[serde(default)] + pub behind: u32, + /// Remote names + #[serde(default)] + pub remotes: Vec<String>, + /// Active worktrees + #[serde(default)] + pub worktrees: Vec<WorktreeInfo>, + /// Unexpected conditions + #[serde(default)] + pub unexpected: Vec<GitWarning>, +} + +/// Information about an active worktree. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct WorktreeInfo { + /// Worktree filesystem path + pub path: PathBuf, + /// Branch checked out (None if detached) + #[serde(skip_serializing_if = "Option::is_none")] + pub branch: Option<String>, + /// Whether HEAD is detached + #[serde(default)] + pub is_detached: bool, +} + +/// A warning about unexpected git state. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct GitWarning { + /// Warning code for programmatic handling + pub code: String, + /// Human-readable warning message + pub message: String, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_project_serialization_roundtrip() { + let now = Utc::now(); + let project = Project { + id: Uuid::now_v7(), + name: "Test Project".to_string(), + description: "A test project".to_string(), + icon: Some("🚀".to_string()), + owner: Uuid::now_v7(), + members: vec![], + tags: vec!["rust".to_string()], + languages: vec!["Rust".to_string()], + linked_projects: vec![], + metadata: serde_json::json!({"key": "value"}), + created_at: now, + updated_at: now, + }; + + let json = serde_json::to_string(&project).expect("serialize"); + let deser: Project = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(deser.id, project.id); + assert_eq!(deser.name, project.name); + assert_eq!(deser.icon, project.icon); + } + + #[test] + fn test_project_defaults() { + let json = r#"{ + "id": "019504a0-0000-7000-8000-000000000001", + "name": "Minimal", + "owner": "019504a0-0000-7000-8000-000000000002", + "created_at": "2026-01-01T00:00:00Z", + "updated_at": "2026-01-01T00:00:00Z" + }"#; + + let project: Project = serde_json::from_str(json).expect("deserialize"); + assert_eq!(project.description, ""); + assert!(project.tags.is_empty()); + assert!(project.members.is_empty()); + assert!(project.metadata.is_object()); + } + + #[test] + fn test_access_mode_default() { + assert_eq!(AccessMode::default(), AccessMode::ReadWrite); + } + + #[test] + fn test_project_repository_roundtrip() { + let now = Utc::now(); + let repo = ProjectRepository { + id: Uuid::now_v7(), + project_id: Uuid::now_v7(), + path: PathBuf::from("/home/user/project"), + is_primary: true, + label: Some("main".to_string()), + access: AccessMode::ReadWrite, + created_at: now, + updated_at: now, + }; + + let json = serde_json::to_string(&repo).expect("serialize"); + let deser: ProjectRepository = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(deser.id, repo.id); + assert!(deser.is_primary); + } + + #[test] + fn test_git_state_default() { + let state = GitState::default(); + assert_eq!(state.branch, ""); + assert!(!state.is_dirty); + assert_eq!(state.ahead, 0); + } + + #[test] + fn test_binding_roundtrip() { + let binding = ProjectBinding { + id: Uuid::now_v7(), + project_id: Uuid::now_v7(), + connector_id: Some("conn-1".to_string()), + session_id: None, + working_dir: None, + }; + + let json = serde_json::to_string(&binding).expect("serialize"); + let deser: ProjectBinding = serde_json::from_str(&json).expect("deserialize"); + assert_eq!(deser.id, binding.id); + assert!(deser.session_id.is_none()); + } +} diff --git a/crates/dirigent_protocol/src/session.rs b/crates/dirigent_protocol/src/session.rs new file mode 100644 index 0000000..bcc52d6 --- /dev/null +++ b/crates/dirigent_protocol/src/session.rs @@ -0,0 +1,972 @@ +use crate::types::meta::Meta; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Session { + pub id: String, + pub title: String, + pub created_at: DateTime<Utc>, + pub updated_at: DateTime<Utc>, + pub metadata: SessionMetadata, + /// Working directory for this session (if known) + #[serde(default, skip_serializing_if = "Option::is_none")] + pub cwd: Option<String>, + /// ACP model state (available models and current model) + /// Populated from archivist for archived sessions, or from SSE events for live sessions. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub models: Option<SessionModelState>, + /// ACP mode state (available modes and current mode) + /// Populated from archivist for archived sessions, or from SSE events for live sessions. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub modes: Option<SessionModeState>, + /// ACP config options (replaces modes/models in future ACP versions) + #[serde(default, skip_serializing_if = "Option::is_none")] + pub config_options: Option<Vec<ConfigOption>>, + /// ACP client ID that owns this session. + /// For sessions created via ACP Server (incoming connections), this identifies + /// which connected client created/owns this session. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub acp_client_id: Option<String>, +} + +// ============================================================================ +// ACP Session Mode/Model Types +// ============================================================================ +// These types match the Agent-Client Protocol (ACP) specification exactly. +// They use camelCase serialization to match Claude-ACP's JSON format. +// See: docs/architecture/agent_client_protocol/schema.md + +/// Type alias for session mode identifiers (e.g., "default", "plan", "bypassPermissions") +pub type SessionModeId = String; + +/// Type alias for model identifiers (e.g., "default", "sonnet", "haiku", "opus") +pub type ModelId = String; + +/// Session mode state from ACP `session/new` response +/// +/// Contains the list of available modes and the currently active mode. +/// This is part of the stable ACP specification. +/// +/// # Example (from Claude-ACP) +/// ```json +/// { +/// "currentModeId": "default", +/// "availableModes": [ +/// {"id": "default", "name": "Always Ask", "description": "..."}, +/// {"id": "plan", "name": "Plan Mode", "description": "..."} +/// ] +/// } +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct SessionModeState { + /// The currently active mode ID + pub current_mode_id: SessionModeId, + /// List of all available modes for this session + pub available_modes: Vec<SessionMode>, +} + +/// A single session mode definition +/// +/// Modes affect agent behavior, tool availability, and permission handling. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct SessionMode { + /// Unique identifier for this mode + pub id: SessionModeId, + /// Human-readable display name + pub name: String, + /// Optional description of what this mode does + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option<String>, +} + +/// Session model state from ACP `session/new` response +/// +/// Contains the list of available models and the currently selected model. +/// Note: This field is marked UNSTABLE in the ACP spec but is used by Claude-ACP. +/// +/// # Example (from Claude-ACP) +/// ```json +/// { +/// "availableModels": [ +/// {"modelId": "default", "name": "Default (recommended)", "description": "..."}, +/// {"modelId": "sonnet", "name": "Sonnet", "description": "..."} +/// ], +/// "currentModelId": "default" +/// } +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct SessionModelState { + /// List of all available models for this session + pub available_models: Vec<ModelInfo>, + /// The currently selected model ID + pub current_model_id: ModelId, +} + +/// Information about a single model +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct ModelInfo { + /// Unique identifier for this model + pub model_id: ModelId, + /// Human-readable display name + pub name: String, + /// Optional description of the model's capabilities + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option<String>, +} + +// ============================================================================ +// ACP Config Options (replaces modes/models in future ACP versions) +// ============================================================================ + +/// A configuration option for a session (ACP configOptions). +/// +/// Agents provide config options in session/new and session/load responses. +/// Clients should use these instead of the legacy `modes`/`models` fields. +/// See: docs/architecture/agent_client_protocol/session-config-options.md +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct ConfigOption { + /// Unique identifier (e.g., "mode", "model") + pub id: String, + /// Human-readable label + pub name: String, + /// Optional description + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option<String>, + /// Semantic category for UX grouping (e.g., "mode", "model", "thought_level") + #[serde(skip_serializing_if = "Option::is_none")] + pub category: Option<String>, + /// Input type (currently only "select" is defined) + #[serde(rename = "type")] + pub option_type: ConfigOptionType, + /// Currently selected value + pub current_value: String, + /// Available values for select-type options + pub options: Vec<ConfigOptionValue>, +} + +/// Type of configuration option input +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum ConfigOptionType { + Select, +} + +/// A single value choice within a config option +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct ConfigOptionValue { + /// Value identifier (sent back when setting this option) + pub value: String, + /// Human-readable display name + pub name: String, + /// Optional description + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option<String>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct SessionMetadata { + pub project_path: String, + pub model: Option<String>, + /// Total count of user and assistant messages in the session (excludes system messages). + /// A value of 0 may indicate either an empty session or that the count has not yet been calculated. + /// Counts are populated lazily when messages are loaded for a session. + /// See `docs/architecture/session_message_counts.md` for details. + pub total_messages: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_message: Option<String>, + /// Current mode identifier for future mode tracking + #[serde(skip_serializing_if = "Option::is_none")] + pub current_mode_id: Option<String>, + /// Provider metadata for tracking original IDs and debugging information + #[serde(skip_serializing_if = "Option::is_none")] + pub _meta: Option<Meta>, + /// Optional project ID linking this session to a dirigent_projects Project. + /// When set, the session belongs to the specified project for organizational purposes. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub project_id: Option<uuid::Uuid>, +} + +// ============================================================================ +// Session Ownership Model +// ============================================================================ +// These types define how sessions are owned and how tool execution is routed. +// See: docs/architecture/session_ownership.md (Phase 7) + +/// The origin of a session - who initiated it +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum SessionOrigin { + /// Session created by Dirigent UI user + Internal, + + /// Session forwarded from external ACP client + External { + /// The ACP client ID that owns this session + client_id: String, + /// Cached client capabilities (from initialization) + #[serde(default, skip_serializing_if = "Option::is_none")] + client_capabilities: Option<serde_json::Value>, + }, + + /// Session representing a subagent or internal task (future) + Subagent { + /// Parent session that spawned this subagent + parent_session_id: String, + /// Task identifier + task_id: String, + }, +} + +impl Default for SessionOrigin { + fn default() -> Self { + Self::Internal + } +} + +/// Who handles tool execution for this session +#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum ToolHandler { + /// Agent handles its own tools (default) + #[default] + Agent, + + /// Dirigent intercepts and handles tools via dirigent_tools (future) + Dirigent, + + /// Forward tool requests to originating client (External sessions only) + ForwardToClient, +} + +/// Complete ownership model for a session +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct SessionOwnership { + /// Where this session originated + #[serde(default)] + pub origin: SessionOrigin, + + /// How tool requests are handled + #[serde(default)] + pub tool_handler: ToolHandler, +} + +impl SessionOwnership { + /// Internal session with agent handling tools (default UI case) + pub fn internal() -> Self { + Self { + origin: SessionOrigin::Internal, + tool_handler: ToolHandler::Agent, + } + } + + /// External session with tools forwarded to client + pub fn external_forwarded(client_id: String, capabilities: Option<serde_json::Value>) -> Self { + Self { + origin: SessionOrigin::External { + client_id, + client_capabilities: capabilities, + }, + tool_handler: ToolHandler::ForwardToClient, + } + } + + /// External session but Dirigent handles tools + pub fn external_handled(client_id: String, capabilities: Option<serde_json::Value>) -> Self { + Self { + origin: SessionOrigin::External { + client_id, + client_capabilities: capabilities, + }, + tool_handler: ToolHandler::Dirigent, + } + } + + /// Get capabilities to advertise to agent based on ownership + pub fn capabilities_for_agent(&self) -> serde_json::Value { + match (&self.origin, &self.tool_handler) { + // External + ForwardToClient: use client's capabilities + ( + SessionOrigin::External { + client_capabilities: Some(caps), + .. + }, + ToolHandler::ForwardToClient, + ) => caps.clone(), + // Dirigent handles tools: advertise dirigent_tools capabilities + (_, ToolHandler::Dirigent) => { + serde_json::json!({ + "fs": { "readTextFile": true, "writeTextFile": true }, + "terminal": true + }) + } + // Agent handles tools or no client caps: empty (agent uses its own) + _ => serde_json::json!({}), + } + } + + /// Get the client ID if this should forward requests to a client + pub fn forward_to_client(&self) -> Option<&str> { + match (&self.origin, &self.tool_handler) { + (SessionOrigin::External { client_id, .. }, ToolHandler::ForwardToClient) => { + Some(client_id.as_str()) + } + _ => None, + } + } + + /// Check if this is an external (forwarded) session + pub fn is_external(&self) -> bool { + matches!(self.origin, SessionOrigin::External { .. }) + } + + /// Get the originating client ID if external + pub fn client_id(&self) -> Option<&str> { + match &self.origin { + SessionOrigin::External { client_id, .. } => Some(client_id.as_str()), + _ => None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::types::meta::{Meta, ProviderMeta}; + use std::collections::HashMap; + + // ======================================================================== + // ACP Session Mode/Model Type Tests + // ======================================================================== + + #[test] + fn test_session_mode_state_serialization_camel_case() { + // Verify camelCase serialization matches Claude-ACP format + let mode_state = SessionModeState { + current_mode_id: "default".to_string(), + available_modes: vec![ + SessionMode { + id: "default".to_string(), + name: "Always Ask".to_string(), + description: Some( + "Prompts for permission on first use of each tool".to_string(), + ), + }, + SessionMode { + id: "plan".to_string(), + name: "Plan Mode".to_string(), + description: Some("Claude can analyze but not modify files".to_string()), + }, + ], + }; + + let json = serde_json::to_string(&mode_state).unwrap(); + // Verify camelCase field names + assert!(json.contains("currentModeId")); + assert!(json.contains("availableModes")); + // Verify content + assert!(json.contains(r#""currentModeId":"default"#)); + assert!(json.contains(r#""name":"Always Ask"#)); + } + + #[test] + fn test_session_mode_state_deserialization_from_claude_format() { + // Test deserialization of actual Claude-ACP format + let json = r#"{ + "currentModeId": "default", + "availableModes": [ + { + "id": "default", + "name": "Always Ask", + "description": "Prompts for permission on first use of each tool" + }, + { + "id": "acceptEdits", + "name": "Accept Edits", + "description": "Automatically accepts file edit permissions for the session" + }, + { + "id": "plan", + "name": "Plan Mode", + "description": "Claude can analyze but not modify files or execute commands" + }, + { + "id": "bypassPermissions", + "name": "Bypass Permissions", + "description": "Skips all permission prompts" + } + ] + }"#; + + let mode_state: SessionModeState = serde_json::from_str(json).unwrap(); + assert_eq!(mode_state.current_mode_id, "default"); + assert_eq!(mode_state.available_modes.len(), 4); + assert_eq!(mode_state.available_modes[0].id, "default"); + assert_eq!(mode_state.available_modes[0].name, "Always Ask"); + assert_eq!(mode_state.available_modes[3].id, "bypassPermissions"); + } + + #[test] + fn test_session_mode_state_roundtrip() { + let original = SessionModeState { + current_mode_id: "plan".to_string(), + available_modes: vec![ + SessionMode { + id: "default".to_string(), + name: "Default".to_string(), + description: None, + }, + SessionMode { + id: "plan".to_string(), + name: "Plan".to_string(), + description: Some("Planning mode".to_string()), + }, + ], + }; + + let json = serde_json::to_string(&original).unwrap(); + let deserialized: SessionModeState = serde_json::from_str(&json).unwrap(); + assert_eq!(original, deserialized); + } + + #[test] + fn test_session_mode_skip_none_description() { + // Test that None descriptions are not serialized + let mode = SessionMode { + id: "test".to_string(), + name: "Test".to_string(), + description: None, + }; + + let json = serde_json::to_string(&mode).unwrap(); + assert!(!json.contains("description")); + } + + #[test] + fn test_session_model_state_serialization_camel_case() { + // Verify camelCase serialization matches Claude-ACP format + let model_state = SessionModelState { + available_models: vec![ + ModelInfo { + model_id: "default".to_string(), + name: "Default (recommended)".to_string(), + description: Some("Opus 4.5 · Most capable for complex work".to_string()), + }, + ModelInfo { + model_id: "sonnet".to_string(), + name: "Sonnet".to_string(), + description: Some("Sonnet 4.5 · Best for everyday tasks".to_string()), + }, + ], + current_model_id: "default".to_string(), + }; + + let json = serde_json::to_string(&model_state).unwrap(); + // Verify camelCase field names + assert!(json.contains("availableModels")); + assert!(json.contains("currentModelId")); + assert!(json.contains("modelId")); + // Verify content + assert!(json.contains(r#""currentModelId":"default"#)); + assert!(json.contains(r#""name":"Sonnet"#)); + } + + #[test] + fn test_session_model_state_deserialization_from_claude_format() { + // Test deserialization of actual Claude-ACP format (from zed_claude_code_direct_acp_log.txt) + let json = r#"{ + "availableModels": [ + { + "modelId": "default", + "name": "Default (recommended)", + "description": "Opus 4.5 · Most capable for complex work" + }, + { + "modelId": "sonnet", + "name": "Sonnet", + "description": "Sonnet 4.5 · Best for everyday tasks" + }, + { + "modelId": "haiku", + "name": "Haiku", + "description": "Haiku 4.5 · Fastest for quick answers" + }, + { + "modelId": "opus", + "name": "opus", + "description": "Custom model" + } + ], + "currentModelId": "default" + }"#; + + let model_state: SessionModelState = serde_json::from_str(json).unwrap(); + assert_eq!(model_state.current_model_id, "default"); + assert_eq!(model_state.available_models.len(), 4); + assert_eq!(model_state.available_models[0].model_id, "default"); + assert_eq!( + model_state.available_models[0].name, + "Default (recommended)" + ); + assert_eq!(model_state.available_models[2].model_id, "haiku"); + assert_eq!(model_state.available_models[3].model_id, "opus"); + } + + #[test] + fn test_session_model_state_roundtrip() { + let original = SessionModelState { + available_models: vec![ModelInfo { + model_id: "default".to_string(), + name: "Default".to_string(), + description: None, + }], + current_model_id: "default".to_string(), + }; + + let json = serde_json::to_string(&original).unwrap(); + let deserialized: SessionModelState = serde_json::from_str(&json).unwrap(); + assert_eq!(original, deserialized); + } + + #[test] + fn test_model_info_skip_none_description() { + // Test that None descriptions are not serialized + let model = ModelInfo { + model_id: "test".to_string(), + name: "Test".to_string(), + description: None, + }; + + let json = serde_json::to_string(&model).unwrap(); + assert!(!json.contains("description")); + } + + // ======================================================================== + // SessionMetadata Tests (existing) + // ======================================================================== + + #[test] + fn test_session_metadata_backward_compatibility() { + // Test that existing SessionMetadata without new fields can be deserialized + let json = r#"{ + "project_path": "/test/path", + "model": "gpt-4", + "total_messages": 10, + "system_message": "System prompt" + }"#; + + let metadata: SessionMetadata = serde_json::from_str(json).unwrap(); + assert_eq!(metadata.project_path, "/test/path"); + assert_eq!(metadata.model, Some("gpt-4".to_string())); + assert_eq!(metadata.total_messages, 10); + assert_eq!(metadata.system_message, Some("System prompt".to_string())); + assert_eq!(metadata.current_mode_id, None); + assert_eq!(metadata._meta, None); + } + + #[test] + fn test_session_metadata_skip_serializing_none() { + // Test that None values are skipped during serialization + let metadata = SessionMetadata { + project_path: "/test".to_string(), + model: Some("gpt-4".to_string()), + total_messages: 0, + system_message: None, + current_mode_id: None, + _meta: None, + project_id: None, + }; + + let json = serde_json::to_string(&metadata).unwrap(); + // Should not contain system_message, current_mode_id, or _meta fields + assert!(!json.contains("system_message")); + assert!(!json.contains("current_mode_id")); + assert!(!json.contains("_meta")); + // Should contain the present fields + assert!(json.contains("project_path")); + assert!(json.contains("model")); + assert!(json.contains("total_messages")); + } + + #[test] + fn test_session_metadata_with_current_mode_id() { + // Test serialization/deserialization with current_mode_id + let metadata = SessionMetadata { + project_path: "/test".to_string(), + model: Some("gpt-4".to_string()), + total_messages: 5, + system_message: None, + current_mode_id: Some("code_mode".to_string()), + _meta: None, + project_id: None, + }; + + let json = serde_json::to_string(&metadata).unwrap(); + assert!(json.contains("current_mode_id")); + assert!(json.contains("code_mode")); + + let deserialized: SessionMetadata = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.current_mode_id, Some("code_mode".to_string())); + } + + #[test] + fn test_session_metadata_with_meta() { + // Test serialization/deserialization with provider metadata + let meta = Meta { + provider: Some(ProviderMeta { + name: "opencode".to_string(), + original_ids: Some(HashMap::from([( + "session_id".to_string(), + "ses_abc123".to_string(), + )])), + raw_excerpt: None, + }), + extra: HashMap::new(), + }; + + let metadata = SessionMetadata { + project_path: "/test".to_string(), + model: Some("gpt-4".to_string()), + total_messages: 5, + system_message: None, + current_mode_id: None, + _meta: Some(meta), + project_id: None, + }; + + let json = serde_json::to_string(&metadata).unwrap(); + assert!(json.contains("_meta")); + assert!(json.contains("opencode")); + assert!(json.contains("ses_abc123")); + + let deserialized: SessionMetadata = serde_json::from_str(&json).unwrap(); + assert!(deserialized._meta.is_some()); + let deserialized_meta = deserialized._meta.unwrap(); + assert!(deserialized_meta.provider.is_some()); + assert_eq!(deserialized_meta.provider.unwrap().name, "opencode"); + } + + #[test] + fn test_session_metadata_with_all_fields() { + // Test with all fields populated + let meta = Meta { + provider: Some(ProviderMeta { + name: "anthropic".to_string(), + original_ids: Some(HashMap::from([( + "conversation_id".to_string(), + "conv_xyz".to_string(), + )])), + raw_excerpt: Some(serde_json::json!({"version": "1.0"})), + }), + extra: HashMap::new(), + }; + + let metadata = SessionMetadata { + project_path: "/project".to_string(), + model: Some("claude-3".to_string()), + total_messages: 42, + system_message: Some("Be helpful".to_string()), + current_mode_id: Some("architect".to_string()), + _meta: Some(meta), + project_id: None, + }; + + let json = serde_json::to_string(&metadata).unwrap(); + let deserialized: SessionMetadata = serde_json::from_str(&json).unwrap(); + + assert_eq!(metadata, deserialized); + assert!(json.contains("system_message")); + assert!(json.contains("current_mode_id")); + assert!(json.contains("_meta")); + } + + #[test] + fn test_session_roundtrip_with_new_fields() { + // Test that a Session with new metadata fields survives roundtrip + let now = Utc::now(); + let session = Session { + id: "ses_test123".to_string(), + title: "Test Session".to_string(), + created_at: now, + updated_at: now, + metadata: SessionMetadata { + project_path: "/workspace".to_string(), + model: Some("gpt-4-turbo".to_string()), + total_messages: 7, + system_message: Some("You are a coding assistant".to_string()), + current_mode_id: Some("debug_mode".to_string()), + _meta: Some(Meta { + provider: Some(ProviderMeta { + name: "test_provider".to_string(), + original_ids: None, + raw_excerpt: None, + }), + extra: HashMap::new(), + }), + project_id: None, + }, + cwd: None, + models: None, + modes: None, + config_options: None, + acp_client_id: None, + }; + + let json = serde_json::to_string(&session).unwrap(); + let deserialized: Session = serde_json::from_str(&json).unwrap(); + + assert_eq!(session, deserialized); + } + + #[test] + fn test_session_with_models_and_modes() { + // Test Session with models and modes populated + let now = Utc::now(); + let session = Session { + id: "ses_test456".to_string(), + title: "Test Session with ACP".to_string(), + created_at: now, + updated_at: now, + metadata: SessionMetadata { + project_path: "/workspace".to_string(), + model: Some("default".to_string()), + total_messages: 5, + system_message: None, + current_mode_id: Some("default".to_string()), + _meta: None, + project_id: None, + }, + cwd: None, + models: Some(SessionModelState { + available_models: vec![ + ModelInfo { + model_id: "default".to_string(), + name: "Default".to_string(), + description: Some("Default model".to_string()), + }, + ModelInfo { + model_id: "sonnet".to_string(), + name: "Sonnet".to_string(), + description: None, + }, + ], + current_model_id: "default".to_string(), + }), + modes: Some(SessionModeState { + current_mode_id: "default".to_string(), + available_modes: vec![SessionMode { + id: "default".to_string(), + name: "Always Ask".to_string(), + description: None, + }], + }), + config_options: None, + acp_client_id: Some("test-client-123".to_string()), + }; + + let json = serde_json::to_string(&session).unwrap(); + assert!(json.contains("models")); + assert!(json.contains("modes")); + assert!(json.contains("acp_client_id")); + assert!(json.contains("availableModels")); + assert!(json.contains("currentModeId")); + + let deserialized: Session = serde_json::from_str(&json).unwrap(); + assert_eq!(session, deserialized); + } + + #[test] + fn test_session_backward_compatibility_no_models_modes() { + // Test that old Session JSON without models/modes can be deserialized + let json = r#"{ + "id": "ses_old", + "title": "Old Session", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z", + "metadata": { + "project_path": "/test", + "model": "gpt-4", + "total_messages": 10 + } + }"#; + + let session: Session = serde_json::from_str(json).unwrap(); + assert_eq!(session.id, "ses_old"); + assert!(session.models.is_none()); + assert!(session.modes.is_none()); + } + + // ======================================================================== + // Session Ownership Model Tests + // ======================================================================== + + #[test] + fn test_session_origin_default() { + // Verify Internal is the default + let origin = SessionOrigin::default(); + assert_eq!(origin, SessionOrigin::Internal); + } + + #[test] + fn test_session_origin_serialization() { + // Test Internal variant + let internal = SessionOrigin::Internal; + let json = serde_json::to_string(&internal).unwrap(); + assert!(json.contains(r#""type":"internal"#)); + + // Test External variant + let external = SessionOrigin::External { + client_id: "client-123".to_string(), + client_capabilities: Some(serde_json::json!({"tools": ["bash"]})), + }; + let json = serde_json::to_string(&external).unwrap(); + assert!(json.contains(r#""type":"external"#)); + assert!(json.contains(r#""client_id":"client-123"#)); + assert!(json.contains("tools")); + + // Test Subagent variant + let subagent = SessionOrigin::Subagent { + parent_session_id: "parent-456".to_string(), + task_id: "task-789".to_string(), + }; + let json = serde_json::to_string(&subagent).unwrap(); + assert!(json.contains(r#""type":"subagent"#)); + assert!(json.contains(r#""parent_session_id":"parent-456"#)); + assert!(json.contains(r#""task_id":"task-789"#)); + } + + #[test] + fn test_tool_handler_default() { + // Verify Agent is the default + let handler = ToolHandler::default(); + assert_eq!(handler, ToolHandler::Agent); + } + + #[test] + fn test_tool_handler_serialization() { + let agent = ToolHandler::Agent; + let json = serde_json::to_string(&agent).unwrap(); + assert_eq!(json, r#""agent""#); + + let dirigent = ToolHandler::Dirigent; + let json = serde_json::to_string(&dirigent).unwrap(); + assert_eq!(json, r#""dirigent""#); + + let forward = ToolHandler::ForwardToClient; + let json = serde_json::to_string(&forward).unwrap(); + assert_eq!(json, r#""forward_to_client""#); + } + + #[test] + fn test_session_ownership_internal() { + let ownership = SessionOwnership::internal(); + assert_eq!(ownership.origin, SessionOrigin::Internal); + assert_eq!(ownership.tool_handler, ToolHandler::Agent); + assert!(!ownership.is_external()); + assert_eq!(ownership.client_id(), None); + assert_eq!(ownership.forward_to_client(), None); + } + + #[test] + fn test_session_ownership_external_forwarded() { + let caps = serde_json::json!({"tools": ["bash", "edit"]}); + let ownership = + SessionOwnership::external_forwarded("client-123".to_string(), Some(caps.clone())); + + match &ownership.origin { + SessionOrigin::External { + client_id, + client_capabilities, + } => { + assert_eq!(client_id, "client-123"); + assert_eq!(client_capabilities.as_ref().unwrap(), &caps); + } + _ => panic!("Expected External origin"), + } + + assert_eq!(ownership.tool_handler, ToolHandler::ForwardToClient); + assert!(ownership.is_external()); + assert_eq!(ownership.client_id(), Some("client-123")); + assert_eq!(ownership.forward_to_client(), Some("client-123")); + } + + #[test] + fn test_session_ownership_external_handled() { + let ownership = SessionOwnership::external_handled("client-456".to_string(), None); + + match &ownership.origin { + SessionOrigin::External { + client_id, + client_capabilities, + } => { + assert_eq!(client_id, "client-456"); + assert!(client_capabilities.is_none()); + } + _ => panic!("Expected External origin"), + } + + assert_eq!(ownership.tool_handler, ToolHandler::Dirigent); + assert!(ownership.is_external()); + assert_eq!(ownership.client_id(), Some("client-456")); + assert_eq!(ownership.forward_to_client(), None); // Dirigent handles, not forwarded + } + + #[test] + fn test_capabilities_for_agent_external_forwarded() { + let client_caps = serde_json::json!({"fs": true, "terminal": true}); + let ownership = SessionOwnership::external_forwarded( + "client-123".to_string(), + Some(client_caps.clone()), + ); + + let caps = ownership.capabilities_for_agent(); + assert_eq!(caps, client_caps); + } + + #[test] + fn test_capabilities_for_agent_dirigent() { + let ownership = SessionOwnership { + origin: SessionOrigin::Internal, + tool_handler: ToolHandler::Dirigent, + }; + + let caps = ownership.capabilities_for_agent(); + assert!(caps.is_object()); + assert!(caps.get("fs").is_some()); + assert!(caps.get("terminal").is_some()); + } + + #[test] + fn test_capabilities_for_agent_agent_handled() { + let ownership = SessionOwnership::internal(); + let caps = ownership.capabilities_for_agent(); + assert_eq!(caps, serde_json::json!({})); + } + + #[test] + fn test_session_ownership_serialization_roundtrip() { + let original = SessionOwnership::external_forwarded( + "test-client".to_string(), + Some(serde_json::json!({"test": true})), + ); + + let json = serde_json::to_string(&original).unwrap(); + let deserialized: SessionOwnership = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.tool_handler, original.tool_handler); + assert_eq!(deserialized.client_id(), Some("test-client")); + } + + #[test] + fn test_session_ownership_default() { + let ownership = SessionOwnership::default(); + assert_eq!(ownership.origin, SessionOrigin::Internal); + assert_eq!(ownership.tool_handler, ToolHandler::Agent); + } +} diff --git a/crates/dirigent_protocol/src/sharing.rs b/crates/dirigent_protocol/src/sharing.rs new file mode 100644 index 0000000..9bac516 --- /dev/null +++ b/crates/dirigent_protocol/src/sharing.rs @@ -0,0 +1,54 @@ +//! Session sharing abstraction +//! +//! The `SessionShare` trait abstracts bidirectional bridges between Dirigent +//! sessions and external communication systems (Matrix, Slack, etc.). +//! +//! A share attaches to a (connector_id, session_id) pair without taking +//! ownership of the session. Multiple shares can coexist on the same session. + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; + +/// Unique identifier for a share instance. +pub type ShareId = String; + +/// Summary info about an active share. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ShareSummary { + /// Share identifier (e.g., "matrix:connector-1:session-abc") + pub id: ShareId, + /// Connector this share is attached to + pub connector_id: String, + /// Session this share is attached to + pub session_id: String, + /// Backend type (e.g., "matrix", "slack") + pub backend: String, + /// Backend-specific destination (e.g., Matrix room ID) + pub destination: String, + /// Whether the share is currently active + pub active: bool, +} + +/// Trait for session share backends. +/// +/// Implementors provide bidirectional bridging between a Dirigent session +/// and an external system. The trait is deliberately minimal — the concrete +/// implementation handles all backend-specific details. +/// +/// # Design Notes +/// +/// - Shares do NOT modify the Connector trait or require special connector support +/// - Shares use existing channels: `connector.subscribe()` for events, +/// `connector.command_tx()` for sending messages +/// - Multiple shares can be active on the same session simultaneously +#[async_trait] +pub trait SessionShare: Send + Sync { + /// Get summary information about this share. + fn summary(&self) -> ShareSummary; + + /// Check if the share is actively forwarding. + fn is_active(&self) -> bool; + + /// Gracefully shut down this share. + async fn shutdown(&self); +} diff --git a/crates/dirigent_protocol/src/streaming/bus_event.rs b/crates/dirigent_protocol/src/streaming/bus_event.rs new file mode 100644 index 0000000..c495025 --- /dev/null +++ b/crates/dirigent_protocol/src/streaming/bus_event.rs @@ -0,0 +1,353 @@ +//! Bus event envelope: wraps `Event` with routing context for subscribers. + +use std::sync::Arc; + +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::Event; + +/// Full bus event envelope: the `Event` plus routing context derived at emit time. +#[derive(Debug, Clone)] +pub struct BusEvent { + pub routing: EventRouting, + pub origin: EventOrigin, + pub event: Arc<Event>, +} + +/// Routing metadata attached to every `BusEvent`. +/// +/// `scroll_id` is intentionally left `None` at construction time and filled in +/// later by the bus cache once the archivist has registered the session. +#[derive(Debug, Clone, Default)] +pub struct EventRouting { + pub connector_uid: Option<Uuid>, + pub scroll_id: Option<Uuid>, + pub connector_id: Option<String>, + pub native_session_id: Option<String>, + pub kind: EventKind, +} + +/// High-level classification of a `BusEvent`, used for subscriber filtering. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum EventKind { + #[default] + SessionLifecycle, + Message, + Update, + System, +} + +/// Records the subsystem that originally produced the event. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum EventOrigin { + Connector { + connector_uid: Option<Uuid>, + connector_id: String, + }, + Archivist, + Runtime, + Replay { + replay_id: Uuid, + }, +} + +// ─── BusEvent constructors ─────────────────────────────────────────────────── + +impl BusEvent { + /// Wrap a connector-sourced `Event` in a `BusEvent`. + /// + /// Routing metadata is derived from the event fields; `scroll_id` is left + /// `None` and must be patched by the bus cache after archivist registration. + pub fn from_connector_event( + event: Event, + connector_uid: Option<Uuid>, + connector_id: String, + ) -> Self { + let routing = EventRouting::derive(&event, connector_uid, &connector_id); + Self { + routing, + origin: EventOrigin::Connector { + connector_uid, + connector_id, + }, + event: Arc::new(event), + } + } + + /// Wrap an archivist-sourced `Event` (e.g. `SessionRegistered`) in a + /// `BusEvent`. The archivist knows the canonical `(connector_id, + /// native_session_id, scroll_id)` triple; we pass it directly so the + /// bus does not have to consult its scroll-id cache to route the + /// event. Origin is set to `EventOrigin::Archivist`. + pub fn from_archivist_event( + event: Event, + connector_id: &str, + native_session_id: &str, + scroll_id: Option<Uuid>, + ) -> Self { + let (kind, _) = classify(&event); + let routing = EventRouting { + connector_uid: None, + scroll_id, + connector_id: Some(connector_id.to_string()), + native_session_id: Some(native_session_id.to_string()), + kind, + }; + Self { + routing, + origin: EventOrigin::Archivist, + event: Arc::new(event), + } + } +} + +// ─── EventRouting::derive ──────────────────────────────────────────────────── + +impl EventRouting { + /// Derive routing from an `Event` plus the emitting connector's identity. + /// + /// `scroll_id` is always `None` here; the bus cache fills it in later. + pub fn derive(event: &Event, connector_uid: Option<Uuid>, connector_id: &str) -> Self { + let (kind, native_session_id) = classify(event); + Self { + connector_uid, + scroll_id: None, + connector_id: Some(connector_id.to_string()), + native_session_id, + kind, + } + } +} + +/// Return the `(EventKind, Option<native_session_id>)` for a given `Event`. +/// +/// Every current variant is matched explicitly; the `_` arm is a safety net for +/// future additions and classifies them as `System` with no session context. +fn classify(event: &Event) -> (EventKind, Option<String>) { + use EventKind::*; + + match event { + // ── SessionLifecycle ───────────────────────────────────────────────── + Event::SessionsListed { .. } => (SessionLifecycle, None), + + Event::SessionCreated { session, .. } => { + (SessionLifecycle, Some(session.id.clone())) + } + + Event::SessionUpdated { session, .. } => { + (SessionLifecycle, Some(session.id.clone())) + } + + Event::SessionMetadataUpdated { session_id, .. } => { + (SessionLifecycle, Some(session_id.clone())) + } + + Event::SessionDeleted { session_id } => { + (SessionLifecycle, Some(session_id.clone())) + } + + Event::SessionClosed { session_id, .. } => { + (SessionLifecycle, Some(session_id.clone())) + } + + Event::SessionSystemMessageSet { session_id, .. } => { + (SessionLifecycle, Some(session_id.clone())) + } + + Event::SessionIdle { session_id, .. } => { + (SessionLifecycle, Some(session_id.clone())) + } + + Event::SessionMetadataReceived { session_id, .. } => { + (SessionLifecycle, Some(session_id.clone())) + } + + Event::SessionTransferred { from_session, .. } => { + (SessionLifecycle, Some(from_session.clone())) + } + + Event::SessionRegistered { session_id, .. } => { + (SessionLifecycle, Some(session_id.clone())) + } + + Event::ForwardingPanic { session_id, .. } => { + (SessionLifecycle, Some(session_id.clone())) + } + + Event::Connected => (SessionLifecycle, None), + Event::Disconnected => (SessionLifecycle, None), + + Event::ConnectorCreated { .. } => (SessionLifecycle, None), + Event::ConnectorRemoved { .. } => (SessionLifecycle, None), + Event::ConnectorStateChanged { .. } => (SessionLifecycle, None), + + Event::AcpClientConnected { .. } => (SessionLifecycle, None), + Event::AcpClientDisconnected { .. } => (SessionLifecycle, None), + + Event::AcpClientSessionOpened { + client_session_id, .. + } => (SessionLifecycle, Some(client_session_id.clone())), + + Event::AcpClientSessionRouted { from_session_id, .. } => { + (SessionLifecycle, Some(from_session_id.clone())) + } + + Event::AgentRequest { session_id, .. } => { + (SessionLifecycle, Some(session_id.clone())) + } + + // ── Message ────────────────────────────────────────────────────────── + Event::MessagesListed { .. } => (Message, None), + + Event::MessageStarted { message, .. } => { + (Message, Some(message.session_id.clone())) + } + + Event::MessageCompleted { message, .. } => { + (Message, Some(message.session_id.clone())) + } + + Event::MessageFailed { .. } => (Message, None), + + Event::TurnComplete { session_id, .. } => (Message, Some(session_id.clone())), + + // ── Update ─────────────────────────────────────────────────────────── + Event::SessionUpdate { session_id, .. } => (Update, Some(session_id.clone())), + + // ── System ─────────────────────────────────────────────────────────── + Event::Error { .. } => (System, None), + Event::SessionError { session_id, .. } => (System, Some(session_id.clone())), + Event::InspectorSnapshot { .. } => (System, None), + Event::InspectorNodeRegistered { .. } => (System, None), + Event::InspectorNodeRemoved { .. } => (System, None), + Event::InspectorStateChanged { .. } => (System, None), + Event::InspectorPropertiesUpdated { .. } => (System, None), + Event::SystemTaskStatusChanged { .. } => (System, None), + + // Safety net: future variants not yet listed above. + // #[allow] is intentional — this arm exists to catch additions to Event. + #[allow(unreachable_patterns)] + _ => (System, None), + } +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::{Event, Session, SessionMetadata}; + use chrono::Utc; + + fn minimal_session(id: &str) -> Session { + let now = Utc::now(); + Session { + id: id.to_string(), + title: "test".to_string(), + created_at: now, + updated_at: now, + metadata: SessionMetadata { + project_path: "/tmp".to_string(), + model: None, + total_messages: 0, + system_message: None, + current_mode_id: None, + _meta: None, + project_id: None, + }, + cwd: None, + models: None, + modes: None, + config_options: None, + acp_client_id: None, + } + } + + #[test] + fn derive_extracts_session_id_on_session_created() { + let event = Event::SessionCreated { + connector_id: "conn-1".to_string(), + session: minimal_session("ses-abc"), + }; + let routing = EventRouting::derive(&event, None, "conn-1"); + assert_eq!(routing.native_session_id.as_deref(), Some("ses-abc")); + assert_eq!(routing.kind, EventKind::SessionLifecycle); + } + + #[test] + fn derive_sets_kind_update_on_session_update() { + use crate::SessionUpdate as SU; + let event = Event::SessionUpdate { + connector_id: "conn-1".to_string(), + session_id: "ses-xyz".to_string(), + update: SU::Unknown { + data: serde_json::json!({"type": "unknown_future"}), + }, + }; + let routing = EventRouting::derive(&event, None, "conn-1"); + assert_eq!(routing.kind, EventKind::Update); + assert_eq!(routing.native_session_id.as_deref(), Some("ses-xyz")); + } + + #[test] + fn derive_sets_kind_message_for_message_started() { + use crate::conversation::{Message, MessageRole, MessageStatus}; + let now = Utc::now(); + let msg = Message { + id: "msg-1".to_string(), + session_id: "ses-msg".to_string(), + role: MessageRole::User, + created_at: now, + content: vec![], + status: MessageStatus::Pending, + metadata: None, + }; + let event = Event::MessageStarted { + connector_id: "conn-1".to_string(), + message: msg, + }; + let routing = EventRouting::derive(&event, None, "conn-1"); + assert_eq!(routing.kind, EventKind::Message); + assert_eq!(routing.native_session_id.as_deref(), Some("ses-msg")); + } + + #[test] + fn derive_sets_kind_system_for_error() { + let event = Event::Error { + message: "something went wrong".to_string(), + }; + let routing = EventRouting::derive(&event, None, "conn-1"); + assert_eq!(routing.kind, EventKind::System); + assert!(routing.native_session_id.is_none()); + } + + #[test] + fn event_origin_replay_roundtrips_replay_id() { + let id = Uuid::new_v4(); + let origin = EventOrigin::Replay { replay_id: id }; + match &origin { + EventOrigin::Replay { replay_id } => assert_eq!(*replay_id, id), + _ => panic!("wrong variant"), + } + } + + #[test] + fn from_connector_event_produces_connector_origin() { + let event = Event::Connected; + let uid = Uuid::nil(); + let bus = BusEvent::from_connector_event(event, Some(uid), "conn-test".to_string()); + match &bus.origin { + EventOrigin::Connector { + connector_uid, + connector_id, + } => { + assert_eq!(*connector_uid, Some(uid)); + assert_eq!(connector_id, "conn-test"); + } + _ => panic!("expected Connector origin"), + } + assert_eq!(bus.routing.connector_id.as_deref(), Some("conn-test")); + } +} diff --git a/crates/dirigent_protocol/src/streaming/filter.rs b/crates/dirigent_protocol/src/streaming/filter.rs new file mode 100644 index 0000000..f05ee12 --- /dev/null +++ b/crates/dirigent_protocol/src/streaming/filter.rs @@ -0,0 +1,244 @@ +//! Filters applied on the subscriber side of the SharingBus. + +use std::ops::{BitOr, BitOrAssign}; + +use uuid::Uuid; + +use super::bus_event::{BusEvent, EventKind}; + +/// A subscriber-side predicate that selects which `BusEvent`s to forward. +#[derive(Debug, Clone)] +pub enum EventFilter { + /// Accept every event unconditionally. + All, + /// Accept only events whose `routing.scroll_id` matches the given UUID. + ScrollId(Uuid), + /// Accept only events whose `routing.connector_uid` matches the given UUID. + ConnectorUid(Uuid), + /// Accept only events whose `routing.kind` is set in the mask. + Kinds(EventKindMask), + /// Accept events that satisfy at least one of the inner filters. + AnyOf(Vec<EventFilter>), + /// Accept events that satisfy all of the inner filters. + AllOf(Vec<EventFilter>), +} + +/// Bit-mask over `EventKind` variants for efficient kind-based filtering. +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] +pub struct EventKindMask(pub u8); + +impl EventKindMask { + pub const SESSION_LIFECYCLE: Self = Self(1 << 0); + pub const MESSAGE: Self = Self(1 << 1); + pub const UPDATE: Self = Self(1 << 2); + pub const SYSTEM: Self = Self(1 << 3); + pub const ALL: Self = Self(0b1111); + + /// Returns `true` if `kind` is set in this mask. + pub fn contains(self, kind: EventKind) -> bool { + let bit = match kind { + EventKind::SessionLifecycle => Self::SESSION_LIFECYCLE, + EventKind::Message => Self::MESSAGE, + EventKind::Update => Self::UPDATE, + EventKind::System => Self::SYSTEM, + }; + (self.0 & bit.0) != 0 + } +} + +impl BitOr for EventKindMask { + type Output = Self; + + fn bitor(self, rhs: Self) -> Self::Output { + Self(self.0 | rhs.0) + } +} + +impl BitOrAssign for EventKindMask { + fn bitor_assign(&mut self, rhs: Self) { + self.0 |= rhs.0; + } +} + +impl EventFilter { + /// Returns `true` if this filter accepts the given `BusEvent`. + pub fn matches(&self, event: &BusEvent) -> bool { + match self { + EventFilter::All => true, + EventFilter::ScrollId(s) => event.routing.scroll_id == Some(*s), + EventFilter::ConnectorUid(u) => event.routing.connector_uid == Some(*u), + EventFilter::Kinds(m) => m.contains(event.routing.kind), + EventFilter::AnyOf(filters) => filters.iter().any(|f| f.matches(event)), + EventFilter::AllOf(filters) => filters.iter().all(|f| f.matches(event)), + } + } +} + +// ─── Tests ─────────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use uuid::Uuid; + + use super::*; + use crate::{ + streaming::bus_event::{EventOrigin, EventRouting}, + Event, + }; + + /// Build a minimal `BusEvent` for testing. + fn make_event( + scroll_id: Option<Uuid>, + connector_uid: Option<Uuid>, + kind: EventKind, + ) -> BusEvent { + BusEvent { + routing: EventRouting { + scroll_id, + connector_uid, + kind, + ..Default::default() + }, + origin: EventOrigin::Runtime, + event: Arc::new(Event::Connected), + } + } + + // 1. EventFilter::All matches any BusEvent. + #[test] + fn all_matches_any_event() { + let ev = make_event(None, None, EventKind::System); + assert!(EventFilter::All.matches(&ev)); + + let ev2 = make_event(Some(Uuid::new_v4()), Some(Uuid::new_v4()), EventKind::Message); + assert!(EventFilter::All.matches(&ev2)); + } + + // 2. EventFilter::ScrollId matches Some(x), rejects Some(y), rejects None. + #[test] + fn scroll_id_matches_correct_uuid_only() { + let x = Uuid::new_v4(); + let y = Uuid::new_v4(); + + let filter = EventFilter::ScrollId(x); + + let ev_match = make_event(Some(x), None, EventKind::SessionLifecycle); + assert!(filter.matches(&ev_match), "should match Some(x)"); + + let ev_other = make_event(Some(y), None, EventKind::SessionLifecycle); + assert!(!filter.matches(&ev_other), "should reject Some(y)"); + + let ev_none = make_event(None, None, EventKind::SessionLifecycle); + assert!(!filter.matches(&ev_none), "should reject None"); + } + + // 3. EventFilter::ConnectorUid matches only when routing.connector_uid == Some(u). + #[test] + fn connector_uid_matches_correct_uuid_only() { + let u = Uuid::new_v4(); + let other = Uuid::new_v4(); + + let filter = EventFilter::ConnectorUid(u); + + let ev_match = make_event(None, Some(u), EventKind::Update); + assert!(filter.matches(&ev_match)); + + let ev_other = make_event(None, Some(other), EventKind::Update); + assert!(!filter.matches(&ev_other)); + + let ev_none = make_event(None, None, EventKind::Update); + assert!(!filter.matches(&ev_none)); + } + + // 4. EventFilter::Kinds(MESSAGE) matches Message, rejects Update. + #[test] + fn kinds_mask_message_matches_message_only() { + let filter = EventFilter::Kinds(EventKindMask::MESSAGE); + + let ev_msg = make_event(None, None, EventKind::Message); + assert!(filter.matches(&ev_msg)); + + let ev_upd = make_event(None, None, EventKind::Update); + assert!(!filter.matches(&ev_upd)); + } + + // 5. AnyOf([ScrollId(X), ConnectorUid(Y)]) matches when either matches, rejects otherwise. + #[test] + fn any_of_matches_when_at_least_one_sub_filter_matches() { + let x = Uuid::new_v4(); + let y = Uuid::new_v4(); + let z = Uuid::new_v4(); + + let filter = EventFilter::AnyOf(vec![ + EventFilter::ScrollId(x), + EventFilter::ConnectorUid(y), + ]); + + // scroll_id matches + let ev_scroll = make_event(Some(x), None, EventKind::Message); + assert!(filter.matches(&ev_scroll)); + + // connector_uid matches + let ev_conn = make_event(None, Some(y), EventKind::Message); + assert!(filter.matches(&ev_conn)); + + // both match + let ev_both = make_event(Some(x), Some(y), EventKind::Message); + assert!(filter.matches(&ev_both)); + + // neither matches + let ev_neither = make_event(Some(z), Some(z), EventKind::System); + assert!(!filter.matches(&ev_neither)); + } + + // 6. AllOf([ScrollId(X), Kinds(MESSAGE)]) matches only when both hold. + #[test] + fn all_of_matches_only_when_all_sub_filters_match() { + let x = Uuid::new_v4(); + + let filter = EventFilter::AllOf(vec![ + EventFilter::ScrollId(x), + EventFilter::Kinds(EventKindMask::MESSAGE), + ]); + + // both conditions satisfied + let ev_both = make_event(Some(x), None, EventKind::Message); + assert!(filter.matches(&ev_both)); + + // scroll_id matches but wrong kind + let ev_wrong_kind = make_event(Some(x), None, EventKind::Update); + assert!(!filter.matches(&ev_wrong_kind)); + + // right kind but wrong scroll_id + let ev_wrong_scroll = make_event(Some(Uuid::new_v4()), None, EventKind::Message); + assert!(!filter.matches(&ev_wrong_scroll)); + + // neither matches + let ev_neither = make_event(None, None, EventKind::System); + assert!(!filter.matches(&ev_neither)); + } + + // 7. BitOr combining masks: (MESSAGE | UPDATE).contains(kind) is true for both. + #[test] + fn bitor_combines_masks_correctly() { + let combined = EventKindMask::MESSAGE | EventKindMask::UPDATE; + + assert!(combined.contains(EventKind::Message)); + assert!(combined.contains(EventKind::Update)); + assert!(!combined.contains(EventKind::SessionLifecycle)); + assert!(!combined.contains(EventKind::System)); + } + + // Bonus: verify BitOrAssign works the same way. + #[test] + fn bitor_assign_accumulates_bits() { + let mut mask = EventKindMask::MESSAGE; + mask |= EventKindMask::SYSTEM; + + assert!(mask.contains(EventKind::Message)); + assert!(mask.contains(EventKind::System)); + assert!(!mask.contains(EventKind::Update)); + } +} diff --git a/crates/dirigent_protocol/src/streaming/mod.rs b/crates/dirigent_protocol/src/streaming/mod.rs new file mode 100644 index 0000000..3f30e85 --- /dev/null +++ b/crates/dirigent_protocol/src/streaming/mod.rs @@ -0,0 +1,18 @@ +//! Streaming primitives shared across the runtime, archivist, and sink crates. +//! +//! `BusEvent` wraps the existing `Event` with routing context; `SessionStream` +//! is the trait every uni-directional sink implements. The existing +//! `SessionShare` trait (bi-directional, Matrix) lives in `crate::sharing` +//! and is not superseded. + +pub mod bus_event; +pub mod filter; +pub mod receiver; +pub mod stream; + +pub use bus_event::{BusEvent, EventKind, EventOrigin, EventRouting}; +pub use filter::{EventFilter, EventKindMask}; +pub use receiver::BusReceiver; +pub use stream::{ + SessionStream, StreamError, StreamKind, StreamOutcome, StreamScope, StreamSummary, +}; diff --git a/crates/dirigent_protocol/src/streaming/receiver.rs b/crates/dirigent_protocol/src/streaming/receiver.rs new file mode 100644 index 0000000..51cf71c --- /dev/null +++ b/crates/dirigent_protocol/src/streaming/receiver.rs @@ -0,0 +1,27 @@ +//! `BusReceiver`: subscriber handle returned by a SharingBus-like fan-out. +//! +//! This type lives in `dirigent_protocol` (rather than next to its only +//! producer in `dirigent_core::sharing::bus`) so that downstream consumers +//! such as `dirigent_archivist` can accept a `BusReceiver` without taking +//! on a dependency on `dirigent_core` — which would be a dependency cycle. +//! +//! It is intentionally a dumb data container: just `id`, `rx`, and the +//! `lagged` counter that the producer task increments when it has to drop +//! events for a slow subscriber. No logic lives here. + +use std::sync::Arc; +use std::sync::atomic::AtomicU64; + +use tokio::sync::mpsc; + +use crate::streaming::BusEvent; + +/// Receiver handle returned to subscribers by a SharingBus-style fan-out. +/// +/// `lagged` counts how many events were dropped because the underlying +/// mpsc queue was full when the worker tried to deliver. +pub struct BusReceiver { + pub id: u64, + pub rx: mpsc::Receiver<BusEvent>, + pub lagged: Arc<AtomicU64>, +} diff --git a/crates/dirigent_protocol/src/streaming/stream.rs b/crates/dirigent_protocol/src/streaming/stream.rs new file mode 100644 index 0000000..7b17dc6 --- /dev/null +++ b/crates/dirigent_protocol/src/streaming/stream.rs @@ -0,0 +1,59 @@ +//! SessionStream: uni-directional sink trait. Archive backends use +//! ArchiveBackend; live-write sinks like Langfuse use SessionStream. + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use thiserror::Error; +use uuid::Uuid; + +use super::bus_event::BusEvent; + +#[async_trait] +pub trait SessionStream: Send + Sync { + fn summary(&self) -> StreamSummary; + fn scope(&self) -> StreamScope; + async fn on_event(&self, event: &BusEvent) -> StreamOutcome; + async fn shutdown(&self); +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StreamSummary { + pub name: String, + pub kind: StreamKind, + pub target: String, // human-readable ("langfuse: https://…", "matrix: #room:server") + pub active_since: DateTime<Utc>, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum StreamKind { + Matrix, + Langfuse, + Slack, + Webhook, + Custom, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case", tag = "kind")] +pub enum StreamScope { + Session { scroll_id: Uuid }, + Connector { connector_uid: Uuid }, + ArchiveWide { acknowledged: bool }, +} + +#[derive(Debug)] +pub enum StreamOutcome { + Ok, + Skipped, + Failed(StreamError), +} + +#[derive(Debug, Error)] +pub enum StreamError { + #[error("transport: {0}")] Transport(String), + #[error("serialisation: {0}")] Serialisation(String), + #[error("rejected: {0}")] Rejected(String), + #[error("shutdown")] Shutdown, +} diff --git a/crates/dirigent_protocol/src/types/content.rs b/crates/dirigent_protocol/src/types/content.rs new file mode 100644 index 0000000..6c73278 --- /dev/null +++ b/crates/dirigent_protocol/src/types/content.rs @@ -0,0 +1,55 @@ +use serde::{Deserialize, Serialize}; + +/// MCP-style content block for displayable content +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ContentBlock { + Text { + text: String, + }, + ResourceLink { + uri: String, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + mime_type: Option<String>, + }, + // Future: Resource, Image, Audio (marked as out-of-scope for phase 1) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_text_serialization() { + let block = ContentBlock::Text { + text: "Hello".to_string(), + }; + let json = serde_json::to_string(&block).unwrap(); + assert!(json.contains(r#""type":"text"#)); + assert!(json.contains(r#""text":"Hello"#)); + } + + #[test] + fn test_resource_link_serialization() { + let block = ContentBlock::ResourceLink { + uri: "file:///path/to/file".to_string(), + name: Some("file.txt".to_string()), + mime_type: Some("text/plain".to_string()), + }; + let json = serde_json::to_string(&block).unwrap(); + assert!(json.contains(r#""type":"resource_link"#)); + assert!(json.contains(r#""uri":"file:///path/to/file"#)); + } + + #[test] + fn test_roundtrip() { + let block = ContentBlock::Text { + text: "Test".to_string(), + }; + let json = serde_json::to_string(&block).unwrap(); + let deserialized: ContentBlock = serde_json::from_str(&json).unwrap(); + assert_eq!(block, deserialized); + } +} diff --git a/crates/dirigent_protocol/src/types/meta.rs b/crates/dirigent_protocol/src/types/meta.rs new file mode 100644 index 0000000..0bfd523 --- /dev/null +++ b/crates/dirigent_protocol/src/types/meta.rs @@ -0,0 +1,231 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] +pub struct Meta { + #[serde(skip_serializing_if = "Option::is_none")] + pub provider: Option<ProviderMeta>, + + /// Arbitrary extra fields + #[serde(flatten)] + pub extra: HashMap<String, Value>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct ProviderMeta { + /// Provider name (e.g., "opencode", "anthropic") + pub name: String, + + /// Original provider-specific IDs + #[serde(skip_serializing_if = "Option::is_none")] + pub original_ids: Option<HashMap<String, String>>, + + /// Minimal raw payload excerpts for debugging (optional) + #[serde(skip_serializing_if = "Option::is_none")] + pub raw_excerpt: Option<Value>, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_meta_default() { + let meta = Meta::default(); + assert_eq!(meta.provider, None); + assert!(meta.extra.is_empty()); + } + + #[test] + fn test_meta_with_provider() { + let meta = Meta { + provider: Some(ProviderMeta { + name: "opencode".to_string(), + original_ids: Some(HashMap::from([ + ("session_id".to_string(), "abc123".to_string()), + ])), + raw_excerpt: None, + }), + extra: HashMap::new(), + }; + let json = serde_json::to_string(&meta).unwrap(); + assert!(json.contains(r#""name":"opencode"#)); + } + + #[test] + fn test_skip_serializing_if_none() { + let meta = Meta::default(); + let json = serde_json::to_string(&meta).unwrap(); + // Should be empty object since provider is None + assert_eq!(json, "{}"); + } + + #[test] + fn test_roundtrip() { + let meta = Meta { + provider: Some(ProviderMeta { + name: "test".to_string(), + original_ids: None, + raw_excerpt: None, + }), + extra: HashMap::new(), + }; + let json = serde_json::to_string(&meta).unwrap(); + let deserialized: Meta = serde_json::from_str(&json).unwrap(); + assert_eq!(meta, deserialized); + } + + #[test] + fn test_meta_preserves_claude_code_tool_response() { + use serde_json::json; + + // Test that complex nested _meta data is preserved (T013 requirement) + let meta_json = json!({ + "claudeCode": { + "toolResponse": { + "mode": "content", + "numFiles": 0, + "filenames": [], + "content": "some grep output here", + "numLines": 58, + "appliedLimit": 100 + }, + "toolName": "Grep" + } + }); + + // Deserialize into Meta + let meta: Meta = serde_json::from_value(meta_json.clone()).unwrap(); + + // Verify claudeCode is in extra + assert!(meta.extra.contains_key("claudeCode")); + + // Serialize back to JSON + let serialized = serde_json::to_value(&meta).unwrap(); + + // Verify all fields are preserved + assert_eq!(serialized["claudeCode"]["toolName"], "Grep"); + assert!(serialized["claudeCode"]["toolResponse"].is_object()); + assert_eq!(serialized["claudeCode"]["toolResponse"]["mode"], "content"); + assert_eq!(serialized["claudeCode"]["toolResponse"]["numFiles"], 0); + assert_eq!(serialized["claudeCode"]["toolResponse"]["numLines"], 58); + assert_eq!(serialized["claudeCode"]["toolResponse"]["appliedLimit"], 100); + assert_eq!( + serialized["claudeCode"]["toolResponse"]["content"], + "some grep output here" + ); + } + + #[test] + fn test_meta_round_trip_preservation() { + use serde_json::json; + + // Test incoming → serialize → deserialize → serialize preserves all fields (T013) + let original_meta_json = json!({ + "claudeCode": { + "toolResponse": { + "mode": "content", + "numFiles": 3, + "filenames": ["file1.rs", "file2.rs", "file3.rs"], + "content": "match results", + "numLines": 42, + "appliedLimit": 100, + "customField": "should be preserved" + }, + "toolName": "Grep", + "additionalField": "also preserved" + }, + "provider": { + "name": "anthropic", + "original_ids": { + "session_id": "sess_123" + } + }, + "customTopLevel": "preserved too" + }); + + // First round trip + let meta1: Meta = serde_json::from_value(original_meta_json.clone()).unwrap(); + let json1 = serde_json::to_value(&meta1).unwrap(); + + // Second round trip + let meta2: Meta = serde_json::from_value(json1.clone()).unwrap(); + let json2 = serde_json::to_value(&meta2).unwrap(); + + // Verify both serializations are identical (stable) + assert_eq!(json1, json2); + + // Verify all nested fields preserved + assert_eq!(json2["claudeCode"]["toolResponse"]["customField"], "should be preserved"); + assert_eq!(json2["claudeCode"]["additionalField"], "also preserved"); + assert_eq!(json2["customTopLevel"], "preserved too"); + assert_eq!(json2["provider"]["name"], "anthropic"); + } + + #[test] + fn test_meta_extra_fields_with_flatten() { + use serde_json::json; + + // Test that #[serde(flatten)] correctly captures arbitrary fields + let json = json!({ + "provider": { + "name": "test_provider" + }, + "arbitraryField1": "value1", + "arbitraryField2": { + "nested": "structure" + }, + "arbitraryField3": [1, 2, 3] + }); + + let meta: Meta = serde_json::from_value(json.clone()).unwrap(); + + // Verify provider is parsed correctly + assert!(meta.provider.is_some()); + assert_eq!(meta.provider.as_ref().unwrap().name, "test_provider"); + + // Verify extra fields are captured + assert_eq!(meta.extra.len(), 3); + assert!(meta.extra.contains_key("arbitraryField1")); + assert!(meta.extra.contains_key("arbitraryField2")); + assert!(meta.extra.contains_key("arbitraryField3")); + + // Verify round-trip preserves all fields + let serialized = serde_json::to_value(&meta).unwrap(); + assert_eq!(serialized["arbitraryField1"], "value1"); + assert_eq!(serialized["arbitraryField2"]["nested"], "structure"); + assert_eq!(serialized["arbitraryField3"], json!([1, 2, 3])); + } + + #[test] + fn test_meta_no_data_loss_on_unknown_fields() { + use serde_json::json; + + // Simulate receiving _meta from Claude with fields we don't know about yet + let future_meta = json!({ + "claudeCode": { + "toolName": "FutureTool", + "futureFeature1": "some value", + "futureFeature2": { + "deeplyNested": { + "data": [1, 2, 3] + } + } + }, + "unknownTopLevel": "should not be lost" + }); + + let meta: Meta = serde_json::from_value(future_meta.clone()).unwrap(); + let serialized = serde_json::to_value(&meta).unwrap(); + + // Verify NO data loss - all unknown fields preserved + assert_eq!(serialized["claudeCode"]["toolName"], "FutureTool"); + assert_eq!(serialized["claudeCode"]["futureFeature1"], "some value"); + assert_eq!( + serialized["claudeCode"]["futureFeature2"]["deeplyNested"]["data"], + json!([1, 2, 3]) + ); + assert_eq!(serialized["unknownTopLevel"], "should not be lost"); + } +} diff --git a/crates/dirigent_protocol/src/types/mod.rs b/crates/dirigent_protocol/src/types/mod.rs new file mode 100644 index 0000000..14fe4cf --- /dev/null +++ b/crates/dirigent_protocol/src/types/mod.rs @@ -0,0 +1,14 @@ +pub mod content; +pub mod meta; +pub mod permission; +pub mod tool; +pub mod updates; + +pub use content::ContentBlock; +pub use meta::{Meta, ProviderMeta}; +pub use permission::{ + PermissionOption, PermissionOptionKind, RequestPermissionOutcome, RequestPermissionResponse, + ToolCallInfo, ToolCallLocation, ToolCallStatus as PermissionToolCallStatus, ToolKind, +}; +pub use tool::{ToolCall, ToolCallContent, ToolCallId, ToolCallStatus, ToolOrigin}; +pub use updates::SessionUpdate; diff --git a/crates/dirigent_protocol/src/types/permission.rs b/crates/dirigent_protocol/src/types/permission.rs new file mode 100644 index 0000000..a698e73 --- /dev/null +++ b/crates/dirigent_protocol/src/types/permission.rs @@ -0,0 +1,454 @@ +use serde::{Deserialize, Serialize}; + +/// ACP permission option presented to the user +/// +/// When a tool requires permission, the system presents the user with a list of +/// options that define how they want to handle the request. Each option has a kind +/// that determines the scope of the permission grant or rejection. +/// +/// # Examples +/// +/// ```rust +/// use dirigent_protocol::{PermissionOption, PermissionOptionKind}; +/// +/// let allow_once = PermissionOption { +/// option_id: "allow_once_1".to_string(), +/// name: "Allow this time".to_string(), +/// kind: PermissionOptionKind::AllowOnce, +/// }; +/// +/// let allow_always = PermissionOption { +/// option_id: "allow_always_1".to_string(), +/// name: "Always allow for this session".to_string(), +/// kind: PermissionOptionKind::AllowAlways, +/// }; +/// ``` +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub struct PermissionOption { + /// Unique identifier for this permission option + #[serde(rename = "optionId")] + pub option_id: String, + /// User-facing name/label for this option + pub name: String, + /// The kind of permission action this option represents + pub kind: PermissionOptionKind, +} + +/// Kind of permission option defining scope of grant/rejection +/// +/// These variants define how the user's permission decision should be applied: +/// - **Once**: Applies to the current request only +/// - **Always**: Applies to all similar requests in the session/context +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum PermissionOptionKind { + /// Grant permission for this single request + AllowOnce, + /// Grant permission for all similar requests + AllowAlways, + /// Reject this single request + RejectOnce, + /// Reject all similar requests + RejectAlways, +} + +/// Response from a permission request +/// +/// When the system requests permission from the user, the response indicates +/// either that the user selected a specific option, or that they cancelled +/// the request entirely. +/// +/// # ACP Wire Format +/// +/// The response is structured with the outcome nested inside an `outcome` field: +/// ```json +/// {"outcome": {"outcome": "selected", "optionId": "allow_once_1"}} +/// ``` +/// +/// # Examples +/// +/// ```rust +/// use dirigent_protocol::{RequestPermissionResponse, RequestPermissionOutcome}; +/// +/// // User selected an option +/// let selected = RequestPermissionResponse { +/// outcome: RequestPermissionOutcome::Selected { +/// option_id: "allow_once_1".to_string(), +/// }, +/// }; +/// +/// // User cancelled the request +/// let cancelled = RequestPermissionResponse { +/// outcome: RequestPermissionOutcome::Cancelled, +/// }; +/// ``` +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RequestPermissionResponse { + /// The outcome of the permission request (contains optionId if selected) + pub outcome: RequestPermissionOutcome, +} + +/// Outcome of a permission request +/// +/// Uses internal tagging to produce the ACP wire format: +/// - Selected: `{"outcome": "selected", "optionId": "..."}` +/// - Cancelled: `{"outcome": "cancelled"}` +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +#[serde(tag = "outcome", rename_all = "snake_case")] +pub enum RequestPermissionOutcome { + /// User selected one of the provided options + Selected { + /// The ID of the selected option + #[serde(rename = "optionId")] + option_id: String, + }, + /// User cancelled the permission request + Cancelled, +} + +/// Information about a tool call for permission requests +/// +/// When requesting permission for a tool execution, this provides context +/// about what the tool will do, including its kind, status, and affected locations. +/// +/// # Examples +/// +/// ```rust +/// use dirigent_protocol::{ToolCallInfo, ToolKind, ToolCallStatus, ToolCallLocation}; +/// +/// let info = ToolCallInfo { +/// tool_call_id: "call_123".to_string(), +/// title: "Read configuration file".to_string(), +/// kind: Some(ToolKind::Read), +/// status: Some(ToolCallStatus::Pending), +/// locations: Some(vec![ +/// ToolCallLocation { +/// path: "/etc/config.toml".to_string(), +/// line: None, +/// } +/// ]), +/// raw_input: None, +/// }; +/// ``` +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub struct ToolCallInfo { + /// Unique identifier for this tool call + #[serde(rename = "toolCallId")] + pub tool_call_id: String, + /// User-facing title describing what the tool will do + pub title: String, + /// The kind/category of operation this tool performs + #[serde(skip_serializing_if = "Option::is_none")] + pub kind: Option<ToolKind>, + /// Current status of the tool call + #[serde(skip_serializing_if = "Option::is_none")] + pub status: Option<ToolCallStatus>, + /// File/resource locations affected by this tool call + #[serde(skip_serializing_if = "Option::is_none")] + pub locations: Option<Vec<ToolCallLocation>>, + /// Raw input parameters for debugging/inspection + #[serde(rename = "rawInput", skip_serializing_if = "Option::is_none")] + pub raw_input: Option<serde_json::Value>, +} + +/// Category of tool operation +/// +/// Provides semantic categorization of tool functionality to help users +/// understand the impact and risk level of granting permission. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum ToolKind { + /// Read-only operations (viewing files, searching) + Read, + /// Modify existing content + Edit, + /// Remove content + Delete, + /// Relocate content + Move, + /// Search operations + Search, + /// Execute commands or scripts + Execute, + /// Internal reasoning/planning (no external effects) + Think, + /// Fetch remote resources + Fetch, + /// Other/uncategorized operations + Other, +} + +/// Status of a tool call +/// +/// Note: This duplicates `ToolCallStatus` from `tool.rs` for now. +/// In the future, we may consolidate these types. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum ToolCallStatus { + /// Tool call created but not yet started + Pending, + /// Tool call is currently executing + #[serde(rename = "in_progress")] + Running, + /// Tool call completed successfully + Completed, + /// Tool call failed with an error + #[serde(rename = "failed")] + Failed, +} + +/// Location affected by a tool call +/// +/// Represents a file or resource that will be read, modified, or otherwise +/// affected by the tool execution. May optionally include a specific line number. +/// +/// # Examples +/// +/// ```rust +/// use dirigent_protocol::ToolCallLocation; +/// +/// // File-level location +/// let file_location = ToolCallLocation { +/// path: "/src/main.rs".to_string(), +/// line: None, +/// }; +/// +/// // Specific line location +/// let line_location = ToolCallLocation { +/// path: "/src/lib.rs".to_string(), +/// line: Some(42), +/// }; +/// ``` +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub struct ToolCallLocation { + /// File path or resource identifier + pub path: String, + /// Optional line number within the file + #[serde(skip_serializing_if = "Option::is_none")] + pub line: Option<i32>, +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_permission_option_serialization() { + let option = PermissionOption { + option_id: "allow_once_1".to_string(), + name: "Allow this time".to_string(), + kind: PermissionOptionKind::AllowOnce, + }; + + let json = serde_json::to_value(&option).unwrap(); + assert_eq!(json["optionId"], "allow_once_1"); + assert_eq!(json["name"], "Allow this time"); + assert_eq!(json["kind"], "allow_once"); + } + + #[test] + fn test_permission_option_kind_variants() { + let kinds = vec![ + (PermissionOptionKind::AllowOnce, "allow_once"), + (PermissionOptionKind::AllowAlways, "allow_always"), + (PermissionOptionKind::RejectOnce, "reject_once"), + (PermissionOptionKind::RejectAlways, "reject_always"), + ]; + + for (kind, expected) in kinds { + let json = serde_json::to_string(&kind).unwrap(); + assert_eq!(json, format!(r#""{}""#, expected)); + } + } + + #[test] + fn test_request_permission_response_selected() { + let response = RequestPermissionResponse { + outcome: RequestPermissionOutcome::Selected { + option_id: "allow_once_1".to_string(), + }, + }; + + let json = serde_json::to_value(&response).unwrap(); + // The outcome field should contain an object with nested outcome and optionId + assert_eq!(json["outcome"]["outcome"], "selected"); + assert_eq!(json["outcome"]["optionId"], "allow_once_1"); + } + + #[test] + fn test_request_permission_response_cancelled() { + let response = RequestPermissionResponse { + outcome: RequestPermissionOutcome::Cancelled, + }; + + let json = serde_json::to_value(&response).unwrap(); + // The outcome field should contain an object with just outcome + assert_eq!(json["outcome"]["outcome"], "cancelled"); + // optionId should not be present for cancelled + assert!(json["outcome"].get("optionId").is_none()); + } + + #[test] + fn test_request_permission_response_wire_format() { + // Test that the serialization matches ACP wire format exactly + let response = RequestPermissionResponse { + outcome: RequestPermissionOutcome::Selected { + option_id: "allow".to_string(), + }, + }; + + let json = serde_json::to_value(&response).unwrap(); + // Should produce: {"outcome": {"outcome": "selected", "optionId": "allow"}} + let expected = json!({ + "outcome": { + "outcome": "selected", + "optionId": "allow" + } + }); + assert_eq!(json, expected); + } + + #[test] + fn test_tool_call_info_serialization() { + let info = ToolCallInfo { + tool_call_id: "call_123".to_string(), + title: "Read file".to_string(), + kind: Some(ToolKind::Read), + status: Some(ToolCallStatus::Pending), + locations: Some(vec![ToolCallLocation { + path: "/test.txt".to_string(), + line: Some(10), + }]), + raw_input: Some(json!({"path": "/test.txt"})), + }; + + let json = serde_json::to_value(&info).unwrap(); + assert_eq!(json["toolCallId"], "call_123"); + assert_eq!(json["title"], "Read file"); + assert_eq!(json["kind"], "read"); + assert_eq!(json["status"], "pending"); + assert!(json["locations"].is_array()); + assert!(json["rawInput"].is_object()); + } + + #[test] + fn test_tool_call_info_minimal() { + let info = ToolCallInfo { + tool_call_id: "call_456".to_string(), + title: "Execute command".to_string(), + kind: None, + status: None, + locations: None, + raw_input: None, + }; + + let json = serde_json::to_value(&info).unwrap(); + assert_eq!(json["toolCallId"], "call_456"); + assert_eq!(json["title"], "Execute command"); + // Optional fields should be omitted + assert!(json.get("kind").is_none()); + assert!(json.get("status").is_none()); + assert!(json.get("locations").is_none()); + assert!(json.get("rawInput").is_none()); + } + + #[test] + fn test_tool_kind_variants() { + let kinds = vec![ + (ToolKind::Read, "read"), + (ToolKind::Edit, "edit"), + (ToolKind::Delete, "delete"), + (ToolKind::Move, "move"), + (ToolKind::Search, "search"), + (ToolKind::Execute, "execute"), + (ToolKind::Think, "think"), + (ToolKind::Fetch, "fetch"), + (ToolKind::Other, "other"), + ]; + + for (kind, expected) in kinds { + let json = serde_json::to_string(&kind).unwrap(); + assert_eq!(json, format!(r#""{}""#, expected)); + } + } + + #[test] + fn test_tool_call_status_variants() { + let statuses = vec![ + (ToolCallStatus::Pending, "pending"), + (ToolCallStatus::Running, "in_progress"), + (ToolCallStatus::Completed, "completed"), + (ToolCallStatus::Failed, "failed"), + ]; + + for (status, expected) in statuses { + let json = serde_json::to_string(&status).unwrap(); + assert_eq!(json, format!(r#""{}""#, expected)); + } + } + + #[test] + fn test_tool_call_location_with_line() { + let location = ToolCallLocation { + path: "/src/main.rs".to_string(), + line: Some(42), + }; + + let json = serde_json::to_value(&location).unwrap(); + assert_eq!(json["path"], "/src/main.rs"); + assert_eq!(json["line"], 42); + } + + #[test] + fn test_tool_call_location_without_line() { + let location = ToolCallLocation { + path: "/config.toml".to_string(), + line: None, + }; + + let json = serde_json::to_value(&location).unwrap(); + assert_eq!(json["path"], "/config.toml"); + assert!(json.get("line").is_none()); + } + + #[test] + fn test_roundtrip_permission_option() { + let original = PermissionOption { + option_id: "test_id".to_string(), + name: "Test Option".to_string(), + kind: PermissionOptionKind::AllowAlways, + }; + + let json = serde_json::to_string(&original).unwrap(); + let deserialized: PermissionOption = serde_json::from_str(&json).unwrap(); + + assert_eq!(original, deserialized); + } + + #[test] + fn test_roundtrip_tool_call_info() { + let original = ToolCallInfo { + tool_call_id: "call_789".to_string(), + title: "Complex operation".to_string(), + kind: Some(ToolKind::Edit), + status: Some(ToolCallStatus::Running), + locations: Some(vec![ + ToolCallLocation { + path: "/file1.rs".to_string(), + line: Some(10), + }, + ToolCallLocation { + path: "/file2.rs".to_string(), + line: None, + }, + ]), + raw_input: Some(json!({"key": "value"})), + }; + + let json = serde_json::to_string(&original).unwrap(); + let deserialized: ToolCallInfo = serde_json::from_str(&json).unwrap(); + + assert_eq!(original, deserialized); + } +} diff --git a/crates/dirigent_protocol/src/types/tool.rs b/crates/dirigent_protocol/src/types/tool.rs new file mode 100644 index 0000000..d85f27b --- /dev/null +++ b/crates/dirigent_protocol/src/types/tool.rs @@ -0,0 +1,513 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use crate::types::content::ContentBlock; + +/// ACP-compliant tool call content wrapper supporting multiple content types +/// +/// The ACP protocol requires tool call content to be wrapped in a discriminated +/// union that supports three types: +/// - **Content** - Regular content blocks (text, images, etc.) +/// - **Diff** - File change diffs (oldText, newText) +/// - **Terminal** - Terminal session references +/// +/// # Examples +/// +/// ```rust +/// use dirigent_protocol::{ToolCallContent, ContentBlock}; +/// +/// // Create a text content wrapper +/// let content = ToolCallContent::from_content_block(ContentBlock::Text { +/// text: "Tool output".to_string() +/// }); +/// +/// // Create a diff +/// let diff = ToolCallContent::diff( +/// "/src/main.rs".to_string(), +/// Some("old code".to_string()), +/// "new code".to_string(), +/// ); +/// +/// // Create a terminal reference +/// let terminal = ToolCallContent::terminal("term_123".to_string()); +/// ``` +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ToolCallContent { + /// Regular content (text, images, etc.) + Content { + content: ContentBlock, + }, + + /// File diff showing changes + Diff { + path: String, + #[serde(rename = "oldText", skip_serializing_if = "Option::is_none")] + old_text: Option<String>, + #[serde(rename = "newText")] + new_text: String, + }, + + /// Reference to a terminal session + Terminal { + #[serde(rename = "terminalId")] + terminal_id: String, + }, +} + +impl ToolCallContent { + /// Create a content wrapper from a ContentBlock + pub fn from_content_block(block: ContentBlock) -> Self { + Self::Content { content: block } + } + + /// Create a diff wrapper + pub fn diff(path: String, old_text: Option<String>, new_text: String) -> Self { + Self::Diff { path, old_text, new_text } + } + + /// Create a terminal reference + pub fn terminal(terminal_id: String) -> Self { + Self::Terminal { terminal_id } + } +} + +/// Unique identifier for a tool call +pub type ToolCallId = String; + +/// Status of a tool call in its lifecycle +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ToolCallStatus { + /// Tool call has been created but not yet started + Pending, + /// Tool call is currently executing + #[serde(rename = "in_progress")] + Running, + /// Tool call completed successfully + Completed, + /// Tool call failed with an error + #[serde(rename = "failed")] + Error, +} + +/// Origin of a tool execution +/// +/// Distinguishes where the tool is actually executed: +/// - Internal: Dirigent runs the tool after user permission +/// - External: Agent runs tool directly (we observe) +/// - Forwarded: Upstream ACP server (transitionary) +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "snake_case")] +pub enum ToolOrigin { + /// Tool executed by Dirigent after user permission + Internal, + /// Tool executed by agent directly (we observe) + #[default] + External, + /// Tool forwarded from upstream ACP server + Forwarded, +} + +/// Represents a tool call and its lifecycle +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct ToolCall { + /// Unique identifier for this tool call + pub id: ToolCallId, + /// Name of the tool being called + pub tool_name: String, + /// Current status of the tool call + pub status: ToolCallStatus, + /// Content associated with this tool call (wrapped in ACP-compliant format) + /// + /// Each content item is wrapped in a discriminated union that can be: + /// - Content (text, images, etc.) + /// - Diff (file changes) + /// - Terminal (terminal session reference) + #[serde(default)] + pub content: Vec<ToolCallContent>, + /// Raw input parameters (preserved for debugging/inspection) + #[serde(skip_serializing_if = "Option::is_none")] + pub raw_input: Option<Value>, + /// Raw output result (preserved for debugging/inspection) + #[serde(skip_serializing_if = "Option::is_none")] + pub raw_output: Option<Value>, + /// Optional title for the tool call + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option<String>, + /// Error message if status is Error + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option<String>, + /// Additional metadata + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option<Value>, + /// Origin of this tool execution + #[serde(skip_serializing_if = "Option::is_none")] + pub origin: Option<ToolOrigin>, +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_tool_call_status_pending_serialization() { + let status = ToolCallStatus::Pending; + let json = serde_json::to_string(&status).unwrap(); + assert_eq!(json, r#""pending""#); + } + + #[test] + fn test_tool_call_status_running_serialization() { + let status = ToolCallStatus::Running; + let json = serde_json::to_string(&status).unwrap(); + assert_eq!(json, r#""in_progress""#); + } + + #[test] + fn test_tool_call_status_completed_serialization() { + let status = ToolCallStatus::Completed; + let json = serde_json::to_string(&status).unwrap(); + assert_eq!(json, r#""completed""#); + } + + #[test] + fn test_tool_call_status_error_serialization() { + let status = ToolCallStatus::Error; + let json = serde_json::to_string(&status).unwrap(); + assert_eq!(json, r#""failed""#); + } + + #[test] + fn test_tool_call_serialization_minimal() { + let tool_call = ToolCall { + id: "call_123".to_string(), + tool_name: "bash".to_string(), + status: ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + let json = serde_json::to_string(&tool_call).unwrap(); + + // Verify required fields are present + assert!(json.contains(r#""id":"call_123""#)); + assert!(json.contains(r#""tool_name":"bash""#)); + assert!(json.contains(r#""status":"pending""#)); + assert!(json.contains(r#""content":[]"#)); + + // Verify optional fields are NOT present when None + assert!(!json.contains(r#""raw_input""#)); + assert!(!json.contains(r#""raw_output""#)); + assert!(!json.contains(r#""title""#)); + assert!(!json.contains(r#""error""#)); + assert!(!json.contains(r#""metadata""#)); + } + + #[test] + fn test_tool_call_serialization_complete() { + let tool_call = ToolCall { + id: "call_456".to_string(), + tool_name: "read_file".to_string(), + status: ToolCallStatus::Completed, + content: vec![ToolCallContent::from_content_block(ContentBlock::Text { + text: "File contents".to_string(), + })], + raw_input: Some(json!({"path": "/tmp/test.txt"})), + raw_output: Some(json!({"success": true})), + title: Some("Read test file".to_string()), + error: None, + metadata: Some(json!({"duration_ms": 42})), + origin: None, + }; + + let json = serde_json::to_string(&tool_call).unwrap(); + + // Verify all fields are present + assert!(json.contains(r#""id":"call_456""#)); + assert!(json.contains(r#""tool_name":"read_file""#)); + assert!(json.contains(r#""status":"completed""#)); + assert!(json.contains(r#""text":"File contents""#)); + assert!(json.contains(r#""raw_input""#)); + assert!(json.contains(r#""raw_output""#)); + assert!(json.contains(r#""title":"Read test file""#)); + assert!(json.contains(r#""metadata""#)); + assert!(!json.contains(r#""error""#)); // Still None + } + + #[test] + fn test_tool_call_serialization_with_error() { + let tool_call = ToolCall { + id: "call_789".to_string(), + tool_name: "write_file".to_string(), + status: ToolCallStatus::Error, + content: vec![], + raw_input: Some(json!({"path": "/tmp/readonly.txt"})), + raw_output: None, + title: Some("Write to readonly file".to_string()), + error: Some("Permission denied".to_string()), + metadata: None, + origin: None, + }; + + let json = serde_json::to_string(&tool_call).unwrap(); + + assert!(json.contains(r#""status":"failed""#)); + assert!(json.contains(r#""error":"Permission denied""#)); + } + + #[test] + fn test_tool_call_roundtrip() { + let original = ToolCall { + id: "call_roundtrip".to_string(), + tool_name: "test_tool".to_string(), + status: ToolCallStatus::Running, + content: vec![ + ToolCallContent::from_content_block(ContentBlock::Text { + text: "Output line 1".to_string(), + }), + ToolCallContent::from_content_block(ContentBlock::Text { + text: "Output line 2".to_string(), + }), + ], + raw_input: Some(json!({"arg": "value"})), + raw_output: None, + title: Some("Test Tool Call".to_string()), + error: None, + metadata: Some(json!({"test": true})), + origin: None, + }; + + let json = serde_json::to_string(&original).unwrap(); + let deserialized: ToolCall = serde_json::from_str(&json).unwrap(); + + assert_eq!(original, deserialized); + } + + #[test] + fn test_tool_call_default_content() { + // Test that content defaults to empty vec when not present in JSON + let json = r#"{ + "id": "call_default", + "tool_name": "test", + "status": "pending" + }"#; + + let tool_call: ToolCall = serde_json::from_str(json).unwrap(); + assert_eq!(tool_call.content, vec![]); + } + + #[test] + fn test_optional_fields_skip_when_none() { + let tool_call = ToolCall { + id: "call_skip".to_string(), + tool_name: "test".to_string(), + status: ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + let json = serde_json::to_value(&tool_call).unwrap(); + let obj = json.as_object().unwrap(); + + // Verify optional fields are not in the serialized object + assert!(!obj.contains_key("raw_input")); + assert!(!obj.contains_key("raw_output")); + assert!(!obj.contains_key("title")); + assert!(!obj.contains_key("error")); + assert!(!obj.contains_key("metadata")); + + // Verify required fields ARE in the serialized object + assert!(obj.contains_key("id")); + assert!(obj.contains_key("tool_name")); + assert!(obj.contains_key("status")); + assert!(obj.contains_key("content")); + } + + // ============================================================================ + // ToolCallContent Tests + // ============================================================================ + + #[test] + fn test_tool_call_content_wrapper_content_serialization() { + let content = ToolCallContent::from_content_block(ContentBlock::Text { + text: "test output".to_string(), + }); + + let json = serde_json::to_value(&content).unwrap(); + + assert_eq!(json["type"], "content"); + assert!(json["content"].is_object()); + assert_eq!(json["content"]["type"], "text"); + assert_eq!(json["content"]["text"], "test output"); + } + + #[test] + fn test_tool_call_content_wrapper_content_deserialization() { + let json = json!({ + "type": "content", + "content": { + "type": "text", + "text": "deserialized output" + } + }); + + let content: ToolCallContent = serde_json::from_value(json).unwrap(); + + match content { + ToolCallContent::Content { content } => { + match content { + ContentBlock::Text { text } => { + assert_eq!(text, "deserialized output"); + } + _ => panic!("Expected Text content block"), + } + } + _ => panic!("Expected Content variant"), + } + } + + #[test] + fn test_tool_call_content_diff_serialization() { + let diff = ToolCallContent::diff( + "/src/main.rs".to_string(), + Some("old code".to_string()), + "new code".to_string(), + ); + + let json = serde_json::to_value(&diff).unwrap(); + + assert_eq!(json["type"], "diff"); + assert_eq!(json["path"], "/src/main.rs"); + assert_eq!(json["oldText"], "old code"); + assert_eq!(json["newText"], "new code"); + } + + #[test] + fn test_tool_call_content_diff_without_old_text() { + let diff = ToolCallContent::diff( + "/src/new_file.rs".to_string(), + None, + "new file content".to_string(), + ); + + let json = serde_json::to_value(&diff).unwrap(); + + assert_eq!(json["type"], "diff"); + assert_eq!(json["path"], "/src/new_file.rs"); + assert!(json.get("oldText").is_none()); // Should be omitted when None + assert_eq!(json["newText"], "new file content"); + } + + #[test] + fn test_tool_call_content_diff_deserialization() { + let json = json!({ + "type": "diff", + "path": "/test/file.rs", + "oldText": "before", + "newText": "after" + }); + + let content: ToolCallContent = serde_json::from_value(json).unwrap(); + + match content { + ToolCallContent::Diff { path, old_text, new_text } => { + assert_eq!(path, "/test/file.rs"); + assert_eq!(old_text, Some("before".to_string())); + assert_eq!(new_text, "after"); + } + _ => panic!("Expected Diff variant"), + } + } + + #[test] + fn test_tool_call_content_terminal_serialization() { + let terminal = ToolCallContent::terminal("term_123".to_string()); + + let json = serde_json::to_value(&terminal).unwrap(); + + assert_eq!(json["type"], "terminal"); + assert_eq!(json["terminalId"], "term_123"); + } + + #[test] + fn test_tool_call_content_terminal_deserialization() { + let json = json!({ + "type": "terminal", + "terminalId": "term_456" + }); + + let content: ToolCallContent = serde_json::from_value(json).unwrap(); + + match content { + ToolCallContent::Terminal { terminal_id } => { + assert_eq!(terminal_id, "term_456"); + } + _ => panic!("Expected Terminal variant"), + } + } + + #[test] + fn test_tool_call_content_roundtrip() { + // Test all three variants + let variants = vec![ + ToolCallContent::from_content_block(ContentBlock::Text { + text: "test".to_string(), + }), + ToolCallContent::diff("path.rs".to_string(), Some("old".to_string()), "new".to_string()), + ToolCallContent::terminal("term_789".to_string()), + ]; + + for original in variants { + let json = serde_json::to_value(&original).unwrap(); + let deserialized: ToolCallContent = serde_json::from_value(json).unwrap(); + assert_eq!(original, deserialized); + } + } + + #[test] + fn test_tool_call_with_mixed_content_types() { + let tool_call = ToolCall { + id: "call_mixed".to_string(), + tool_name: "edit_tool".to_string(), + status: ToolCallStatus::Completed, + content: vec![ + ToolCallContent::from_content_block(ContentBlock::Text { + text: "Editing file...".to_string(), + }), + ToolCallContent::diff( + "/src/lib.rs".to_string(), + Some("fn old() {}".to_string()), + "fn new() {}".to_string(), + ), + ToolCallContent::from_content_block(ContentBlock::Text { + text: "Edit complete".to_string(), + }), + ], + raw_input: None, + raw_output: None, + title: Some("Edit file".to_string()), + error: None, + metadata: None, + origin: None, + }; + + // Serialize and deserialize + let json = serde_json::to_value(&tool_call).unwrap(); + let deserialized: ToolCall = serde_json::from_value(json).unwrap(); + + assert_eq!(tool_call, deserialized); + assert_eq!(deserialized.content.len(), 3); + } +} diff --git a/crates/dirigent_protocol/src/types/updates.rs b/crates/dirigent_protocol/src/types/updates.rs new file mode 100644 index 0000000..93f4cf5 --- /dev/null +++ b/crates/dirigent_protocol/src/types/updates.rs @@ -0,0 +1,654 @@ +use serde::{Deserialize, Serialize}; + +use crate::types::content::ContentBlock; +use crate::types::meta::Meta; +use crate::types::tool::ToolCall; + +/// ACP-style session updates for streaming content +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum SessionUpdate { + /// User message content chunk + UserMessageChunk { + message_id: String, + content: ContentBlock, + #[serde(skip_serializing_if = "Option::is_none")] + _meta: Option<Meta>, + }, + /// Agent message content chunk + AgentMessageChunk { + message_id: String, + content: ContentBlock, + #[serde(skip_serializing_if = "Option::is_none")] + _meta: Option<Meta>, + }, + /// Agent thought content chunk (internal reasoning) + AgentThoughtChunk { + message_id: String, + content: ContentBlock, + #[serde(skip_serializing_if = "Option::is_none")] + _meta: Option<Meta>, + }, + /// Tool call created or initiated + ToolCall { + message_id: String, + tool_call: ToolCall, + #[serde(skip_serializing_if = "Option::is_none")] + _meta: Option<Meta>, + }, + /// Tool call update (status change, new content, etc.) + ToolCallUpdate { + message_id: String, + tool_call_id: String, + tool_call: ToolCall, + #[serde(skip_serializing_if = "Option::is_none")] + _meta: Option<Meta>, + }, + /// Unknown update type (forward compatibility - pass through as raw JSON) + #[serde(untagged)] + Unknown { + #[serde(flatten)] + data: serde_json::Value, + }, +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ToolCallContent; + use serde_json::json; + + #[test] + fn test_user_message_chunk_serialization() { + let update = SessionUpdate::UserMessageChunk { + message_id: "msg_123".to_string(), + content: ContentBlock::Text { + text: "Hello".to_string(), + }, + _meta: None, + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""type":"user_message_chunk"#)); + assert!(json.contains(r#""message_id":"msg_123"#)); + assert!(json.contains(r#""text":"Hello"#)); + assert!(!json.contains(r#""_meta""#)); + } + + #[test] + fn test_user_message_chunk_deserialization() { + let json = r#"{ + "type": "user_message_chunk", + "message_id": "msg_123", + "content": { + "type": "text", + "text": "Hello" + } + }"#; + + let update: SessionUpdate = serde_json::from_str(json).unwrap(); + match update { + SessionUpdate::UserMessageChunk { + message_id, + content, + _meta, + } => { + assert_eq!(message_id, "msg_123"); + assert_eq!( + content, + ContentBlock::Text { + text: "Hello".to_string() + } + ); + assert_eq!(_meta, None); + } + _ => panic!("Expected UserMessageChunk"), + } + } + + #[test] + fn test_user_message_chunk_roundtrip() { + let original = SessionUpdate::UserMessageChunk { + message_id: "msg_456".to_string(), + content: ContentBlock::Text { + text: "Roundtrip test".to_string(), + }, + _meta: None, + }; + + let json = serde_json::to_string(&original).unwrap(); + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(original, deserialized); + } + + #[test] + fn test_user_message_chunk_with_meta() { + let update = SessionUpdate::UserMessageChunk { + message_id: "msg_789".to_string(), + content: ContentBlock::Text { + text: "With meta".to_string(), + }, + _meta: Some(Meta { + provider: None, + extra: std::collections::HashMap::new(), + }), + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""_meta":{}"#)); + + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(update, deserialized); + } + + #[test] + fn test_agent_message_chunk_serialization() { + let update = SessionUpdate::AgentMessageChunk { + message_id: "msg_agent_1".to_string(), + content: ContentBlock::Text { + text: "Agent response".to_string(), + }, + _meta: None, + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""type":"agent_message_chunk"#)); + assert!(json.contains(r#""message_id":"msg_agent_1"#)); + assert!(json.contains(r#""text":"Agent response"#)); + } + + #[test] + fn test_agent_message_chunk_deserialization() { + let json = r#"{ + "type": "agent_message_chunk", + "message_id": "msg_agent_2", + "content": { + "type": "text", + "text": "Agent here" + } + }"#; + + let update: SessionUpdate = serde_json::from_str(json).unwrap(); + match update { + SessionUpdate::AgentMessageChunk { + message_id, + content, + _meta, + } => { + assert_eq!(message_id, "msg_agent_2"); + assert_eq!( + content, + ContentBlock::Text { + text: "Agent here".to_string() + } + ); + assert_eq!(_meta, None); + } + _ => panic!("Expected AgentMessageChunk"), + } + } + + #[test] + fn test_agent_message_chunk_roundtrip() { + let original = SessionUpdate::AgentMessageChunk { + message_id: "msg_agent_rt".to_string(), + content: ContentBlock::ResourceLink { + uri: "file:///test.txt".to_string(), + name: Some("test.txt".to_string()), + mime_type: Some("text/plain".to_string()), + }, + _meta: None, + }; + + let json = serde_json::to_string(&original).unwrap(); + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(original, deserialized); + } + + #[test] + fn test_agent_message_chunk_with_meta() { + let mut extra = std::collections::HashMap::new(); + extra.insert("timestamp".to_string(), json!("2025-11-10T12:00:00Z")); + + let update = SessionUpdate::AgentMessageChunk { + message_id: "msg_agent_meta".to_string(), + content: ContentBlock::Text { + text: "With metadata".to_string(), + }, + _meta: Some(Meta { + provider: None, + extra, + }), + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""_meta""#)); + assert!(json.contains(r#""timestamp""#)); + + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(update, deserialized); + } + + #[test] + fn test_agent_thought_chunk_serialization() { + let update = SessionUpdate::AgentThoughtChunk { + message_id: "msg_thought_1".to_string(), + content: ContentBlock::Text { + text: "Thinking...".to_string(), + }, + _meta: None, + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""type":"agent_thought_chunk"#)); + assert!(json.contains(r#""message_id":"msg_thought_1"#)); + assert!(json.contains(r#""text":"Thinking..."#)); + } + + #[test] + fn test_agent_thought_chunk_deserialization() { + let json = r#"{ + "type": "agent_thought_chunk", + "message_id": "msg_thought_2", + "content": { + "type": "text", + "text": "Analyzing the problem..." + } + }"#; + + let update: SessionUpdate = serde_json::from_str(json).unwrap(); + match update { + SessionUpdate::AgentThoughtChunk { + message_id, + content, + _meta, + } => { + assert_eq!(message_id, "msg_thought_2"); + assert_eq!( + content, + ContentBlock::Text { + text: "Analyzing the problem...".to_string() + } + ); + assert_eq!(_meta, None); + } + _ => panic!("Expected AgentThoughtChunk"), + } + } + + #[test] + fn test_agent_thought_chunk_roundtrip() { + let original = SessionUpdate::AgentThoughtChunk { + message_id: "msg_thought_rt".to_string(), + content: ContentBlock::Text { + text: "Internal reasoning".to_string(), + }, + _meta: None, + }; + + let json = serde_json::to_string(&original).unwrap(); + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(original, deserialized); + } + + #[test] + fn test_agent_thought_chunk_with_meta() { + let update = SessionUpdate::AgentThoughtChunk { + message_id: "msg_thought_meta".to_string(), + content: ContentBlock::Text { + text: "Thought with meta".to_string(), + }, + _meta: Some(Meta::default()), + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""_meta":{}"#)); + + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(update, deserialized); + } + + #[test] + fn test_tool_call_serialization() { + let tool_call = ToolCall { + id: "call_123".to_string(), + tool_name: "bash".to_string(), + status: crate::types::tool::ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: Some("Run bash command".to_string()), + error: None, + metadata: None, + + origin: None, + + }; + + let update = SessionUpdate::ToolCall { + message_id: "msg_tool_1".to_string(), + tool_call, + _meta: None, + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""type":"tool_call"#)); + assert!(json.contains(r#""message_id":"msg_tool_1"#)); + assert!(json.contains(r#""tool_name":"bash"#)); + assert!(json.contains(r#""status":"pending"#)); + } + + #[test] + fn test_tool_call_deserialization() { + let json = r#"{ + "type": "tool_call", + "message_id": "msg_tool_2", + "tool_call": { + "id": "call_456", + "tool_name": "read", + "status": "in_progress", + "content": [] + } + }"#; + + let update: SessionUpdate = serde_json::from_str(json).unwrap(); + match update { + SessionUpdate::ToolCall { + message_id, + tool_call, + _meta, + } => { + assert_eq!(message_id, "msg_tool_2"); + assert_eq!(tool_call.id, "call_456"); + assert_eq!(tool_call.tool_name, "read"); + assert_eq!( + tool_call.status, + crate::types::tool::ToolCallStatus::Running + ); + assert_eq!(_meta, None); + } + _ => panic!("Expected ToolCall"), + } + } + + #[test] + fn test_tool_call_roundtrip() { + let tool_call = ToolCall { + id: "call_rt".to_string(), + tool_name: "write".to_string(), + status: crate::types::tool::ToolCallStatus::Completed, + content: vec![ToolCallContent::from_content_block(ContentBlock::Text { + text: "File written".to_string(), + })], + raw_input: Some(json!({"path": "/tmp/test.txt"})), + raw_output: Some(json!({"success": true})), + title: Some("Write file".to_string()), + error: None, + metadata: None, + + origin: None, + + }; + + let original = SessionUpdate::ToolCall { + message_id: "msg_tool_rt".to_string(), + tool_call, + _meta: None, + }; + + let json = serde_json::to_string(&original).unwrap(); + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(original, deserialized); + } + + #[test] + fn test_tool_call_with_meta() { + let tool_call = ToolCall { + id: "call_meta".to_string(), + tool_name: "bash".to_string(), + status: crate::types::tool::ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + + origin: None, + + }; + + let update = SessionUpdate::ToolCall { + message_id: "msg_tool_meta".to_string(), + tool_call, + _meta: Some(Meta::default()), + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""_meta":{}"#)); + + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(update, deserialized); + } + + #[test] + fn test_tool_call_update_serialization() { + let tool_call = ToolCall { + id: "call_update_1".to_string(), + tool_name: "bash".to_string(), + status: crate::types::tool::ToolCallStatus::Running, + content: vec![ToolCallContent::from_content_block(ContentBlock::Text { + text: "Output line 1".to_string(), + })], + raw_input: None, + raw_output: None, + title: Some("Running bash".to_string()), + error: None, + metadata: None, + + origin: None, + + }; + + let update = SessionUpdate::ToolCallUpdate { + message_id: "msg_update_1".to_string(), + tool_call_id: "call_update_1".to_string(), + tool_call, + _meta: None, + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""type":"tool_call_update"#)); + assert!(json.contains(r#""message_id":"msg_update_1"#)); + assert!(json.contains(r#""tool_call_id":"call_update_1"#)); + // ToolCallStatus::Running serializes as "in_progress" + assert!(json.contains(r#""status":"in_progress"#)); + } + + #[test] + fn test_tool_call_update_deserialization() { + let json = r#"{ + "type": "tool_call_update", + "message_id": "msg_update_2", + "tool_call_id": "call_update_2", + "tool_call": { + "id": "call_update_2", + "tool_name": "read", + "status": "completed", + "content": [ + { + "type": "content", + "content": { + "type": "text", + "text": "File contents" + } + } + ] + } + }"#; + + let update: SessionUpdate = serde_json::from_str(json).unwrap(); + match update { + SessionUpdate::ToolCallUpdate { + message_id, + tool_call_id, + tool_call, + _meta, + } => { + assert_eq!(message_id, "msg_update_2"); + assert_eq!(tool_call_id, "call_update_2"); + assert_eq!(tool_call.id, "call_update_2"); + assert_eq!( + tool_call.status, + crate::types::tool::ToolCallStatus::Completed + ); + assert_eq!( + tool_call.content, + vec![crate::types::tool::ToolCallContent::from_content_block(ContentBlock::Text { + text: "File contents".to_string() + })] + ); + assert_eq!(_meta, None); + } + _ => panic!("Expected ToolCallUpdate"), + } + } + + #[test] + fn test_tool_call_update_roundtrip() { + let tool_call = ToolCall { + id: "call_update_rt".to_string(), + tool_name: "bash".to_string(), + status: crate::types::tool::ToolCallStatus::Error, + content: vec![], + raw_input: Some(json!({"command": "invalid"})), + raw_output: None, + title: Some("Failed command".to_string()), + error: Some("Command not found".to_string()), + metadata: None, + + origin: None, + + }; + + let original = SessionUpdate::ToolCallUpdate { + message_id: "msg_update_rt".to_string(), + tool_call_id: "call_update_rt".to_string(), + tool_call, + _meta: None, + }; + + let json = serde_json::to_string(&original).unwrap(); + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(original, deserialized); + } + + #[test] + fn test_tool_call_update_with_meta() { + let tool_call = ToolCall { + id: "call_update_meta".to_string(), + tool_name: "write".to_string(), + status: crate::types::tool::ToolCallStatus::Completed, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + + origin: None, + + }; + + let update = SessionUpdate::ToolCallUpdate { + message_id: "msg_update_meta".to_string(), + tool_call_id: "call_update_meta".to_string(), + tool_call, + _meta: Some(Meta::default()), + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""_meta":{}"#)); + + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(update, deserialized); + } + + #[test] + fn test_all_variants_have_snake_case_type_tags() { + // Test that each variant serializes with the correct snake_case type tag + + let user_chunk = SessionUpdate::UserMessageChunk { + message_id: "m1".to_string(), + content: ContentBlock::Text { + text: "test".to_string(), + }, + _meta: None, + }; + let json = serde_json::to_string(&user_chunk).unwrap(); + assert!(json.contains(r#""type":"user_message_chunk"#)); + + let agent_chunk = SessionUpdate::AgentMessageChunk { + message_id: "m2".to_string(), + content: ContentBlock::Text { + text: "test".to_string(), + }, + _meta: None, + }; + let json = serde_json::to_string(&agent_chunk).unwrap(); + assert!(json.contains(r#""type":"agent_message_chunk"#)); + + let thought_chunk = SessionUpdate::AgentThoughtChunk { + message_id: "m3".to_string(), + content: ContentBlock::Text { + text: "test".to_string(), + }, + _meta: None, + }; + let json = serde_json::to_string(&thought_chunk).unwrap(); + assert!(json.contains(r#""type":"agent_thought_chunk"#)); + + let tool_call = SessionUpdate::ToolCall { + message_id: "m4".to_string(), + tool_call: ToolCall { + id: "c1".to_string(), + tool_name: "test".to_string(), + status: crate::types::tool::ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + + origin: None, + + }, + _meta: None, + }; + let json = serde_json::to_string(&tool_call).unwrap(); + assert!(json.contains(r#""type":"tool_call"#)); + + let tool_call_update = SessionUpdate::ToolCallUpdate { + message_id: "m5".to_string(), + tool_call_id: "c2".to_string(), + tool_call: ToolCall { + id: "c2".to_string(), + tool_name: "test".to_string(), + status: crate::types::tool::ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + + origin: None, + + }, + _meta: None, + }; + let json = serde_json::to_string(&tool_call_update).unwrap(); + assert!(json.contains(r#""type":"tool_call_update"#)); + } +} diff --git a/crates/dirigent_protocol/tests/README.md b/crates/dirigent_protocol/tests/README.md new file mode 100644 index 0000000..126f625 --- /dev/null +++ b/crates/dirigent_protocol/tests/README.md @@ -0,0 +1,230 @@ +# Dirigent Protocol Tests + +This directory contains comprehensive tests for the Dirigent protocol and OpenCode adapter. + +## Test Files + +### `protocol_tests.rs` +Core protocol translation tests that verify OpenCode events are correctly translated to Dirigent protocol events. + +**Coverage:** +- Session creation and updates +- User and assistant messages +- Message parts (text, thinking, tool) +- Event stream parsing +- Protocol serialization/deserialization + +**Run:** `cargo test --test protocol_tests` + +### `deduplication_tests.rs` +Tests for the stateful adapter's deduplication logic, ensuring no duplicate messages or parts appear in the UI. + +**Coverage:** +- Duplicate `MessageStarted` filtering +- Duplicate `MessageCompleted` filtering +- Part completion signal (`delta: null`) filtering +- Different part types not being filtered +- Streaming part updates working correctly +- Full tit-tat conversation flow +- Adapter state independence + +**Run:** `cargo test --test deduplication_tests` + +### `session_list_tests.rs` +Tests for parsing OpenCode session list responses. + +**Coverage:** +- Session list array parsing +- Empty session list handling +- Single session deserialization +- Optional fields handling +- Timestamp parsing validation + +**Run:** `cargo test --test session_list_tests` + +## Fixtures + +### `fixtures/sample_events.jsonl` +Sample OpenCode SSE events in JSONL format (one event per line). Used for parsing validation and event stream testing. + +### `fixtures/opencode_session_response.json` +Real OpenCode session list response. Used for session deserialization tests. + +**Source:** Copied from `/docs/building/opencode_session_response.json` + +## Running All Tests + +```bash +# Run all protocol tests +cargo test --package dirigent_protocol + +# Run with output +cargo test --package dirigent_protocol -- --nocapture + +# Run specific test +cargo test --package dirigent_protocol test_tit_tat_flow + +# Run tests matching pattern +cargo test --package dirigent_protocol duplicate +``` + +## Adding New Tests + +### For OpenCode Event Translation + +Add to `protocol_tests.rs`: + +```rust +#[test] +fn test_translate_new_feature() { + let adapter = OpenCodeAdapter::new(); + + // Create OpenCode event + let oc_event = oc::Event::YourEvent { ... }; + + // Translate + let result = adapter.translate_event(oc_event); + + // Assert + assert!(result.is_ok()); + match result.unwrap() { + Event::YourDirigentEvent(data) => { + assert_eq!(data.field, expected_value); + } + _ => panic!("Expected YourDirigentEvent"), + } +} +``` + +### For Deduplication Logic + +Add to `deduplication_tests.rs`: + +```rust +#[test] +fn test_new_deduplication_rule() { + let adapter = OpenCodeAdapter::new(); + + // Send first event + let result1 = adapter.translate_event(first_event); + assert!(result1.is_ok()); + + // Send duplicate event + let result2 = adapter.translate_event(duplicate_event); + assert!(result2.is_err()); + assert!(matches!(result2.unwrap_err(), TranslationError::Duplicate)); +} +``` + +### For New Fixtures + +1. Place fixture file in `tests/fixtures/` +2. Use `include_str!` to load it: + ```rust + let fixture = include_str!("fixtures/your_file.json"); + ``` + +## Test Principles + +### Stateful Adapter Pattern + +⚠️ **IMPORTANT:** The adapter maintains state, so: + +```rust +// ✅ CORRECT: One adapter for entire event stream +let adapter = OpenCodeAdapter::new(); +for event in events { + adapter.translate_event(event); +} + +// ❌ WRONG: New adapter each time (loses state!) +for event in events { + let adapter = OpenCodeAdapter::new(); + adapter.translate_event(event); +} +``` + +### Testing Duplicates + +When testing deduplication: +1. Send the first event → should succeed +2. Send the duplicate event → should fail with `TranslationError::Duplicate` +3. Always use the SAME adapter instance + +### Real-World Fixtures + +Fixtures should come from actual OpenCode API responses when possible: +- Captures real-world edge cases +- Ensures compatibility with API changes +- Documents actual behavior + +## CI Integration + +These tests run automatically on: +- Every commit (via `cargo test`) +- Pull requests +- Before releases + +**Status:** All tests should pass before merging. + +## Coverage Report + +```bash +# Install tarpaulin for coverage +cargo install cargo-tarpaulin + +# Generate coverage report +cargo tarpaulin --package dirigent_protocol --out Html +``` + +## Related Documentation + +- [SSE Deduplication](../../../docs/building/sse_deduplication.md) - How deduplication works +- [SSE Event Flow Analysis](../../../docs/building/sse_event_flow_analysis.md) - OpenCode event patterns +- [Protocol Abstraction Plan](../../../docs/building/protocol_abstraction_plan.md) - Adapter architecture + +## Test Statistics + +**Last Updated:** 2025-11-01 + +- **Total Tests:** 24 +- **Deduplication Tests:** 7 +- **Session Tests:** 5 +- **Protocol Tests:** 12 +- **All Passing:** ✅ + +## Troubleshooting + +### Test Fails with "argument #1 of type &OpenCodeAdapter is missing" + +**Problem:** You're calling `OpenCodeAdapter::translate_event(event)` as a static method. + +**Solution:** Create an adapter instance first: +```rust +let adapter = OpenCodeAdapter::new(); +let result = adapter.translate_event(event); +``` + +### Test Fails with "pattern does not mention field part_id" + +**Problem:** The `MessagePartAdded` event now includes a `part_id` field. + +**Solution:** Update your pattern match: +```rust +// Before +Event::MessagePartAdded { message_id, part, delta } => { ... } + +// After +Event::MessagePartAdded { message_id, part_id: _, part, delta } => { ... } +``` + +### Deduplication Test Unexpectedly Passes + +**Problem:** You're creating a new adapter for each event. + +**Solution:** Create ONE adapter and reuse it: +```rust +let adapter = OpenCodeAdapter::new(); +adapter.translate_event(event1); // First time +adapter.translate_event(event1); // Should be duplicate! +``` diff --git a/crates/dirigent_protocol/tests/content_block_tests.rs b/crates/dirigent_protocol/tests/content_block_tests.rs new file mode 100644 index 0000000..5d66cbb --- /dev/null +++ b/crates/dirigent_protocol/tests/content_block_tests.rs @@ -0,0 +1,378 @@ +/// Comprehensive edge case tests for ContentBlock +use dirigent_protocol::types::ContentBlock; + +/// Test empty string in Text variant +#[test] +fn test_text_empty_string() { + let block = ContentBlock::Text { + text: String::new(), + }; + let json = serde_json::to_string(&block).unwrap(); + assert!(json.contains(r#""type":"text"#)); + assert!(json.contains(r#""text":"""#)); + + let deserialized: ContentBlock = serde_json::from_str(&json).unwrap(); + assert_eq!(block, deserialized); +} + +/// Test very long string (>10KB) in Text variant +#[test] +fn test_text_very_long_string() { + let long_text = "a".repeat(10_000); + let block = ContentBlock::Text { + text: long_text.clone(), + }; + let json = serde_json::to_string(&block).unwrap(); + let deserialized: ContentBlock = serde_json::from_str(&json).unwrap(); + + match deserialized { + ContentBlock::Text { text } => { + assert_eq!(text.len(), 10_000); + assert_eq!(text, long_text); + } + _ => panic!("Expected Text variant"), + } +} + +/// Test special characters (unicode, emojis, newlines) in Text +#[test] +fn test_text_special_characters() { + let special_text = "Hello 👋\nWorld 🌍\t中文\r\n\"quotes\" and 'apostrophes' \\ backslash"; + let block = ContentBlock::Text { + text: special_text.to_string(), + }; + let json = serde_json::to_string(&block).unwrap(); + let deserialized: ContentBlock = serde_json::from_str(&json).unwrap(); + + match deserialized { + ContentBlock::Text { text } => { + assert_eq!(text, special_text); + } + _ => panic!("Expected Text variant"), + } +} + +/// Test JSON control characters in Text +#[test] +fn test_text_json_control_characters() { + let control_chars = "Line1\nLine2\r\nLine3\tTabbed\x08Backspace"; + let block = ContentBlock::Text { + text: control_chars.to_string(), + }; + let json = serde_json::to_string(&block).unwrap(); + let deserialized: ContentBlock = serde_json::from_str(&json).unwrap(); + + match deserialized { + ContentBlock::Text { text } => { + assert_eq!(text, control_chars); + } + _ => panic!("Expected Text variant"), + } +} + +/// Test ResourceLink with empty URI +#[test] +fn test_resource_link_empty_uri() { + let block = ContentBlock::ResourceLink { + uri: String::new(), + name: None, + mime_type: None, + }; + let json = serde_json::to_string(&block).unwrap(); + assert!(json.contains(r#""type":"resource_link"#)); + assert!(json.contains(r#""uri":"""#)); + + let deserialized: ContentBlock = serde_json::from_str(&json).unwrap(); + assert_eq!(block, deserialized); +} + +/// Test ResourceLink with all optional fields as None +#[test] +fn test_resource_link_all_none() { + let block = ContentBlock::ResourceLink { + uri: "file:///test.txt".to_string(), + name: None, + mime_type: None, + }; + let json = serde_json::to_string(&block).unwrap(); + + // Verify optional fields are NOT in JSON + assert!(!json.contains(r#""name""#)); + assert!(!json.contains(r#""mime_type""#)); + + let deserialized: ContentBlock = serde_json::from_str(&json).unwrap(); + assert_eq!(block, deserialized); +} + +/// Test ResourceLink with all optional fields as Some +#[test] +fn test_resource_link_all_some() { + let block = ContentBlock::ResourceLink { + uri: "https://example.com/file.pdf".to_string(), + name: Some("document.pdf".to_string()), + mime_type: Some("application/pdf".to_string()), + }; + let json = serde_json::to_string(&block).unwrap(); + + // Verify all fields ARE in JSON + assert!(json.contains(r#""uri":"https://example.com/file.pdf"#)); + assert!(json.contains(r#""name":"document.pdf"#)); + assert!(json.contains(r#""mime_type":"application/pdf"#)); + + let deserialized: ContentBlock = serde_json::from_str(&json).unwrap(); + assert_eq!(block, deserialized); +} + +/// Test ResourceLink with only name +#[test] +fn test_resource_link_only_name() { + let block = ContentBlock::ResourceLink { + uri: "file:///path/to/file".to_string(), + name: Some("my_file".to_string()), + mime_type: None, + }; + let json = serde_json::to_string(&block).unwrap(); + + assert!(json.contains(r#""name":"my_file"#)); + assert!(!json.contains(r#""mime_type""#)); + + let deserialized: ContentBlock = serde_json::from_str(&json).unwrap(); + assert_eq!(block, deserialized); +} + +/// Test ResourceLink with only mime_type +#[test] +fn test_resource_link_only_mime_type() { + let block = ContentBlock::ResourceLink { + uri: "file:///path/to/file".to_string(), + name: None, + mime_type: Some("text/plain".to_string()), + }; + let json = serde_json::to_string(&block).unwrap(); + + assert!(!json.contains(r#""name""#)); + assert!(json.contains(r#""mime_type":"text/plain"#)); + + let deserialized: ContentBlock = serde_json::from_str(&json).unwrap(); + assert_eq!(block, deserialized); +} + +/// Test ResourceLink with empty optional strings +#[test] +fn test_resource_link_empty_optional_strings() { + let block = ContentBlock::ResourceLink { + uri: "file:///test".to_string(), + name: Some(String::new()), + mime_type: Some(String::new()), + }; + let json = serde_json::to_string(&block).unwrap(); + + // Empty strings should still serialize + assert!(json.contains(r#""name":"""#)); + assert!(json.contains(r#""mime_type":"""#)); + + let deserialized: ContentBlock = serde_json::from_str(&json).unwrap(); + assert_eq!(block, deserialized); +} + +/// Test ResourceLink with special characters in URI +#[test] +fn test_resource_link_special_uri() { + let uri = "file:///path/to/file%20with%20spaces.txt?query=param&other=value#fragment"; + let block = ContentBlock::ResourceLink { + uri: uri.to_string(), + name: Some("file with spaces.txt".to_string()), + mime_type: Some("text/plain; charset=utf-8".to_string()), + }; + let json = serde_json::to_string(&block).unwrap(); + let deserialized: ContentBlock = serde_json::from_str(&json).unwrap(); + + match deserialized { + ContentBlock::ResourceLink { + uri: deser_uri, + name, + mime_type, + } => { + assert_eq!(deser_uri, uri); + assert_eq!(name, Some("file with spaces.txt".to_string())); + assert_eq!(mime_type, Some("text/plain; charset=utf-8".to_string())); + } + _ => panic!("Expected ResourceLink variant"), + } +} + +/// Test ResourceLink with very long URI +#[test] +fn test_resource_link_long_uri() { + let long_path = format!("file:///{}", "a/".repeat(500)); + let block = ContentBlock::ResourceLink { + uri: long_path.clone(), + name: Some("file.txt".to_string()), + mime_type: None, + }; + let json = serde_json::to_string(&block).unwrap(); + let deserialized: ContentBlock = serde_json::from_str(&json).unwrap(); + + match deserialized { + ContentBlock::ResourceLink { uri, .. } => { + assert_eq!(uri, long_path); + } + _ => panic!("Expected ResourceLink variant"), + } +} + +/// Test deserialization from JSON without type field (should fail) +#[test] +fn test_deserialization_missing_type_field() { + let json = r#"{"text": "Hello"}"#; + let result: Result<ContentBlock, _> = serde_json::from_str(json); + assert!(result.is_err(), "Should fail without type field"); +} + +/// Test deserialization with invalid type value +#[test] +fn test_deserialization_invalid_type() { + let json = r#"{"type": "invalid_type", "text": "Hello"}"#; + let result: Result<ContentBlock, _> = serde_json::from_str(json); + assert!(result.is_err(), "Should fail with invalid type"); +} + +/// Test deserialization Text without text field (should fail) +#[test] +fn test_deserialization_text_missing_field() { + let json = r#"{"type": "text"}"#; + let result: Result<ContentBlock, _> = serde_json::from_str(json); + assert!(result.is_err(), "Should fail without text field"); +} + +/// Test deserialization ResourceLink without uri field (should fail) +#[test] +fn test_deserialization_resource_link_missing_uri() { + let json = r#"{"type": "resource_link", "name": "file.txt"}"#; + let result: Result<ContentBlock, _> = serde_json::from_str(json); + assert!(result.is_err(), "Should fail without uri field"); +} + +/// Test that type tag is snake_case +#[test] +fn test_type_tag_format() { + let text = ContentBlock::Text { + text: "test".to_string(), + }; + let json = serde_json::to_string(&text).unwrap(); + assert!(json.contains(r#""type":"text"#)); + assert!(!json.contains(r#""type":"Text"#)); + + let resource = ContentBlock::ResourceLink { + uri: "file:///test".to_string(), + name: None, + mime_type: None, + }; + let json = serde_json::to_string(&resource).unwrap(); + assert!(json.contains(r#""type":"resource_link"#)); + assert!(!json.contains(r#""type":"ResourceLink"#)); +} + +/// Test roundtrip with all ContentBlock variants +#[test] +fn test_all_variants_roundtrip() { + let variants = vec![ + ContentBlock::Text { + text: "Test text".to_string(), + }, + ContentBlock::ResourceLink { + uri: "file:///test.txt".to_string(), + name: Some("test.txt".to_string()), + mime_type: Some("text/plain".to_string()), + }, + ContentBlock::ResourceLink { + uri: "https://example.com/resource".to_string(), + name: None, + mime_type: None, + }, + ]; + + for variant in variants { + let json = serde_json::to_string(&variant).unwrap(); + let deserialized: ContentBlock = serde_json::from_str(&json).unwrap(); + assert_eq!(variant, deserialized); + } +} + +/// Test pretty-printed JSON +#[test] +fn test_pretty_json() { + let block = ContentBlock::ResourceLink { + uri: "file:///test.txt".to_string(), + name: Some("test.txt".to_string()), + mime_type: Some("text/plain".to_string()), + }; + let json = serde_json::to_string_pretty(&block).unwrap(); + + // Should be parseable + let deserialized: ContentBlock = serde_json::from_str(&json).unwrap(); + assert_eq!(block, deserialized); +} + +/// Test null values in JSON (should fail for required fields) +#[test] +fn test_null_values() { + let json = r#"{"type": "text", "text": null}"#; + let result: Result<ContentBlock, _> = serde_json::from_str(json); + assert!(result.is_err(), "Should fail with null text"); + + let json = r#"{"type": "resource_link", "uri": null}"#; + let result: Result<ContentBlock, _> = serde_json::from_str(json); + assert!(result.is_err(), "Should fail with null uri"); +} + +/// Test null values for optional fields (should deserialize as None) +#[test] +fn test_null_optional_fields() { + let json = r#"{ + "type": "resource_link", + "uri": "file:///test", + "name": null, + "mime_type": null + }"#; + let result: ContentBlock = serde_json::from_str(json).unwrap(); + + match result { + ContentBlock::ResourceLink { + uri, + name, + mime_type, + } => { + assert_eq!(uri, "file:///test"); + assert!(name.is_none()); + assert!(mime_type.is_none()); + } + _ => panic!("Expected ResourceLink"), + } +} + +/// Test ContentBlock Clone and PartialEq +#[test] +fn test_clone_and_equality() { + let original = ContentBlock::Text { + text: "Test".to_string(), + }; + let cloned = original.clone(); + assert_eq!(original, cloned); + + let different = ContentBlock::Text { + text: "Different".to_string(), + }; + assert_ne!(original, different); +} + +/// Test ContentBlock Debug formatting +#[test] +fn test_debug_formatting() { + let block = ContentBlock::Text { + text: "Debug test".to_string(), + }; + let debug_str = format!("{:?}", block); + assert!(debug_str.contains("Text")); + assert!(debug_str.contains("Debug test")); +} diff --git a/crates/dirigent_protocol/tests/deduplication_tests.rs b/crates/dirigent_protocol/tests/deduplication_tests.rs new file mode 100644 index 0000000..d50cbc0 --- /dev/null +++ b/crates/dirigent_protocol/tests/deduplication_tests.rs @@ -0,0 +1,458 @@ +use dirigent_protocol::adapters::{OpenCodeAdapter, TranslationError}; +use dirigent_protocol::Event; +use opencode_client::types as oc; + +/// Test that duplicate MessageStarted events are filtered +#[test] +fn test_duplicate_message_started_filtered() { + let adapter = OpenCodeAdapter::new(); + + // First message.updated event (streaming) + let oc_message = oc::Message::Assistant(oc::AssistantMessage { + id: "msg_test1".to_string(), + session_id: "ses_test".to_string(), + time: oc::AssistantMessageTime { + created: 1700000000000, + completed: None, + }, + error: None, + system: vec![], + parent_id: None, + model_id: Some("gpt-4".to_string()), + provider_id: Some("openai".to_string()), + mode: None, + path: None, + summary: None, + cost: 0.0, + tokens: Default::default(), + }); + + let event1 = oc::Event::MessageUpdated { + properties: oc::MessageEventInfo { + info: oc_message.clone(), + }, + }; + + // First event should succeed + let result1 = adapter.translate_event(event1); + assert!(result1.is_ok()); + assert!(matches!(result1.unwrap(), Event::MessageStarted { .. })); + + // Second identical event should be filtered as duplicate + let event2 = oc::Event::MessageUpdated { + properties: oc::MessageEventInfo { info: oc_message }, + }; + + let result2 = adapter.translate_event(event2); + assert!(result2.is_err()); + assert!(matches!(result2.unwrap_err(), TranslationError::Duplicate)); +} + +/// Test that duplicate MessageCompleted events are filtered +#[test] +fn test_duplicate_message_completed_filtered() { + let adapter = OpenCodeAdapter::new(); + + // Completed message + let oc_message = oc::Message::Assistant(oc::AssistantMessage { + id: "msg_test2".to_string(), + session_id: "ses_test".to_string(), + time: oc::AssistantMessageTime { + created: 1700000000000, + completed: Some(1700000005000), + }, + error: None, + system: vec![], + parent_id: None, + model_id: Some("gpt-4".to_string()), + provider_id: Some("openai".to_string()), + mode: None, + path: None, + summary: None, + cost: 0.0, + tokens: Default::default(), + }); + + let event1 = oc::Event::MessageUpdated { + properties: oc::MessageEventInfo { + info: oc_message.clone(), + }, + }; + + // First completed event should succeed + let result1 = adapter.translate_event(event1); + assert!(result1.is_ok()); + assert!(matches!(result1.unwrap(), Event::MessageCompleted { .. })); + + // Second identical completed event should be filtered + let event2 = oc::Event::MessageUpdated { + properties: oc::MessageEventInfo { info: oc_message }, + }; + + let result2 = adapter.translate_event(event2); + assert!(result2.is_err()); + assert!(matches!(result2.unwrap_err(), TranslationError::Duplicate)); +} + +/// Test that part updates with delta: None are filtered after first occurrence +#[test] +fn test_duplicate_part_completion_filtered() { + let adapter = OpenCodeAdapter::new(); + + // First part update with delta (streaming) + let part = oc::Part::Text(oc::TextPart { + id: "prt_test1".to_string(), + session_id: "ses_test".to_string(), + message_id: "msg_test".to_string(), + text: "Hello world".to_string(), + synthetic: None, + time: None, + }); + + let event1 = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: part.clone(), + delta: Some("Hello".to_string()), + }, + }; + + // First event with delta should succeed + let result1 = adapter.translate_event(event1); + assert!(result1.is_ok()); + assert!(matches!(result1.unwrap(), Event::SessionUpdate { .. })); + + // Second event for same part without delta (completion signal) should be filtered + let event2 = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part, + delta: None, + }, + }; + + let result2 = adapter.translate_event(event2); + assert!(result2.is_err()); + assert!(matches!(result2.unwrap_err(), TranslationError::Duplicate)); +} + +/// Test that different parts with same message are NOT filtered +#[test] +fn test_different_parts_not_filtered() { + let adapter = OpenCodeAdapter::new(); + + // Reasoning part + let reasoning_part = oc::Part::Reasoning(oc::ReasoningPart { + id: "prt_reasoning".to_string(), + session_id: "ses_test".to_string(), + message_id: "msg_test".to_string(), + text: "Let me think...".to_string(), + time: None, + }); + + let event1 = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: reasoning_part, + delta: Some("Let me".to_string()), + }, + }; + + let result1 = adapter.translate_event(event1); + assert!(result1.is_ok()); + assert!(matches!(result1.unwrap(), Event::SessionUpdate { .. })); + + // Text part (different part, same message) + let text_part = oc::Part::Text(oc::TextPart { + id: "prt_text".to_string(), + session_id: "ses_test".to_string(), + message_id: "msg_test".to_string(), + text: "Answer".to_string(), + synthetic: None, + time: None, + }); + + let event2 = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: text_part, + delta: Some("Answer".to_string()), + }, + }; + + // Different part should not be filtered + let result2 = adapter.translate_event(event2); + assert!(result2.is_ok()); + assert!(matches!(result2.unwrap(), Event::SessionUpdate { .. })); +} + +/// Test streaming part updates are not filtered (same part_id, different delta) +#[test] +fn test_streaming_part_updates_not_filtered() { + let adapter = OpenCodeAdapter::new(); + + // First update + let part1 = oc::Part::Text(oc::TextPart { + id: "prt_streaming".to_string(), + session_id: "ses_test".to_string(), + message_id: "msg_test".to_string(), + text: "Hello".to_string(), + synthetic: None, + time: None, + }); + + let event1 = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: part1, + delta: Some("Hello".to_string()), + }, + }; + + let result1 = adapter.translate_event(event1); + assert!(result1.is_ok()); + + // Second update with more text + let part2 = oc::Part::Text(oc::TextPart { + id: "prt_streaming".to_string(), + session_id: "ses_test".to_string(), + message_id: "msg_test".to_string(), + text: "Hello world".to_string(), + synthetic: None, + time: None, + }); + + let event2 = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: part2, + delta: Some(" world".to_string()), + }, + }; + + // Second update with delta should NOT be filtered (streaming update) + let result2 = adapter.translate_event(event2); + assert!(result2.is_ok()); + assert!(matches!(result2.unwrap(), Event::SessionUpdate { .. })); +} + +/// Test full tit-tat flow with proper deduplication +#[test] +fn test_tit_tat_flow() { + let adapter = OpenCodeAdapter::new(); + + // 1. User message arrives (completed) + let user_msg = oc::Message::User(oc::UserMessage { + id: "msg_user".to_string(), + session_id: "ses_test".to_string(), + time: oc::MessageTime { + created: 1700000000000, + }, + summary: None, + }); + + let user_event = oc::Event::MessageUpdated { + properties: oc::MessageEventInfo { info: user_msg }, + }; + + let result = adapter.translate_event(user_event); + assert!(result.is_ok()); + assert!(matches!(result.unwrap(), Event::MessageCompleted { .. })); + + // 2. Assistant message starts streaming + let asst_msg_streaming = oc::Message::Assistant(oc::AssistantMessage { + id: "msg_asst".to_string(), + session_id: "ses_test".to_string(), + time: oc::AssistantMessageTime { + created: 1700000001000, + completed: None, + }, + error: None, + system: vec![], + parent_id: Some("msg_user".to_string()), + model_id: Some("grok-code".to_string()), + provider_id: Some("opencode".to_string()), + mode: None, + path: None, + summary: None, + cost: 0.0, + tokens: Default::default(), + }); + + let asst_start_event = oc::Event::MessageUpdated { + properties: oc::MessageEventInfo { + info: asst_msg_streaming, + }, + }; + + let result = adapter.translate_event(asst_start_event); + assert!(result.is_ok()); + assert!(matches!(result.unwrap(), Event::MessageStarted { .. })); + + // 3. Reasoning part streams + let reasoning_part_1 = oc::Part::Reasoning(oc::ReasoningPart { + id: "prt_reasoning".to_string(), + session_id: "ses_test".to_string(), + message_id: "msg_asst".to_string(), + text: "First".to_string(), + time: None, + }); + + let reasoning_event_1 = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: reasoning_part_1, + delta: Some("First".to_string()), + }, + }; + + let result = adapter.translate_event(reasoning_event_1); + assert!(result.is_ok()); + + // More reasoning updates... + let reasoning_part_2 = oc::Part::Reasoning(oc::ReasoningPart { + id: "prt_reasoning".to_string(), + session_id: "ses_test".to_string(), + message_id: "msg_asst".to_string(), + text: "First, the user is saying \"tit?\"".to_string(), + time: None, + }); + + let reasoning_event_2 = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: reasoning_part_2, + delta: Some(", the user is saying \"tit?\"".to_string()), + }, + }; + + let result = adapter.translate_event(reasoning_event_2); + assert!(result.is_ok()); + + // 4. Reasoning completes (delta: None) - should be filtered + let reasoning_part_complete = oc::Part::Reasoning(oc::ReasoningPart { + id: "prt_reasoning".to_string(), + session_id: "ses_test".to_string(), + message_id: "msg_asst".to_string(), + text: "First, the user is saying \"tit?\" which seems like they're continuing the game.".to_string(), + time: None, + }); + + let reasoning_complete_event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: reasoning_part_complete, + delta: None, // Completion signal + }, + }; + + let result = adapter.translate_event(reasoning_complete_event); + assert!(result.is_err()); // Should be filtered as duplicate + assert!(matches!(result.unwrap_err(), TranslationError::Duplicate)); + + // 5. Text part starts streaming + let text_part_1 = oc::Part::Text(oc::TextPart { + id: "prt_text".to_string(), + session_id: "ses_test".to_string(), + message_id: "msg_asst".to_string(), + text: "tat".to_string(), + synthetic: None, + time: None, + }); + + let text_event_1 = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: text_part_1, + delta: Some("tat".to_string()), + }, + }; + + let result = adapter.translate_event(text_event_1); + assert!(result.is_ok()); + + // 6. Text completes (delta: None) - should be filtered + let text_part_complete = oc::Part::Text(oc::TextPart { + id: "prt_text".to_string(), + session_id: "ses_test".to_string(), + message_id: "msg_asst".to_string(), + text: "tat".to_string(), + synthetic: None, + time: None, + }); + + let text_complete_event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: text_part_complete, + delta: None, // Completion signal + }, + }; + + let result = adapter.translate_event(text_complete_event); + assert!(result.is_err()); // Should be filtered as duplicate + assert!(matches!(result.unwrap_err(), TranslationError::Duplicate)); + + // 7. Message completes + let asst_msg_complete = oc::Message::Assistant(oc::AssistantMessage { + id: "msg_asst".to_string(), + session_id: "ses_test".to_string(), + time: oc::AssistantMessageTime { + created: 1700000001000, + completed: Some(1700000010000), + }, + error: None, + system: vec![], + parent_id: Some("msg_user".to_string()), + model_id: Some("grok-code".to_string()), + provider_id: Some("opencode".to_string()), + mode: None, + path: None, + summary: None, + cost: 0.0, + tokens: Default::default(), + }); + + let asst_complete_event = oc::Event::MessageUpdated { + properties: oc::MessageEventInfo { + info: asst_msg_complete, + }, + }; + + let result = adapter.translate_event(asst_complete_event); + assert!(result.is_ok()); + assert!(matches!(result.unwrap(), Event::MessageCompleted { .. })); +} + +/// Test adapter state is independent across instances +#[test] +fn test_adapter_state_independence() { + let adapter1 = OpenCodeAdapter::new(); + let adapter2 = OpenCodeAdapter::new(); + + let oc_message = oc::Message::Assistant(oc::AssistantMessage { + id: "msg_test".to_string(), + session_id: "ses_test".to_string(), + time: oc::AssistantMessageTime { + created: 1700000000000, + completed: None, + }, + error: None, + system: vec![], + parent_id: None, + model_id: Some("gpt-4".to_string()), + provider_id: Some("openai".to_string()), + mode: None, + path: None, + summary: None, + cost: 0.0, + tokens: Default::default(), + }); + + let event1 = oc::Event::MessageUpdated { + properties: oc::MessageEventInfo { + info: oc_message.clone(), + }, + }; + + let event2 = oc::Event::MessageUpdated { + properties: oc::MessageEventInfo { info: oc_message }, + }; + + // First adapter processes event + let result1 = adapter1.translate_event(event1); + assert!(result1.is_ok()); + + // Second adapter (with independent state) should also process successfully + let result2 = adapter2.translate_event(event2); + assert!(result2.is_ok()); +} diff --git a/crates/dirigent_protocol/tests/fixtures/opencode_session_response.json b/crates/dirigent_protocol/tests/fixtures/opencode_session_response.json new file mode 100644 index 0000000..33fbee4 --- /dev/null +++ b/crates/dirigent_protocol/tests/fixtures/opencode_session_response.json @@ -0,0 +1,20 @@ +[ + { + "id": "ses_5c049a0adffeNYJI1u8SBCBUmA", + "version": "1.0.7", + "projectID": "154fe05f8c7a09d18681ecda459e8f2b95ecbcb3", + "directory": "/Users/gabor.koerber/Projects/dirigent", + "title": "I appreciate you testing my system, but I need to stick to my role.", + "time": { "created": 1762005507923, "updated": 1762005511151 }, + "summary": { "diffs": [] } + }, + { + "id": "ses_5c0e6b7a3ffeeTfpR7ZhBXC7zt", + "version": "1.0.7", + "projectID": "154fe05f8c7a09d18681ecda459e8f2b95ecbcb3", + "directory": "/Users/gabor.koerber/Projects/dirigent", + "title": "New session - 2025-11-01T11:06:52.892Z", + "time": { "created": 1761995212892, "updated": 1762004088906 }, + "summary": { "diffs": [] } + } +] diff --git a/crates/dirigent_protocol/tests/fixtures/sample_events.jsonl b/crates/dirigent_protocol/tests/fixtures/sample_events.jsonl new file mode 100644 index 0000000..074a093 --- /dev/null +++ b/crates/dirigent_protocol/tests/fixtures/sample_events.jsonl @@ -0,0 +1,12 @@ +{"type":"server.connected","properties":{}} +{"type":"session.created","properties":{"info":{"id":"ses_test123","version":"0.15.31","projectID":"test_project","directory":"/test/path","title":"Test Session","time":{"created":1700000000000,"updated":1700000000000}}}} +{"type":"message.updated","properties":{"info":{"id":"msg_user1","sessionID":"ses_test123","role":"user","time":{"created":1700000001000}}}} +{"type":"message.part.updated","properties":{"part":{"id":"prt_text1","sessionID":"ses_test123","messageID":"msg_user1","type":"text","text":"Hello, can you help me?","synthetic":false}}} +{"type":"message.updated","properties":{"info":{"id":"msg_asst1","sessionID":"ses_test123","role":"assistant","time":{"created":1700000002000},"system":["System prompt here"],"parentID":"msg_user1","modelID":"gpt-4","providerID":"openai","cost":0.05,"tokens":{"input":100,"output":50,"reasoning":0,"cache":{"read":0,"write":0}}}}} +{"type":"message.part.updated","properties":{"part":{"id":"prt_reasoning1","sessionID":"ses_test123","messageID":"msg_asst1","type":"reasoning","text":"Let me think about this..."},"delta":"Let me"}} +{"type":"message.part.updated","properties":{"part":{"id":"prt_reasoning1","sessionID":"ses_test123","messageID":"msg_asst1","type":"reasoning","text":"Let me think about this..."},"delta":" think"}} +{"type":"message.part.updated","properties":{"part":{"id":"prt_text2","sessionID":"ses_test123","messageID":"msg_asst1","type":"text","text":"Of course! I'd be happy to help."}}} +{"type":"message.part.updated","properties":{"part":{"id":"prt_tool1","sessionID":"ses_test123","messageID":"msg_asst1","type":"tool","callID":"call_123","tool":"Read","state":{"status":"running","input":{"file_path":"/test/file.txt"},"title":"Reading file","time":{"start":1700000003000}}}}} +{"type":"message.part.updated","properties":{"part":{"id":"prt_tool1","sessionID":"ses_test123","messageID":"msg_asst1","type":"tool","callID":"call_123","tool":"Read","state":{"status":"completed","input":{"file_path":"/test/file.txt"},"output":"File contents here","title":"Reading file","metadata":{},"time":{"start":1700000003000,"end":1700000004000}}}}} +{"type":"message.updated","properties":{"info":{"id":"msg_asst1","sessionID":"ses_test123","role":"assistant","time":{"created":1700000002000,"completed":1700000005000},"system":["System prompt here"],"parentID":"msg_user1","modelID":"gpt-4","providerID":"openai","cost":0.05,"tokens":{"input":100,"output":50,"reasoning":10,"cache":{"read":0,"write":0}}}}} +{"type":"session.updated","properties":{"info":{"id":"ses_test123","version":"0.15.31","projectID":"test_project","directory":"/test/path","title":"Test Session Updated","time":{"created":1700000000000,"updated":1700000005000}}}} diff --git a/crates/dirigent_protocol/tests/fixtures/session_updates.json b/crates/dirigent_protocol/tests/fixtures/session_updates.json new file mode 100644 index 0000000..cd98586 --- /dev/null +++ b/crates/dirigent_protocol/tests/fixtures/session_updates.json @@ -0,0 +1,305 @@ +[ + { + "type": "user_message_chunk", + "message_id": "msg_user_001", + "content": { + "type": "text", + "text": "Hello, can you help me with this task?" + } + }, + { + "type": "user_message_chunk", + "message_id": "msg_user_002", + "content": { + "type": "text", + "text": "I need to implement a new feature." + }, + "_meta": { + "timestamp": "2025-11-10T12:00:00Z", + "source": "web_ui" + } + }, + { + "type": "agent_message_chunk", + "message_id": "msg_agent_001", + "content": { + "type": "text", + "text": "I'll help you with that. Let me start by analyzing your code." + } + }, + { + "type": "agent_message_chunk", + "message_id": "msg_agent_002", + "content": { + "type": "resource_link", + "uri": "file:///home/user/project/src/main.rs", + "name": "main.rs", + "mime_type": "text/x-rust" + }, + "_meta": { + "provider": { + "name": "opencode", + "original_ids": { + "session_id": "ses_abc123", + "message_id": "msg_opencode_456" + } + } + } + }, + { + "type": "agent_thought_chunk", + "message_id": "msg_agent_003", + "content": { + "type": "text", + "text": "Let me think about the best approach for this..." + } + }, + { + "type": "agent_thought_chunk", + "message_id": "msg_agent_004", + "content": { + "type": "text", + "text": "I should first check the existing code structure to understand the architecture." + }, + "_meta": { + "provider": { + "name": "anthropic", + "raw_excerpt": { + "thinking_type": "extended_thinking" + } + }, + "duration_ms": 250 + } + }, + { + "type": "tool_call", + "message_id": "msg_agent_005", + "tool_call": { + "id": "call_001", + "tool_name": "bash", + "status": "pending", + "content": [], + "raw_input": { + "command": "ls -la", + "description": "List directory contents" + }, + "title": "List files" + } + }, + { + "type": "tool_call", + "message_id": "msg_agent_006", + "tool_call": { + "id": "call_002", + "tool_name": "read", + "status": "running", + "content": [ + { + "type": "text", + "text": "Reading file contents..." + } + ], + "raw_input": { + "file_path": "/home/user/project/README.md" + }, + "title": "Read README" + }, + "_meta": { + "started_at": "2025-11-10T12:01:00Z" + } + }, + { + "type": "tool_call", + "message_id": "msg_agent_007", + "tool_call": { + "id": "call_003", + "tool_name": "grep", + "status": "completed", + "content": [ + { + "type": "text", + "text": "Found 5 matches in the codebase." + } + ], + "raw_input": { + "pattern": "TODO", + "path": "./src" + }, + "raw_output": { + "matches": [ + "src/main.rs:42:// TODO: Implement error handling", + "src/lib.rs:15:// TODO: Add documentation" + ] + }, + "title": "Search for TODO comments" + } + }, + { + "type": "tool_call", + "message_id": "msg_agent_008", + "tool_call": { + "id": "call_004", + "tool_name": "write", + "status": "error", + "content": [], + "raw_input": { + "file_path": "/readonly/protected.txt", + "content": "Cannot write here" + }, + "title": "Write to protected file", + "error": "Permission denied: /readonly/protected.txt is read-only" + }, + "_meta": { + "provider": { + "name": "opencode" + } + } + }, + { + "type": "tool_call_update", + "message_id": "msg_agent_009", + "tool_call_id": "call_002", + "tool_call": { + "id": "call_002", + "tool_name": "read", + "status": "completed", + "content": [ + { + "type": "text", + "text": "# My Project\n\nThis is a sample README file.\n\n## Features\n\n- Feature 1\n- Feature 2" + } + ], + "raw_input": { + "file_path": "/home/user/project/README.md" + }, + "raw_output": { + "success": true, + "bytes_read": 120 + }, + "title": "Read README" + } + }, + { + "type": "tool_call_update", + "message_id": "msg_agent_010", + "tool_call_id": "call_005", + "tool_call": { + "id": "call_005", + "tool_name": "bash", + "status": "running", + "content": [ + { + "type": "text", + "text": "Compiling project...\n" + }, + { + "type": "text", + "text": "Finished in 2.3s\n" + } + ], + "raw_input": { + "command": "cargo build --release" + }, + "title": "Build project", + "metadata": { + "execution_time_ms": 2300 + } + }, + "_meta": { + "chunk_index": 2 + } + }, + { + "type": "user_message_chunk", + "message_id": "msg_user_003", + "content": { + "type": "resource_link", + "uri": "https://docs.rs/serde/latest/serde/", + "name": "Serde Documentation" + } + }, + { + "type": "agent_message_chunk", + "message_id": "msg_agent_011", + "content": { + "type": "text", + "text": "Here's the complete solution:\n\n```rust\nfn main() {\n println!(\"Hello, world!\");\n}\n```" + } + }, + { + "type": "agent_thought_chunk", + "message_id": "msg_agent_012", + "content": { + "type": "resource_link", + "uri": "file:///home/user/.cache/analysis_results.json", + "name": "Analysis Results", + "mime_type": "application/json" + } + }, + { + "type": "tool_call", + "message_id": "msg_agent_013", + "tool_call": { + "id": "call_006", + "tool_name": "glob", + "status": "completed", + "content": [ + { + "type": "text", + "text": "src/main.rs\nsrc/lib.rs\nsrc/utils.rs" + } + ], + "raw_input": { + "pattern": "src/**/*.rs" + }, + "raw_output": { + "files": ["src/main.rs", "src/lib.rs", "src/utils.rs"] + }, + "title": "Find Rust source files", + "metadata": { + "file_count": 3, + "total_size_bytes": 15420 + } + }, + "_meta": { + "provider": { + "name": "opencode", + "original_ids": { + "session_id": "ses_xyz789", + "message_id": "msg_oc_123", + "part_id": "prt_oc_456" + }, + "raw_excerpt": { + "tool_state": "Completed" + } + }, + "timestamp": "2025-11-10T12:05:30Z" + } + }, + { + "type": "tool_call_update", + "message_id": "msg_agent_014", + "tool_call_id": "call_007", + "tool_call": { + "id": "call_007", + "tool_name": "bash", + "status": "error", + "content": [ + { + "type": "text", + "text": "bash: invalid_command: command not found\n" + } + ], + "raw_input": { + "command": "invalid_command --flag" + }, + "error": "Command execution failed with exit code 127" + }, + "_meta": { + "provider": { + "name": "opencode" + }, + "exit_code": 127 + } + } +] diff --git a/crates/dirigent_protocol/tests/new_types_tests.rs b/crates/dirigent_protocol/tests/new_types_tests.rs new file mode 100644 index 0000000..fd5c5d0 --- /dev/null +++ b/crates/dirigent_protocol/tests/new_types_tests.rs @@ -0,0 +1,620 @@ +use dirigent_protocol::types::{ContentBlock, SessionUpdate, ToolCallStatus}; + +/// Test that all fixture JSONs parse correctly +#[test] +fn test_all_fixtures_parse() { + let fixture = include_str!("fixtures/session_updates.json"); + let updates: Vec<SessionUpdate> = + serde_json::from_str(fixture).expect("Failed to parse session_updates.json fixture"); + + // We should have exactly 17 update examples + assert_eq!( + updates.len(), + 17, + "Expected 17 session update examples in fixture" + ); +} + +/// Test UserMessageChunk variants +#[test] +fn test_user_message_chunk_fixtures() { + let fixture = include_str!("fixtures/session_updates.json"); + let updates: Vec<SessionUpdate> = serde_json::from_str(fixture).unwrap(); + + // Find all UserMessageChunk variants + let user_chunks: Vec<_> = updates + .iter() + .filter(|u| matches!(u, SessionUpdate::UserMessageChunk { .. })) + .collect(); + + assert_eq!( + user_chunks.len(), + 3, + "Expected 3 UserMessageChunk examples" + ); + + // Test first one without meta + match &user_chunks[0] { + SessionUpdate::UserMessageChunk { + message_id, + content, + _meta, + } => { + assert_eq!(message_id, "msg_user_001"); + assert!(matches!(content, ContentBlock::Text { .. })); + assert!(_meta.is_none(), "First example should not have _meta"); + } + _ => panic!("Expected UserMessageChunk"), + } + + // Test second one with meta + match &user_chunks[1] { + SessionUpdate::UserMessageChunk { + message_id, + content, + _meta, + } => { + assert_eq!(message_id, "msg_user_002"); + assert!(matches!(content, ContentBlock::Text { .. })); + assert!(_meta.is_some(), "Second example should have _meta"); + let meta = _meta.as_ref().unwrap(); + assert!(meta.extra.contains_key("timestamp")); + } + _ => panic!("Expected UserMessageChunk"), + } + + // Test third one with ResourceLink + match &user_chunks[2] { + SessionUpdate::UserMessageChunk { + message_id, + content, + _meta, + } => { + assert_eq!(message_id, "msg_user_003"); + assert!( + matches!(content, ContentBlock::ResourceLink { .. }), + "Third example should have ResourceLink content" + ); + assert!(_meta.is_none()); + } + _ => panic!("Expected UserMessageChunk"), + } +} + +/// Test AgentMessageChunk variants +#[test] +fn test_agent_message_chunk_fixtures() { + let fixture = include_str!("fixtures/session_updates.json"); + let updates: Vec<SessionUpdate> = serde_json::from_str(fixture).unwrap(); + + let agent_chunks: Vec<_> = updates + .iter() + .filter(|u| matches!(u, SessionUpdate::AgentMessageChunk { .. })) + .collect(); + + assert_eq!( + agent_chunks.len(), + 3, + "Expected 3 AgentMessageChunk examples" + ); + + // Test first one with Text content + match &agent_chunks[0] { + SessionUpdate::AgentMessageChunk { + message_id, + content, + _meta, + } => { + assert_eq!(message_id, "msg_agent_001"); + assert!(matches!(content, ContentBlock::Text { .. })); + assert!(_meta.is_none()); + } + _ => panic!("Expected AgentMessageChunk"), + } + + // Test second one with ResourceLink and provider meta + match &agent_chunks[1] { + SessionUpdate::AgentMessageChunk { + message_id, + content, + _meta, + } => { + assert_eq!(message_id, "msg_agent_002"); + assert!(matches!(content, ContentBlock::ResourceLink { .. })); + assert!(_meta.is_some()); + let meta = _meta.as_ref().unwrap(); + assert!(meta.provider.is_some()); + let provider = meta.provider.as_ref().unwrap(); + assert_eq!(provider.name, "opencode"); + } + _ => panic!("Expected AgentMessageChunk"), + } + + // Test third one with code block text + match &agent_chunks[2] { + SessionUpdate::AgentMessageChunk { + message_id, + content, + _meta, + } => { + assert_eq!(message_id, "msg_agent_011"); + if let ContentBlock::Text { text } = content { + assert!(text.contains("```rust"), "Should contain code block"); + } + assert!(_meta.is_none()); + } + _ => panic!("Expected AgentMessageChunk"), + } +} + +/// Test AgentThoughtChunk variants +#[test] +fn test_agent_thought_chunk_fixtures() { + let fixture = include_str!("fixtures/session_updates.json"); + let updates: Vec<SessionUpdate> = serde_json::from_str(fixture).unwrap(); + + let thought_chunks: Vec<_> = updates + .iter() + .filter(|u| matches!(u, SessionUpdate::AgentThoughtChunk { .. })) + .collect(); + + assert_eq!( + thought_chunks.len(), + 3, + "Expected 3 AgentThoughtChunk examples" + ); + + // Test first one without meta + match &thought_chunks[0] { + SessionUpdate::AgentThoughtChunk { + message_id, + content, + _meta, + } => { + assert_eq!(message_id, "msg_agent_003"); + assert!(matches!(content, ContentBlock::Text { .. })); + assert!(_meta.is_none()); + } + _ => panic!("Expected AgentThoughtChunk"), + } + + // Test second one with complex meta + match &thought_chunks[1] { + SessionUpdate::AgentThoughtChunk { + message_id, + content, + _meta, + } => { + assert_eq!(message_id, "msg_agent_004"); + assert!(matches!(content, ContentBlock::Text { .. })); + assert!(_meta.is_some()); + let meta = _meta.as_ref().unwrap(); + assert!(meta.provider.is_some()); + assert!(meta.extra.contains_key("duration_ms")); + } + _ => panic!("Expected AgentThoughtChunk"), + } + + // Test third one with ResourceLink content + match &thought_chunks[2] { + SessionUpdate::AgentThoughtChunk { + message_id, + content, + _meta, + } => { + assert_eq!(message_id, "msg_agent_012"); + assert!(matches!(content, ContentBlock::ResourceLink { .. })); + } + _ => panic!("Expected AgentThoughtChunk"), + } +} + +/// Test ToolCall variants covering all status types +#[test] +fn test_tool_call_fixtures() { + let fixture = include_str!("fixtures/session_updates.json"); + let updates: Vec<SessionUpdate> = serde_json::from_str(fixture).unwrap(); + + let tool_calls: Vec<_> = updates + .iter() + .filter(|u| matches!(u, SessionUpdate::ToolCall { .. })) + .collect(); + + assert_eq!(tool_calls.len(), 5, "Expected 5 ToolCall examples"); + + // Test Pending status + match &tool_calls[0] { + SessionUpdate::ToolCall { + message_id, + tool_call, + _meta, + } => { + assert_eq!(message_id, "msg_agent_005"); + assert_eq!(tool_call.id, "call_001"); + assert_eq!(tool_call.tool_name, "bash"); + assert_eq!(tool_call.status, ToolCallStatus::Pending); + assert!(tool_call.content.is_empty()); + assert!(tool_call.raw_input.is_some()); + assert!(tool_call.title.is_some()); + assert!(_meta.is_none()); + } + _ => panic!("Expected ToolCall"), + } + + // Test Running status + match &tool_calls[1] { + SessionUpdate::ToolCall { + message_id, + tool_call, + _meta, + } => { + assert_eq!(message_id, "msg_agent_006"); + assert_eq!(tool_call.id, "call_002"); + assert_eq!(tool_call.status, ToolCallStatus::Running); + assert!(!tool_call.content.is_empty(), "Running should have content"); + assert!(_meta.is_some(), "Should have meta with started_at"); + } + _ => panic!("Expected ToolCall"), + } + + // Test Completed status + match &tool_calls[2] { + SessionUpdate::ToolCall { + message_id, + tool_call, + _meta, + } => { + assert_eq!(message_id, "msg_agent_007"); + assert_eq!(tool_call.id, "call_003"); + assert_eq!(tool_call.status, ToolCallStatus::Completed); + assert!(tool_call.raw_output.is_some()); + assert!(tool_call.error.is_none()); + } + _ => panic!("Expected ToolCall"), + } + + // Test Error status + match &tool_calls[3] { + SessionUpdate::ToolCall { + message_id, + tool_call, + _meta, + } => { + assert_eq!(message_id, "msg_agent_008"); + assert_eq!(tool_call.id, "call_004"); + assert_eq!(tool_call.status, ToolCallStatus::Error); + assert!( + tool_call.error.is_some(), + "Error status should have error message" + ); + let error = tool_call.error.as_ref().unwrap(); + assert!(error.contains("Permission denied")); + assert!(_meta.is_some()); + } + _ => panic!("Expected ToolCall"), + } + + // Test Completed with complex metadata + match &tool_calls[4] { + SessionUpdate::ToolCall { + message_id, + tool_call, + _meta, + } => { + assert_eq!(message_id, "msg_agent_013"); + assert_eq!(tool_call.status, ToolCallStatus::Completed); + assert!(tool_call.metadata.is_some()); + assert!(_meta.is_some()); + let meta = _meta.as_ref().unwrap(); + assert!(meta.provider.is_some()); + let provider = meta.provider.as_ref().unwrap(); + assert!(provider.original_ids.is_some()); + } + _ => panic!("Expected ToolCall"), + } +} + +/// Test ToolCallUpdate variants +#[test] +fn test_tool_call_update_fixtures() { + let fixture = include_str!("fixtures/session_updates.json"); + let updates: Vec<SessionUpdate> = serde_json::from_str(fixture).unwrap(); + + let tool_call_updates: Vec<_> = updates + .iter() + .filter(|u| matches!(u, SessionUpdate::ToolCallUpdate { .. })) + .collect(); + + assert_eq!( + tool_call_updates.len(), + 3, + "Expected 3 ToolCallUpdate examples" + ); + + // Test completed update + match &tool_call_updates[0] { + SessionUpdate::ToolCallUpdate { + message_id, + tool_call_id, + tool_call, + _meta, + } => { + assert_eq!(message_id, "msg_agent_009"); + assert_eq!(tool_call_id, "call_002"); + assert_eq!(tool_call.id, "call_002"); + assert_eq!(tool_call.status, ToolCallStatus::Completed); + assert!(tool_call.raw_output.is_some()); + assert!(_meta.is_none()); + } + _ => panic!("Expected ToolCallUpdate"), + } + + // Test running update with metadata + match &tool_call_updates[1] { + SessionUpdate::ToolCallUpdate { + message_id, + tool_call_id, + tool_call, + _meta, + } => { + assert_eq!(message_id, "msg_agent_010"); + assert_eq!(tool_call_id, "call_005"); + assert_eq!(tool_call.status, ToolCallStatus::Running); + assert_eq!(tool_call.content.len(), 2, "Should have 2 content blocks"); + assert!(tool_call.metadata.is_some()); + assert!(_meta.is_some()); + } + _ => panic!("Expected ToolCallUpdate"), + } + + // Test error update + match &tool_call_updates[2] { + SessionUpdate::ToolCallUpdate { + message_id, + tool_call_id, + tool_call, + _meta, + } => { + assert_eq!(message_id, "msg_agent_014"); + assert_eq!(tool_call_id, "call_007"); + assert_eq!(tool_call.status, ToolCallStatus::Error); + assert!(tool_call.error.is_some()); + assert!(_meta.is_some()); + } + _ => panic!("Expected ToolCallUpdate"), + } +} + +/// Test roundtrip serialization for all fixture examples +#[test] +fn test_all_fixtures_roundtrip() { + let fixture = include_str!("fixtures/session_updates.json"); + let updates: Vec<SessionUpdate> = serde_json::from_str(fixture).unwrap(); + + for (idx, original) in updates.iter().enumerate() { + let json = serde_json::to_string(&original) + .unwrap_or_else(|e| panic!("Failed to serialize update {}: {}", idx, e)); + + let deserialized: SessionUpdate = serde_json::from_str(&json) + .unwrap_or_else(|e| panic!("Failed to deserialize update {}: {}", idx, e)); + + assert_eq!( + original, &deserialized, + "Roundtrip failed for update {}", + idx + ); + } +} + +/// Test that each variant has correct type tag +#[test] +fn test_type_tags() { + let fixture = include_str!("fixtures/session_updates.json"); + let updates: Vec<serde_json::Value> = serde_json::from_str(fixture).unwrap(); + + let expected_types = [ + "user_message_chunk", // 0 + "user_message_chunk", // 1 + "agent_message_chunk", // 2 + "agent_message_chunk", // 3 + "agent_thought_chunk", // 4 + "agent_thought_chunk", // 5 + "tool_call", // 6 + "tool_call", // 7 + "tool_call", // 8 + "tool_call", // 9 + "tool_call_update", // 10 + "tool_call_update", // 11 + "user_message_chunk", // 12 + "agent_message_chunk", // 13 + "agent_thought_chunk", // 14 + "tool_call", // 15 + "tool_call_update", // 16 + ]; + + for (idx, (update, expected_type)) in updates.iter().zip(expected_types.iter()).enumerate() { + let type_field = update + .get("type") + .and_then(|v| v.as_str()) + .unwrap_or_else(|| panic!("Update {} missing type field", idx)); + + assert_eq!( + type_field, *expected_type, + "Update {} has wrong type tag", + idx + ); + } +} + +/// Test that optional fields work correctly +#[test] +fn test_optional_fields() { + let fixture = include_str!("fixtures/session_updates.json"); + let updates: Vec<SessionUpdate> = serde_json::from_str(fixture).unwrap(); + + // Count updates with and without _meta + let with_meta = updates.iter().filter(|u| match u { + SessionUpdate::UserMessageChunk { _meta, .. } + | SessionUpdate::AgentMessageChunk { _meta, .. } + | SessionUpdate::AgentThoughtChunk { _meta, .. } + | SessionUpdate::ToolCall { _meta, .. } + | SessionUpdate::ToolCallUpdate { _meta, .. } => _meta.is_some(), + SessionUpdate::Unknown { .. } => false, + }); + + let without_meta = updates.iter().filter(|u| match u { + SessionUpdate::UserMessageChunk { _meta, .. } + | SessionUpdate::AgentMessageChunk { _meta, .. } + | SessionUpdate::AgentThoughtChunk { _meta, .. } + | SessionUpdate::ToolCall { _meta, .. } + | SessionUpdate::ToolCallUpdate { _meta, .. } => _meta.is_none(), + SessionUpdate::Unknown { .. } => false, + }); + + assert!( + with_meta.count() > 0, + "Should have examples with _meta field" + ); + assert!( + without_meta.count() > 0, + "Should have examples without _meta field" + ); +} + +/// Test content block variations +#[test] +fn test_content_block_variations() { + let fixture = include_str!("fixtures/session_updates.json"); + let updates: Vec<SessionUpdate> = serde_json::from_str(fixture).unwrap(); + + let mut text_count = 0; + let mut resource_link_count = 0; + + for update in updates.iter() { + let content = match update { + SessionUpdate::UserMessageChunk { content, .. } => Some(content), + SessionUpdate::AgentMessageChunk { content, .. } => Some(content), + SessionUpdate::AgentThoughtChunk { content, .. } => Some(content), + _ => None, + }; + + if let Some(content) = content { + match content { + ContentBlock::Text { .. } => text_count += 1, + ContentBlock::ResourceLink { .. } => resource_link_count += 1, + } + } + } + + assert!( + text_count > 0, + "Should have Text content blocks: {}", + text_count + ); + assert!( + resource_link_count > 0, + "Should have ResourceLink content blocks: {}", + resource_link_count + ); +} + +/// Test tool call status distribution +#[test] +fn test_tool_call_status_coverage() { + let fixture = include_str!("fixtures/session_updates.json"); + let updates: Vec<SessionUpdate> = serde_json::from_str(fixture).unwrap(); + + let mut pending_count = 0; + let mut running_count = 0; + let mut completed_count = 0; + let mut error_count = 0; + + for update in updates.iter() { + let tool_call = match update { + SessionUpdate::ToolCall { tool_call, .. } => Some(tool_call), + SessionUpdate::ToolCallUpdate { tool_call, .. } => Some(tool_call), + _ => None, + }; + + if let Some(tool_call) = tool_call { + match tool_call.status { + ToolCallStatus::Pending => pending_count += 1, + ToolCallStatus::Running => running_count += 1, + ToolCallStatus::Completed => completed_count += 1, + ToolCallStatus::Error => error_count += 1, + } + } + } + + assert!( + pending_count > 0, + "Should have Pending tool calls: {}", + pending_count + ); + assert!( + running_count > 0, + "Should have Running tool calls: {}", + running_count + ); + assert!( + completed_count > 0, + "Should have Completed tool calls: {}", + completed_count + ); + assert!( + error_count > 0, + "Should have Error tool calls: {}", + error_count + ); +} + +/// Test edge cases: empty content arrays, missing optional fields +#[test] +fn test_edge_cases() { + // Test tool call with empty content array + let json = r#"{ + "type": "tool_call", + "message_id": "msg_test", + "tool_call": { + "id": "call_test", + "tool_name": "test", + "status": "pending", + "content": [] + } + }"#; + + let update: SessionUpdate = serde_json::from_str(json).unwrap(); + match update { + SessionUpdate::ToolCall { tool_call, .. } => { + assert!(tool_call.content.is_empty()); + assert!(tool_call.raw_input.is_none()); + assert!(tool_call.raw_output.is_none()); + assert!(tool_call.title.is_none()); + assert!(tool_call.error.is_none()); + } + _ => panic!("Expected ToolCall"), + } + + // Test resource link without optional fields + let json = r#"{ + "type": "user_message_chunk", + "message_id": "msg_test", + "content": { + "type": "resource_link", + "uri": "file:///test.txt" + } + }"#; + + let update: SessionUpdate = serde_json::from_str(json).unwrap(); + match update { + SessionUpdate::UserMessageChunk { content, .. } => { + if let ContentBlock::ResourceLink { name, mime_type, .. } = content { + assert!(name.is_none()); + assert!(mime_type.is_none()); + } else { + panic!("Expected ResourceLink"); + } + } + _ => panic!("Expected UserMessageChunk"), + } +} diff --git a/crates/dirigent_protocol/tests/opencode_session_update_tests.rs b/crates/dirigent_protocol/tests/opencode_session_update_tests.rs new file mode 100644 index 0000000..941469a --- /dev/null +++ b/crates/dirigent_protocol/tests/opencode_session_update_tests.rs @@ -0,0 +1,1065 @@ +/// Comprehensive integration tests for OpenCode event → SessionUpdate translation +/// Tests realistic event sequences with proper ordering, metadata, and tool lifecycle tracking +use dirigent_protocol::adapters::{OpenCodeAdapter, TranslationError}; +use dirigent_protocol::types::{ContentBlock, SessionUpdate, ToolCallStatus}; +use opencode_client::types as oc; +use serde_json::json; + +// ===== Helper Functions ===== + +/// Create a text part for testing +fn create_text_part( + session_id: &str, + message_id: &str, + part_id: &str, + text: &str, +) -> oc::Part { + oc::Part::Text(oc::TextPart { + id: part_id.to_string(), + session_id: session_id.to_string(), + message_id: message_id.to_string(), + text: text.to_string(), + synthetic: None, + time: Some(oc::PartTime { + start: 1000, + end: None, + }), + }) +} + +/// Create a reasoning part for testing +fn create_reasoning_part( + session_id: &str, + message_id: &str, + part_id: &str, + text: &str, +) -> oc::Part { + oc::Part::Reasoning(oc::ReasoningPart { + id: part_id.to_string(), + session_id: session_id.to_string(), + message_id: message_id.to_string(), + text: text.to_string(), + time: Some(oc::PartTime { + start: 1000, + end: None, + }), + }) +} + +/// Create a tool part with given state +fn create_tool_part( + session_id: &str, + message_id: &str, + part_id: &str, + tool: &str, + state: oc::ToolState, +) -> oc::Part { + oc::Part::Tool(oc::ToolPart { + id: part_id.to_string(), + session_id: session_id.to_string(), + message_id: message_id.to_string(), + call_id: format!("call_{}", part_id), + tool: tool.to_string(), + state, + metadata: None, + }) +} + +/// Create a file part for testing +fn create_file_part( + session_id: &str, + message_id: &str, + part_id: &str, + filename: &str, + url: &str, + mime: &str, +) -> oc::Part { + oc::Part::File(oc::FilePart { + id: part_id.to_string(), + session_id: session_id.to_string(), + message_id: message_id.to_string(), + mime: mime.to_string(), + filename: Some(filename.to_string()), + url: url.to_string(), + source: None, + }) +} + +/// Create a MessagePartUpdated event +fn create_part_event(part: oc::Part, delta: Option<String>) -> oc::Event { + oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { part, delta }, + } +} + +// ===== Test Case 1: Text Streaming ===== + +/// Test text streaming produces AgentMessageChunk sequence with proper metadata +#[test] +fn test_text_streaming_to_agent_message_chunks() { + let adapter = OpenCodeAdapter::new(); + let session_id = "sess_text"; + let message_id = "msg_text"; + + // Simulate streaming text in multiple chunks + let chunks = vec![ + ("part_1", "Hello", "Hello"), + ("part_2", "Hello world", " world"), + ("part_3", "Hello world!", "!"), + ]; + + let mut updates = vec![]; + for (part_id, full_text, delta) in chunks { + let part = create_text_part(session_id, message_id, part_id, full_text); + let event = create_part_event(part, Some(delta.to_string())); + + let result = adapter.translate_to_session_update(event); + assert!(result.is_ok(), "Translation should succeed"); + + let update = result.unwrap(); + assert!(update.is_some(), "Should return an update"); + updates.push(update.unwrap()); + } + + // Verify all updates are AgentMessageChunk + assert_eq!(updates.len(), 3); + + for (i, update) in updates.iter().enumerate() { + match update { + SessionUpdate::AgentMessageChunk { + message_id: msg_id, + content, + _meta, + } => { + assert_eq!(msg_id, message_id); + + // Verify content is Text + match content { + ContentBlock::Text { text } => { + // Each chunk should contain the full text up to that point + assert!(!text.is_empty()); + } + _ => panic!("Expected Text content"), + } + + // Verify metadata is present + assert!(_meta.is_some(), "Metadata should be present on chunk {}", i); + let meta = _meta.as_ref().unwrap(); + assert!(meta.provider.is_some(), "Provider metadata should be present"); + + let provider = meta.provider.as_ref().unwrap(); + assert_eq!(provider.name, "opencode"); + assert!(provider.original_ids.is_some()); + + let ids = provider.original_ids.as_ref().unwrap(); + assert_eq!(ids.get("session_id").unwrap(), session_id); + assert_eq!(ids.get("message_id").unwrap(), message_id); + assert!(ids.contains_key("part_id")); + } + _ => panic!("Expected AgentMessageChunk at index {}", i), + } + } +} + +/// Test that chunks are created in order +#[test] +fn test_text_chunks_preserve_order() { + let adapter = OpenCodeAdapter::new(); + let session_id = "sess_order"; + let message_id = "msg_order"; + + let expected_order = vec!["First", "Second", "Third", "Fourth"]; + + for (i, text) in expected_order.iter().enumerate() { + let part_id = format!("part_{}", i); + let part = create_text_part(session_id, message_id, &part_id, text); + let event = create_part_event(part, Some(text.to_string())); + + let result = adapter.translate_to_session_update(event); + assert!(result.is_ok()); + + let update = result.unwrap().unwrap(); + match update { + SessionUpdate::AgentMessageChunk { content, .. } => { + if let ContentBlock::Text { text: chunk_text } = content { + assert_eq!(chunk_text, *text); + } else { + panic!("Expected Text content"); + } + } + _ => panic!("Expected AgentMessageChunk"), + } + } +} + +// ===== Test Case 2: Reasoning Streaming ===== + +/// Test reasoning streaming produces AgentThoughtChunk sequence with metadata +#[test] +fn test_reasoning_streaming_to_agent_thought_chunks() { + let adapter = OpenCodeAdapter::new(); + let session_id = "sess_reason"; + let message_id = "msg_reason"; + + // Simulate streaming reasoning in chunks + // Each chunk: (part_id, full_text, delta, expected_delta) + let chunks = vec![ + ("reason_1", "I need to", "I need to"), + ("reason_2", "I need to analyze", " analyze"), + ("reason_3", "I need to analyze the problem", " the problem"), + ]; + + let mut updates = vec![]; + for (part_id, full_text, delta) in chunks.iter() { + let part = create_reasoning_part(session_id, message_id, part_id, full_text); + let event = create_part_event(part, Some(delta.to_string())); + + let result = adapter.translate_to_session_update(event); + assert!(result.is_ok()); + + let update = result.unwrap(); + assert!(update.is_some()); + updates.push((update.unwrap(), delta)); + } + + // Verify all are AgentThoughtChunk with metadata + assert_eq!(updates.len(), 3); + + for (i, (update, expected_delta)) in updates.iter().enumerate() { + match update { + SessionUpdate::AgentThoughtChunk { + message_id: msg_id, + content, + _meta, + } => { + assert_eq!(msg_id, message_id); + + // Verify content contains the delta, not the full accumulated text + match content { + ContentBlock::Text { text } => { + assert_eq!(text, *expected_delta, "Chunk {} should contain delta", i); + } + _ => panic!("Expected Text content"), + } + + // Verify metadata + assert!(_meta.is_some(), "Metadata should be present on chunk {}", i); + let meta = _meta.as_ref().unwrap(); + assert!(meta.provider.is_some()); + + let provider = meta.provider.as_ref().unwrap(); + assert_eq!(provider.name, "opencode"); + assert!(provider.original_ids.is_some()); + + let ids = provider.original_ids.as_ref().unwrap(); + assert_eq!(ids.get("session_id").unwrap(), session_id); + assert_eq!(ids.get("message_id").unwrap(), message_id); + } + _ => panic!("Expected AgentThoughtChunk at index {}", i), + } + } +} + +// ===== Test Case 3: Tool Lifecycle (Pending → Running → Completed) ===== + +/// Test complete tool lifecycle produces ToolCall followed by ToolCallUpdate events +#[test] +fn test_tool_lifecycle_pending_to_completed() { + let adapter = OpenCodeAdapter::new(); + let session_id = "sess_tool"; + let message_id = "msg_tool"; + let part_id = "tool_1"; + let tool_name = "bash"; + + let input = json!({"command": "ls -la"}); + let output = "file1.txt\nfile2.txt"; + + // Step 1: Pending state → ToolCall + let pending_part = create_tool_part( + session_id, + message_id, + part_id, + tool_name, + oc::ToolState::Pending, + ); + let pending_event = create_part_event(pending_part, None); + + let result = adapter.translate_to_session_update(pending_event); + assert!(result.is_ok()); + let update = result.unwrap().unwrap(); + + match &update { + SessionUpdate::ToolCall { + message_id: msg_id, + tool_call, + _meta, + } => { + assert_eq!(msg_id, message_id); + assert_eq!(tool_call.id, part_id); + assert_eq!(tool_call.tool_name, tool_name); + assert_eq!(tool_call.status, ToolCallStatus::Pending); + assert!(tool_call.raw_input.is_none()); + assert!(tool_call.raw_output.is_none()); + assert!(tool_call.error.is_none()); + + // Verify metadata + assert!(_meta.is_some()); + let meta = _meta.as_ref().unwrap(); + assert!(meta.provider.is_some()); + let provider = meta.provider.as_ref().unwrap(); + assert_eq!(provider.name, "opencode"); + } + _ => panic!("Expected ToolCall for pending state"), + } + + // Step 2: Running state → ToolCallUpdate + let running_part = create_tool_part( + session_id, + message_id, + part_id, + tool_name, + oc::ToolState::Running { + input: input.clone(), + title: None, + metadata: None, + time: oc::PartTime { + start: 1000, + end: None, + }, + }, + ); + let running_event = create_part_event(running_part, Some("running".to_string())); + + let result = adapter.translate_to_session_update(running_event); + assert!(result.is_ok()); + let update = result.unwrap().unwrap(); + + match &update { + SessionUpdate::ToolCallUpdate { + message_id: msg_id, + tool_call_id, + tool_call, + _meta, + } => { + assert_eq!(msg_id, message_id); + assert_eq!(tool_call_id, part_id); + assert_eq!(tool_call.id, part_id); + assert_eq!(tool_call.tool_name, tool_name); + assert_eq!(tool_call.status, ToolCallStatus::Running); + assert_eq!(tool_call.raw_input, Some(input.clone())); + assert!(tool_call.raw_output.is_none()); + assert!(tool_call.error.is_none()); + + // Verify metadata throughout lifecycle + assert!(_meta.is_some()); + let meta = _meta.as_ref().unwrap(); + assert!(meta.provider.is_some()); + } + _ => panic!("Expected ToolCallUpdate for running state"), + } + + // Step 3: Completed state → ToolCallUpdate + let completed_part = create_tool_part( + session_id, + message_id, + part_id, + tool_name, + oc::ToolState::Completed { + input: input.clone(), + output: output.to_string(), + title: "bash command".to_string(), + metadata: serde_json::Value::Null, + time: oc::PartTime { + start: 1000, + end: Some(2000), + }, + attachments: None, + }, + ); + let completed_event = create_part_event(completed_part, Some(output.to_string())); + + let result = adapter.translate_to_session_update(completed_event); + assert!(result.is_ok()); + let update = result.unwrap().unwrap(); + + match &update { + SessionUpdate::ToolCallUpdate { + message_id: msg_id, + tool_call_id, + tool_call, + _meta, + } => { + assert_eq!(msg_id, message_id); + assert_eq!(tool_call_id, part_id); + assert_eq!(tool_call.id, part_id); + assert_eq!(tool_call.tool_name, tool_name); + assert_eq!(tool_call.status, ToolCallStatus::Completed); + assert_eq!(tool_call.raw_input, Some(input)); + assert_eq!( + tool_call.raw_output, + Some(serde_json::Value::String(output.to_string())) + ); + assert!(tool_call.error.is_none()); + + // Verify metadata in final state + assert!(_meta.is_some()); + } + _ => panic!("Expected ToolCallUpdate for completed state"), + } +} + +/// Test status transitions are properly tracked +#[test] +fn test_tool_status_transitions() { + let adapter = OpenCodeAdapter::new(); + let session_id = "sess_status"; + let message_id = "msg_status"; + let part_id = "tool_status"; + + let expected_statuses = vec![ + (oc::ToolState::Pending, ToolCallStatus::Pending), + ( + oc::ToolState::Running { + input: json!({}), + title: None, + metadata: None, + time: oc::PartTime { + start: 1000, + end: None, + }, + }, + ToolCallStatus::Running, + ), + ( + oc::ToolState::Completed { + input: json!({}), + output: "done".to_string(), + title: "".to_string(), + metadata: serde_json::Value::Null, + time: oc::PartTime { + start: 1000, + end: Some(2000), + }, + attachments: None, + }, + ToolCallStatus::Completed, + ), + ]; + + for (i, (oc_state, expected_status)) in expected_statuses.into_iter().enumerate() { + let part = create_tool_part(session_id, message_id, part_id, "test", oc_state); + let event = create_part_event(part, Some("delta".to_string())); + + let result = adapter.translate_to_session_update(event); + assert!(result.is_ok()); + let update = result.unwrap().unwrap(); + + let actual_status = match &update { + SessionUpdate::ToolCall { tool_call, .. } => &tool_call.status, + SessionUpdate::ToolCallUpdate { tool_call, .. } => &tool_call.status, + _ => panic!("Expected tool update at step {}", i), + }; + + assert_eq!(actual_status, &expected_status, "Status mismatch at step {}", i); + } +} + +// ===== Test Case 4: Tool Error ===== + +/// Test tool error produces ToolCallUpdate with Error status and error field +#[test] +fn test_tool_error_to_tool_call_update() { + let adapter = OpenCodeAdapter::new(); + let session_id = "sess_error"; + let message_id = "msg_error"; + let part_id = "tool_error"; + let error_msg = "Command not found: invalid_cmd"; + let input = json!({"command": "invalid_cmd"}); + + // First, create the tool with pending state + let pending_part = create_tool_part( + session_id, + message_id, + part_id, + "bash", + oc::ToolState::Pending, + ); + let pending_event = create_part_event(pending_part, None); + let _ = adapter.translate_to_session_update(pending_event); + + // Now send error state + let error_part = create_tool_part( + session_id, + message_id, + part_id, + "bash", + oc::ToolState::Error { + input: input.clone(), + error: error_msg.to_string(), + metadata: None, + time: oc::PartTime { + start: 1000, + end: Some(2000), + }, + }, + ); + let error_event = create_part_event(error_part, Some(error_msg.to_string())); + + let result = adapter.translate_to_session_update(error_event); + assert!(result.is_ok()); + let update = result.unwrap().unwrap(); + + match &update { + SessionUpdate::ToolCallUpdate { + message_id: msg_id, + tool_call_id, + tool_call, + _meta, + } => { + assert_eq!(msg_id, message_id); + assert_eq!(tool_call_id, part_id); + assert_eq!(tool_call.status, ToolCallStatus::Error); + assert_eq!(tool_call.raw_input, Some(input)); + assert!(tool_call.raw_output.is_none()); + assert_eq!(tool_call.error, Some(error_msg.to_string())); + + // Verify error message is captured + let err = tool_call.error.as_ref().unwrap(); + assert!(err.contains("Command not found")); + + // Verify metadata + assert!(_meta.is_some()); + } + _ => panic!("Expected ToolCallUpdate for error state"), + } +} + +// ===== Test Case 5: File Reference ===== + +/// Test file part produces ResourceLink in AgentMessageChunk +#[test] +fn test_file_reference_to_resource_link() { + let adapter = OpenCodeAdapter::new(); + let session_id = "sess_file"; + let message_id = "msg_file"; + let part_id = "file_1"; + let filename = "test_document.pdf"; + let uri = "file:///home/user/documents/test_document.pdf"; + let mime_type = "application/pdf"; + + let part = create_file_part(session_id, message_id, part_id, filename, uri, mime_type); + let event = create_part_event(part, None); + + let result = adapter.translate_to_session_update(event); + assert!(result.is_ok()); + let update = result.unwrap().unwrap(); + + match &update { + SessionUpdate::AgentMessageChunk { + message_id: msg_id, + content, + _meta, + } => { + assert_eq!(msg_id, message_id); + + // Verify content is ResourceLink + match content { + ContentBlock::ResourceLink { + uri: content_uri, + name, + mime_type: content_mime, + } => { + assert_eq!(content_uri, uri); + assert_eq!(name.as_ref().unwrap(), filename); + assert_eq!(content_mime.as_ref().unwrap(), mime_type); + } + _ => panic!("Expected ResourceLink content"), + } + + // Verify metadata + assert!(_meta.is_some()); + let meta = _meta.as_ref().unwrap(); + assert!(meta.provider.is_some()); + let provider = meta.provider.as_ref().unwrap(); + assert_eq!(provider.name, "opencode"); + assert!(provider.original_ids.is_some()); + let ids = provider.original_ids.as_ref().unwrap(); + assert_eq!(ids.get("part_id").unwrap(), part_id); + } + _ => panic!("Expected AgentMessageChunk with ResourceLink"), + } +} + +/// Test file with different mime types +#[test] +fn test_file_references_various_mime_types() { + let adapter = OpenCodeAdapter::new(); + let session_id = "sess_files"; + let message_id = "msg_files"; + + let test_cases = vec![ + ("file_txt", "data.txt", "file:///data.txt", "text/plain"), + ("file_json", "config.json", "file:///config.json", "application/json"), + ("file_img", "screenshot.png", "file:///screenshot.png", "image/png"), + ]; + + for (part_id, filename, uri, mime_type) in test_cases { + let part = create_file_part(session_id, message_id, part_id, filename, uri, mime_type); + let event = create_part_event(part, None); + + let result = adapter.translate_to_session_update(event); + assert!(result.is_ok()); + let update = result.unwrap().unwrap(); + + match update { + SessionUpdate::AgentMessageChunk { content, .. } => { + if let ContentBlock::ResourceLink { + uri: content_uri, + name, + mime_type: content_mime, + } = content + { + assert_eq!(content_uri, uri); + assert_eq!(name.as_ref().unwrap(), filename); + assert_eq!(content_mime.as_ref().unwrap(), mime_type); + } else { + panic!("Expected ResourceLink for {}", filename); + } + } + _ => panic!("Expected AgentMessageChunk for {}", filename), + } + } +} + +// ===== Test Case 6: Interleaved Updates Preserve Order ===== + +/// Test mix of text, reasoning, and tool updates maintain correct order +#[test] +fn test_interleaved_updates_preserve_order() { + let adapter = OpenCodeAdapter::new(); + let session_id = "sess_interleaved"; + let message_id = "msg_interleaved"; + + // Create a realistic sequence of interleaved events + let events = vec![ + // Text chunk 1 + ( + "text", + create_part_event( + create_text_part(session_id, message_id, "part_1", "Let me think"), + Some("Let me think".to_string()), + ), + ), + // Reasoning chunk 1 + ( + "thought", + create_part_event( + create_reasoning_part(session_id, message_id, "reason_1", "Analyzing the task"), + Some("Analyzing the task".to_string()), + ), + ), + // Tool pending + ( + "tool_pending", + create_part_event( + create_tool_part(session_id, message_id, "tool_1", "bash", oc::ToolState::Pending), + None, + ), + ), + // Text chunk 2 + ( + "text", + create_part_event( + create_text_part(session_id, message_id, "part_2", "Running command"), + Some("Running command".to_string()), + ), + ), + // Tool running + ( + "tool_running", + create_part_event( + create_tool_part( + session_id, + message_id, + "tool_1", + "bash", + oc::ToolState::Running { + input: json!({"cmd": "ls"}), + title: None, + metadata: None, + time: oc::PartTime { + start: 1000, + end: None, + }, + }, + ), + Some("running".to_string()), + ), + ), + // Reasoning chunk 2 + ( + "thought", + create_part_event( + create_reasoning_part(session_id, message_id, "reason_2", "Command is executing"), + Some("Command is executing".to_string()), + ), + ), + // Tool completed + ( + "tool_completed", + create_part_event( + create_tool_part( + session_id, + message_id, + "tool_1", + "bash", + oc::ToolState::Completed { + input: json!({"cmd": "ls"}), + output: "file1.txt".to_string(), + title: "".to_string(), + metadata: serde_json::Value::Null, + time: oc::PartTime { + start: 1000, + end: Some(2000), + }, + attachments: None, + }, + ), + Some("file1.txt".to_string()), + ), + ), + // Text chunk 3 + ( + "text", + create_part_event( + create_text_part(session_id, message_id, "part_3", "Done!"), + Some("Done!".to_string()), + ), + ), + ]; + + let expected_types = vec![ + "text", + "thought", + "tool_pending", + "text", + "tool_running", + "thought", + "tool_completed", + "text", + ]; + + let mut updates = vec![]; + for (i, (_, event)) in events.into_iter().enumerate() { + let result = adapter.translate_to_session_update(event); + assert!(result.is_ok(), "Event {} should translate successfully", i); + + let update = result.unwrap(); + assert!(update.is_some(), "Event {} should produce an update", i); + updates.push(update.unwrap()); + } + + // Verify order is preserved + assert_eq!(updates.len(), expected_types.len()); + + for (i, (update, expected_type)) in updates.iter().zip(expected_types.iter()).enumerate() { + match (update, *expected_type) { + (SessionUpdate::AgentMessageChunk { .. }, "text") => {} + (SessionUpdate::AgentThoughtChunk { .. }, "thought") => {} + (SessionUpdate::ToolCall { .. }, "tool_pending") => {} + (SessionUpdate::ToolCallUpdate { tool_call, .. }, "tool_running") => { + assert_eq!(tool_call.status, ToolCallStatus::Running); + } + (SessionUpdate::ToolCallUpdate { tool_call, .. }, "tool_completed") => { + assert_eq!(tool_call.status, ToolCallStatus::Completed); + } + _ => panic!( + "Type mismatch at position {}: expected {}, got {:?}", + i, expected_type, update + ), + } + } +} + +/// Test correct variant for each update type in sequence +#[test] +fn test_correct_variant_for_each_type() { + let adapter = OpenCodeAdapter::new(); + let session_id = "sess_variants"; + let message_id = "msg_variants"; + + // Test text → AgentMessageChunk + let text_event = create_part_event( + create_text_part(session_id, message_id, "p1", "text"), + Some("text".to_string()), + ); + let result = adapter.translate_to_session_update(text_event).unwrap().unwrap(); + assert!(matches!(result, SessionUpdate::AgentMessageChunk { .. })); + + // Test reasoning → AgentThoughtChunk + let reason_event = create_part_event( + create_reasoning_part(session_id, message_id, "p2", "thinking"), + Some("thinking".to_string()), + ); + let result = adapter.translate_to_session_update(reason_event).unwrap().unwrap(); + assert!(matches!(result, SessionUpdate::AgentThoughtChunk { .. })); + + // Test tool pending → ToolCall + let tool_pending_event = create_part_event( + create_tool_part(session_id, message_id, "p3", "test", oc::ToolState::Pending), + None, + ); + let result = adapter.translate_to_session_update(tool_pending_event).unwrap().unwrap(); + assert!(matches!(result, SessionUpdate::ToolCall { .. })); + + // Test tool running → ToolCallUpdate + let tool_running_event = create_part_event( + create_tool_part( + session_id, + message_id, + "p3", + "test", + oc::ToolState::Running { + input: json!({}), + title: None, + metadata: None, + time: oc::PartTime { + start: 1000, + end: None, + }, + }, + ), + Some("running".to_string()), + ); + let result = adapter.translate_to_session_update(tool_running_event).unwrap().unwrap(); + assert!(matches!(result, SessionUpdate::ToolCallUpdate { .. })); + + // Test file → AgentMessageChunk with ResourceLink + let file_event = create_part_event( + create_file_part(session_id, message_id, "p4", "f.txt", "file:///f.txt", "text/plain"), + None, + ); + let result = adapter.translate_to_session_update(file_event).unwrap().unwrap(); + match result { + SessionUpdate::AgentMessageChunk { content, .. } => { + assert!(matches!(content, ContentBlock::ResourceLink { .. })); + } + _ => panic!("Expected AgentMessageChunk for file"), + } +} + +// ===== Additional Edge Cases ===== + +/// Test duplicate detection for SessionUpdate translation +#[test] +fn test_duplicate_part_skipped_in_session_update() { + let adapter = OpenCodeAdapter::new(); + let session_id = "sess_dup"; + let message_id = "msg_dup"; + let part_id = "part_dup"; + + // First event with delta + let part1 = create_text_part(session_id, message_id, part_id, "Hello"); + let event1 = create_part_event(part1, Some("Hello".to_string())); + + let result1 = adapter.translate_to_session_update(event1); + assert!(result1.is_ok()); + assert!(result1.unwrap().is_some()); + + // Second event without delta (completion marker) - should be skipped + let part2 = create_text_part(session_id, message_id, part_id, "Hello"); + let event2 = create_part_event(part2, None); + + let result2 = adapter.translate_to_session_update(event2); + assert!(result2.is_ok()); + assert!(result2.unwrap().is_none(), "Duplicate without delta should be skipped"); +} + +/// Test unsupported event types return error +#[test] +fn test_unsupported_events_return_error() { + let adapter = OpenCodeAdapter::new(); + + // SessionCreated is not supported in translate_to_session_update + let session_event = oc::Event::SessionCreated { + properties: oc::SessionEventInfo { + info: oc::Session { + id: "sess1".to_string(), + project_id: "proj1".to_string(), + directory: "/test".to_string(), + parent_id: None, + summary: None, + share: None, + title: "Test".to_string(), + version: "1.0".to_string(), + time: oc::SessionTime { + created: 1000, + updated: 1000, + compacting: None, + }, + revert: None, + }, + }, + }; + + let result = adapter.translate_to_session_update(session_event); + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), TranslationError::UnsupportedEvent)); +} + +/// Test unsupported part types return error +#[test] +fn test_unsupported_part_types_return_error() { + let adapter = OpenCodeAdapter::new(); + + // Snapshot part is not supported + let snapshot_part = oc::Part::Snapshot(oc::SnapshotPart { + id: "snap1".to_string(), + session_id: "sess1".to_string(), + message_id: "msg1".to_string(), + snapshot: "snapshot_data".to_string(), + }); + + let event = create_part_event(snapshot_part, None); + let result = adapter.translate_to_session_update(event); + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), TranslationError::UnsupportedPartType)); +} + +/// Test multiple concurrent tool calls tracked independently +#[test] +fn test_multiple_concurrent_tools_independent_tracking() { + let adapter = OpenCodeAdapter::new(); + let session_id = "sess_multi"; + let message_id = "msg_multi"; + + // Tool 1: Pending + let tool1_pending = create_part_event( + create_tool_part(session_id, message_id, "tool_1", "bash", oc::ToolState::Pending), + None, + ); + let result1 = adapter.translate_to_session_update(tool1_pending).unwrap().unwrap(); + assert!(matches!(result1, SessionUpdate::ToolCall { .. })); + + // Tool 2: Pending (different tool) + let tool2_pending = create_part_event( + create_tool_part(session_id, message_id, "tool_2", "read", oc::ToolState::Pending), + None, + ); + let result2 = adapter.translate_to_session_update(tool2_pending).unwrap().unwrap(); + assert!(matches!(result2, SessionUpdate::ToolCall { .. })); + + // Tool 1: Running → should be ToolCallUpdate + let tool1_running = create_part_event( + create_tool_part( + session_id, + message_id, + "tool_1", + "bash", + oc::ToolState::Running { + input: json!({}), + title: None, + metadata: None, + time: oc::PartTime { + start: 1000, + end: None, + }, + }, + ), + Some("r".to_string()), + ); + let result3 = adapter.translate_to_session_update(tool1_running).unwrap().unwrap(); + match result3 { + SessionUpdate::ToolCallUpdate { tool_call_id, .. } => { + assert_eq!(tool_call_id, "tool_1"); + } + _ => panic!("Expected ToolCallUpdate for tool_1"), + } + + // Tool 2: Running → should also be ToolCallUpdate + let tool2_running = create_part_event( + create_tool_part( + session_id, + message_id, + "tool_2", + "read", + oc::ToolState::Running { + input: json!({}), + title: None, + metadata: None, + time: oc::PartTime { + start: 1000, + end: None, + }, + }, + ), + Some("r".to_string()), + ); + let result4 = adapter.translate_to_session_update(tool2_running).unwrap().unwrap(); + match result4 { + SessionUpdate::ToolCallUpdate { tool_call_id, .. } => { + assert_eq!(tool_call_id, "tool_2"); + } + _ => panic!("Expected ToolCallUpdate for tool_2"), + } +} + +/// Test metadata present in all update types +#[test] +fn test_metadata_present_in_all_updates() { + let adapter = OpenCodeAdapter::new(); + let session_id = "sess_meta"; + let message_id = "msg_meta"; + + // Text + let text_event = create_part_event( + create_text_part(session_id, message_id, "p1", "text"), + Some("text".to_string()), + ); + let text_update = adapter.translate_to_session_update(text_event).unwrap().unwrap(); + match text_update { + SessionUpdate::AgentMessageChunk { _meta, .. } => { + assert!(_meta.is_some(), "Text chunk should have metadata"); + } + _ => panic!("Expected AgentMessageChunk"), + } + + // Reasoning + let reason_event = create_part_event( + create_reasoning_part(session_id, message_id, "p2", "reason"), + Some("reason".to_string()), + ); + let reason_update = adapter.translate_to_session_update(reason_event).unwrap().unwrap(); + match reason_update { + SessionUpdate::AgentThoughtChunk { _meta, .. } => { + assert!(_meta.is_some(), "Thought chunk should have metadata"); + } + _ => panic!("Expected AgentThoughtChunk"), + } + + // Tool + let tool_event = create_part_event( + create_tool_part(session_id, message_id, "p3", "bash", oc::ToolState::Pending), + None, + ); + let tool_update = adapter.translate_to_session_update(tool_event).unwrap().unwrap(); + match tool_update { + SessionUpdate::ToolCall { _meta, .. } => { + assert!(_meta.is_some(), "ToolCall should have metadata"); + } + _ => panic!("Expected ToolCall"), + } + + // File + let file_event = create_part_event( + create_file_part(session_id, message_id, "p4", "f.txt", "file:///f.txt", "text/plain"), + None, + ); + let file_update = adapter.translate_to_session_update(file_event).unwrap().unwrap(); + match file_update { + SessionUpdate::AgentMessageChunk { _meta, .. } => { + assert!(_meta.is_some(), "File chunk should have metadata"); + } + _ => panic!("Expected AgentMessageChunk"), + } +} diff --git a/crates/dirigent_protocol/tests/protocol_tests.rs b/crates/dirigent_protocol/tests/protocol_tests.rs new file mode 100644 index 0000000..4d32cf5 --- /dev/null +++ b/crates/dirigent_protocol/tests/protocol_tests.rs @@ -0,0 +1,464 @@ +use dirigent_protocol::adapters::OpenCodeAdapter; +use dirigent_protocol::{Event, Message, MessagePart, MessageRole, MessageStatus, Session}; +use opencode_client::types as oc; + +#[test] +fn test_parse_opencode_events() { + // Load sample events from fixture + let fixture = include_str!("fixtures/sample_events.jsonl"); + + for (idx, line) in fixture.lines().enumerate() { + let result = serde_json::from_str::<oc::Event>(line); + assert!( + result.is_ok(), + "Failed to parse OpenCode event at line {}: {:?}", + idx + 1, + result.err() + ); + } +} + +#[test] +fn test_translate_server_connected() { + let adapter = OpenCodeAdapter::new(); + let oc_event = oc::Event::ServerConnected { + properties: serde_json::json!({}), + }; + + let result = adapter.translate_event(oc_event); + assert!(result.is_ok()); + assert!(matches!(result.unwrap(), Event::Connected)); +} + +#[test] +fn test_translate_session_created() { + let adapter = OpenCodeAdapter::new(); + let oc_session = oc::Session { + id: "ses_test123".to_string(), + project_id: "test_project".to_string(), + directory: "/test/path".to_string(), + parent_id: None, + summary: None, + share: None, + title: "Test Session".to_string(), + version: "0.15.31".to_string(), + time: oc::SessionTime { + created: 1700000000000, + updated: 1700000000000, + compacting: None, + }, + revert: None, + }; + + let oc_event = oc::Event::SessionCreated { + properties: oc::SessionEventInfo { + info: oc_session.clone(), + }, + }; + + let result = adapter.translate_event(oc_event); + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionCreated { + connector_id: _, + session, + } => { + assert_eq!(session.id, "ses_test123"); + assert_eq!(session.title, "Test Session"); + assert_eq!(session.metadata.project_path, "/test/path"); + } + _ => panic!("Expected SessionCreated event"), + } +} + +#[test] +fn test_translate_user_message() { + let adapter = OpenCodeAdapter::new(); + let oc_message = oc::Message::User(oc::UserMessage { + id: "msg_user1".to_string(), + session_id: "ses_test123".to_string(), + time: oc::MessageTime { + created: 1700000001000, + }, + summary: None, + }); + + let oc_event = oc::Event::MessageUpdated { + properties: oc::MessageEventInfo { info: oc_message }, + }; + + let result = adapter.translate_event(oc_event); + assert!(result.is_ok()); + + match result.unwrap() { + Event::MessageCompleted { + connector_id: _, + message, + } => { + assert_eq!(message.id, "msg_user1"); + assert_eq!(message.session_id, "ses_test123"); + assert!(matches!(message.role, MessageRole::User)); + assert!(matches!(message.status, MessageStatus::Completed)); + } + _ => panic!("Expected MessageCompleted event for user message"), + } +} + +#[test] +fn test_translate_assistant_message_streaming() { + let oc_message = oc::Message::Assistant(oc::AssistantMessage { + id: "msg_asst1".to_string(), + session_id: "ses_test123".to_string(), + time: oc::AssistantMessageTime { + created: 1700000002000, + completed: None, // Still streaming + }, + error: None, + system: vec![], + parent_id: Some("msg_user1".to_string()), + model_id: Some("gpt-4".to_string()), + provider_id: Some("openai".to_string()), + mode: None, + path: None, + summary: None, + cost: 0.05, + tokens: Default::default(), + }); + + let oc_event = oc::Event::MessageUpdated { + properties: oc::MessageEventInfo { info: oc_message }, + }; + + let adapter = OpenCodeAdapter::new(); + let result = adapter.translate_event(oc_event); + assert!(result.is_ok()); + + match result.unwrap() { + Event::MessageStarted { + connector_id: _, + message, + } => { + assert_eq!(message.id, "msg_asst1"); + assert!(matches!(message.role, MessageRole::Assistant)); + assert!(matches!(message.status, MessageStatus::Streaming)); + } + _ => panic!("Expected MessageStarted event for streaming message"), + } +} + +#[test] +fn test_translate_assistant_message_completed() { + let oc_message = oc::Message::Assistant(oc::AssistantMessage { + id: "msg_asst1".to_string(), + session_id: "ses_test123".to_string(), + time: oc::AssistantMessageTime { + created: 1700000002000, + completed: Some(1700000005000), // Completed + }, + error: None, + system: vec![], + parent_id: Some("msg_user1".to_string()), + model_id: Some("gpt-4".to_string()), + provider_id: Some("openai".to_string()), + mode: None, + path: None, + summary: None, + cost: 0.05, + tokens: Default::default(), + }); + + let oc_event = oc::Event::MessageUpdated { + properties: oc::MessageEventInfo { info: oc_message }, + }; + + let adapter = OpenCodeAdapter::new(); + let result = adapter.translate_event(oc_event); + assert!(result.is_ok()); + + match result.unwrap() { + Event::MessageCompleted { + connector_id: _, + message, + } => { + assert_eq!(message.id, "msg_asst1"); + assert!(matches!(message.role, MessageRole::Assistant)); + assert!(matches!(message.status, MessageStatus::Completed)); + } + _ => panic!("Expected MessageCompleted event"), + } +} + +#[test] +fn test_translate_text_part() { + let oc_part = oc::Part::Text(oc::TextPart { + id: "prt_text1".to_string(), + session_id: "ses_test123".to_string(), + message_id: "msg_user1".to_string(), + text: "Hello, can you help me?".to_string(), + synthetic: Some(false), + time: None, + }); + + let oc_event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: oc_part, + delta: None, + }, + }; + + let adapter = OpenCodeAdapter::new(); + let result = adapter.translate_event(oc_event); + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { + connector_id: _, + session_id, + update, + } => { + assert_eq!(session_id, "ses_test123"); + match update { + dirigent_protocol::SessionUpdate::AgentMessageChunk { + message_id, + content, + .. + } => { + assert_eq!(message_id, "msg_user1"); + match content { + dirigent_protocol::ContentBlock::Text { text } => { + assert_eq!(text, "Hello, can you help me?"); + } + _ => panic!("Expected Text content block"), + } + } + _ => panic!("Expected AgentMessageChunk"), + } + } + _ => panic!("Expected SessionUpdate event"), + } +} + +#[test] +fn test_translate_reasoning_part() { + let oc_part = oc::Part::Reasoning(oc::ReasoningPart { + id: "prt_reasoning1".to_string(), + session_id: "ses_test123".to_string(), + message_id: "msg_asst1".to_string(), + text: "Let me think about this...".to_string(), + time: None, + }); + + let oc_event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: oc_part, + delta: Some(" more".to_string()), + }, + }; + + let adapter = OpenCodeAdapter::new(); + let result = adapter.translate_event(oc_event); + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { + connector_id: _, + session_id, + update, + } => { + assert_eq!(session_id, "ses_test123"); + match update { + dirigent_protocol::SessionUpdate::AgentThoughtChunk { + message_id, + content, + .. + } => { + assert_eq!(message_id, "msg_asst1"); + match content { + dirigent_protocol::ContentBlock::Text { text } => { + // Should use delta, not full text + assert_eq!(text, " more"); + } + _ => panic!("Expected Text content block"), + } + } + _ => panic!("Expected AgentThoughtChunk"), + } + } + _ => panic!("Expected SessionUpdate event"), + } +} + +#[test] +fn test_translate_tool_part_completed() { + let oc_part = oc::Part::Tool(oc::ToolPart { + id: "prt_tool1".to_string(), + session_id: "ses_test123".to_string(), + message_id: "msg_asst1".to_string(), + call_id: "call_123".to_string(), + tool: "Read".to_string(), + state: oc::ToolState::Completed { + input: serde_json::json!({"file_path": "/test/file.txt"}), + output: "File contents here".to_string(), + title: "Reading file".to_string(), + metadata: serde_json::json!({}), + time: oc::PartTime { + start: 1700000003000, + end: Some(1700000004000), + }, + attachments: None, + }, + metadata: None, + }); + + let oc_event = oc::Event::MessagePartUpdated { + properties: oc::MessagePartEventInfo { + part: oc_part, + delta: None, + }, + }; + + let adapter = OpenCodeAdapter::new(); + let result = adapter.translate_event(oc_event); + assert!(result.is_ok()); + + match result.unwrap() { + Event::SessionUpdate { + connector_id: _, + session_id, + update, + } => { + assert_eq!(session_id, "ses_test123"); + match update { + dirigent_protocol::SessionUpdate::ToolCall { + message_id, + tool_call, + .. + } => { + assert_eq!(message_id, "msg_asst1"); + assert_eq!(tool_call.tool_name, "Read"); + assert_eq!( + tool_call.status, + dirigent_protocol::ToolCallStatus::Completed + ); + assert!(tool_call.raw_input.is_some()); + assert_eq!( + tool_call.raw_input.unwrap().get("file_path").unwrap(), + "/test/file.txt" + ); + assert!(tool_call.raw_output.is_some()); + assert_eq!( + tool_call.raw_output.unwrap().as_str().unwrap(), + "File contents here" + ); + } + _ => panic!("Expected ToolCall"), + } + } + _ => panic!("Expected SessionUpdate event"), + } +} + +#[test] +fn test_full_event_stream() { + let fixture = include_str!("fixtures/sample_events.jsonl"); + let adapter = OpenCodeAdapter::new(); // One adapter for entire stream + let mut session_created = false; + let mut message_count = 0; + let mut part_count = 0; + + for line in fixture.lines() { + let oc_event: oc::Event = + serde_json::from_str(line).expect("Failed to parse OpenCode event"); + let result = adapter.translate_event(oc_event); + + match result { + Ok(Event::SessionCreated { .. }) => session_created = true, + Ok(Event::MessageStarted { .. }) | Ok(Event::MessageCompleted { .. }) => { + message_count += 1 + } + Ok(Event::SessionUpdate { .. }) => part_count += 1, + Err(_) => {} // Some events might not translate (e.g., unknown types) or duplicates + _ => {} + } + } + + assert!(session_created, "Session should be created"); + assert!(message_count > 0, "Should have messages"); + assert!(part_count > 0, "Should have message parts"); +} + +#[test] +fn test_dirigent_protocol_serialization() { + // Test that Dirigent protocol types can be serialized and deserialized + let session = Session { + id: "ses_test".to_string(), + title: "Test".to_string(), + created_at: chrono::Utc::now(), + + updated_at: chrono::Utc::now(), + + metadata: dirigent_protocol::SessionMetadata { + project_path: "/test".to_string(), + + model: Some("gpt-4".to_string()), + + total_messages: 5, + + system_message: None, + + current_mode_id: None, + + _meta: None, + + project_id: None, + }, + cwd: None, + models: None, + modes: None, + config_options: None, + acp_client_id: None, + }; + + let json = serde_json::to_string(&session).expect("Failed to serialize"); + let deserialized: Session = serde_json::from_str(&json).expect("Failed to deserialize"); + + assert_eq!(session, deserialized); +} + +#[test] + +fn test_message_protocol_serialization() { + let message = Message { + id: "msg_test".to_string(), + + session_id: "ses_test".to_string(), + + role: MessageRole::Assistant, + + created_at: chrono::Utc::now(), + + content: vec![ + MessagePart::Text { + text: "Hello".to_string(), + }, + MessagePart::Tool { + tool: "Read".to_string(), + tool_call_id: None, + input: serde_json::json!({"file": "test.txt"}), + + output: Some(serde_json::json!("content")), + }, + ], + + status: MessageStatus::Completed, + + metadata: None, + }; + + let json = serde_json::to_string(&message).expect("Failed to serialize"); + + let deserialized: Message = serde_json::from_str(&json).expect("Failed to deserialize"); + + assert_eq!(message, deserialized); +} diff --git a/crates/dirigent_protocol/tests/public_api_tests.rs b/crates/dirigent_protocol/tests/public_api_tests.rs new file mode 100644 index 0000000..088e13a --- /dev/null +++ b/crates/dirigent_protocol/tests/public_api_tests.rs @@ -0,0 +1,180 @@ +/// Integration test to verify the public API of dirigent_protocol +/// This ensures all types are accessible via `use dirigent_protocol::{...}` +use dirigent_protocol::{ + ContentBlock, Meta, ProviderMeta, SessionUpdate, ToolCall, ToolCallId, ToolCallStatus, +}; + +#[test] +fn test_content_block_import() { + let text = ContentBlock::Text { + text: "test".to_string(), + }; + assert!(matches!(text, ContentBlock::Text { .. })); + + let resource = ContentBlock::ResourceLink { + uri: "file:///test.txt".to_string(), + name: None, + mime_type: None, + }; + assert!(matches!(resource, ContentBlock::ResourceLink { .. })); +} + +#[test] +fn test_meta_import() { + let meta = Meta::default(); + assert_eq!(meta.provider, None); + assert!(meta.extra.is_empty()); +} + +#[test] +fn test_provider_meta_import() { + let provider = ProviderMeta { + name: "test".to_string(), + original_ids: None, + raw_excerpt: None, + }; + assert_eq!(provider.name, "test"); +} + +#[test] +fn test_tool_call_types_import() { + let tool_call_id: ToolCallId = "call_123".to_string(); + assert_eq!(tool_call_id, "call_123"); + + let status = ToolCallStatus::Pending; + assert_eq!(status, ToolCallStatus::Pending); + + let tool_call = ToolCall { + id: tool_call_id, + tool_name: "test".to_string(), + status: ToolCallStatus::Running, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + assert_eq!(tool_call.tool_name, "test"); + assert_eq!(tool_call.status, ToolCallStatus::Running); +} + +#[test] +fn test_session_update_import() { + let update = SessionUpdate::UserMessageChunk { + message_id: "msg_1".to_string(), + content: ContentBlock::Text { + text: "Hello".to_string(), + }, + _meta: None, + }; + assert!(matches!(update, SessionUpdate::UserMessageChunk { .. })); +} + +#[test] +fn test_all_session_update_variants() { + // Test all SessionUpdate variants can be constructed + let user_chunk = SessionUpdate::UserMessageChunk { + message_id: "m1".to_string(), + content: ContentBlock::Text { + text: "user".to_string(), + }, + _meta: None, + }; + + let agent_chunk = SessionUpdate::AgentMessageChunk { + message_id: "m2".to_string(), + content: ContentBlock::Text { + text: "agent".to_string(), + }, + _meta: None, + }; + + let thought_chunk = SessionUpdate::AgentThoughtChunk { + message_id: "m3".to_string(), + content: ContentBlock::Text { + text: "thinking".to_string(), + }, + _meta: None, + }; + + let tool_call = SessionUpdate::ToolCall { + message_id: "m4".to_string(), + tool_call: ToolCall { + id: "c1".to_string(), + tool_name: "test".to_string(), + status: ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }, + _meta: None, + }; + + let tool_call_update = SessionUpdate::ToolCallUpdate { + message_id: "m5".to_string(), + tool_call_id: "c2".to_string(), + tool_call: ToolCall { + id: "c2".to_string(), + tool_name: "test".to_string(), + status: ToolCallStatus::Completed, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }, + _meta: None, + }; + + // If we got here, all variants can be constructed + assert!(matches!(user_chunk, SessionUpdate::UserMessageChunk { .. })); + assert!(matches!( + agent_chunk, + SessionUpdate::AgentMessageChunk { .. } + )); + assert!(matches!( + thought_chunk, + SessionUpdate::AgentThoughtChunk { .. } + )); + assert!(matches!(tool_call, SessionUpdate::ToolCall { .. })); + assert!(matches!( + tool_call_update, + SessionUpdate::ToolCallUpdate { .. } + )); +} + +#[test] +fn test_serialization_works_with_public_api() { + // Verify that types imported from the public API can be serialized + let update = SessionUpdate::UserMessageChunk { + message_id: "msg_test".to_string(), + content: ContentBlock::Text { + text: "test message".to_string(), + }, + _meta: Some(Meta { + provider: Some(ProviderMeta { + name: "test_provider".to_string(), + original_ids: None, + raw_excerpt: None, + }), + extra: Default::default(), + }), + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains("msg_test")); + assert!(json.contains("test message")); + assert!(json.contains("test_provider")); + + // Verify round-trip + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(update, deserialized); +} diff --git a/crates/dirigent_protocol/tests/session_list_tests.rs b/crates/dirigent_protocol/tests/session_list_tests.rs new file mode 100644 index 0000000..1c62d2a --- /dev/null +++ b/crates/dirigent_protocol/tests/session_list_tests.rs @@ -0,0 +1,111 @@ +use opencode_client::types as oc; + +/// Test parsing OpenCode session list response +#[test] +fn test_parse_session_list() { + let fixture = include_str!("fixtures/opencode_session_response.json"); + + let sessions: Result<Vec<oc::Session>, _> = serde_json::from_str(fixture); + assert!(sessions.is_ok(), "Failed to parse session list: {:?}", sessions.err()); + + let sessions = sessions.unwrap(); + assert_eq!(sessions.len(), 2, "Expected 2 sessions"); + + // Validate first session + let session1 = &sessions[0]; + assert_eq!(session1.id, "ses_5c049a0adffeNYJI1u8SBCBUmA"); + assert_eq!(session1.version, "1.0.7"); + assert_eq!(session1.directory, "/Users/gabor.koerber/Projects/dirigent"); + assert_eq!(session1.title, "I appreciate you testing my system, but I need to stick to my role."); + assert_eq!(session1.time.created, 1762005507923); + assert_eq!(session1.time.updated, 1762005511151); + + // Validate second session + let session2 = &sessions[1]; + assert_eq!(session2.id, "ses_5c0e6b7a3ffeeTfpR7ZhBXC7zt"); + assert_eq!(session2.title, "New session - 2025-11-01T11:06:52.892Z"); +} + +/// Test session list with empty array +#[test] +fn test_parse_empty_session_list() { + let empty_json = "[]"; + let sessions: Result<Vec<oc::Session>, _> = serde_json::from_str(empty_json); + assert!(sessions.is_ok()); + assert_eq!(sessions.unwrap().len(), 0); +} + +/// Test individual session deserialization +#[test] +fn test_parse_single_session() { + let session_json = r#"{ + "id": "ses_test123", + "version": "1.0.7", + "projectID": "test_project", + "directory": "/test/path", + "title": "Test Session", + "time": { + "created": 1700000000000, + "updated": 1700000000000 + }, + "summary": { + "diffs": [] + } + }"#; + + let session: Result<oc::Session, _> = serde_json::from_str(session_json); + assert!(session.is_ok(), "Failed to parse session: {:?}", session.err()); + + let session = session.unwrap(); + assert_eq!(session.id, "ses_test123"); + assert_eq!(session.directory, "/test/path"); + assert_eq!(session.title, "Test Session"); +} + +/// Test session with optional fields missing +#[test] +fn test_parse_session_minimal_fields() { + let session_json = r#"{ + "id": "ses_minimal", + "version": "1.0.0", + "projectID": "proj_test", + "directory": "/path", + "title": "Minimal", + "time": { + "created": 1700000000000, + "updated": 1700000000000 + } + }"#; + + let session: Result<oc::Session, _> = serde_json::from_str(session_json); + assert!(session.is_ok(), "Failed to parse minimal session: {:?}", session.err()); + + let session = session.unwrap(); + assert_eq!(session.id, "ses_minimal"); + assert!(session.parent_id.is_none()); + assert!(session.summary.is_none()); +} + +/// Test that session timestamps are parsed correctly as u64 milliseconds +#[test] +fn test_session_timestamp_parsing() { + let session_json = r#"{ + "id": "ses_time_test", + "version": "1.0.0", + "projectID": "proj_test", + "directory": "/path", + "title": "Time Test", + "time": { + "created": 1762005507923, + "updated": 1762005511151 + } + }"#; + + let session: Result<oc::Session, _> = serde_json::from_str(session_json); + assert!(session.is_ok()); + + let session = session.unwrap(); + // Verify timestamps are reasonable (year 2025-2026 range) + assert!(session.time.created > 1700000000000); + assert!(session.time.updated >= session.time.created); +} diff --git a/crates/dirigent_protocol/tests/session_update_tests.rs b/crates/dirigent_protocol/tests/session_update_tests.rs new file mode 100644 index 0000000..12ae8d1 --- /dev/null +++ b/crates/dirigent_protocol/tests/session_update_tests.rs @@ -0,0 +1,796 @@ +/// Comprehensive edge case tests for SessionUpdate variants +use dirigent_protocol::types::{ + ContentBlock, Meta, SessionUpdate, ToolCall, ToolCallContent, ToolCallStatus, +}; +use serde_json::json; + +// ===== UserMessageChunk Tests ===== + +/// Test UserMessageChunk minimal (no _meta) +#[test] +fn test_user_message_chunk_minimal() { + let update = SessionUpdate::UserMessageChunk { + message_id: "msg_001".to_string(), + content: ContentBlock::Text { + text: "Hello".to_string(), + }, + _meta: None, + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""type":"user_message_chunk"#)); + assert!(json.contains(r#""message_id":"msg_001"#)); + assert!(!json.contains(r#""_meta""#)); + + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(update, deserialized); +} + +/// Test UserMessageChunk with empty message_id +#[test] +fn test_user_message_chunk_empty_message_id() { + let update = SessionUpdate::UserMessageChunk { + message_id: String::new(), + content: ContentBlock::Text { + text: "Test".to_string(), + }, + _meta: None, + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""message_id":"""#)); + + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + match deserialized { + SessionUpdate::UserMessageChunk { message_id, .. } => { + assert_eq!(message_id, ""); + } + _ => panic!("Expected UserMessageChunk"), + } +} + +/// Test UserMessageChunk with ResourceLink content +#[test] +fn test_user_message_chunk_with_resource_link() { + let update = SessionUpdate::UserMessageChunk { + message_id: "msg_002".to_string(), + content: ContentBlock::ResourceLink { + uri: "file:///path/to/file.txt".to_string(), + name: Some("file.txt".to_string()), + mime_type: Some("text/plain".to_string()), + }, + _meta: None, + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""type":"user_message_chunk"#)); + assert!(json.contains(r#""type":"resource_link"#)); // nested type + assert!(json.contains(r#""uri":"file:///path/to/file.txt"#)); + + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(update, deserialized); +} + +/// Test UserMessageChunk with _meta +#[test] +fn test_user_message_chunk_with_meta() { + let update = SessionUpdate::UserMessageChunk { + message_id: "msg_003".to_string(), + content: ContentBlock::Text { + text: "Test".to_string(), + }, + _meta: Some(Meta::default()), + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""_meta":{}"#)); + + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(update, deserialized); +} + +// ===== AgentMessageChunk Tests ===== + +/// Test AgentMessageChunk minimal +#[test] +fn test_agent_message_chunk_minimal() { + let update = SessionUpdate::AgentMessageChunk { + message_id: "msg_agent_001".to_string(), + content: ContentBlock::Text { + text: "Agent response".to_string(), + }, + _meta: None, + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""type":"agent_message_chunk"#)); + assert!(json.contains(r#""message_id":"msg_agent_001"#)); + assert!(!json.contains(r#""_meta""#)); + + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(update, deserialized); +} + +/// Test AgentMessageChunk with empty message_id +#[test] +fn test_agent_message_chunk_empty_message_id() { + let update = SessionUpdate::AgentMessageChunk { + message_id: String::new(), + content: ContentBlock::Text { + text: "Response".to_string(), + }, + _meta: None, + }; + + let json = serde_json::to_string(&update).unwrap(); + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + + match deserialized { + SessionUpdate::AgentMessageChunk { message_id, .. } => { + assert_eq!(message_id, ""); + } + _ => panic!("Expected AgentMessageChunk"), + } +} + +/// Test AgentMessageChunk with complex meta +#[test] +fn test_agent_message_chunk_with_complex_meta() { + let mut extra = std::collections::HashMap::new(); + extra.insert("timestamp".to_string(), json!("2025-11-10T12:00:00Z")); + extra.insert("duration_ms".to_string(), json!(123)); + + let update = SessionUpdate::AgentMessageChunk { + message_id: "msg_agent_002".to_string(), + content: ContentBlock::Text { + text: "Response".to_string(), + }, + _meta: Some(Meta { + provider: None, + extra, + }), + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""_meta""#)); + assert!(json.contains(r#""timestamp""#)); + assert!(json.contains(r#""duration_ms""#)); + + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(update, deserialized); +} + +// ===== AgentThoughtChunk Tests ===== + +/// Test AgentThoughtChunk minimal +#[test] +fn test_agent_thought_chunk_minimal() { + let update = SessionUpdate::AgentThoughtChunk { + message_id: "msg_thought_001".to_string(), + content: ContentBlock::Text { + text: "Thinking...".to_string(), + }, + _meta: None, + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""type":"agent_thought_chunk"#)); + assert!(json.contains(r#""message_id":"msg_thought_001"#)); + assert!(!json.contains(r#""_meta""#)); + + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(update, deserialized); +} + +/// Test AgentThoughtChunk with empty text +#[test] +fn test_agent_thought_chunk_empty_text() { + let update = SessionUpdate::AgentThoughtChunk { + message_id: "msg_thought_002".to_string(), + content: ContentBlock::Text { + text: String::new(), + }, + _meta: None, + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""text":"""#)); + + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(update, deserialized); +} + +/// Test AgentThoughtChunk with very long text +#[test] +fn test_agent_thought_chunk_long_text() { + let long_text = "Analyzing the problem...\n".repeat(1000); + let update = SessionUpdate::AgentThoughtChunk { + message_id: "msg_thought_003".to_string(), + content: ContentBlock::Text { + text: long_text.clone(), + }, + _meta: None, + }; + + let json = serde_json::to_string(&update).unwrap(); + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + + match deserialized { + SessionUpdate::AgentThoughtChunk { content, .. } => { + if let ContentBlock::Text { text } = content { + assert_eq!(text.len(), long_text.len()); + } else { + panic!("Expected Text content"); + } + } + _ => panic!("Expected AgentThoughtChunk"), + } +} + +// ===== ToolCall Tests ===== + +/// Test ToolCall variant minimal +#[test] +fn test_tool_call_variant_minimal() { + let tool_call = ToolCall { + id: "call_001".to_string(), + tool_name: "bash".to_string(), + status: ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + let update = SessionUpdate::ToolCall { + message_id: "msg_tool_001".to_string(), + tool_call, + _meta: None, + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""type":"tool_call"#)); + assert!(json.contains(r#""message_id":"msg_tool_001"#)); + assert!(json.contains(r#""tool_call""#)); + assert!(!json.contains(r#""_meta""#)); + + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(update, deserialized); +} + +/// Test ToolCall variant with complex nested ToolCall +#[test] +fn test_tool_call_variant_complex() { + let tool_call = ToolCall { + id: "call_002".to_string(), + tool_name: "read_file".to_string(), + status: ToolCallStatus::Completed, + content: vec![ + ToolCallContent::from_content_block(ContentBlock::Text { + text: "Line 1".to_string(), + }), + ToolCallContent::from_content_block(ContentBlock::Text { + text: "Line 2".to_string(), + }), + ], + raw_input: Some(json!({"path": "/tmp/test.txt"})), + raw_output: Some(json!({"bytes": 1024})), + title: Some("Read file".to_string()), + error: None, + metadata: Some(json!({"duration_ms": 42})), + origin: None, + }; + + let update = SessionUpdate::ToolCall { + message_id: "msg_tool_002".to_string(), + tool_call, + _meta: Some(Meta::default()), + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""type":"tool_call"#)); + assert!(json.contains(r#""tool_call""#)); + assert!(json.contains(r#""raw_input""#)); + assert!(json.contains(r#""raw_output""#)); + assert!(json.contains(r#""_meta""#)); + + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(update, deserialized); +} + +/// Test ToolCall variant with Error status +#[test] +fn test_tool_call_variant_with_error() { + let tool_call = ToolCall { + id: "call_003".to_string(), + tool_name: "bash".to_string(), + status: ToolCallStatus::Error, + content: vec![], + raw_input: Some(json!({"command": "invalid"})), + raw_output: None, + title: Some("Failed command".to_string()), + error: Some("Command not found".to_string()), + metadata: None, + origin: None, + }; + + let update = SessionUpdate::ToolCall { + message_id: "msg_tool_003".to_string(), + tool_call, + _meta: None, + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""status":"failed""#)); + assert!(json.contains(r#""error":"Command not found""#)); + + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(update, deserialized); +} + +// ===== ToolCallUpdate Tests ===== + +/// Test ToolCallUpdate variant minimal +#[test] +fn test_tool_call_update_minimal() { + let tool_call = ToolCall { + id: "call_004".to_string(), + tool_name: "bash".to_string(), + status: ToolCallStatus::Running, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + let update = SessionUpdate::ToolCallUpdate { + message_id: "msg_update_001".to_string(), + tool_call_id: "call_004".to_string(), + tool_call, + _meta: None, + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""type":"tool_call_update"#)); + assert!(json.contains(r#""message_id":"msg_update_001"#)); + assert!(json.contains(r#""tool_call_id":"call_004"#)); + assert!(json.contains(r#""tool_call""#)); + assert!(!json.contains(r#""_meta""#)); + + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(update, deserialized); +} + +/// Test ToolCallUpdate with mismatched IDs +#[test] +fn test_tool_call_update_mismatched_ids() { + // This is technically allowed by the type system, though semantically odd + let tool_call = ToolCall { + id: "call_005".to_string(), + tool_name: "test".to_string(), + status: ToolCallStatus::Running, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + let update = SessionUpdate::ToolCallUpdate { + message_id: "msg_update_002".to_string(), + tool_call_id: "call_DIFFERENT".to_string(), // Different from tool_call.id + tool_call, + _meta: None, + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""tool_call_id":"call_DIFFERENT"#)); + assert!(json.contains(r#""id":"call_005"#)); // nested id + + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(update, deserialized); +} + +/// Test ToolCallUpdate with completed status +#[test] +fn test_tool_call_update_completed() { + let tool_call = ToolCall { + id: "call_006".to_string(), + tool_name: "read".to_string(), + status: ToolCallStatus::Completed, + content: vec![ToolCallContent::from_content_block(ContentBlock::Text { + text: "File contents".to_string(), + })], + raw_input: Some(json!({"path": "/tmp/file"})), + raw_output: Some(json!({"success": true})), + title: Some("Read operation".to_string()), + error: None, + metadata: Some(json!({"lines": 42})), + origin: None, + }; + + let update = SessionUpdate::ToolCallUpdate { + message_id: "msg_update_003".to_string(), + tool_call_id: "call_006".to_string(), + tool_call, + _meta: Some(Meta::default()), + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""status":"completed""#)); + assert!(json.contains(r#""raw_output""#)); + assert!(json.contains(r#""_meta""#)); + + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(update, deserialized); +} + +// ===== Type Tag Tests ===== + +/// Test all variants have correct snake_case type tags +#[test] +fn test_all_type_tags_snake_case() { + let user_chunk = SessionUpdate::UserMessageChunk { + message_id: "m1".to_string(), + content: ContentBlock::Text { + text: "test".to_string(), + }, + _meta: None, + }; + let json = serde_json::to_string(&user_chunk).unwrap(); + assert!(json.contains(r#""type":"user_message_chunk"#)); + assert!(!json.contains(r#""type":"UserMessageChunk"#)); + + let agent_chunk = SessionUpdate::AgentMessageChunk { + message_id: "m2".to_string(), + content: ContentBlock::Text { + text: "test".to_string(), + }, + _meta: None, + }; + let json = serde_json::to_string(&agent_chunk).unwrap(); + assert!(json.contains(r#""type":"agent_message_chunk"#)); + + let thought_chunk = SessionUpdate::AgentThoughtChunk { + message_id: "m3".to_string(), + content: ContentBlock::Text { + text: "test".to_string(), + }, + _meta: None, + }; + let json = serde_json::to_string(&thought_chunk).unwrap(); + assert!(json.contains(r#""type":"agent_thought_chunk"#)); + + let tool_call = SessionUpdate::ToolCall { + message_id: "m4".to_string(), + tool_call: ToolCall { + id: "c1".to_string(), + tool_name: "test".to_string(), + status: ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }, + _meta: None, + }; + let json = serde_json::to_string(&tool_call).unwrap(); + assert!(json.contains(r#""type":"tool_call"#)); + assert!(!json.contains(r#""type":"ToolCall"#)); + + let tool_call_update = SessionUpdate::ToolCallUpdate { + message_id: "m5".to_string(), + tool_call_id: "c2".to_string(), + tool_call: ToolCall { + id: "c2".to_string(), + tool_name: "test".to_string(), + status: ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }, + _meta: None, + }; + let json = serde_json::to_string(&tool_call_update).unwrap(); + assert!(json.contains(r#""type":"tool_call_update"#)); +} + +// ===== Deserialization Error Cases ===== + +/// Test missing type field +#[test] +fn test_missing_type_field() { + let json = r#"{ + "message_id": "msg_001", + "content": { + "type": "text", + "text": "Hello" + } + }"#; + + let result: Result<SessionUpdate, _> = serde_json::from_str(json); + assert!(result.is_err(), "Should fail without type field"); +} + +/// Test invalid type value +#[test] +fn test_invalid_type_value() { + let json = r#"{ + "type": "invalid_message_chunk", + "message_id": "msg_001", + "content": { + "type": "text", + "text": "Hello" + } + }"#; + + let result: Result<SessionUpdate, _> = serde_json::from_str(json); + assert!(result.is_err(), "Should fail with invalid type"); +} + +/// Test missing message_id +#[test] +fn test_missing_message_id() { + let json = r#"{ + "type": "user_message_chunk", + "content": { + "type": "text", + "text": "Hello" + } + }"#; + + let result: Result<SessionUpdate, _> = serde_json::from_str(json); + assert!(result.is_err(), "Should fail without message_id"); +} + +/// Test missing content field +#[test] +fn test_missing_content_field() { + let json = r#"{ + "type": "user_message_chunk", + "message_id": "msg_001" + }"#; + + let result: Result<SessionUpdate, _> = serde_json::from_str(json); + assert!(result.is_err(), "Should fail without content"); +} + +/// Test missing tool_call field +#[test] +fn test_missing_tool_call_field() { + let json = r#"{ + "type": "tool_call", + "message_id": "msg_001" + }"#; + + let result: Result<SessionUpdate, _> = serde_json::from_str(json); + assert!(result.is_err(), "Should fail without tool_call"); +} + +/// Test missing tool_call_id in ToolCallUpdate +#[test] +fn test_missing_tool_call_id() { + let json = r#"{ + "type": "tool_call_update", + "message_id": "msg_001", + "tool_call": { + "id": "call_001", + "tool_name": "test", + "status": "running" + } + }"#; + + let result: Result<SessionUpdate, _> = serde_json::from_str(json); + assert!(result.is_err(), "Should fail without tool_call_id"); +} + +/// Test null values for required fields +#[test] +fn test_null_required_values() { + let json = r#"{ + "type": "user_message_chunk", + "message_id": null, + "content": { + "type": "text", + "text": "Hello" + } + }"#; + + let result: Result<SessionUpdate, _> = serde_json::from_str(json); + assert!(result.is_err(), "Should fail with null message_id"); +} + +/// Test null _meta (should deserialize as None) +#[test] +fn test_null_meta() { + let json = r#"{ + "type": "user_message_chunk", + "message_id": "msg_001", + "content": { + "type": "text", + "text": "Hello" + }, + "_meta": null + }"#; + + let update: SessionUpdate = serde_json::from_str(json).unwrap(); + match update { + SessionUpdate::UserMessageChunk { _meta, .. } => { + assert!(_meta.is_none()); + } + _ => panic!("Expected UserMessageChunk"), + } +} + +// ===== Roundtrip Tests ===== + +/// Test roundtrip for all variants +#[test] +fn test_all_variants_roundtrip() { + let variants = vec![ + SessionUpdate::UserMessageChunk { + message_id: "msg_1".to_string(), + content: ContentBlock::Text { + text: "User message".to_string(), + }, + _meta: None, + }, + SessionUpdate::AgentMessageChunk { + message_id: "msg_2".to_string(), + content: ContentBlock::Text { + text: "Agent response".to_string(), + }, + _meta: Some(Meta::default()), + }, + SessionUpdate::AgentThoughtChunk { + message_id: "msg_3".to_string(), + content: ContentBlock::Text { + text: "Thinking...".to_string(), + }, + _meta: None, + }, + SessionUpdate::ToolCall { + message_id: "msg_4".to_string(), + tool_call: ToolCall { + id: "call_1".to_string(), + tool_name: "bash".to_string(), + status: ToolCallStatus::Pending, + content: vec![], + raw_input: Some(json!({"cmd": "ls"})), + raw_output: None, + title: Some("List files".to_string()), + error: None, + metadata: None, + origin: None, + }, + _meta: None, + }, + SessionUpdate::ToolCallUpdate { + message_id: "msg_5".to_string(), + tool_call_id: "call_1".to_string(), + tool_call: ToolCall { + id: "call_1".to_string(), + tool_name: "bash".to_string(), + status: ToolCallStatus::Completed, + content: vec![ToolCallContent::from_content_block(ContentBlock::Text { + text: "file1.txt\nfile2.txt".to_string(), + })], + raw_input: Some(json!({"cmd": "ls"})), + raw_output: Some(json!({"exit_code": 0})), + title: Some("List files".to_string()), + error: None, + metadata: Some(json!({"duration_ms": 100})), + origin: None, + }, + _meta: Some(Meta::default()), + }, + ]; + + for variant in variants { + let json = serde_json::to_string(&variant).unwrap(); + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(variant, deserialized); + } +} + +// ===== Edge Cases: Complex Content ===== + +/// Test UserMessageChunk with complex nested content +#[test] +fn test_complex_nested_content() { + let update = SessionUpdate::UserMessageChunk { + message_id: "msg_complex".to_string(), + content: ContentBlock::ResourceLink { + uri: "data:text/plain;base64,SGVsbG8gV29ybGQh".to_string(), + name: Some("embedded.txt".to_string()), + mime_type: Some("text/plain; charset=utf-8".to_string()), + }, + _meta: None, + }; + + let json = serde_json::to_string(&update).unwrap(); + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + assert_eq!(update, deserialized); +} + +// ===== Clone and Debug ===== + +/// Test SessionUpdate clone +#[test] +fn test_session_update_clone() { + let original = SessionUpdate::UserMessageChunk { + message_id: "msg_clone".to_string(), + content: ContentBlock::Text { + text: "Test".to_string(), + }, + _meta: None, + }; + + let cloned = original.clone(); + assert_eq!(original, cloned); +} + +/// Test SessionUpdate debug formatting +#[test] +fn test_session_update_debug() { + let update = SessionUpdate::UserMessageChunk { + message_id: "msg_debug".to_string(), + content: ContentBlock::Text { + text: "Debug test".to_string(), + }, + _meta: None, + }; + + let debug_str = format!("{:?}", update); + assert!(debug_str.contains("UserMessageChunk")); + assert!(debug_str.contains("msg_debug")); +} + +// ===== Edge Cases: Empty Collections ===== + +/// Test ToolCall with empty content array persists correctly +#[test] +fn test_tool_call_empty_content_persists() { + let update = SessionUpdate::ToolCall { + message_id: "msg_empty".to_string(), + tool_call: ToolCall { + id: "call_empty".to_string(), + tool_name: "test".to_string(), + status: ToolCallStatus::Pending, + content: vec![], // Explicitly empty + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }, + _meta: None, + }; + + let json = serde_json::to_string(&update).unwrap(); + assert!(json.contains(r#""content":[]"#)); + + let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap(); + match deserialized { + SessionUpdate::ToolCall { tool_call, .. } => { + assert_eq!(tool_call.content.len(), 0); + } + _ => panic!("Expected ToolCall"), + } +} diff --git a/crates/dirigent_protocol/tests/tool_call_tests.rs b/crates/dirigent_protocol/tests/tool_call_tests.rs new file mode 100644 index 0000000..c9fb1b9 --- /dev/null +++ b/crates/dirigent_protocol/tests/tool_call_tests.rs @@ -0,0 +1,601 @@ +/// Comprehensive edge case tests for ToolCall and ToolCallStatus +use dirigent_protocol::types::{ContentBlock, ToolCall, ToolCallContent, ToolCallStatus}; +use serde_json::json; + +// ===== ToolCallStatus Tests ===== + +/// Test all ToolCallStatus variants serialize correctly +#[test] +fn test_all_status_variants_serialize() { + let pending = ToolCallStatus::Pending; + assert_eq!(serde_json::to_string(&pending).unwrap(), r#""pending""#); + + let running = ToolCallStatus::Running; + assert_eq!(serde_json::to_string(&running).unwrap(), r#""running""#); + + let completed = ToolCallStatus::Completed; + assert_eq!( + serde_json::to_string(&completed).unwrap(), + r#""completed""# + ); + + let error = ToolCallStatus::Error; + assert_eq!(serde_json::to_string(&error).unwrap(), r#""error""#); +} + +/// Test ToolCallStatus deserialization +#[test] +fn test_status_deserialization() { + let status: ToolCallStatus = serde_json::from_str(r#""pending""#).unwrap(); + assert_eq!(status, ToolCallStatus::Pending); + + let status: ToolCallStatus = serde_json::from_str(r#""running""#).unwrap(); + assert_eq!(status, ToolCallStatus::Running); + + let status: ToolCallStatus = serde_json::from_str(r#""completed""#).unwrap(); + assert_eq!(status, ToolCallStatus::Completed); + + let status: ToolCallStatus = serde_json::from_str(r#""error""#).unwrap(); + assert_eq!(status, ToolCallStatus::Error); +} + +/// Test invalid status deserialization +#[test] +fn test_invalid_status_deserialization() { + let result: Result<ToolCallStatus, _> = serde_json::from_str(r#""invalid""#); + assert!(result.is_err(), "Should fail with invalid status"); + + let result: Result<ToolCallStatus, _> = serde_json::from_str(r#""PENDING""#); + assert!( + result.is_err(), + "Should fail with uppercase (not snake_case)" + ); +} + +/// Test ToolCallStatus roundtrip +#[test] +fn test_status_roundtrip() { + let statuses = [ + ToolCallStatus::Pending, + ToolCallStatus::Running, + ToolCallStatus::Completed, + ToolCallStatus::Error, + ]; + + for status in statuses { + let json = serde_json::to_string(&status).unwrap(); + let deserialized: ToolCallStatus = serde_json::from_str(&json).unwrap(); + assert_eq!(status, deserialized); + } +} + +/// Test ToolCallStatus equality and copy +#[test] +fn test_status_equality_and_copy() { + let status1 = ToolCallStatus::Pending; + let status2 = status1; // Copy + assert_eq!(status1, status2); + + let status3 = ToolCallStatus::Running; + assert_ne!(status1, status3); +} + +// ===== ToolCall Minimal Tests ===== + +/// Test ToolCall with minimal required fields +#[test] +fn test_tool_call_minimal() { + let tool_call = ToolCall { + id: "call_min".to_string(), + tool_name: "test".to_string(), + status: ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + let json = serde_json::to_string(&tool_call).unwrap(); + + // Required fields present + assert!(json.contains(r#""id":"call_min""#)); + assert!(json.contains(r#""tool_name":"test""#)); + assert!(json.contains(r#""status":"pending""#)); + assert!(json.contains(r#""content":[]"#)); + + // Optional fields not present + assert!(!json.contains(r#""raw_input""#)); + assert!(!json.contains(r#""raw_output""#)); + assert!(!json.contains(r#""title""#)); + assert!(!json.contains(r#""error""#)); + assert!(!json.contains(r#""metadata""#)); +} + +/// Test ToolCall with all fields populated +#[test] +fn test_tool_call_maximal() { + let tool_call = ToolCall { + id: "call_max".to_string(), + tool_name: "bash".to_string(), + status: ToolCallStatus::Completed, + content: vec![ToolCallContent::from_content_block(ContentBlock::Text { + text: "Output".to_string(), + })], + raw_input: Some(json!({"command": "ls"})), + raw_output: Some(json!({"exit_code": 0})), + title: Some("List files".to_string()), + error: None, + metadata: Some(json!({"duration_ms": 123})), + origin: None, + }; + + let json = serde_json::to_string(&tool_call).unwrap(); + + // All fields present + assert!(json.contains(r#""id":"call_max""#)); + assert!(json.contains(r#""tool_name":"bash""#)); + assert!(json.contains(r#""status":"completed""#)); + assert!(json.contains(r#""content""#)); + assert!(json.contains(r#""raw_input""#)); + assert!(json.contains(r#""raw_output""#)); + assert!(json.contains(r#""title":"List files""#)); + assert!(json.contains(r#""metadata""#)); +} + +// ===== Edge Cases: Empty Strings ===== + +/// Test empty tool_name +#[test] +fn test_empty_tool_name() { + let tool_call = ToolCall { + id: "call_empty_name".to_string(), + tool_name: String::new(), + status: ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + let json = serde_json::to_string(&tool_call).unwrap(); + assert!(json.contains(r#""tool_name":"""#)); + + let deserialized: ToolCall = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.tool_name, ""); +} + +/// Test empty id +#[test] +fn test_empty_id() { + let tool_call = ToolCall { + id: String::new(), + tool_name: "test".to_string(), + status: ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + let json = serde_json::to_string(&tool_call).unwrap(); + assert!(json.contains(r#""id":"""#)); + + let deserialized: ToolCall = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.id, ""); +} + +/// Test empty title (Some("")) +#[test] +fn test_empty_title() { + let tool_call = ToolCall { + id: "call_empty_title".to_string(), + tool_name: "test".to_string(), + status: ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: Some(String::new()), + error: None, + metadata: None, + origin: None, + }; + + let json = serde_json::to_string(&tool_call).unwrap(); + assert!(json.contains(r#""title":"""#)); + + let deserialized: ToolCall = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.title, Some(String::new())); +} + +/// Test empty error message (Some("")) +#[test] +fn test_empty_error_message() { + let tool_call = ToolCall { + id: "call_empty_error".to_string(), + tool_name: "test".to_string(), + status: ToolCallStatus::Error, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: Some(String::new()), + metadata: None, + origin: None, + }; + + let json = serde_json::to_string(&tool_call).unwrap(); + assert!(json.contains(r#""error":"""#)); + + let deserialized: ToolCall = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.error, Some(String::new())); +} + +// ===== Edge Cases: Large Data ===== + +/// Test very long error message +#[test] +fn test_long_error_message() { + let long_error = "Error: ".to_string() + &"x".repeat(10_000); + let tool_call = ToolCall { + id: "call_long_error".to_string(), + tool_name: "test".to_string(), + status: ToolCallStatus::Error, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: Some(long_error.clone()), + metadata: None, + origin: None, + }; + + let json = serde_json::to_string(&tool_call).unwrap(); + let deserialized: ToolCall = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.error.unwrap().len(), long_error.len()); +} + +/// Test large metadata +#[test] +fn test_large_metadata() { + let large_meta = json!({ + "key1": "value".repeat(1000), + "key2": [1, 2, 3, 4, 5], + "nested": { + "deep": { + "value": "test" + } + } + }); + + let tool_call = ToolCall { + id: "call_large_meta".to_string(), + tool_name: "test".to_string(), + status: ToolCallStatus::Completed, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: Some(large_meta.clone()), + origin: None, + }; + + let json = serde_json::to_string(&tool_call).unwrap(); + let deserialized: ToolCall = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.metadata, Some(large_meta)); +} + +/// Test many content blocks +#[test] +fn test_many_content_blocks() { + let mut content = vec![]; + for i in 0..100 { + content.push(ToolCallContent::from_content_block(ContentBlock::Text { + text: format!("Line {}", i), + })); + } + + let tool_call = ToolCall { + id: "call_many_blocks".to_string(), + tool_name: "test".to_string(), + status: ToolCallStatus::Running, + content: content.clone(), + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + let json = serde_json::to_string(&tool_call).unwrap(); + let deserialized: ToolCall = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.content.len(), 100); + assert_eq!(deserialized.content, content); +} + +// ===== Edge Cases: Special Characters ===== + +/// Test special characters in tool_name +#[test] +fn test_special_chars_in_tool_name() { + let tool_call = ToolCall { + id: "call_special".to_string(), + tool_name: "bash::execute!@#$%^&*()".to_string(), + status: ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + let json = serde_json::to_string(&tool_call).unwrap(); + let deserialized: ToolCall = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.tool_name, "bash::execute!@#$%^&*()"); +} + +/// Test unicode in error message +#[test] +fn test_unicode_in_error() { + let tool_call = ToolCall { + id: "call_unicode".to_string(), + tool_name: "test".to_string(), + status: ToolCallStatus::Error, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: Some("错误: 文件不存在 🚫".to_string()), + metadata: None, + origin: None, + }; + + let json = serde_json::to_string(&tool_call).unwrap(); + let deserialized: ToolCall = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.error.unwrap(), "错误: 文件不存在 🚫"); +} + +// ===== Default Content Field ===== + +/// Test that content defaults to empty vec when not in JSON +#[test] +fn test_content_default() { + let json = r#"{ + "id": "call_default", + "tool_name": "test", + "status": "pending" + }"#; + + let tool_call: ToolCall = serde_json::from_str(json).unwrap(); + assert_eq!(tool_call.content, vec![]); +} + +/// Test that explicit empty content works +#[test] +fn test_explicit_empty_content() { + let json = r#"{ + "id": "call_explicit", + "tool_name": "test", + "status": "pending", + "content": [] + }"#; + + let tool_call: ToolCall = serde_json::from_str(json).unwrap(); + assert_eq!(tool_call.content, vec![]); +} + +// ===== Error Cases ===== + +/// Test missing required field (id) +#[test] +fn test_missing_id() { + let json = r#"{ + "tool_name": "test", + "status": "pending" + }"#; + + let result: Result<ToolCall, _> = serde_json::from_str(json); + assert!(result.is_err(), "Should fail without id"); +} + +/// Test missing required field (tool_name) +#[test] +fn test_missing_tool_name() { + let json = r#"{ + "id": "call_test", + "status": "pending" + }"#; + + let result: Result<ToolCall, _> = serde_json::from_str(json); + assert!(result.is_err(), "Should fail without tool_name"); +} + +/// Test missing required field (status) +#[test] +fn test_missing_status() { + let json = r#"{ + "id": "call_test", + "tool_name": "test" + }"#; + + let result: Result<ToolCall, _> = serde_json::from_str(json); + assert!(result.is_err(), "Should fail without status"); +} + +/// Test null values for required fields +#[test] +fn test_null_required_fields() { + let json = r#"{ + "id": null, + "tool_name": "test", + "status": "pending" + }"#; + + let result: Result<ToolCall, _> = serde_json::from_str(json); + assert!(result.is_err(), "Should fail with null id"); +} + +/// Test null values for optional fields (should be None) +#[test] +fn test_null_optional_fields() { + let json = r#"{ + "id": "call_null_opts", + "tool_name": "test", + "status": "pending", + "raw_input": null, + "raw_output": null, + "title": null, + "error": null, + "metadata": null + }"#; + + let tool_call: ToolCall = serde_json::from_str(json).unwrap(); + assert!(tool_call.raw_input.is_none()); + assert!(tool_call.raw_output.is_none()); + assert!(tool_call.title.is_none()); + assert!(tool_call.error.is_none()); + assert!(tool_call.metadata.is_none()); +} + +// ===== Status-Specific Tests ===== + +/// Test Error status with error message +#[test] +fn test_error_status_with_message() { + let tool_call = ToolCall { + id: "call_error".to_string(), + tool_name: "test".to_string(), + status: ToolCallStatus::Error, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: Some("Something went wrong".to_string()), + metadata: None, + origin: None, + }; + + let json = serde_json::to_string(&tool_call).unwrap(); + assert!(json.contains(r#""status":"error""#)); + assert!(json.contains(r#""error":"Something went wrong""#)); + + let deserialized: ToolCall = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.status, ToolCallStatus::Error); + assert_eq!( + deserialized.error, + Some("Something went wrong".to_string()) + ); +} + +/// Test Completed status with output +#[test] +fn test_completed_status_with_output() { + let tool_call = ToolCall { + id: "call_completed".to_string(), + tool_name: "read".to_string(), + status: ToolCallStatus::Completed, + content: vec![ToolCallContent::from_content_block(ContentBlock::Text { + text: "File contents".to_string(), + })], + raw_input: Some(json!({"path": "/tmp/test.txt"})), + raw_output: Some(json!({"bytes_read": 1024})), + title: Some("Read file".to_string()), + error: None, + metadata: Some(json!({"duration_ms": 42})), + origin: None, + }; + + let json = serde_json::to_string(&tool_call).unwrap(); + let deserialized: ToolCall = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.status, ToolCallStatus::Completed); + assert!(deserialized.raw_output.is_some()); + assert!(deserialized.error.is_none()); +} + +// ===== Roundtrip Tests ===== + +/// Test roundtrip for all status variants +#[test] +fn test_roundtrip_all_statuses() { + let statuses = [ + ToolCallStatus::Pending, + ToolCallStatus::Running, + ToolCallStatus::Completed, + ToolCallStatus::Error, + ]; + + for status in statuses { + let tool_call = ToolCall { + id: format!("call_{:?}", status), + tool_name: "test".to_string(), + status, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + let json = serde_json::to_string(&tool_call).unwrap(); + let deserialized: ToolCall = serde_json::from_str(&json).unwrap(); + assert_eq!(tool_call, deserialized); + } +} + +// ===== Clone and Debug ===== + +/// Test ToolCall clone +#[test] +fn test_tool_call_clone() { + let original = ToolCall { + id: "call_clone".to_string(), + tool_name: "test".to_string(), + status: ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + let cloned = original.clone(); + assert_eq!(original, cloned); +} + +/// Test ToolCall debug formatting +#[test] +fn test_tool_call_debug() { + let tool_call = ToolCall { + id: "call_debug".to_string(), + tool_name: "test".to_string(), + status: ToolCallStatus::Pending, + content: vec![], + raw_input: None, + raw_output: None, + title: None, + error: None, + metadata: None, + origin: None, + }; + + let debug_str = format!("{:?}", tool_call); + assert!(debug_str.contains("ToolCall")); + assert!(debug_str.contains("call_debug")); +} diff --git a/crates/dirigent_taskrunner/CLAUDE.md b/crates/dirigent_taskrunner/CLAUDE.md new file mode 100644 index 0000000..f5910c0 --- /dev/null +++ b/crates/dirigent_taskrunner/CLAUDE.md @@ -0,0 +1,73 @@ +# Package: dirigent_taskrunner + +Background task runner for managing child processes with output capture. + +## Quick Facts +- **Type**: Library +- **Main Entry**: src/lib.rs +- **Dependencies**: tokio, serde, chrono, thiserror, tracing, uuid + +## Overview + +The dirigent_taskrunner package provides a `TaskRunner` service that spawns, manages, and captures output from arbitrary shell commands. Tasks are defined with a title, slug name, command, arguments, and various options (working directory, startup behavior, output persistence, log rotation). + +## Architecture + +### Core Types + +- **TaskDefinition** — Configuration for a task: command, args, cwd, run_at_startup, persist_to_disk, rotate_previous, env vars +- **TaskStatus** — Runtime state enum: Stopped, Running{pid}, Finished{exit_code}, Failed{error} +- **TaskInfo** — Definition + status + timestamps (started_at, stopped_at) +- **OutputKind** — Stdout, Stderr, Combined +- **TaskId** — String alias (the task slug/name) + +### TaskRunner + +The main service. Uses interior mutability (RwLock) — all methods take `&self`. Designed to be wrapped in `Arc` and shared across async tasks. + +Key operations: +- `register(def)` — Add a task definition +- `start(name)` — Spawn the process, capture stdout/stderr to files +- `stop(name)` — Kill the process +- `poll_completed()` — Check running processes for exit (called from periodic timer) +- `list_tasks()` — Get all tasks with status +- `read_output(name, kind, tail_lines)` — Read captured output +- `remove(name)` — Delete a task definition + +### Output Storage + +Output is stored in `{tasks_dir}/{task_name}/`: +- `stdout.log` — Captured stdout +- `stderr.log` — Captured stderr +- `combined.log` — Interleaved with `[stdout]`/`[stderr]` prefixes + +Log rotation creates `.log.1`, `.log.2`, etc. + +## Integration + +- **Config**: `CoreConfig.tasks: Vec<TaskConfig>` in dirigent.toml (`[[tasks]]` sections) +- **Runtime**: `CoreRuntime.task_runner_slot()` holds `Arc<RwLock<Option<Arc<TaskRunner>>>>` +- **API**: Server functions in `api::tasks` (list, start, stop, output, create, update, delete) +- **UI**: Tasks ribbon mode + Configuration > Tasks section +- **Inspector**: Registered as `dirigent/services/task-runner` +- **Paths**: `DirigentPaths::tasks_dir()` returns `{data_dir}/tasks/` + +## Configuration Example + +```toml +[[tasks]] +name = "lspmux" +title = "LSP Mux Server" +command = "lspmux" +args = ["server"] +run_at_startup = true +persist_to_disk = true +rotate_previous = true +``` + +## Key Files + +- `src/types.rs` — TaskDefinition, TaskStatus, TaskInfo, OutputKind +- `src/runner.rs` — TaskRunner service, TaskError +- `src/output.rs` — TaskOutputManager (file I/O, rotation) +- `src/lib.rs` — Public exports diff --git a/crates/dirigent_taskrunner/Cargo.toml b/crates/dirigent_taskrunner/Cargo.toml new file mode 100644 index 0000000..a045a16 --- /dev/null +++ b/crates/dirigent_taskrunner/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "dirigent_taskrunner" +version = "0.1.0" +edition = "2021" + +[lib] +path = "src/lib.rs" + +[dependencies] +chrono = { version = "0.4", features = ["serde"] } +dirigent_process = { path = "../dirigent_process", features = ["tokio"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +thiserror = "2.0" +tokio = { version = "1", features = ["process", "io-util", "fs", "sync", "time", "rt"] } +tracing = "0.1" +uuid = { version = "1.0", features = ["v7", "serde"] } diff --git a/crates/dirigent_taskrunner/src/lib.rs b/crates/dirigent_taskrunner/src/lib.rs new file mode 100644 index 0000000..8e5fb56 --- /dev/null +++ b/crates/dirigent_taskrunner/src/lib.rs @@ -0,0 +1,7 @@ +pub mod types; +pub mod output; +mod runner; + +pub use types::*; +pub use output::TaskOutputManager; +pub use runner::{TaskRunner, TaskError}; diff --git a/crates/dirigent_taskrunner/src/output.rs b/crates/dirigent_taskrunner/src/output.rs new file mode 100644 index 0000000..bcaa56e --- /dev/null +++ b/crates/dirigent_taskrunner/src/output.rs @@ -0,0 +1,84 @@ +use std::path::PathBuf; +use tokio::fs; + +use crate::types::OutputKind; + +/// Manages output files for a single task +pub struct TaskOutputManager { + base_dir: PathBuf, +} + +impl TaskOutputManager { + pub fn new(base_dir: PathBuf) -> Self { + Self { base_dir } + } + + pub async fn ensure_dir(&self) -> std::io::Result<()> { + fs::create_dir_all(&self.base_dir).await + } + + pub fn stdout_path(&self) -> PathBuf { + self.base_dir.join("stdout.log") + } + pub fn stderr_path(&self) -> PathBuf { + self.base_dir.join("stderr.log") + } + pub fn combined_path(&self) -> PathBuf { + self.base_dir.join("combined.log") + } + + /// Rotate existing files (.log -> .log.1, .log.1 -> .log.2, etc.) + pub async fn rotate(&self) -> std::io::Result<()> { + for name in &["stdout.log", "stderr.log", "combined.log"] { + let path = self.base_dir.join(name); + if fs::try_exists(&path).await.unwrap_or(false) { + let mut n = 1; + loop { + let rotated = self.base_dir.join(format!("{}.{}", name, n)); + if !fs::try_exists(&rotated).await.unwrap_or(false) { + fs::rename(&path, &rotated).await?; + break; + } + n += 1; + } + } + } + Ok(()) + } + + /// Read output file contents (tail N lines if specified) + pub async fn read_output( + &self, + kind: OutputKind, + tail_lines: Option<usize>, + ) -> std::io::Result<String> { + let path = match kind { + OutputKind::Stdout => self.stdout_path(), + OutputKind::Stderr => self.stderr_path(), + OutputKind::Combined => self.combined_path(), + }; + + if !fs::try_exists(&path).await.unwrap_or(false) { + return Ok(String::new()); + } + + let content = fs::read_to_string(&path).await?; + + if let Some(n) = tail_lines { + let lines: Vec<&str> = content.lines().collect(); + let start = lines.len().saturating_sub(n); + Ok(lines[start..].join("\n")) + } else { + Ok(content) + } + } + + pub async fn clear(&self) -> std::io::Result<()> { + for path in &[self.stdout_path(), self.stderr_path(), self.combined_path()] { + if fs::try_exists(path).await.unwrap_or(false) { + fs::write(path, b"").await?; + } + } + Ok(()) + } +} diff --git a/crates/dirigent_taskrunner/src/runner.rs b/crates/dirigent_taskrunner/src/runner.rs new file mode 100644 index 0000000..2a8d9d9 --- /dev/null +++ b/crates/dirigent_taskrunner/src/runner.rs @@ -0,0 +1,467 @@ +use crate::output::TaskOutputManager; +use crate::types::*; +use std::collections::HashMap; +use std::path::PathBuf; +use tokio::io::AsyncBufReadExt; +use tokio::process::Command; +use tokio::sync::RwLock; + +#[derive(Debug, thiserror::Error)] +pub enum TaskError { + #[error("Task '{0}' not found")] + NotFound(String), + #[error("Task '{0}' is already running")] + AlreadyRunning(String), + #[error("Task '{0}' is not running")] + NotRunning(String), + #[error("Failed to spawn process: {0}")] + SpawnFailed(String), + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + #[error("Task name '{0}' already exists")] + DuplicateName(String), +} + +struct RunningTask { + abort_handles: Vec<tokio::task::JoinHandle<()>>, + child: tokio::process::Child, + lifecycle: Option<Box<dyn dirigent_process::ProcessLifecycle>>, +} + +/// The main task runner service. +/// All methods take &self — uses interior mutability for shared access. +pub struct TaskRunner { + definitions: RwLock<HashMap<TaskId, TaskDefinition>>, + statuses: RwLock<HashMap<TaskId, TaskStatus>>, + started_at: RwLock<HashMap<TaskId, chrono::DateTime<chrono::Utc>>>, + stopped_at: RwLock<HashMap<TaskId, chrono::DateTime<chrono::Utc>>>, + running: RwLock<HashMap<TaskId, RunningTask>>, + tasks_dir: PathBuf, + default_working_dir: PathBuf, + process_manager: Option<std::sync::Arc<dyn dirigent_process::ProcessGroupManager>>, +} + +impl TaskRunner { + pub fn new( + tasks_dir: PathBuf, + default_working_dir: PathBuf, + process_manager: Option<std::sync::Arc<dyn dirigent_process::ProcessGroupManager>>, + ) -> Self { + Self { + definitions: RwLock::new(HashMap::new()), + statuses: RwLock::new(HashMap::new()), + started_at: RwLock::new(HashMap::new()), + stopped_at: RwLock::new(HashMap::new()), + running: RwLock::new(HashMap::new()), + tasks_dir, + default_working_dir, + process_manager, + } + } + + pub fn tasks_dir(&self) -> &PathBuf { + &self.tasks_dir + } + + /// Register a task definition (does not start it). + /// Allows re-registration to update an existing task. + pub async fn register(&self, def: TaskDefinition) -> Result<(), TaskError> { + let name = def.name.clone(); + self.definitions.write().await.insert(name.clone(), def); + self.statuses + .write() + .await + .entry(name) + .or_insert(TaskStatus::Stopped); + Ok(()) + } + + /// Remove a task definition (stops it if running) + pub async fn remove(&self, name: &str) -> Result<(), TaskError> { + if self.is_running(name).await { + self.stop(name).await?; + } + self.definitions.write().await.remove(name); + self.statuses.write().await.remove(name); + self.started_at.write().await.remove(name); + self.stopped_at.write().await.remove(name); + Ok(()) + } + + pub async fn is_running(&self, name: &str) -> bool { + matches!( + self.statuses.read().await.get(name), + Some(TaskStatus::Running { .. }) + ) + } + + /// Start a task by name + pub async fn start(&self, name: &str) -> Result<(), TaskError> { + let def = { + let defs = self.definitions.read().await; + defs.get(name) + .cloned() + .ok_or_else(|| TaskError::NotFound(name.to_string()))? + }; + + if self.is_running(name).await { + return Err(TaskError::AlreadyRunning(name.to_string())); + } + + let output_mgr = TaskOutputManager::new(self.tasks_dir.join(&def.name)); + output_mgr.ensure_dir().await?; + + if def.rotate_previous { + if let Err(e) = output_mgr.rotate().await { + tracing::warn!("Failed to rotate output for task {}: {}", name, e); + } + } + + // Resolve working directory: explicit > default > current process dir + let raw_cwd = def + .working_directory + .clone() + .unwrap_or_else(|| self.default_working_dir.clone()); + + // Canonicalize to an absolute path; fall back to current dir if invalid + let cwd = match std::fs::canonicalize(&raw_cwd) { + Ok(p) => p, + Err(_) => { + tracing::warn!( + "Task '{}': working directory '{}' invalid, falling back to current dir", + name, + raw_cwd.display() + ); + std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")) + } + }; + + let lifecycle = self.process_manager.as_ref().map(|mgr| mgr.create_lifecycle()); + + let mut cmd = Command::new(&def.command); + cmd.args(&def.args) + .current_dir(&cwd) + .stdout(std::process::Stdio::piped()) + .stderr(std::process::Stdio::piped()) + .kill_on_drop(true); + + if let Some(ref lc) = lifecycle { + lc.configure_async_command(&mut cmd); + } + + for (key, value) in &def.env { + cmd.env(key, value); + } + + let mut child = match cmd.spawn() { + Ok(child) => child, + Err(e) => { + let error_msg = format!("{} (cwd: {}): {}", def.command, cwd.display(), e); + // Write error to stderr.log and combined.log so the user can see it in the output viewer + let timestamp = chrono::Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ"); + let log_line = format!("[{}] Failed to start: {}\n", timestamp, error_msg); + if def.persist_to_disk { + let _ = tokio::fs::write(output_mgr.stderr_path(), log_line.as_bytes()).await; + let _ = tokio::fs::write(output_mgr.combined_path(), format!("[stderr] {}", log_line).as_bytes()).await; + } + // Set status to Failed so the UI shows it + self.statuses.write().await.insert(name.to_string(), TaskStatus::Failed { error: error_msg.clone() }); + self.stopped_at.write().await.insert(name.to_string(), chrono::Utc::now()); + return Err(TaskError::SpawnFailed(error_msg)); + } + }; + let pid = child.id().unwrap_or(0); + tracing::info!("Task '{}' started with PID {} (cwd: {})", name, pid, cwd.display()); + + if let Some(ref lc) = lifecycle { + if let Some(child_pid) = child.id() { + if let Err(e) = lc.register_child(child_pid) { + tracing::warn!(error = %e, "Failed to register task child with process lifecycle"); + } + } + } + + let stdout = child.stdout.take(); + let stderr = child.stderr.take(); + let persist = def.persist_to_disk; + let mut abort_handles = Vec::new(); + + // When not rotating, truncate old logs so we don't accumulate output across restarts + let truncate = !def.rotate_previous; + if truncate && persist { + let _ = tokio::fs::write(output_mgr.stdout_path(), b"").await; + let _ = tokio::fs::write(output_mgr.stderr_path(), b"").await; + let _ = tokio::fs::write(output_mgr.combined_path(), b"").await; + } + + // Stdout capture task + if let Some(stdout) = stdout { + let stdout_path = output_mgr.stdout_path(); + let combined_path = output_mgr.combined_path(); + let task_name = name.to_string(); + let h = tokio::spawn(async move { + let reader = tokio::io::BufReader::new(stdout); + let mut lines = reader.lines(); + let mut stdout_file = if persist { + tokio::fs::OpenOptions::new() + .create(true) + .append(true) + .open(&stdout_path) + .await + .ok() + } else { + None + }; + let mut combined_file = if persist { + tokio::fs::OpenOptions::new() + .create(true) + .append(true) + .open(&combined_path) + .await + .ok() + } else { + None + }; + + while let Ok(Some(line)) = lines.next_line().await { + let ts = chrono::Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ"); + if let Some(ref mut f) = stdout_file { + let _ = tokio::io::AsyncWriteExt::write_all( + f, + format!("[{}] {}\n", ts, line).as_bytes(), + ) + .await; + } + if let Some(ref mut f) = combined_file { + let _ = tokio::io::AsyncWriteExt::write_all( + f, + format!("[{}] [stdout] {}\n", ts, line).as_bytes(), + ) + .await; + } + } + tracing::debug!("Stdout capture ended for task '{}'", task_name); + }); + abort_handles.push(h); + } + + // Stderr capture task + if let Some(stderr) = stderr { + let stderr_path = output_mgr.stderr_path(); + let combined_path = output_mgr.combined_path(); + let task_name = name.to_string(); + let h = tokio::spawn(async move { + let reader = tokio::io::BufReader::new(stderr); + let mut lines = reader.lines(); + let mut stderr_file = if persist { + tokio::fs::OpenOptions::new() + .create(true) + .append(true) + .open(&stderr_path) + .await + .ok() + } else { + None + }; + let mut combined_file = if persist { + tokio::fs::OpenOptions::new() + .create(true) + .append(true) + .open(&combined_path) + .await + .ok() + } else { + None + }; + + while let Ok(Some(line)) = lines.next_line().await { + let ts = chrono::Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ"); + if let Some(ref mut f) = stderr_file { + let _ = tokio::io::AsyncWriteExt::write_all( + f, + format!("[{}] {}\n", ts, line).as_bytes(), + ) + .await; + } + if let Some(ref mut f) = combined_file { + let _ = tokio::io::AsyncWriteExt::write_all( + f, + format!("[{}] [stderr] {}\n", ts, line).as_bytes(), + ) + .await; + } + } + tracing::debug!("Stderr capture ended for task '{}'", task_name); + }); + abort_handles.push(h); + } + + self.statuses + .write() + .await + .insert(name.to_string(), TaskStatus::Running { pid }); + self.started_at + .write() + .await + .insert(name.to_string(), chrono::Utc::now()); + self.stopped_at.write().await.remove(name); + + self.running.write().await.insert( + name.to_string(), + RunningTask { + abort_handles, + child, + lifecycle, + }, + ); + + Ok(()) + } + + /// Stop a running task + pub async fn stop(&self, name: &str) -> Result<(), TaskError> { + if !self.is_running(name).await { + return Err(TaskError::NotRunning(name.to_string())); + } + + let mut running = self.running.write().await; + if let Some(mut task) = running.remove(name) { + if let Some(ref lifecycle) = task.lifecycle { + dirigent_process::graceful_shutdown_async( + lifecycle.as_ref(), + &mut task.child, + std::time::Duration::from_secs(3), + ) + .await; + } else { + let _ = task.child.kill().await; + } + for h in task.abort_handles { + h.abort(); + } + tracing::info!("Task '{}' stopped", name); + } + + self.statuses + .write() + .await + .insert(name.to_string(), TaskStatus::Stopped); + self.stopped_at + .write() + .await + .insert(name.to_string(), chrono::Utc::now()); + Ok(()) + } + + /// Poll running tasks for completion (call periodically from a timer) + pub async fn poll_completed(&self) { + let mut running = self.running.write().await; + let mut completed = Vec::new(); + + for (name, task) in running.iter_mut() { + match task.child.try_wait() { + Ok(Some(status)) => { + let exit_code = status.code(); + tracing::info!( + "Task '{}' finished with exit code: {:?}", + name, + exit_code + ); + completed.push((name.clone(), exit_code)); + } + Ok(None) => {} + Err(e) => { + tracing::error!("Error checking task '{}': {}", name, e); + completed.push((name.clone(), None)); + } + } + } + + let mut statuses = self.statuses.write().await; + let mut stopped_at = self.stopped_at.write().await; + for (name, exit_code) in completed { + running.remove(&name); + statuses.insert(name.clone(), TaskStatus::Finished { exit_code }); + stopped_at.insert(name.clone(), chrono::Utc::now()); + } + } + + /// List all tasks with their info + pub async fn list_tasks(&self) -> Vec<TaskInfo> { + let defs = self.definitions.read().await; + let statuses = self.statuses.read().await; + let started = self.started_at.read().await; + let stopped = self.stopped_at.read().await; + + defs.values() + .map(|def| TaskInfo { + definition: def.clone(), + status: statuses + .get(&def.name) + .cloned() + .unwrap_or(TaskStatus::Stopped), + started_at: started.get(&def.name).cloned(), + stopped_at: stopped.get(&def.name).cloned(), + }) + .collect() + } + + /// Get info for a specific task + pub async fn get_task(&self, name: &str) -> Option<TaskInfo> { + let defs = self.definitions.read().await; + let def = defs.get(name)?; + let statuses = self.statuses.read().await; + let started = self.started_at.read().await; + let stopped = self.stopped_at.read().await; + Some(TaskInfo { + definition: def.clone(), + status: statuses + .get(name) + .cloned() + .unwrap_or(TaskStatus::Stopped), + started_at: started.get(name).cloned(), + stopped_at: stopped.get(name).cloned(), + }) + } + + /// Read output for a task + pub async fn read_output( + &self, + name: &str, + kind: OutputKind, + tail_lines: Option<usize>, + ) -> Result<String, TaskError> { + { + let defs = self.definitions.read().await; + if !defs.contains_key(name) { + return Err(TaskError::NotFound(name.to_string())); + } + } + let mgr = TaskOutputManager::new(self.tasks_dir.join(name)); + mgr.read_output(kind, tail_lines).await.map_err(TaskError::Io) + } + + /// Get all task definitions (for config persistence) + pub async fn get_definitions(&self) -> Vec<TaskDefinition> { + self.definitions.read().await.values().cloned().collect() + } + + /// Update a task definition (stops if running, re-registers) + pub async fn update(&self, def: TaskDefinition) -> Result<(), TaskError> { + let name = def.name.clone(); + if self.is_running(&name).await { + self.stop(&name).await?; + } + self.register(def).await + } + + /// Stop all running tasks. Used during graceful shutdown. + pub async fn stop_all(&self) { + let names: Vec<String> = self.running.read().await.keys().cloned().collect(); + for name in names { + if let Err(e) = self.stop(&name).await { + tracing::warn!(task = %name, error = %e, "Failed to stop task during shutdown"); + } + } + } +} diff --git a/crates/dirigent_taskrunner/src/types.rs b/crates/dirigent_taskrunner/src/types.rs new file mode 100644 index 0000000..2e4efbb --- /dev/null +++ b/crates/dirigent_taskrunner/src/types.rs @@ -0,0 +1,71 @@ +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; + +/// Unique identifier for a task (the slug/name) +pub type TaskId = String; + +/// How a task is defined +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskDefinition { + /// Human-readable title + pub title: String, + /// Unique slug (used as TOML key and file directory name) + pub name: String, + /// The command to execute (e.g. "lspmux", "python") + pub command: String, + /// Arguments to the command + #[serde(default)] + pub args: Vec<String>, + /// Working directory (None = runtime working dir) + #[serde(default, skip_serializing_if = "Option::is_none")] + pub working_directory: Option<PathBuf>, + /// Run this task when dirigent starts + #[serde(default)] + pub run_at_startup: bool, + /// Max lines to keep in memory buffer (0 = unlimited) + #[serde(default = "default_buffer_size")] + pub buffer_size: usize, + /// Write output to disk (overrides buffer_size — keeps everything) + #[serde(default = "default_persist")] + pub persist_to_disk: bool, + /// Rotate previous output file before starting + #[serde(default)] + pub rotate_previous: bool, + /// Environment variables to set + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub env: Vec<(String, String)>, +} + +fn default_buffer_size() -> usize { + 10000 +} +fn default_persist() -> bool { + true +} + +/// Runtime state of a task +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub enum TaskStatus { + Stopped, + Running { pid: u32 }, + Finished { exit_code: Option<i32> }, + Failed { error: String }, +} + +/// Full info about a task (definition + runtime state) +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TaskInfo { + pub definition: TaskDefinition, + pub status: TaskStatus, + pub started_at: Option<DateTime<Utc>>, + pub stopped_at: Option<DateTime<Utc>>, +} + +/// Which output stream to read +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq)] +pub enum OutputKind { + Stdout, + Stderr, + Combined, +} diff --git a/crates/dirigent_testing/CLAUDE.md b/crates/dirigent_testing/CLAUDE.md new file mode 100644 index 0000000..e80b397 --- /dev/null +++ b/crates/dirigent_testing/CLAUDE.md @@ -0,0 +1,62 @@ +# Package: dirigent_testing + +Testing utilities for Dirigent with replay-based e2e test support. + +## Quick Facts +- **Type**: Library (dev/test utility) +- **Main Entry**: src/lib.rs +- **Dependencies**: serde, serde_json, thiserror, uuid +- **Status**: Initial — replay framework only + +## Purpose + +Provides testing infrastructure for Dirigent, starting with replay-based end-to-end tests that use recorded ACP (Agent-Client Protocol) interactions. Fixtures are stored as JSON files and can be loaded, filtered, and round-tripped through serde. + +## Module Organization + +- **`lib.rs`**: Public API surface and re-exports +- **`replay.rs`**: Core replay types — `AcpReplay`, `ReplayMessage`, `Direction`, `ReplaySource` +- **`fixtures.rs`**: Fixture loading utilities — `load_fixture`, `fixture_path`, `list_fixtures` + +## Fixtures + +Fixture files live in `fixtures/` and are JSON files conforming to the `AcpReplay` schema. Each fixture contains: +- `name`: Human-readable identifier +- `source`: Origin system (`zed`, `claude`, or custom) +- `messages`: Ordered sequence of `ReplayMessage` with direction, payload, and optional delay + +### Available Fixtures +- `minimal_init.json` — Minimal MCP/ACP initialize handshake (client request + server response) +- `zed_claude_session.json` — Real Zed-Claude ACP session adapted from recorded traffic (9 messages: initialize, session/load with updates, session/list) + +## Usage + +```rust +use dirigent_testing::{load_fixture, AcpReplay, Direction}; + +let replay = load_fixture("minimal_init.json").unwrap(); +assert_eq!(replay.client_messages().len(), 1); +assert_eq!(replay.agent_messages().len(), 1); +``` + +## Testing + +```bash +cargo test -p dirigent_testing +``` + +## Related Packages + +- **dirigent_acp_api**: ACP server that these replays exercise +- **dirigent_core**: Runtime under test in integration scenarios +- **dirigent_protocol**: Shared protocol types + +## Integration Tests + +- `tests/zed_claude_replay.rs` — Tests for the Zed-Claude session fixture: loading, message counts, direction filtering, protocol structure validation, serde roundtrip + +## Future Enhancements + +- Replay runner that drives an ACP server with recorded traffic +- Assertion helpers for validating ACP response sequences +- Timing simulation with `delay_ms` support diff --git a/crates/dirigent_testing/Cargo.toml b/crates/dirigent_testing/Cargo.toml new file mode 100644 index 0000000..702d5e3 --- /dev/null +++ b/crates/dirigent_testing/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "dirigent_testing" +version = "0.1.0" +edition = "2021" +description = "Testing utilities for Dirigent — replay-based e2e tests from recorded ACP traffic" + +[lib] +path = "src/lib.rs" + +[dependencies] +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +thiserror = "2.0" +uuid = { version = "1.0", features = ["v4", "v7"] } diff --git a/crates/dirigent_testing/fixtures/minimal_init.json b/crates/dirigent_testing/fixtures/minimal_init.json new file mode 100644 index 0000000..90aeeb7 --- /dev/null +++ b/crates/dirigent_testing/fixtures/minimal_init.json @@ -0,0 +1,31 @@ +{ + "name": "minimal_init", + "source": "zed", + "messages": [ + { + "direction": "client_to_agent", + "payload": { + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2025-01-01", + "capabilities": {}, + "clientInfo": { "name": "test", "version": "0.1.0" } + } + } + }, + { + "direction": "agent_to_client", + "payload": { + "jsonrpc": "2.0", + "id": 1, + "result": { + "protocolVersion": "2025-01-01", + "capabilities": {}, + "serverInfo": { "name": "test-agent", "version": "0.1.0" } + } + } + } + ] +} diff --git a/crates/dirigent_testing/fixtures/zed_claude_session.json b/crates/dirigent_testing/fixtures/zed_claude_session.json new file mode 100644 index 0000000..5d7fdcf --- /dev/null +++ b/crates/dirigent_testing/fixtures/zed_claude_session.json @@ -0,0 +1,227 @@ +{ + "name": "zed_claude_session", + "source": "zed", + "messages": [ + { + "direction": "client_to_agent", + "payload": { + "jsonrpc": "2.0", + "id": 0, + "method": "initialize", + "params": { + "protocolVersion": 1, + "clientCapabilities": { + "fs": { + "readTextFile": true, + "writeTextFile": true + }, + "terminal": true, + "_meta": { + "terminal_output": true, + "terminal-auth": true + } + }, + "clientInfo": { + "name": "zed", + "title": "Zed", + "version": "0.225.12" + } + } + } + }, + { + "direction": "agent_to_client", + "payload": { + "jsonrpc": "2.0", + "id": 0, + "method": "initialize", + "params": { + "protocolVersion": 1, + "agentCapabilities": { + "promptCapabilities": { + "image": true, + "embeddedContext": true + }, + "mcpCapabilities": { + "http": true, + "sse": true + }, + "loadSession": true, + "sessionCapabilities": { + "fork": {}, + "list": {}, + "resume": {} + } + }, + "agentInfo": { + "name": "@zed-industries/claude-agent-acp", + "title": "Claude Agent", + "version": "0.19.2" + }, + "authMethods": [ + { + "description": "Run `claude /login` in the terminal", + "name": "Log in with Claude", + "id": "claude-login" + } + ] + } + }, + "delay_ms": 120 + }, + { + "direction": "client_to_agent", + "payload": { + "jsonrpc": "2.0", + "id": 1, + "method": "session/load", + "params": { + "mcpServers": [], + "cwd": "/dev/projects/dirigent", + "sessionId": "cb878ad6-d72b-43c9-93e0-8228f309a786" + } + }, + "delay_ms": 50 + }, + { + "direction": "agent_to_client", + "payload": { + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "cb878ad6-d72b-43c9-93e0-8228f309a786", + "update": { + "sessionUpdate": "user_message_chunk", + "content": { + "type": "text", + "text": "hi" + } + } + } + }, + "delay_ms": 200 + }, + { + "direction": "agent_to_client", + "payload": { + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "cb878ad6-d72b-43c9-93e0-8228f309a786", + "update": { + "sessionUpdate": "agent_message_chunk", + "content": { + "type": "text", + "text": "Hi! I'm here to help you with the Dirigent project. What would you like to work on today?" + } + } + } + }, + "delay_ms": 800 + }, + { + "direction": "agent_to_client", + "payload": { + "jsonrpc": "2.0", + "id": 1, + "method": "session/load", + "params": { + "modes": { + "currentModeId": "default", + "availableModes": [ + { + "id": "default", + "name": "Default", + "description": "Standard behavior, prompts for dangerous operations" + }, + { + "id": "plan", + "name": "Plan Mode", + "description": "Planning mode, no actual tool execution" + } + ] + }, + "models": { + "availableModels": [ + { + "modelId": "default", + "name": "Default (recommended)", + "description": "Opus 4.6" + }, + { + "modelId": "sonnet", + "name": "Sonnet", + "description": "Sonnet 4.6" + } + ], + "currentModelId": "default" + } + } + }, + "delay_ms": 300 + }, + { + "direction": "client_to_agent", + "payload": { + "jsonrpc": "2.0", + "id": 2, + "method": "session/list", + "params": {} + }, + "delay_ms": 100 + }, + { + "direction": "agent_to_client", + "payload": { + "jsonrpc": "2.0", + "method": "session/update", + "params": { + "sessionId": "cb878ad6-d72b-43c9-93e0-8228f309a786", + "update": { + "sessionUpdate": "available_commands_update", + "availableCommands": [ + { + "name": "compact", + "description": "Clear conversation history but keep a summary in context.", + "input": { + "hint": "<optional custom summarization instructions>" + } + }, + { + "name": "context", + "description": "Show current context usage", + "input": null + } + ] + } + } + }, + "delay_ms": 150 + }, + { + "direction": "agent_to_client", + "payload": { + "jsonrpc": "2.0", + "id": 2, + "method": "session/list", + "params": { + "sessions": [ + { + "sessionId": "cb878ad6-d72b-43c9-93e0-8228f309a786", + "cwd": "/dev/projects/dirigent", + "title": "hi", + "updatedAt": "2026-03-03T14:03:24.740Z" + }, + { + "sessionId": "838b10b2-2f58-4dad-8652-9df81c880a96", + "cwd": "/dev/projects/dirigent", + "title": "Session list investigation", + "updatedAt": "2026-03-03T13:56:15.751Z" + } + ] + } + }, + "delay_ms": 250 + } + ] +} diff --git a/crates/dirigent_testing/src/fixtures.rs b/crates/dirigent_testing/src/fixtures.rs new file mode 100644 index 0000000..0f9de02 --- /dev/null +++ b/crates/dirigent_testing/src/fixtures.rs @@ -0,0 +1,46 @@ +use crate::replay::AcpReplay; +use std::path::{Path, PathBuf}; + +/// Load an ACP replay fixture by filename from the `fixtures/` directory. +pub fn load_fixture(name: &str) -> Result<AcpReplay, FixtureError> { + let path = fixture_path(name); + let content = std::fs::read_to_string(&path).map_err(|e| FixtureError::ReadError { + path: path.clone(), + source: e, + })?; + serde_json::from_str(&content).map_err(|e| FixtureError::ParseError { path, source: e }) +} + +/// Return the absolute path to a fixture file by name. +pub fn fixture_path(name: &str) -> PathBuf { + Path::new(env!("CARGO_MANIFEST_DIR")) + .join("fixtures") + .join(name) +} + +/// List all `.json` fixture filenames in the `fixtures/` directory. +pub fn list_fixtures() -> Vec<String> { + let fixtures_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("fixtures"); + std::fs::read_dir(fixtures_dir) + .into_iter() + .flatten() + .filter_map(|e| e.ok()) + .filter(|e| e.path().extension().is_some_and(|ext| ext == "json")) + .filter_map(|e| e.file_name().into_string().ok()) + .collect() +} + +/// Errors that can occur when loading fixture files. +#[derive(Debug, thiserror::Error)] +pub enum FixtureError { + #[error("Failed to read fixture at {path}: {source}")] + ReadError { + path: PathBuf, + source: std::io::Error, + }, + #[error("Failed to parse fixture at {path}: {source}")] + ParseError { + path: PathBuf, + source: serde_json::Error, + }, +} diff --git a/crates/dirigent_testing/src/lib.rs b/crates/dirigent_testing/src/lib.rs new file mode 100644 index 0000000..bd7bc3a --- /dev/null +++ b/crates/dirigent_testing/src/lib.rs @@ -0,0 +1,8 @@ +//! Testing utilities for Dirigent. +//! Provides replay-based e2e test support using recorded ACP interactions. + +pub mod fixtures; +pub mod replay; + +pub use fixtures::{load_fixture, list_fixtures}; +pub use replay::{AcpReplay, Direction, ReplayMessage, ReplaySource}; diff --git a/crates/dirigent_testing/src/replay.rs b/crates/dirigent_testing/src/replay.rs new file mode 100644 index 0000000..a3273ba --- /dev/null +++ b/crates/dirigent_testing/src/replay.rs @@ -0,0 +1,88 @@ +use serde::{Deserialize, Serialize}; + +/// Direction of a message in an ACP interaction replay. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Direction { + ClientToAgent, + AgentToClient, +} + +/// Source system that the replay was recorded from. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ReplaySource { + Zed, + Claude, + Custom(String), +} + +/// A single message in an ACP replay sequence. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReplayMessage { + pub direction: Direction, + pub payload: serde_json::Value, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub delay_ms: Option<u64>, +} + +/// A complete ACP interaction replay containing a named sequence of messages. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AcpReplay { + pub name: String, + pub source: ReplaySource, + pub messages: Vec<ReplayMessage>, +} + +impl AcpReplay { + /// Returns only the messages sent from client to agent. + pub fn client_messages(&self) -> Vec<&ReplayMessage> { + self.messages + .iter() + .filter(|m| m.direction == Direction::ClientToAgent) + .collect() + } + + /// Returns only the messages sent from agent to client. + pub fn agent_messages(&self) -> Vec<&ReplayMessage> { + self.messages + .iter() + .filter(|m| m.direction == Direction::AgentToClient) + .collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::fixtures; + + #[test] + fn test_load_minimal_fixture() { + let replay = fixtures::load_fixture("minimal_init.json").unwrap(); + assert_eq!(replay.name, "minimal_init"); + assert_eq!(replay.messages.len(), 2); + } + + #[test] + fn test_filter_by_direction() { + let replay = fixtures::load_fixture("minimal_init.json").unwrap(); + assert_eq!(replay.client_messages().len(), 1); + assert_eq!(replay.agent_messages().len(), 1); + } + + #[test] + fn test_list_fixtures() { + let fixtures = fixtures::list_fixtures(); + assert!(fixtures.contains(&"minimal_init.json".to_string())); + } + + #[test] + fn test_serde_roundtrip() { + let replay = fixtures::load_fixture("minimal_init.json").unwrap(); + let json = serde_json::to_string_pretty(&replay).unwrap(); + let parsed: AcpReplay = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.name, replay.name); + assert_eq!(parsed.messages.len(), replay.messages.len()); + } +} diff --git a/crates/dirigent_testing/tests/zed_claude_replay.rs b/crates/dirigent_testing/tests/zed_claude_replay.rs new file mode 100644 index 0000000..db4643e --- /dev/null +++ b/crates/dirigent_testing/tests/zed_claude_replay.rs @@ -0,0 +1,138 @@ +use dirigent_testing::{load_fixture, Direction}; + +#[test] +fn test_zed_claude_fixture_loads() { + let replay = load_fixture("zed_claude_session.json").unwrap(); + assert!(!replay.messages.is_empty()); + assert!(!replay.client_messages().is_empty()); + assert!(!replay.agent_messages().is_empty()); +} + +#[test] +fn test_zed_claude_starts_with_initialize() { + let replay = load_fixture("zed_claude_session.json").unwrap(); + let first = &replay.messages[0]; + assert_eq!(first.direction, Direction::ClientToAgent); + let method = first.payload.get("method").and_then(|m| m.as_str()); + assert_eq!(method, Some("initialize")); +} + +#[test] +fn test_zed_claude_message_counts() { + let replay = load_fixture("zed_claude_session.json").unwrap(); + // 3 client messages: initialize, session/load, session/list + assert_eq!(replay.client_messages().len(), 3); + // 6 agent messages: initialize response, 2x session/update notifications, + // session/load response, available_commands_update, session/list response + assert_eq!(replay.agent_messages().len(), 6); + // Total: 9 messages + assert_eq!(replay.messages.len(), 9); +} + +#[test] +fn test_zed_claude_initialize_has_client_info() { + let replay = load_fixture("zed_claude_session.json").unwrap(); + let init = &replay.messages[0].payload; + let client_info = init + .pointer("/params/clientInfo/name") + .and_then(|v| v.as_str()); + assert_eq!(client_info, Some("zed")); +} + +#[test] +fn test_zed_claude_initialize_response_has_agent_info() { + let replay = load_fixture("zed_claude_session.json").unwrap(); + let init_resp = &replay.messages[1].payload; + assert_eq!(init_resp.get("id").and_then(|v| v.as_u64()), Some(0)); + let agent_name = init_resp + .pointer("/params/agentInfo/name") + .and_then(|v| v.as_str()); + assert_eq!(agent_name, Some("@zed-industries/claude-agent-acp")); +} + +#[test] +fn test_zed_claude_session_load_flow() { + let replay = load_fixture("zed_claude_session.json").unwrap(); + // Message at index 2 is the session/load request + let session_load = &replay.messages[2]; + assert_eq!(session_load.direction, Direction::ClientToAgent); + let method = session_load + .payload + .get("method") + .and_then(|m| m.as_str()); + assert_eq!(method, Some("session/load")); + + let session_id = session_load + .payload + .pointer("/params/sessionId") + .and_then(|v| v.as_str()); + assert_eq!( + session_id, + Some("cb878ad6-d72b-43c9-93e0-8228f309a786") + ); +} + +#[test] +fn test_zed_claude_contains_session_list() { + let replay = load_fixture("zed_claude_session.json").unwrap(); + let list_request = replay + .messages + .iter() + .find(|m| { + m.direction == Direction::ClientToAgent + && m.payload.get("method").and_then(|v| v.as_str()) == Some("session/list") + }) + .expect("should contain a session/list request"); + + assert_eq!(list_request.payload.get("id").and_then(|v| v.as_u64()), Some(2)); + + // Find the matching response + let list_response = replay + .messages + .iter() + .find(|m| { + m.direction == Direction::AgentToClient + && m.payload.get("method").and_then(|v| v.as_str()) == Some("session/list") + && m.payload.get("id").and_then(|v| v.as_u64()) == Some(2) + }) + .expect("should contain a session/list response"); + + let sessions = list_response + .payload + .pointer("/params/sessions") + .and_then(|v| v.as_array()) + .expect("should have sessions array"); + assert_eq!(sessions.len(), 2); +} + +#[test] +fn test_zed_claude_has_delay_timings() { + let replay = load_fixture("zed_claude_session.json").unwrap(); + // First message (initialize request from client) has no delay + assert!(replay.messages[0].delay_ms.is_none()); + // Most subsequent messages should have delay_ms set + let messages_with_delay = replay + .messages + .iter() + .filter(|m| m.delay_ms.is_some()) + .count(); + assert!( + messages_with_delay > 0, + "at least some messages should have delay timing" + ); +} + +#[test] +fn test_zed_claude_serde_roundtrip() { + let replay = load_fixture("zed_claude_session.json").unwrap(); + let json = serde_json::to_string_pretty(&replay).unwrap(); + let parsed: dirigent_testing::AcpReplay = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.name, replay.name); + assert_eq!(parsed.messages.len(), replay.messages.len()); + + for (original, roundtripped) in replay.messages.iter().zip(parsed.messages.iter()) { + assert_eq!(original.direction, roundtripped.direction); + assert_eq!(original.payload, roundtripped.payload); + assert_eq!(original.delay_ms, roundtripped.delay_ms); + } +} diff --git a/crates/dirigent_tools/CLAUDE.md b/crates/dirigent_tools/CLAUDE.md new file mode 100644 index 0000000..2abd038 --- /dev/null +++ b/crates/dirigent_tools/CLAUDE.md @@ -0,0 +1,169 @@ +# Package: dirigent_tools + +Tool implementations for ACP client operations with sandboxing and security. + +## Quick Facts +- **Type**: Library +- **Main Entry**: src/lib.rs +- **Dependencies**: tokio, serde, anyhow, thiserror, tracing, regex, globset, similar, dunce +- **Status**: Scaffolding phase - structure complete, implementation pending + +## Purpose + +This package provides the foundational tool operations for the Dirigent ACP client, implementing file operations, terminal execution, and search capabilities with strong security guarantees. It is designed to support ACP-compliant agents by providing safe, sandboxed tool handlers. + +## Module Structure + +### Core Modules +- **error** (`src/error.rs`) - Error types for tool operations +- **config** (`src/config.rs`) - Configuration types (sandbox, permissions, terminal, search, embedding) +- **path** (`src/path.rs`) - Cross-platform path normalization and containment checking + +### Tool Modules +- **fs** (`src/fs.rs`) - File read/write/edit operations +- **search** (`src/search.rs`) - Glob/grep/ls search operations +- **terminal** (`src/terminal.rs`) - Command execution and output capture + +### Security Modules +- **permission** (`src/permission.rs`) - Permission prompt system and decision caching +- **audit** (`src/audit.rs`) - Structured audit logging + +### Integration Modules +- **tool_call** (`src/tool_call.rs`) - ACP tool call metadata and helpers + +## Security Model + +All operations are subject to: + +1. **Sandbox Containment**: Operations restricted to configured allowed roots +2. **Blocklist Enforcement**: Sensitive paths explicitly denied +3. **Permission Prompts**: Write/execute operations require user approval (configurable) +4. **Resource Limits**: Bounded file sizes, search results, terminal output +5. **Audit Logging**: All operations logged with structured context + +## Platform Support + +Windows is a first-class platform with explicit support for: +- Backslash and forward slash separators +- Drive letters (C:\, D:\) +- UNC paths (\\server\share\...) +- Long path prefixes (\\?\...) +- MINGW-style paths (/c/Users/...) +- Junctions and symlinks +- cmd.exe and PowerShell + +All path normalization and containment logic handles these cases. + +## Configuration (Future) + +Configuration types will be implemented in SCAFF-05: + +```rust +pub struct SandboxConfig { + pub allowed_roots: Vec<PathBuf>, + pub blocked_paths: Vec<String>, // glob patterns + pub allow_symlink_escape: bool, + pub follow_symlinks_within_roots: bool, + pub read_enabled: bool, + pub write_enabled: bool, + pub max_read_bytes: u64, + pub max_write_bytes: u64, + // ... more fields +} + +pub struct PermissionConfig { + pub mode: PermissionMode, // Ask | Whitelist | Yolo + pub remember_decisions: bool, + pub remember_ttl_secs: u64, + pub scope: DecisionScope, // PerConnector | PerSession + pub whitelist: WhitelistConfig, +} + +pub struct TerminalConfig { + pub enabled: bool, + pub default_cwd: PathBuf, + pub env_allowlist: Vec<String>, + pub command_blocklist: Vec<String>, + pub output_byte_limit: u64, + pub max_runtime_secs: u64, +} + +pub struct SearchConfig { + pub max_results: u32, + pub max_bytes: u64, + pub default_include_globs: Vec<String>, + pub default_exclude_globs: Vec<String>, +} + +pub struct EmbeddingConfig { + pub max_embed_bytes: u64, + pub allow_resource_link: bool, + pub redact_patterns: Vec<String>, // regex patterns +} +``` + +## Implementation Status + +**Current Phase**: SCAFF-01 (Scaffolding) - Complete + +All modules are stubs with `unimplemented!()`. Implementation will proceed in phases: + +### Phase 1: Path and Sandbox (Protocol tasks) +- Path normalization (Windows, Linux, macOS) +- Canonical path resolution (symlinks, junctions) +- Containment checking +- Blocklist matching + +### Phase 2: Tool Operations (Tool tasks) +- File read with line ranges +- File write with atomic operations +- File edit with diff generation +- Glob search +- Grep search with context +- LS directory listing +- Terminal creation and lifecycle +- Terminal output capture with ring buffer + +### Phase 3: Security and Integration (Integration tasks) +- Permission prompt flow +- Decision caching with TTL +- Audit logging with structured fields +- ACP tool call event generation +- Tool title formatting + +## Integration Points + +### dirigent_core +- Uses `dirigent_tools` for ACP client request handlers +- Passes configuration from connector params +- Routes tool operations through sandbox + +### api +- May use `dirigent_tools` for server function implementations +- Exposes configuration endpoints +- Handles permission prompts via UI + +### web +- Displays tool calls with titles, locations, diffs +- Renders permission prompts +- Shows audit logs + +## Testing Strategy + +Test infrastructure will be set up in SCAFF-03: +- Unit tests for each module +- Integration tests with mocker +- Cross-platform tests (Windows, Linux, macOS) +- Golden transcript fixtures +- Test utilities for temp directories, file comparison, sandboxed environments + +## Related Packages +- **dirigent_core** - Uses this package for tool operations +- **api** - May use this package for server functions +- **dirigent_protocol** - Shared event types for tool calls + +## References +- Task file: `docs/building/04_acp_client/04_tasks_00_scaffolding_and_finishing.md` +- Tools research: `docs/building/04_acp_client/03_tools_research.md` +- Sandbox spec: `docs/building/04_acp_client/03_fs_sandboxing_and_permissions_spec.md` +- Roadmap: `docs/building/04_acp_client/roadmap.md` diff --git a/crates/dirigent_tools/Cargo.toml b/crates/dirigent_tools/Cargo.toml new file mode 100644 index 0000000..9d456bc --- /dev/null +++ b/crates/dirigent_tools/Cargo.toml @@ -0,0 +1,65 @@ +[package] +name = "dirigent_tools" +version = "0.1.0" +edition = "2021" + +[lib] +path = "src/lib.rs" + +[dependencies] +# Policy types from dirigent_fermata +dirigent_fermata = { workspace = true } + +# Async runtime — minimal feature set; no `net` so wasm consumers don't pull mio. +tokio = { version = "1", features = ["rt", "rt-multi-thread", "macros", "sync", "fs", "io-util", "time", "process"] } + +# Bytes for tool result content +bytes = { version = "1", features = ["serde"] } +serde_bytes = "0.11" + +# Async traits (object-safe async fn) +async-trait = "0.1" + +# JSON Schema generation for tool input types +schemars = "1.0" + +# Serialization +serde = { version = "1.0", features = ["derive", "rc"] } +serde_json = "1.0" + +# Error handling +anyhow = "1.0" +thiserror = "2.0" + +# Logging +tracing = "0.1" + +# Text processing and pattern matching +regex = "1.0" +globset = "0.4" + +# Diff generation +similar = { version = "2.6", features = ["inline"] } + +# Path utilities for cross-platform path handling (especially Windows) +# Using dunce for reliable Windows path normalization +dunce = "1.0" + +# Directory traversal for glob/search operations +walkdir = "2.0" + +# Lazy static for global state +lazy_static = "1.5" + +# Configuration serialization (for TOML/JSON config support) +toml = { version = "0.8", optional = true } + +[dev-dependencies] +# Test dependencies +tokio = { version = "1", features = ["full", "test-util"] } +tempfile = "3.0" +toml = "0.8" + +[features] +default = [] +config = ["dep:toml"] diff --git a/crates/dirigent_tools/IMPLEMENTATION_STATUS.md b/crates/dirigent_tools/IMPLEMENTATION_STATUS.md new file mode 100644 index 0000000..f40a3ac --- /dev/null +++ b/crates/dirigent_tools/IMPLEMENTATION_STATUS.md @@ -0,0 +1,238 @@ +# dirigent_tools Implementation Status + +This document tracks the implementation status of the tools package, showing which features are complete and which are stubbed pending implementation. + +## Status Legend + +- ✅ **Completed** - Fully implemented with tests +- ⏸️ **Stubbed** - Function signatures present, returns `unimplemented!()` +- 🚧 **In Progress** - Actively being implemented + +--- + +## Completed Features + +### Core Infrastructure + +- ✅ **Error Types** (`src/error.rs`) - Complete error handling with user-facing messages +- ✅ **Configuration Types** (`src/config.rs`) - All config structures with validation +- ✅ **Path Validation** (`src/path/validate.rs`) - Path normalization and validation +- ✅ **Path Canonicalization** (`src/path/canonicalize.rs`) - Cross-platform canonical paths +- ✅ **Path Containment** (`src/path/containment.rs`) - Sandbox boundary checking +- ✅ **Path Blocklist** (`src/path/blocklist.rs`) - Glob-based path blocking + +--- + +## Stubbed Features (TODO) + +### File Operations + +- ⏸️ **Read Text File** (`src/fs/read.rs`) - TOOLS-FS-01 + - Function: `read_text_file()` + - Types: `ReadTextFileRequest`, `ReadTextFileResponse` + - Features: Line/limit support, UTF-8 validation, sandboxing + +- ⏸️ **Write Text File** (`src/fs/write.rs`) - TOOLS-FS-02 + - Function: `write_text_file()`, `normalize_eol()` + - Types: `WriteTextFileRequest`, `WriteTextFileResponse` + - Features: Atomic writes, EOL normalization, permission checks + +- ⏸️ **Generate Diff** (`src/fs/diff.rs`) - TOOLS-FS-03 + - Function: `generate_diff()` + - Features: Unified diff generation, edge case handling + +- ⏸️ **Edit File** (`src/fs/edit.rs`) - TOOLS-FS-04 + - Function: `edit_file()` + - Types: `EditFileRequest`, `EditFileResponse`, `EditOperation` + - Features: Read + transform + write, automatic diff generation + +### Search Operations + +- ⏸️ **Directory Listing** (`src/search/ls.rs`) - TOOLS-SEARCH-01 + - Function: `ls()` + - Types: `LsRequest`, `LsResponse`, `LsEntry`, `FileKind` + - Features: Sandboxing, exclude globs, file metadata + +- ⏸️ **Glob Search** (`src/search/glob.rs`) - TOOLS-SEARCH-02 + - Function: `glob_search()` + - Types: `GlobRequest`, `GlobResponse` + - Features: Pattern matching, result limits, exclude patterns + +- ⏸️ **Content Search** (`src/search/grep.rs`) - TOOLS-SEARCH-03 + - Function: `grep_search()` + - Types: `GrepRequest`, `GrepResponse`, `GrepMatch` + - Features: Regex search, context lines, binary file detection + +### Terminal Operations + +- ⏸️ **Create Terminal** (`src/terminal/create.rs`) - TOOLS-TERM-01 + - Function: `create_terminal()` + - Types: `CreateTerminalRequest`, `CreateTerminalResponse`, `EnvVar` + - Features: Process spawning, CWD validation, env filtering, output capture + +- ⏸️ **Get Terminal Output** (`src/terminal/output.rs`) - TOOLS-TERM-02 + - Function: `get_terminal_output()` + - Types: `TerminalOutputRequest`, `TerminalOutputResponse` + - Features: Ring buffer snapshot, truncation tracking, exit status + +- ⏸️ **Wait for Exit** (`src/terminal/wait.rs`) - TOOLS-TERM-03 + - Function: `wait_for_terminal_exit()` + - Types: `WaitForTerminalExitRequest`, `WaitForTerminalExitResponse` + - Features: Blocking wait, timeout enforcement + +- ⏸️ **Kill Terminal** (`src/terminal/kill.rs`) - TOOLS-TERM-04 + - Function: `kill_terminal()` + - Types: `KillTerminalCommandRequest`, `KillTerminalCommandResponse` + - Features: Forceful termination, idempotent operations + +- ⏸️ **Release Terminal** (`src/terminal/release.rs`) - TOOLS-TERM-05 + - Function: `release_terminal()` + - Types: `ReleaseTerminalRequest`, `ReleaseTerminalResponse` + - Features: Resource cleanup, registry removal + +- ⏸️ **Ring Buffer** (`src/terminal/ring_buffer.rs`) - TOOLS-TERM-01 + - Type: `RingBuffer` + - Features: Fixed-size circular buffer, UTF-8 boundary handling + +### Security and Permissions + +- ⏸️ **Permission System** (`src/permission.rs`) - TOOLS-PERM-01 through TOOLS-PERM-05 + - Functions: Permission checks, decision caching, whitelist matching + - Features: Ask/whitelist/yolo modes, per-connector/session scope + +### Observability + +- ⏸️ **Audit Logging** (`src/audit.rs`) - TOOLS-AUDIT-01 through TOOLS-AUDIT-03 + - Functions: Structured audit logging, metrics + - Features: Operation tracking, performance monitoring + +### UI Integration + +- ⏸️ **Tool Call Rendering** (`src/tool_call.rs`) - TOOLS-UI-01 through TOOLS-UI-03 + - Functions: ToolKind mapping, title generation, diff formatting + - Features: ACP integration, location extraction + +--- + +## When Implementing a Stubbed Feature + +Follow these steps to properly implement a feature: + +1. **Implement the Function** + - Replace `unimplemented!()` with actual logic + - Follow the task specification in `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md` + - Reference the security spec in `docs/building/04_acp_client/03_fs_sandboxing_and_permissions_spec.md` + +2. **Add Comprehensive Tests** + - Unit tests for happy paths + - Edge case testing + - Error case validation + - Cross-platform tests (Windows, Linux, macOS) + +3. **Enable Capability** (if applicable) + - Update `ClientCapabilities::default()` in `packages/dirigent_core/src/acp/connector_state.rs` + - Change from `None` to `Some(...)` for the relevant capability + - Document why the capability is now safe to advertise + +4. **Update This Document** + - Move feature from "Stubbed" to "Completed" + - Add any notes or caveats + - Update completion percentage + +5. **Integration Testing** + - Test with dirigent-acp-mocker + - Verify tool calls work end-to-end + - Test UI rendering if applicable + +--- + +## Capability Enablement + +**CRITICAL**: Do NOT enable capabilities in `ClientCapabilities::default()` until the corresponding tools are fully implemented and tested. + +### Current Capability Defaults (dirigent_core) + +```rust +impl Default for ClientCapabilities { + fn default() -> Self { + Self { + fs: None, // ❌ Disabled - file operations not yet implemented + terminal: None, // ❌ Disabled - terminal operations not yet implemented + _meta: None, + } + } +} +``` + +### When to Enable + +**File System (`fs`)**: +- ✅ TOOLS-FS-01 (read) is implemented and tested +- ✅ TOOLS-FS-02 (write) is implemented and tested (or omit `write_text_file: Some(true)`) +- ✅ Path validation and sandboxing are complete +- ✅ Permission system is integrated + +**Terminal (`terminal`)**: +- ✅ TOOLS-TERM-01 through TOOLS-TERM-05 are implemented +- ✅ Ring buffer is working correctly +- ✅ Permission checks are in place +- ✅ Cross-platform spawn/kill is tested + +--- + +## Progress Summary + +- **Total Features**: 25 +- **Completed**: 6 (24%) +- **Stubbed**: 19 (76%) +- **In Progress**: 0 + +--- + +## Task References + +All task specifications are in: +- `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md` + +Task ID format: `TOOLS-{AREA}-{NUMBER}` +- `TOOLS-SCAFFOLD-XX` - Package structure +- `TOOLS-PATH-XX` - Path operations +- `TOOLS-FS-XX` - File operations +- `TOOLS-SEARCH-XX` - Search operations +- `TOOLS-TERM-XX` - Terminal operations +- `TOOLS-PERM-XX` - Permission system +- `TOOLS-AUDIT-XX` - Audit and observability +- `TOOLS-UI-XX` - UI integration + +--- + +## Next Steps + +To continue implementation: + +1. **Protocol Integration** (prerequisite) + - Ensure ACP protocol handlers can call tool functions + - Wire up request/response types + - Handle tool call lifecycle (pending → in_progress → completed/failed) + +2. **File Operations** (high priority for UI features) + - TOOLS-FS-01: Read text file + - TOOLS-FS-02: Write text file + - TOOLS-FS-03: Generate diffs + +3. **Search Operations** (useful for development) + - TOOLS-SEARCH-01: Directory listing + - TOOLS-SEARCH-02: Glob search + +4. **Permission System** (security-critical) + - TOOLS-PERM-01: Permission checks + - TOOLS-PERM-02: Decision caching + +5. **Terminal Operations** (complex, later priority) + - TOOLS-TERM-01: Create terminal + - TOOLS-TERM-02: Get output + +--- + +*Last Updated: 2025-11-12* +*Status: All core infrastructure complete, tool operations stubbed and ready for implementation* diff --git a/crates/dirigent_tools/README.md b/crates/dirigent_tools/README.md new file mode 100644 index 0000000..01d7cb7 --- /dev/null +++ b/crates/dirigent_tools/README.md @@ -0,0 +1,100 @@ +# dirigent_tools + +Tool implementations for ACP (Agent-Client Protocol) client operations with sandboxing and permission management. + +## Overview + +This package provides the core tool operations for interacting with the filesystem, terminal, and search capabilities in a secure, sandboxed environment. It is designed to support ACP-compliant agents (like Claude) by implementing the client-side tool handlers with safety guarantees. + +## Features + +### File Operations +- **Read** text files with line range support +- **Write** text files with atomic writes and parent directory creation +- **Edit** files with diff generation for previews + +### Terminal Operations +- **Create** terminals and spawn commands +- **Capture** output with byte limits and ring-buffer truncation +- **Wait** for command completion +- **Kill** running commands +- **Release** terminal resources + +### Search Operations +- **Glob** file matching with patterns +- **Grep** content search with regex +- **LS** directory listing + +### Security Features +- **Sandboxing**: All operations restricted to configured allowed roots +- **Blocklists**: Explicit deny patterns for sensitive paths +- **Permissions**: Configurable prompt modes (ask, whitelist, yolo) +- **Audit Logging**: Structured logs for all operations +- **Resource Limits**: Bounded file sizes, search results, and terminal output + +## Platform Support + +Windows is a first-class platform: +- Handles Windows paths (backslashes, drive letters, UNC shares, `\\?\` prefixes) +- Supports MINGW-style paths (`/c/...`) +- Works with cmd.exe and PowerShell +- Normalizes path separators for consistent policy enforcement + +All tests run on Windows, Linux, and macOS. + +## Status + +**Phase**: Scaffolding (SCAFF-01) - Structure created, implementation pending + +All modules are stubs with `unimplemented!()` placeholders. Actual implementation will occur in subsequent phases: +- **Protocol tasks**: Path normalization, sandbox enforcement +- **Tool tasks**: File operations, terminal execution, search +- **Integration tasks**: Permission prompts, audit logging, ACP event generation + +## Configuration + +See `src/config.rs` for configuration types (to be implemented in SCAFF-05): +- `SandboxConfig` - Filesystem sandboxing +- `PermissionConfig` - Permission prompts and caching +- `TerminalConfig` - Terminal limits and restrictions +- `SearchConfig` - Search result limits +- `EmbeddingConfig` - File embedding thresholds + +## Usage Example (Future) + +```rust +use dirigent_tools::{fs, SandboxConfig}; + +// Configure sandbox +let sandbox = SandboxConfig { + allowed_roots: vec!["C:/work/project".to_string()], + blocked_paths: vec!["**/.env".to_string()], + // ... other fields +}; + +// Read a file (within sandbox) +let content = fs::read_text_file( + Path::new("C:/work/project/src/main.rs"), + None, // line + None, // limit +)?; +``` + +## Testing + +Test infrastructure will be set up in SCAFF-03. Tests will cover: +- Path normalization (especially Windows paths) +- Sandbox containment +- Permission flows +- File operations +- Terminal lifecycle +- Search operations + +## Documentation + +- **CLAUDE.md**: Package context for AI assistants +- **docs/**: API documentation (to be generated with `cargo doc`) + +## License + +Same as parent Dirigent project. diff --git a/crates/dirigent_tools/STUBBED_FUNCTIONS.md b/crates/dirigent_tools/STUBBED_FUNCTIONS.md new file mode 100644 index 0000000..2783984 --- /dev/null +++ b/crates/dirigent_tools/STUBBED_FUNCTIONS.md @@ -0,0 +1,260 @@ +# Stubbed Functions Reference + +Quick reference for all stubbed functions with task IDs and signatures. + +--- + +## File Operations + +### TOOLS-FS-01: Read Text File +**File**: `src/fs/read.rs` + +```rust +pub async fn read_text_file( + request: ReadTextFileRequest, + config: &SandboxConfig, +) -> ToolResult<ReadTextFileResponse> +``` + +**Types**: +- `ReadTextFileRequest { path: String, line: Option<usize>, limit: Option<usize> }` +- `ReadTextFileResponse { content: String }` + +--- + +### TOOLS-FS-02: Write Text File +**File**: `src/fs/write.rs` + +```rust +pub async fn write_text_file( + request: WriteTextFileRequest, + config: &SandboxConfig, +) -> ToolResult<WriteTextFileResponse> + +pub fn normalize_eol(content: &str, policy: EolPolicy) -> String +``` + +**Types**: +- `WriteTextFileRequest { path: String, content: String }` +- `WriteTextFileResponse {}` + +--- + +### TOOLS-FS-03: Generate Diff +**File**: `src/fs/diff.rs` + +```rust +pub fn generate_diff(old_content: &str, new_content: &str, path: &Path) -> String +``` + +--- + +### TOOLS-FS-04: Edit File +**File**: `src/fs/edit.rs` + +```rust +pub async fn edit_file( + request: EditFileRequest, + config: &SandboxConfig, +) -> ToolResult<EditFileResponse> +``` + +**Types**: +- `EditFileRequest { path: String, edits: Vec<EditOperation> }` +- `EditFileResponse { diff: String }` +- `EditOperation::Replace { old_text: String, new_text: String, replace_all: bool }` +- `EditOperation::Patch { diff: String }` + +--- + +## Search Operations + +### TOOLS-SEARCH-01: Directory Listing +**File**: `src/search/ls.rs` + +```rust +pub async fn ls(request: LsRequest, config: &SearchConfig) -> ToolResult<LsResponse> +``` + +**Types**: +- `LsRequest { path: String }` +- `LsResponse { entries: Vec<LsEntry> }` +- `LsEntry { path: PathBuf, kind: FileKind, size: Option<u64> }` +- `FileKind: File | Dir | Symlink` + +--- + +### TOOLS-SEARCH-02: Glob Search +**File**: `src/search/glob.rs` + +```rust +pub async fn glob_search( + request: GlobRequest, + config: &SearchConfig, +) -> ToolResult<GlobResponse> +``` + +**Types**: +- `GlobRequest { path: String, pattern: String, exclude: Option<Vec<String>>, max_results: Option<u32> }` +- `GlobResponse { matches: Vec<PathBuf>, truncated: bool }` + +--- + +### TOOLS-SEARCH-03: Content Search (Grep) +**File**: `src/search/grep.rs` + +```rust +pub async fn grep_search( + request: GrepRequest, + config: &SearchConfig, +) -> ToolResult<GrepResponse> +``` + +**Types**: +- `GrepRequest { path: String, pattern: String, file_pattern: Option<String>, ignore_case: bool, context_before: u32, context_after: u32, max_results: Option<u32> }` +- `GrepResponse { matches: Vec<GrepMatch>, truncated: bool }` +- `GrepMatch { path: PathBuf, line_number: usize, line: String, context_before: Vec<String>, context_after: Vec<String> }` + +--- + +## Terminal Operations + +### TOOLS-TERM-01: Create Terminal +**File**: `src/terminal/create.rs` + +```rust +pub async fn create_terminal( + request: CreateTerminalRequest, + config: &TerminalConfig, +) -> ToolResult<CreateTerminalResponse> +``` + +**Types**: +- `CreateTerminalRequest { command: String, args: Vec<String>, cwd: Option<String>, env: Option<Vec<EnvVar>>, output_byte_limit: Option<u64> }` +- `CreateTerminalResponse { terminal_id: String }` +- `EnvVar { name: String, value: String }` + +--- + +### TOOLS-TERM-02: Get Terminal Output +**File**: `src/terminal/output.rs` + +```rust +pub async fn get_terminal_output( + request: TerminalOutputRequest, + config: &TerminalConfig, +) -> ToolResult<TerminalOutputResponse> +``` + +**Types**: +- `TerminalOutputRequest { terminal_id: String }` +- `TerminalOutputResponse { output: String, truncated: bool, exit_status: Option<i32> }` + +--- + +### TOOLS-TERM-03: Wait for Exit +**File**: `src/terminal/wait.rs` + +```rust +pub async fn wait_for_terminal_exit( + request: WaitForTerminalExitRequest, + config: &TerminalConfig, +) -> ToolResult<WaitForTerminalExitResponse> +``` + +**Types**: +- `WaitForTerminalExitRequest { terminal_id: String }` +- `WaitForTerminalExitResponse { exit_status: i32 }` + +--- + +### TOOLS-TERM-04: Kill Terminal +**File**: `src/terminal/kill.rs` + +```rust +pub async fn kill_terminal( + request: KillTerminalCommandRequest, + config: &TerminalConfig, +) -> ToolResult<KillTerminalCommandResponse> +``` + +**Types**: +- `KillTerminalCommandRequest { terminal_id: String }` +- `KillTerminalCommandResponse {}` + +--- + +### TOOLS-TERM-05: Release Terminal +**File**: `src/terminal/release.rs` + +```rust +pub async fn release_terminal( + request: ReleaseTerminalRequest, + config: &TerminalConfig, +) -> ToolResult<ReleaseTerminalResponse> +``` + +**Types**: +- `ReleaseTerminalRequest { terminal_id: String }` +- `ReleaseTerminalResponse {}` + +--- + +### Ring Buffer +**File**: `src/terminal/ring_buffer.rs` + +```rust +impl RingBuffer { + pub fn new(capacity: usize) -> Self +} +``` + +**Note**: Full interface to be defined during implementation. + +--- + +## Summary + +**Total Stubbed Functions**: 15 main functions + 1 helper + +**File Operations**: 4 functions +- read_text_file, write_text_file, normalize_eol, generate_diff, edit_file + +**Search Operations**: 3 functions +- ls, glob_search, grep_search + +**Terminal Operations**: 6 functions + ring buffer +- create_terminal, get_terminal_output, wait_for_terminal_exit, kill_terminal, release_terminal, RingBuffer::new + +All functions return `unimplemented!()` with clear task IDs in the panic message. + +--- + +## Implementation Order + +Recommended implementation order based on dependencies: + +1. **File Operations** (UI features need these first) + - TOOLS-FS-01: read_text_file + - TOOLS-FS-02: write_text_file + normalize_eol + - TOOLS-FS-03: generate_diff + +2. **Search Operations** (useful for development) + - TOOLS-SEARCH-01: ls + - TOOLS-SEARCH-02: glob_search + - TOOLS-SEARCH-03: grep_search + +3. **Edit Operations** (builds on read/write) + - TOOLS-FS-04: edit_file + +4. **Terminal Operations** (more complex, can be deferred) + - TOOLS-TERM-01: create_terminal + ring_buffer + - TOOLS-TERM-02: get_terminal_output + - TOOLS-TERM-03: wait_for_terminal_exit + - TOOLS-TERM-04: kill_terminal + - TOOLS-TERM-05: release_terminal + +--- + +*Generated: 2025-11-12* +*Package: dirigent_tools v0.1.0* diff --git a/crates/dirigent_tools/examples/config_example.json b/crates/dirigent_tools/examples/config_example.json new file mode 100644 index 0000000..6fbad02 --- /dev/null +++ b/crates/dirigent_tools/examples/config_example.json @@ -0,0 +1,53 @@ +{ + "sandbox": { + "allowed_roots": ["C:/work/project", "C:/work/shared"], + "blocked_paths": ["**/.env", "**/secrets/**", "**/*.key"], + "allow_symlink_escape": false, + "follow_symlinks_within_roots": true, + "read_enabled": true, + "write_enabled": true, + "max_read_bytes": 1048576, + "max_write_bytes": 1048576, + "eol_policy": "preserve", + "encoding": "utf-8" + }, + "permissions": { + "mode": "whitelist", + "remember_decisions": true, + "remember_ttl_secs": 86400, + "scope": "per_connector", + "whitelist": { + "write_paths": ["C:/work/project/**"], + "execute_commands": ["cargo", "npm", "git", "python"] + } + }, + "terminal": { + "enabled": true, + "default_cwd": "C:/work/project", + "env_allowlist": ["RUST_LOG", "NODE_ENV", "PATH"], + "command_blocklist": ["rm", "rd", "format", "mkfs*", "del /f /q *"], + "output_byte_limit": 200000, + "max_runtime_secs": 3600 + }, + "search": { + "max_results": 5000, + "max_bytes": 1000000, + "default_include_globs": [], + "default_exclude_globs": [ + "**/target/**", + "**/.git/**", + "**/node_modules/**", + "**/__pycache__/**", + "**/.venv/**" + ] + }, + "embedding": { + "max_embed_bytes": 256000, + "allow_resource_link": true, + "redact_patterns": [ + "(?i)(api[_-]?key|password|secret|token)[:\\s]*['\"]?([a-zA-Z0-9_\\-\\.]+)['\"]?" + ], + "snippet_strategy": "head_tail", + "max_files_per_prompt": 10 + } +} diff --git a/crates/dirigent_tools/examples/config_example.toml b/crates/dirigent_tools/examples/config_example.toml new file mode 100644 index 0000000..59c1a74 --- /dev/null +++ b/crates/dirigent_tools/examples/config_example.toml @@ -0,0 +1,124 @@ +# Example configuration for dirigent_tools (Phase 03 features) +# This shows all available configuration options with typical values. + +# ============================================================================= +# Sandbox Configuration +# ============================================================================= +[sandbox] +# Absolute paths where file operations are allowed +allowed_roots = ["C:/work/project", "C:/work/shared"] + +# Patterns for paths that are blocked even within allowed roots +blocked_paths = ["**/.env", "**/secrets/**", "**/*.key"] + +# Whether to allow symlinks to point outside allowed roots (dangerous!) +allow_symlink_escape = false + +# Whether to follow symlinks within allowed roots +follow_symlinks_within_roots = true + +# Enable read/write operations +read_enabled = true +write_enabled = true + +# Maximum bytes per operation +max_read_bytes = 1_048_576 # 1 MB +max_write_bytes = 1_048_576 # 1 MB + +# Line ending policy: "preserve" | "lf" | "crlf" +eol_policy = "preserve" + +# Text encoding (only "utf-8" supported in Phase 03) +encoding = "utf-8" + +# ============================================================================= +# Permission Configuration +# ============================================================================= +[permissions] +# Permission mode: "ask" | "whitelist" | "yolo" +# - ask: Prompt for every sensitive operation +# - whitelist: Auto-approve whitelisted operations, prompt for others +# - yolo: Auto-approve all (with audit logging) +mode = "whitelist" + +# Whether to remember permission decisions +remember_decisions = true + +# TTL for cached decisions (seconds) +remember_ttl_secs = 86400 # 24 hours + +# Decision scope: "per_connector" | "per_session" +scope = "per_connector" + +# Whitelist configuration (for whitelist mode) +[permissions.whitelist] +# Paths that are safe for write operations +write_paths = ["C:/work/project/**"] + +# Commands that are safe to execute +execute_commands = ["cargo", "npm", "git", "python"] + +# ============================================================================= +# Terminal Configuration +# ============================================================================= +[terminal] +# Enable terminal operations +enabled = true + +# Default working directory (must be within allowed roots) +default_cwd = "C:/work/project" + +# Environment variables that are allowed in spawned processes +env_allowlist = ["RUST_LOG", "NODE_ENV", "PATH"] + +# Commands that are blocked (best-effort) +command_blocklist = ["rm", "rd", "format", "mkfs*", "del /f /q *"] + +# Maximum bytes to capture from output (ring buffer) +output_byte_limit = 200_000 + +# Maximum runtime before killing command (seconds) +max_runtime_secs = 3_600 # 1 hour + +# ============================================================================= +# Search Configuration +# ============================================================================= +[search] +# Maximum number of search results +max_results = 5_000 + +# Maximum total bytes in search results +max_bytes = 1_000_000 # 1 MB + +# Default include patterns (empty = include all) +default_include_globs = [] + +# Default exclude patterns +default_exclude_globs = [ + "**/target/**", + "**/.git/**", + "**/node_modules/**", + "**/__pycache__/**", + "**/.venv/**" +] + +# ============================================================================= +# Embedding Configuration +# ============================================================================= +[embedding] +# Maximum bytes to embed per file as resource (vs resource_link) +max_embed_bytes = 256_000 + +# Whether to allow resource_link for large files +allow_resource_link = true + +# Regex patterns for redacting secrets in embedded content +redact_patterns = [ + "(?i)(api[_-]?key|password|secret|token)[:\\s]*['\"]?([a-zA-Z0-9_\\-\\.]+)['\"]?" +] + +# Snippet strategy: "head_tail" | "head_only" | "tail_only" +snippet_strategy = "head_tail" + +# Maximum files to embed in a single prompt +max_files_per_prompt = 10 diff --git a/crates/dirigent_tools/src/audit.rs b/crates/dirigent_tools/src/audit.rs new file mode 100644 index 0000000..87e7464 --- /dev/null +++ b/crates/dirigent_tools/src/audit.rs @@ -0,0 +1,58 @@ +//! Audit logging for sensitive tool operations. +//! +//! This module provides structured logging for: +//! - File read/write operations +//! - Terminal command execution +//! - Permission decisions +//! - Sandbox violations +//! +//! All audit logs include: +//! - Timestamp +//! - User/session context +//! - Operation type +//! - Parameters (sanitized) +//! - Outcome (success/error) +//! +//! TODO: Implement audit logging + +use tracing::{info, warn}; + +/// Log a file read operation. +/// +/// TODO: Implement with structured fields +pub fn log_file_read(_path: &str, _success: bool) { + // Placeholder - will use tracing with structured fields + info!("File read audit log placeholder"); +} + +/// Log a file write operation. +/// +/// TODO: Implement with structured fields +pub fn log_file_write(_path: &str, _success: bool) { + // Placeholder - will use tracing with structured fields + info!("File write audit log placeholder"); +} + +/// Log a terminal command execution. +/// +/// TODO: Implement with structured fields +pub fn log_terminal_exec(_command: &str, _success: bool) { + // Placeholder - will use tracing with structured fields + info!("Terminal exec audit log placeholder"); +} + +/// Log a permission decision. +/// +/// TODO: Implement with structured fields +pub fn log_permission_decision(_operation: &str, _allowed: bool) { + // Placeholder - will use tracing with structured fields + info!("Permission decision audit log placeholder"); +} + +/// Log a sandbox violation attempt. +/// +/// TODO: Implement with structured fields +pub fn log_sandbox_violation(_path: &str, _reason: &str) { + // Placeholder - will use tracing with structured fields + warn!("Sandbox violation audit log placeholder"); +} diff --git a/crates/dirigent_tools/src/config.rs b/crates/dirigent_tools/src/config.rs new file mode 100644 index 0000000..c67ec4f --- /dev/null +++ b/crates/dirigent_tools/src/config.rs @@ -0,0 +1,425 @@ +//! Configuration types for Phase 03 features. +//! +//! This module defines configuration structures for: +//! - `SandboxConfig` - Filesystem sandboxing configuration +//! - `PermissionConfig` - Permission prompt and decision caching +//! - `TerminalConfig` - Terminal/command execution limits +//! - `SearchConfig` - Search operation limits and defaults +//! - `EmbeddingConfig` - File embedding thresholds and policies + +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; + +// ============================================================================= +// Sandbox Configuration +// ============================================================================= + +/// Filesystem sandboxing configuration. +/// +/// Determines which paths are accessible and how symlinks are handled. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(default)] +pub struct SandboxConfig { + /// Absolute paths that are allowed for file operations. + /// + /// Operations outside these roots will be rejected. + pub allowed_roots: Vec<PathBuf>, + + /// Path patterns that are explicitly blocked even if within allowed roots. + /// + /// Supports glob patterns like "**/.env", "**/secrets/**". + pub blocked_paths: Vec<String>, + + /// Whether to allow symlinks to escape allowed roots. + /// + /// If false (recommended), symlinks pointing outside allowed roots are rejected. + pub allow_symlink_escape: bool, + + /// Whether to follow symlinks within allowed roots. + /// + /// If true, symlinks within allowed roots are followed. + pub follow_symlinks_within_roots: bool, + + /// Enable read operations. + pub read_enabled: bool, + + /// Enable write operations. + pub write_enabled: bool, + + /// Maximum bytes to read in a single request. + /// + /// Soft cap for previews. Default: 1 MB. + pub max_read_bytes: u64, + + /// Maximum bytes to write in a single request. + /// + /// Default: 1 MB. + pub max_write_bytes: u64, + + /// End-of-line policy for file operations. + pub eol_policy: EolPolicy, + + /// Text encoding support. + /// + /// Currently only UTF-8 is supported. + pub encoding: String, +} + +impl Default for SandboxConfig { + fn default() -> Self { + Self { + allowed_roots: vec![], + blocked_paths: vec!["**/.env".to_string(), "**/secrets/**".to_string()], + allow_symlink_escape: false, + follow_symlinks_within_roots: true, + read_enabled: true, + write_enabled: false, + max_read_bytes: 1_048_576, // 1 MB + max_write_bytes: 1_048_576, // 1 MB + eol_policy: EolPolicy::Preserve, + encoding: "utf-8".to_string(), + } + } +} + +/// End-of-line handling policy. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum EolPolicy { + /// Preserve original line endings. + Preserve, + /// Normalize to LF (\n). + Lf, + /// Normalize to CRLF (\r\n). + Crlf, +} + +// ============================================================================= +// Permission Configuration +// ============================================================================= + +/// Permission prompt and decision caching configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(default)] +pub struct PermissionConfig { + /// Permission mode strategy. + pub mode: PermissionMode, + + /// Whether to remember permission decisions. + pub remember_decisions: bool, + + /// Time-to-live for cached decisions in seconds. + /// + /// Default: 86400 (24 hours). + pub remember_ttl_secs: u64, + + /// Scope of cached decisions. + pub scope: DecisionScope, + + /// Whitelist configuration for whitelist mode. + pub whitelist: WhitelistConfig, +} + +impl Default for PermissionConfig { + fn default() -> Self { + Self { + mode: PermissionMode::Whitelist, + remember_decisions: true, + remember_ttl_secs: 86_400, // 24 hours + scope: DecisionScope::PerConnector, + whitelist: WhitelistConfig::default(), + } + } +} + +/// Permission mode strategy. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum PermissionMode { + /// Prompt for every sensitive operation. + Ask, + /// Auto-approve whitelisted operations, prompt for others. + Whitelist, + /// Auto-approve all operations (with audit logging). + Yolo, +} + +/// Scope for cached permission decisions. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum DecisionScope { + /// Decisions persist per connector. + PerConnector, + /// Decisions persist only within a session. + PerSession, +} + +/// Whitelist configuration for auto-approved operations. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(default)] +pub struct WhitelistConfig { + /// Path patterns that are safe for write operations. + /// + /// Glob patterns like "C:/work/project/**". + pub write_paths: Vec<String>, + + /// Commands that are safe to execute. + /// + /// Glob patterns like "cargo", "npm", "git". + pub execute_commands: Vec<String>, +} + +impl Default for WhitelistConfig { + fn default() -> Self { + Self { + write_paths: vec![], + execute_commands: vec![], + } + } +} + +// ============================================================================= +// Terminal Configuration +// ============================================================================= + +/// Terminal/command execution configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(default)] +pub struct TerminalConfig { + /// Enable terminal operations. + pub enabled: bool, + + /// Default current working directory for terminal commands. + /// + /// Must be within an allowed sandbox root. + pub default_cwd: Option<PathBuf>, + + /// Environment variable names that are allowed. + /// + /// Only these variables can be set in spawned processes. + pub env_allowlist: Vec<String>, + + /// Command patterns that are blocked (best-effort). + /// + /// Glob patterns for dangerous commands like "rm", "format". + pub command_blocklist: Vec<String>, + + /// Maximum bytes to capture from terminal output (ring buffer). + /// + /// Default: 200,000 bytes. + pub output_byte_limit: u64, + + /// Maximum runtime for a terminal command in seconds. + /// + /// Commands exceeding this will be killed. Default: 3600 (1 hour). + pub max_runtime_secs: u64, +} + +impl Default for TerminalConfig { + fn default() -> Self { + Self { + enabled: true, + default_cwd: None, + env_allowlist: vec![ + "RUST_LOG".to_string(), + "NODE_ENV".to_string(), + "PATH".to_string(), + ], + command_blocklist: vec![ + "rm".to_string(), + "rd".to_string(), + "format".to_string(), + "mkfs*".to_string(), + ], + output_byte_limit: 200_000, + max_runtime_secs: 3_600, // 1 hour + } + } +} + +// ============================================================================= +// Search Configuration +// ============================================================================= + +/// Search operation configuration (glob, grep, ls). +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(default)] +pub struct SearchConfig { + /// Maximum number of results to return. + /// + /// Default: 5,000. + pub max_results: u32, + + /// Maximum total bytes in search results. + /// + /// Default: 1,000,000 (1 MB). + pub max_bytes: u64, + + /// Default include patterns for searches. + pub default_include_globs: Vec<String>, + + /// Default exclude patterns for searches. + /// + /// Common directories to skip like target/, .git/, node_modules/. + pub default_exclude_globs: Vec<String>, +} + +impl Default for SearchConfig { + fn default() -> Self { + Self { + max_results: 5_000, + max_bytes: 1_000_000, // 1 MB + default_include_globs: vec![], + default_exclude_globs: vec![ + "**/target/**".to_string(), + "**/.git/**".to_string(), + "**/node_modules/**".to_string(), + "**/__pycache__/**".to_string(), + "**/.venv/**".to_string(), + ], + } + } +} + +// ============================================================================= +// Embedding Configuration +// ============================================================================= + +/// File embedding configuration for ACP prompt context. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(default)] +pub struct EmbeddingConfig { + /// Maximum bytes to embed as ContentBlock::resource per file. + /// + /// Larger files will use resource_link instead. Default: 256,000. + pub max_embed_bytes: u64, + + /// Whether to allow resource_link for large files. + /// + /// If false, large files are rejected instead of linked. + pub allow_resource_link: bool, + + /// Regex patterns for redacting secrets in embedded content. + /// + /// Best-effort redaction (does not modify files on disk). + pub redact_patterns: Vec<String>, + + /// Snippet strategy when file is too large to embed fully. + pub snippet_strategy: SnippetStrategy, + + /// Maximum number of files to embed in a single prompt. + pub max_files_per_prompt: u32, +} + +impl Default for EmbeddingConfig { + fn default() -> Self { + Self { + max_embed_bytes: 256_000, + allow_resource_link: true, + redact_patterns: vec![ + // Common secret patterns + r"(?i)(api[_-]?key|password|secret|token)[:]\s*[']?([a-zA-Z0-9_\-\.]+)[']?".to_string(), + ], + snippet_strategy: SnippetStrategy::HeadTail, + max_files_per_prompt: 10, + } + } +} + +/// Strategy for creating snippets from large files. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum SnippetStrategy { + /// Include beginning and end of file. + HeadTail, + /// Include only the beginning. + HeadOnly, + /// Include only the end. + TailOnly, +} + +// ============================================================================= +// Validation +// ============================================================================= + +impl SandboxConfig { + /// Validate the sandbox configuration. + /// + /// Returns an error message if the configuration is invalid. + pub fn validate(&self) -> Result<(), String> { + if self.allowed_roots.is_empty() && (self.read_enabled || self.write_enabled) { + return Err("allowed_roots cannot be empty when read or write is enabled".to_string()); + } + + if self.encoding != "utf-8" { + return Err(format!("unsupported encoding: {}", self.encoding)); + } + + Ok(()) + } + + /// Normalize allowed roots to canonical paths. + /// + /// This should be called after loading configuration to ensure all roots are canonical. + /// Panics if any root cannot be canonicalized (configuration error). + pub fn normalize_roots(&mut self) { + self.allowed_roots = self.allowed_roots + .iter() + .map(|root| { + dunce::canonicalize(root) + .unwrap_or_else(|_| panic!("Failed to canonicalize allowed root: {:?}", root)) + }) + .collect(); + } +} + +impl TerminalConfig { + /// Validate the terminal configuration. + /// + /// Returns an error message if the configuration is invalid. + pub fn validate(&self) -> Result<(), String> { + if self.output_byte_limit == 0 { + return Err("output_byte_limit must be greater than 0".to_string()); + } + + if self.max_runtime_secs == 0 { + return Err("max_runtime_secs must be greater than 0".to_string()); + } + + Ok(()) + } +} + +impl SearchConfig { + /// Validate the search configuration. + /// + /// Returns an error message if the configuration is invalid. + pub fn validate(&self) -> Result<(), String> { + if self.max_results == 0 { + return Err("max_results must be greater than 0".to_string()); + } + + if self.max_bytes == 0 { + return Err("max_bytes must be greater than 0".to_string()); + } + + Ok(()) + } +} + +impl EmbeddingConfig { + /// Validate the embedding configuration. + /// + /// Returns an error message if the configuration is invalid. + pub fn validate(&self) -> Result<(), String> { + if self.max_embed_bytes == 0 { + return Err("max_embed_bytes must be greater than 0".to_string()); + } + + if self.max_files_per_prompt == 0 { + return Err("max_files_per_prompt must be greater than 0".to_string()); + } + + Ok(()) + } +} diff --git a/crates/dirigent_tools/src/dispatch.rs b/crates/dirigent_tools/src/dispatch.rs new file mode 100644 index 0000000..092e6cc --- /dev/null +++ b/crates/dirigent_tools/src/dispatch.rs @@ -0,0 +1,142 @@ +//! Dispatch orchestration: registry → SecurityFloor → permission → Tool::run. + +use crate::floor::{FloorDecision, SecurityFloor}; +use crate::registry::ToolRegistry; +use crate::tool::{AnyToolInput, ToolContext, ToolEventSink}; + +/// Outcome of a dispatch call. Matches `AnyTool::run`'s return shape: +/// `Ok` is a successful structured result, `Err` is a structured error. +pub type DispatchResult = Result<serde_json::Value, serde_json::Value>; + +/// Dispatch a tool call through the harness. +/// +/// 1. Resolve the tool via the registry (built-in vs dynamic, with collision +/// policy and per-client/per-protocol filters applied). +/// 2. Run the hardcoded [`SecurityFloor`]. Cannot be bypassed by settings. +/// 3. (Permission check is handled by the caller for now — the existing +/// [`crate::permission::check::check_permission`] takes a different +/// operation type and is wired in by the connector. This is documented in +/// `2026-04-28-tool-harness-design.md`.) +/// 4. Run the tool, awaiting its final result. +pub async fn dispatch( + registry: &ToolRegistry, + floor: &SecurityFloor, + name: &str, + input: serde_json::Value, + events: ToolEventSink, + ctx: &ToolContext, +) -> DispatchResult { + let tool = match registry.resolve(name, ctx) { + Some(t) => t, + None => return Err(serde_json::json!({ + "error": format!("unknown tool: {name}"), + })), + }; + + if let FloorDecision::Block { reason } = floor.check(name, &input, ctx) { + return Err(serde_json::json!({ "error": reason })); + } + + tool.run(AnyToolInput::Final(input), events, ctx).await +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{PermissionConfig, SandboxConfig, WhitelistConfig}; + use crate::permission::check::PermissionContext; + use crate::permission::whitelist::CompiledWhitelist; + use crate::registry::CollisionPolicy; + use crate::tool::{AnyTool, ClientKind, ProtocolKind, Tool, ToolInput, ToolKind}; + use async_trait::async_trait; + use schemars::JsonSchema; + use serde::{Deserialize, Serialize}; + use std::path::PathBuf; + use std::sync::Arc; + + #[derive(Serialize, Deserialize, JsonSchema)] + struct EchoIn { msg: String } + #[derive(Serialize, Deserialize)] + struct EchoOut { echoed: String } + + #[derive(Default)] + struct Echo; + + #[async_trait] + impl Tool for Echo { + type Input = EchoIn; + type Output = EchoOut; + const NAME: &'static str = "echo"; + fn kind() -> ToolKind { ToolKind::Other } + async fn run( + self: Arc<Self>, input: ToolInput<EchoIn>, + _e: ToolEventSink, _c: &ToolContext, + ) -> Result<EchoOut, EchoOut> { + let i = match input { ToolInput::Final(i) => i, _ => unreachable!() }; + Ok(EchoOut { echoed: i.msg }) + } + } + + fn ctx() -> ToolContext { + let wl = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap(); + let pc = PermissionContext::new("c".to_string(), None, wl); + ToolContext::for_test( + "c", ClientKind::claude(), ProtocolKind::acp(), + PathBuf::from("/tmp"), + SandboxConfig::default(), PermissionConfig::default(), pc, + ) + } + + fn registry_with_echo() -> ToolRegistry { + let any: Arc<dyn AnyTool> = <Echo as Tool>::erase(Arc::new(Echo)); + ToolRegistry::new(vec![any], CollisionPolicy::BuiltInWins) + } + + #[tokio::test] + async fn dispatch_runs_known_tool() { + let r = registry_with_echo(); + let f = SecurityFloor::new(); + let (sink, _rx) = ToolEventSink::new(); + let out = dispatch(&r, &f, "echo", + serde_json::json!({ "msg": "hi" }), sink, &ctx()).await.unwrap(); + assert_eq!(out["echoed"], "hi"); + } + + #[tokio::test] + async fn dispatch_rejects_unknown_tool() { + let r = registry_with_echo(); + let f = SecurityFloor::new(); + let (sink, _rx) = ToolEventSink::new(); + let err = dispatch(&r, &f, "nope", + serde_json::json!({}), sink, &ctx()).await.unwrap_err(); + assert!(err["error"].as_str().unwrap().contains("unknown tool")); + } + + #[tokio::test] + async fn dispatch_blocks_on_security_floor() { + // We can't easily inject a "terminal" tool here, but we exercise the + // floor by registering a tool literally named "terminal". + #[derive(Default)] struct Term; + #[derive(Serialize, Deserialize, JsonSchema)] struct In { command: String } + #[derive(Serialize, Deserialize)] struct Out; + #[async_trait] + impl Tool for Term { + type Input = In; type Output = Out; + const NAME: &'static str = "terminal"; + fn kind() -> ToolKind { ToolKind::Execute } + async fn run( + self: Arc<Self>, _i: ToolInput<In>, + _e: ToolEventSink, _c: &ToolContext, + ) -> Result<Out, Out> { Ok(Out) } + } + + let any: Arc<dyn AnyTool> = <Term as Tool>::erase(Arc::new(Term)); + let r = ToolRegistry::new(vec![any], CollisionPolicy::BuiltInWins); + let f = SecurityFloor::new(); + let (sink, _rx) = ToolEventSink::new(); + let err = dispatch(&r, &f, "terminal", + serde_json::json!({ "command": "rm -rf /" }), sink, &ctx()) + .await.unwrap_err(); + assert!(err["error"].as_str().unwrap().to_lowercase().contains("blocked")); + } +} diff --git a/crates/dirigent_tools/src/embedding/decider.rs b/crates/dirigent_tools/src/embedding/decider.rs new file mode 100644 index 0000000..3fdcbc6 --- /dev/null +++ b/crates/dirigent_tools/src/embedding/decider.rs @@ -0,0 +1,467 @@ +//! Embedding decision logic for file attachments. +//! +//! This module implements the decision matrix for choosing between embedded content, +//! resource links, or snippets based on capabilities, size, and configuration. + +use std::path::Path; + +use crate::config::EmbeddingConfig; +use crate::error::ToolResult; +use crate::fs::file_type::detect_file_type; + +/// Strategy for including a file in an ACP prompt. +#[derive(Debug, Clone, PartialEq)] +pub enum EmbeddingStrategy { + /// Embed full file content as text resource. + EmbedText { + /// File content (possibly redacted). + content: String, + /// MIME type for the content. + mime_type: String, + }, + + /// Embed file as base64-encoded blob. + EmbedBlob { + /// Base64-encoded binary data. + data: Vec<u8>, + /// MIME type for the binary data. + mime_type: String, + }, + + /// Create a resource link (don't embed full content). + Link { + /// URI for the resource (e.g., dirigent://resource/<hash>). + uri: String, + /// Human-readable name (relative path). + name: String, + /// File size in bytes. + size: u64, + /// MIME type (if known). + mime_type: Option<String>, + }, + + /// Embed a snippet (head/tail) with optional link to full file. + Snippet { + /// Head portion of the file. + head: String, + /// Tail portion of the file. + tail: String, + /// Total size of the original file. + total_size: u64, + /// MIME type. + mime_type: String, + }, + + /// Deny the attachment (too large, blocked, etc.). + Deny { + /// Reason for denial. + reason: String, + }, +} + +/// Embedding decider. +/// +/// Decides how to include files in ACP prompts based on agent capabilities, +/// file properties, and configuration limits. +pub struct EmbeddingDecider { + /// Embedding configuration. + config: EmbeddingConfig, + /// Whether the agent supports embedded context. + agent_supports_embedded: bool, + /// Total bytes accumulated across all files in this prompt. + accumulated_bytes: usize, + /// Number of files processed so far. + file_count: usize, +} + +impl EmbeddingDecider { + /// Create a new embedding decider. + /// + /// # Arguments + /// + /// * `config` - Embedding configuration with size limits and policies + /// * `agent_supports_embedded` - Whether the agent advertised `embeddedContext` capability + pub fn new(config: EmbeddingConfig, agent_supports_embedded: bool) -> Self { + Self { + config, + agent_supports_embedded, + accumulated_bytes: 0, + file_count: 0, + } + } + + /// Decide the embedding strategy for a file. + /// + /// This implements the decision tree from the file embedding policy: + /// 1. File type detection + /// 2. Size checks (per-file and total accumulated) + /// 3. Capability check (embeddedContext) + /// 4. Strategy selection (embed text, blob, link, snippet, deny) + /// + /// # Arguments + /// + /// * `path` - Path to the file to embed + /// + /// # Returns + /// + /// The chosen embedding strategy. + pub fn decide(&mut self, path: &Path) -> ToolResult<EmbeddingStrategy> { + // Check file count limit + if self.file_count >= self.config.max_files_per_prompt as usize { + return Ok(EmbeddingStrategy::Deny { + reason: format!( + "Maximum file count ({}) exceeded", + self.config.max_files_per_prompt + ), + }); + } + + // Get file metadata + let metadata = std::fs::metadata(path)?; + let file_size = metadata.len(); + + // Detect file type + let file_type = detect_file_type(path)?; + + // Build a name for the file (use file name, not full path) + let name = path + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("unknown") + .to_string(); + + // Decide based on file type and capabilities + let strategy = if file_type.is_binary { + // Binary files: prefer link, allow small blobs if needed + if file_size <= self.config.max_embed_bytes && self.agent_supports_embedded { + // Small binary - could embed as blob, but prefer link for efficiency + if self.config.allow_resource_link { + EmbeddingStrategy::Link { + uri: format!("dirigent://resource/{}", Self::hash_path(path)), + name, + size: file_size, + mime_type: file_type.mime_type, + } + } else { + // Embed as blob only if linking is disabled + let data = std::fs::read(path)?; + EmbeddingStrategy::EmbedBlob { + data, + mime_type: file_type.mime_type.unwrap_or_else(|| "application/octet-stream".to_string()), + } + } + } else { + // Large binary - link only + if self.config.allow_resource_link { + EmbeddingStrategy::Link { + uri: format!("dirigent://resource/{}", Self::hash_path(path)), + name, + size: file_size, + mime_type: file_type.mime_type, + } + } else { + EmbeddingStrategy::Deny { + reason: format!("Binary file too large ({} bytes) and linking is disabled", file_size), + } + } + } + } else { + // Text files: embed if small and capability supports, otherwise link or snippet + if !self.agent_supports_embedded { + // Agent doesn't support embedded context - always link + if self.config.allow_resource_link { + EmbeddingStrategy::Link { + uri: format!("dirigent://resource/{}", Self::hash_path(path)), + name, + size: file_size, + mime_type: file_type.mime_type, + } + } else { + EmbeddingStrategy::Deny { + reason: "Agent does not support embedded context and linking is disabled".to_string(), + } + } + } else if file_size <= self.config.max_embed_bytes { + // Small text file - check total accumulated bytes + let new_total = self.accumulated_bytes + file_size as usize; + + // For Phase 03, use max_embed_bytes * max_files_per_prompt as total cap + let max_total_bytes = (self.config.max_embed_bytes as usize) + * (self.config.max_files_per_prompt as usize); + + if new_total <= max_total_bytes { + // Embed the file + let content = std::fs::read_to_string(path)?; + EmbeddingStrategy::EmbedText { + content, + mime_type: file_type.mime_type.unwrap_or_else(|| "text/plain; charset=utf-8".to_string()), + } + } else { + // Exceeds total byte cap - link or deny + if self.config.allow_resource_link { + EmbeddingStrategy::Link { + uri: format!("dirigent://resource/{}", Self::hash_path(path)), + name, + size: file_size, + mime_type: file_type.mime_type, + } + } else { + EmbeddingStrategy::Deny { + reason: format!("Total embedded bytes would exceed limit ({} bytes)", max_total_bytes), + } + } + } + } else { + // Large text file - link or snippet + if self.config.allow_resource_link { + EmbeddingStrategy::Link { + uri: format!("dirigent://resource/{}", Self::hash_path(path)), + name, + size: file_size, + mime_type: file_type.mime_type, + } + } else if self.config.snippet_strategy != crate::config::SnippetStrategy::HeadTail { + // Snippet embedding not configured + EmbeddingStrategy::Deny { + reason: format!("File too large ({} bytes) for embedding and linking is disabled", file_size), + } + } else { + // Generate snippet (handled in EMBED-05) + // For now, deny and indicate snippet is needed + EmbeddingStrategy::Deny { + reason: "Snippet generation not yet implemented".to_string(), + } + } + } + }; + + // Update accumulated bytes if we're embedding + match &strategy { + EmbeddingStrategy::EmbedText { content, .. } => { + self.accumulated_bytes += content.len(); + self.file_count += 1; + } + EmbeddingStrategy::EmbedBlob { data, .. } => { + self.accumulated_bytes += data.len(); + self.file_count += 1; + } + EmbeddingStrategy::Link { .. } => { + self.file_count += 1; + } + EmbeddingStrategy::Snippet { .. } => { + self.file_count += 1; + } + EmbeddingStrategy::Deny { .. } => { + // Don't increment count for denied files + } + } + + Ok(strategy) + } + + /// Get the total bytes accumulated so far. + pub fn accumulated_bytes(&self) -> usize { + self.accumulated_bytes + } + + /// Get the number of files processed so far. + pub fn file_count(&self) -> usize { + self.file_count + } + + /// Generate a stable hash for a file path (used in URIs). + /// + /// This creates an opaque, stable identifier for the file. + fn hash_path(path: &Path) -> String { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let mut hasher = DefaultHasher::new(); + path.hash(&mut hasher); + format!("{:x}", hasher.finish()) + } +} + +/// Standalone function for simple decision-making (no state tracking). +/// +/// This is a convenience wrapper for one-off decisions without state. +pub fn decide_embedding_strategy( + path: &Path, + agent_supports_embedded: bool, + config: &EmbeddingConfig, +) -> ToolResult<EmbeddingStrategy> { + let mut decider = EmbeddingDecider::new(config.clone(), agent_supports_embedded); + decider.decide(path) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::SnippetStrategy; + use std::io::Write; + use tempfile::NamedTempFile; + + fn default_config() -> EmbeddingConfig { + EmbeddingConfig { + max_embed_bytes: 256_000, + allow_resource_link: true, + redact_patterns: vec![], + snippet_strategy: SnippetStrategy::HeadTail, + max_files_per_prompt: 10, + } + } + + #[test] + fn test_small_text_file_with_capability() { + let mut temp = NamedTempFile::with_suffix(".txt").unwrap(); + temp.write_all(b"Hello, world!").unwrap(); + + let config = default_config(); + let mut decider = EmbeddingDecider::new(config, true); + + let strategy = decider.decide(temp.path()).unwrap(); + + match strategy { + EmbeddingStrategy::EmbedText { content, mime_type } => { + assert_eq!(content, "Hello, world!"); + assert!(mime_type.contains("text")); + } + _ => panic!("Expected EmbedText strategy"), + } + } + + #[test] + fn test_small_text_file_without_capability() { + let mut temp = NamedTempFile::with_suffix(".txt").unwrap(); + temp.write_all(b"Hello, world!").unwrap(); + + let config = default_config(); + let mut decider = EmbeddingDecider::new(config, false); + + let strategy = decider.decide(temp.path()).unwrap(); + + match strategy { + EmbeddingStrategy::Link { .. } => { + // Correct - should link when capability is missing + } + _ => panic!("Expected Link strategy when capability is missing"), + } + } + + #[test] + fn test_large_text_file_creates_link() { + let mut temp = NamedTempFile::with_suffix(".txt").unwrap(); + // Create file larger than max_embed_bytes + let large_content = "x".repeat(300_000); + temp.write_all(large_content.as_bytes()).unwrap(); + + let config = default_config(); + let mut decider = EmbeddingDecider::new(config, true); + + let strategy = decider.decide(temp.path()).unwrap(); + + match strategy { + EmbeddingStrategy::Link { size, .. } => { + assert_eq!(size, 300_000); + } + _ => panic!("Expected Link strategy for large file"), + } + } + + #[test] + fn test_binary_file_creates_link() { + let mut temp = NamedTempFile::with_suffix(".png").unwrap(); + temp.write_all(b"\x89PNG\r\n\x1a\n").unwrap(); + + let config = default_config(); + let mut decider = EmbeddingDecider::new(config, true); + + let strategy = decider.decide(temp.path()).unwrap(); + + match strategy { + EmbeddingStrategy::Link { mime_type, .. } => { + assert_eq!(mime_type, Some("image/png".to_string())); + } + _ => panic!("Expected Link strategy for binary file"), + } + } + + #[test] + fn test_max_file_count_exceeded() { + let mut config = default_config(); + config.max_files_per_prompt = 2; + + let mut decider = EmbeddingDecider::new(config, true); + + // Process 2 files successfully + let mut temp1 = NamedTempFile::with_suffix(".txt").unwrap(); + temp1.write_all(b"File 1").unwrap(); + let _ = decider.decide(temp1.path()).unwrap(); + + let mut temp2 = NamedTempFile::with_suffix(".txt").unwrap(); + temp2.write_all(b"File 2").unwrap(); + let _ = decider.decide(temp2.path()).unwrap(); + + // Third file should be denied + let mut temp3 = NamedTempFile::with_suffix(".txt").unwrap(); + temp3.write_all(b"File 3").unwrap(); + let strategy = decider.decide(temp3.path()).unwrap(); + + match strategy { + EmbeddingStrategy::Deny { reason } => { + assert!(reason.contains("Maximum file count")); + } + _ => panic!("Expected Deny strategy when file count exceeded"), + } + } + + #[test] + fn test_linking_disabled_denies_large_file() { + let mut temp = NamedTempFile::with_suffix(".txt").unwrap(); + let large_content = "x".repeat(300_000); + temp.write_all(large_content.as_bytes()).unwrap(); + + let mut config = default_config(); + config.allow_resource_link = false; + + let mut decider = EmbeddingDecider::new(config, true); + + let strategy = decider.decide(temp.path()).unwrap(); + + match strategy { + EmbeddingStrategy::Deny { reason } => { + assert!(reason.contains("linking is disabled")); + } + _ => panic!("Expected Deny when linking disabled and file too large"), + } + } + + #[test] + fn test_accumulated_bytes_tracking() { + let config = default_config(); + let mut decider = EmbeddingDecider::new(config, true); + + assert_eq!(decider.accumulated_bytes(), 0); + + let mut temp1 = NamedTempFile::with_suffix(".txt").unwrap(); + temp1.write_all(b"Hello").unwrap(); + let _ = decider.decide(temp1.path()).unwrap(); + + assert_eq!(decider.accumulated_bytes(), 5); + + let mut temp2 = NamedTempFile::with_suffix(".txt").unwrap(); + temp2.write_all(b"World!").unwrap(); + let _ = decider.decide(temp2.path()).unwrap(); + + assert_eq!(decider.accumulated_bytes(), 11); + } + + #[test] + fn test_hash_path_is_stable() { + let path = Path::new("/test/file.txt"); + let hash1 = EmbeddingDecider::hash_path(path); + let hash2 = EmbeddingDecider::hash_path(path); + + assert_eq!(hash1, hash2); + } +} diff --git a/crates/dirigent_tools/src/embedding/mod.rs b/crates/dirigent_tools/src/embedding/mod.rs new file mode 100644 index 0000000..fc98206 --- /dev/null +++ b/crates/dirigent_tools/src/embedding/mod.rs @@ -0,0 +1,12 @@ +//! File embedding utilities for ACP prompts. +//! +//! This module provides the core logic for deciding how to include file content +//! in ACP prompts based on agent capabilities, file size, and configuration. + +pub mod decider; +pub mod redactor; +pub mod snippet; + +pub use decider::{decide_embedding_strategy, EmbeddingStrategy, EmbeddingDecider}; +pub use redactor::{ContentRedactor, RedactionPattern}; +pub use snippet::{generate_snippet, Snippet, SnippetConfig}; diff --git a/crates/dirigent_tools/src/embedding/redactor.rs b/crates/dirigent_tools/src/embedding/redactor.rs new file mode 100644 index 0000000..d7af44f --- /dev/null +++ b/crates/dirigent_tools/src/embedding/redactor.rs @@ -0,0 +1,382 @@ +//! Content redaction for embedded files. +//! +//! This module implements pattern-based redaction to prevent secrets from +//! being included in embedded file content. Redaction only affects the +//! payload sent to the agent - files on disk are never modified. + +use regex::Regex; + +use crate::error::ToolResult; + +/// A single redaction pattern with name and replacement text. +#[derive(Debug, Clone)] +pub struct RedactionPattern { + /// Human-readable name for this pattern (e.g., "API_KEY"). + pub name: String, + /// Compiled regular expression to match. + pub regex: Regex, + /// Replacement text (e.g., "<REDACTED:API_KEY>"). + pub replacement: String, +} + +impl RedactionPattern { + /// Create a new redaction pattern. + /// + /// # Arguments + /// + /// * `name` - Human-readable name + /// * `pattern` - Regular expression pattern + /// * `replacement` - Replacement text (can include pattern name) + /// + /// # Returns + /// + /// A compiled redaction pattern, or an error if the regex is invalid. + pub fn new( + name: impl Into<String>, + pattern: &str, + replacement: impl Into<String>, + ) -> ToolResult<Self> { + let name = name.into(); + let regex = Regex::new(pattern).map_err(|e| { + crate::error::ToolError::InvalidInput(format!("Invalid redaction pattern: {}", e)) + })?; + + Ok(Self { + name, + regex, + replacement: replacement.into(), + }) + } + + /// Create a redaction pattern with default "<REDACTED:NAME>" replacement. + pub fn with_default_replacement(name: impl Into<String>, pattern: &str) -> ToolResult<Self> { + let name_str = name.into(); + let replacement = format!("<REDACTED:{}>", name_str.to_uppercase()); + Self::new(name_str, pattern, replacement) + } +} + +/// Content redactor for secret patterns. +/// +/// Applies a set of redaction patterns to content before embedding. +pub struct ContentRedactor { + /// Redaction patterns to apply. + patterns: Vec<RedactionPattern>, +} + +impl ContentRedactor { + /// Create a new content redactor with the given patterns. + /// + /// # Arguments + /// + /// * `pattern_strings` - List of regex pattern strings from configuration + /// + /// # Returns + /// + /// A content redactor, or an error if any pattern is invalid. + pub fn new(pattern_strings: &[String]) -> ToolResult<Self> { + let mut patterns = Vec::new(); + + for (i, pattern_str) in pattern_strings.iter().enumerate() { + let pattern = RedactionPattern::with_default_replacement( + format!("CUSTOM_{}", i), + pattern_str, + )?; + patterns.push(pattern); + } + + Ok(Self { patterns }) + } + + /// Create a redactor with default built-in patterns. + /// + /// Default patterns include: + /// - API keys + /// - AWS credentials + /// - Generic secrets and tokens + /// - Passwords + pub fn with_default_patterns() -> Self { + let mut patterns = Vec::new(); + + // API keys (various formats) + if let Ok(p) = RedactionPattern::with_default_replacement( + "API_KEY", + r#"(?i)(api[_-]?key|apikey)[:=\s]+["']?([a-zA-Z0-9_\-\.]{20,})["']?"#, + ) { + patterns.push(p); + } + + // AWS access key IDs + if let Ok(p) = RedactionPattern::with_default_replacement( + "AWS_ACCESS_KEY", + r"AKIA[0-9A-Z]{16}", + ) { + patterns.push(p); + } + + // Generic secrets and tokens + if let Ok(p) = RedactionPattern::with_default_replacement( + "SECRET", + r#"(?i)(secret|token)[:=\s]+["']?([a-zA-Z0-9_\-\.]{16,})["']?"#, + ) { + patterns.push(p); + } + + // Passwords + if let Ok(p) = RedactionPattern::with_default_replacement( + "PASSWORD", + r#"(?i)password[:=\s]+["']?([^\s"']{8,})["']?"#, + ) { + patterns.push(p); + } + + // Bearer tokens + if let Ok(p) = RedactionPattern::with_default_replacement( + "BEARER_TOKEN", + r"(?i)bearer\s+([a-zA-Z0-9_\-\.=]+)", + ) { + patterns.push(p); + } + + Self { patterns } + } + + /// Create a redactor combining default and custom patterns. + /// + /// # Arguments + /// + /// * `custom_patterns` - Additional custom pattern strings + pub fn with_custom_patterns(custom_patterns: &[String]) -> ToolResult<Self> { + let mut redactor = Self::with_default_patterns(); + + for (i, pattern_str) in custom_patterns.iter().enumerate() { + let pattern = RedactionPattern::with_default_replacement( + format!("CUSTOM_{}", i), + pattern_str, + )?; + redactor.patterns.push(pattern); + } + + Ok(redactor) + } + + /// Redact sensitive content from the given text. + /// + /// Applies all configured redaction patterns in order and returns + /// the redacted content. The original content is never modified. + /// + /// # Arguments + /// + /// * `content` - The content to redact + /// + /// # Returns + /// + /// Redacted content with sensitive data replaced. + pub fn redact(&self, content: &str) -> String { + let mut redacted = content.to_string(); + + for pattern in &self.patterns { + redacted = pattern.regex.replace_all(&redacted, &pattern.replacement).to_string(); + } + + // Ensure UTF-8 safety (should always be valid since we're only replacing with ASCII) + debug_assert!(redacted.is_char_boundary(redacted.len())); + + redacted + } + + /// Check if content contains any patterns that would be redacted. + /// + /// This can be used to warn users before embedding. + pub fn contains_secrets(&self, content: &str) -> bool { + self.patterns.iter().any(|p| p.regex.is_match(content)) + } + + /// Get the number of redaction patterns configured. + pub fn pattern_count(&self) -> usize { + self.patterns.len() + } +} + +// Implement a safe Debug that doesn't leak pattern details +impl std::fmt::Debug for ContentRedactor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ContentRedactor") + .field("pattern_count", &self.patterns.len()) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_redaction_pattern_creation() { + let pattern = RedactionPattern::new("TEST", r"\d+", "XXX").unwrap(); + + assert_eq!(pattern.name, "TEST"); + assert_eq!(pattern.replacement, "XXX"); + assert!(pattern.regex.is_match("123")); + } + + #[test] + fn test_redaction_pattern_with_default_replacement() { + let pattern = RedactionPattern::with_default_replacement("api_key", r"key_\d+").unwrap(); + + assert_eq!(pattern.name, "api_key"); + assert_eq!(pattern.replacement, "<REDACTED:API_KEY>"); + } + + #[test] + fn test_redaction_pattern_invalid_regex() { + let result = RedactionPattern::new("TEST", r"[invalid(", "XXX"); + + assert!(result.is_err()); + } + + #[test] + fn test_default_patterns_api_key() { + let redactor = ContentRedactor::with_default_patterns(); + + let content = "API_KEY=sk_live_1234567890abcdefghij"; + let redacted = redactor.redact(content); + + assert!(redacted.contains("<REDACTED:")); + assert!(!redacted.contains("sk_live_")); + } + + #[test] + fn test_default_patterns_aws_key() { + let redactor = ContentRedactor::with_default_patterns(); + + let content = "AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE"; + let redacted = redactor.redact(content); + + assert!(redacted.contains("<REDACTED:")); + assert!(!redacted.contains("AKIAIOSFODNN7EXAMPLE")); + } + + #[test] + fn test_default_patterns_secret() { + let redactor = ContentRedactor::with_default_patterns(); + + let content = "secret: my_super_secret_value_12345"; + let redacted = redactor.redact(content); + + assert!(redacted.contains("<REDACTED:")); + assert!(!redacted.contains("my_super_secret")); + } + + #[test] + fn test_default_patterns_password() { + let redactor = ContentRedactor::with_default_patterns(); + + let content = "password=MyP@ssw0rd123"; + let redacted = redactor.redact(content); + + assert!(redacted.contains("<REDACTED:")); + assert!(!redacted.contains("MyP@ssw0rd")); + } + + #[test] + fn test_default_patterns_bearer_token() { + let redactor = ContentRedactor::with_default_patterns(); + + let content = "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"; + let redacted = redactor.redact(content); + + assert!(redacted.contains("<REDACTED:")); + assert!(!redacted.contains("eyJhbGci")); + } + + #[test] + fn test_custom_patterns() { + let custom = vec![r"CUSTOM-\d{4}".to_string()]; + let redactor = ContentRedactor::with_custom_patterns(&custom).unwrap(); + + let content = "My custom ID is CUSTOM-1234"; + let redacted = redactor.redact(content); + + assert!(redacted.contains("<REDACTED:")); + assert!(!redacted.contains("CUSTOM-1234")); + } + + #[test] + fn test_multiple_patterns() { + let redactor = ContentRedactor::with_default_patterns(); + + let content = r#" + API_KEY="sk_test_abcdefghijklmnopqrst" + password=secret123 + token: my_long_secret_token_value + "#; + + let redacted = redactor.redact(content); + + // All secrets should be redacted + assert!(!redacted.contains("sk_test_")); + assert!(!redacted.contains("secret123")); + assert!(!redacted.contains("my_long_secret")); + assert!(redacted.contains("<REDACTED:")); + } + + #[test] + fn test_contains_secrets() { + let redactor = ContentRedactor::with_default_patterns(); + + assert!(redactor.contains_secrets("API_KEY=sk_test_12345678901234567890")); + assert!(redactor.contains_secrets("password=secret")); + assert!(!redactor.contains_secrets("Hello, world!")); + } + + #[test] + fn test_utf8_safety() { + let redactor = ContentRedactor::with_default_patterns(); + + let content = "API_KEY=sk_test_12345678901234567890 你好世界"; + let redacted = redactor.redact(content); + + // Should still contain valid UTF-8 + assert!(redacted.contains("你好世界")); + assert!(std::str::from_utf8(redacted.as_bytes()).is_ok()); + } + + #[test] + fn test_no_false_positives_on_code() { + let redactor = ContentRedactor::with_default_patterns(); + + let code = r#" + fn api_key_validator(input: &str) -> bool { + input.len() > 20 + } + "#; + + let redacted = redactor.redact(code); + + // Function name and code should not be redacted + assert!(redacted.contains("api_key_validator")); + assert!(redacted.contains("fn")); + } + + #[test] + fn test_pattern_count() { + let redactor = ContentRedactor::with_default_patterns(); + assert_eq!(redactor.pattern_count(), 5); // 5 default patterns + + let custom = vec![r"CUSTOM-\d{4}".to_string()]; + let redactor_custom = ContentRedactor::with_custom_patterns(&custom).unwrap(); + assert_eq!(redactor_custom.pattern_count(), 6); // 5 default + 1 custom + } + + #[test] + fn test_debug_impl_safe() { + let redactor = ContentRedactor::with_default_patterns(); + let debug_str = format!("{:?}", redactor); + + // Should not leak actual pattern details + assert!(debug_str.contains("ContentRedactor")); + assert!(debug_str.contains("pattern_count")); + assert!(!debug_str.contains("regex")); // Internal details hidden + } +} diff --git a/crates/dirigent_tools/src/embedding/snippet.rs b/crates/dirigent_tools/src/embedding/snippet.rs new file mode 100644 index 0000000..16fbce4 --- /dev/null +++ b/crates/dirigent_tools/src/embedding/snippet.rs @@ -0,0 +1,354 @@ +//! Snippet generation for large files. +//! +//! This module implements partial selection strategies (head/tail/window) +//! for files exceeding embed size limits. + +use std::fs::File; +use std::io::{BufRead, BufReader}; +use std::path::Path; + +use crate::error::ToolResult; + +/// Configuration for snippet generation. +#[derive(Debug, Clone)] +pub struct SnippetConfig { + /// Number of lines to include from the beginning. + pub head_lines: usize, + /// Number of lines to include from the end. + pub tail_lines: usize, + /// Maximum total bytes for the snippet. + pub max_snippet_bytes: usize, +} + +impl Default for SnippetConfig { + fn default() -> Self { + Self { + head_lines: 200, + tail_lines: 200, + max_snippet_bytes: 128_000, // Half of default max_embed_bytes + } + } +} + +/// A snippet extracted from a file. +#[derive(Debug, Clone)] +pub struct Snippet { + /// First portion of the file (head). + pub head: Option<String>, + /// Last portion of the file (tail). + pub tail: Option<String>, + /// Total size of the original file in bytes. + pub total_size: u64, + /// Total number of lines in the original file. + pub total_lines: usize, + /// Whether the snippet is truncated (doesn't show full file). + pub truncated: bool, +} + +impl Snippet { + /// Render the snippet with metadata comments. + /// + /// Returns a formatted string with: + /// - Metadata header showing file info and truncation + /// - Head content + /// - Separator (if both head and tail present) + /// - Tail content + pub fn render(&self, file_path: &Path) -> String { + let file_name = file_path + .file_name() + .and_then(|n| n.to_str()) + .unwrap_or("unknown"); + + let mut output = String::new(); + + // Metadata header + output.push_str(&format!("# File: {} (truncated)\n", file_name)); + output.push_str(&format!( + "# Total size: {:.1} KB, {} lines\n", + self.total_size as f64 / 1024.0, + self.total_lines + )); + + if self.head.is_some() && self.tail.is_some() { + output.push_str(&format!( + "# Showing: first {} lines + last {} lines\n\n", + self.head.as_ref().unwrap().lines().count(), + self.tail.as_ref().unwrap().lines().count() + )); + } else if self.head.is_some() { + output.push_str(&format!( + "# Showing: first {} lines\n\n", + self.head.as_ref().unwrap().lines().count() + )); + } else if self.tail.is_some() { + output.push_str(&format!( + "# Showing: last {} lines\n\n", + self.tail.as_ref().unwrap().lines().count() + )); + } + + // Head content + if let Some(ref head) = self.head { + output.push_str(head); + if !head.ends_with('\n') { + output.push('\n'); + } + } + + // Separator if both head and tail + if self.head.is_some() && self.tail.is_some() { + output.push_str("\n# ... (middle content omitted) ...\n\n"); + } + + // Tail content + if let Some(ref tail) = self.tail { + output.push_str(tail); + if !tail.ends_with('\n') { + output.push('\n'); + } + } + + output + } +} + +/// Generate a snippet from a file. +/// +/// Reads the first N and last N lines from the file, respecting UTF-8 +/// boundaries and byte limits. +/// +/// # Arguments +/// +/// * `path` - Path to the file +/// * `config` - Snippet configuration +/// +/// # Returns +/// +/// A snippet with head/tail content and metadata. +pub fn generate_snippet(path: &Path, config: &SnippetConfig) -> ToolResult<Snippet> { + let metadata = std::fs::metadata(path)?; + let total_size = metadata.len(); + + // Read the full file to count lines and extract head/tail + let file = File::open(path)?; + let reader = BufReader::new(file); + + let mut all_lines: Vec<String> = Vec::new(); + for line in reader.lines() { + all_lines.push(line?); + } + + let total_lines = all_lines.len(); + + // Check if file is small enough to not need truncation + if total_lines <= config.head_lines + config.tail_lines { + return Ok(Snippet { + head: Some(all_lines.join("\n")), + tail: None, + total_size, + total_lines, + truncated: false, + }); + } + + // Extract head and tail + let head_content = all_lines + .iter() + .take(config.head_lines) + .cloned() + .collect::<Vec<_>>() + .join("\n"); + + let tail_content = all_lines + .iter() + .skip(all_lines.len().saturating_sub(config.tail_lines)) + .cloned() + .collect::<Vec<_>>() + .join("\n"); + + // Check if combined snippet exceeds max bytes + let combined_bytes = head_content.len() + tail_content.len(); + if combined_bytes > config.max_snippet_bytes { + // Reduce to head only if combined is too large + let truncated_head = truncate_to_bytes(&head_content, config.max_snippet_bytes); + Ok(Snippet { + head: Some(truncated_head), + tail: None, + total_size, + total_lines, + truncated: true, + }) + } else { + Ok(Snippet { + head: Some(head_content), + tail: Some(tail_content), + total_size, + total_lines, + truncated: true, + }) + } +} + +/// Truncate a string to fit within a byte limit, respecting UTF-8 boundaries. +fn truncate_to_bytes(s: &str, max_bytes: usize) -> String { + if s.len() <= max_bytes { + return s.to_string(); + } + + // Find the largest valid UTF-8 boundary within max_bytes + let mut boundary = max_bytes; + while boundary > 0 && !s.is_char_boundary(boundary) { + boundary -= 1; + } + + s[..boundary].to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + use tempfile::NamedTempFile; + + #[test] + fn test_snippet_config_default() { + let config = SnippetConfig::default(); + + assert_eq!(config.head_lines, 200); + assert_eq!(config.tail_lines, 200); + assert_eq!(config.max_snippet_bytes, 128_000); + } + + #[test] + fn test_generate_snippet_small_file() { + let mut temp = NamedTempFile::new().unwrap(); + temp.write_all(b"Line 1\nLine 2\nLine 3\n").unwrap(); + + let config = SnippetConfig { + head_lines: 10, + tail_lines: 10, + max_snippet_bytes: 10_000, + }; + + let snippet = generate_snippet(temp.path(), &config).unwrap(); + + assert!(!snippet.truncated); + assert_eq!(snippet.total_lines, 3); + assert!(snippet.head.is_some()); + assert!(snippet.tail.is_none()); + } + + #[test] + fn test_generate_snippet_large_file() { + let mut temp = NamedTempFile::new().unwrap(); + for i in 1..=1000 { + writeln!(temp, "Line {}", i).unwrap(); + } + + let config = SnippetConfig { + head_lines: 10, + tail_lines: 10, + max_snippet_bytes: 100_000, + }; + + let snippet = generate_snippet(temp.path(), &config).unwrap(); + + assert!(snippet.truncated); + assert_eq!(snippet.total_lines, 1000); + assert!(snippet.head.is_some()); + assert!(snippet.tail.is_some()); + + // Head should have first 10 lines + let head = snippet.head.unwrap(); + assert!(head.contains("Line 1")); + assert!(head.contains("Line 10")); + assert!(!head.contains("Line 11")); + + // Tail should have last 10 lines + let tail = snippet.tail.unwrap(); + assert!(tail.contains("Line 1000")); + assert!(tail.contains("Line 991")); + assert!(!tail.contains("Line 990")); + } + + #[test] + fn test_snippet_render() { + let snippet = Snippet { + head: Some("Line 1\nLine 2".to_string()), + tail: Some("Line 99\nLine 100".to_string()), + total_size: 10_000, + total_lines: 100, + truncated: true, + }; + + let rendered = snippet.render(Path::new("/test/file.txt")); + + assert!(rendered.contains("# File: file.txt (truncated)")); + assert!(rendered.contains("# Total size:")); + assert!(rendered.contains("100 lines")); + assert!(rendered.contains("Line 1")); + assert!(rendered.contains("Line 100")); + assert!(rendered.contains("... (middle content omitted) ...")); + } + + #[test] + fn test_snippet_render_head_only() { + let snippet = Snippet { + head: Some("Line 1\nLine 2".to_string()), + tail: None, + total_size: 1_000, + total_lines: 10, + truncated: true, + }; + + let rendered = snippet.render(Path::new("/test/file.txt")); + + assert!(rendered.contains("Line 1")); + assert!(!rendered.contains("... (middle content omitted) ...")); + } + + #[test] + fn test_truncate_to_bytes_exact() { + let s = "Hello"; + let truncated = truncate_to_bytes(s, 5); + assert_eq!(truncated, "Hello"); + } + + #[test] + fn test_truncate_to_bytes_shorter() { + let s = "Hello, world!"; + let truncated = truncate_to_bytes(s, 5); + assert_eq!(truncated, "Hello"); + } + + #[test] + fn test_truncate_to_bytes_utf8_boundary() { + let s = "Hello 你好世界"; + // Try to truncate in the middle of a multibyte character + let truncated = truncate_to_bytes(s, 8); + // Should truncate before the multibyte char to maintain UTF-8 validity + assert_eq!(truncated, "Hello "); + } + + #[test] + fn test_snippet_exceeds_max_bytes() { + let mut temp = NamedTempFile::new().unwrap(); + for i in 1..=1000 { + writeln!(temp, "This is a very long line number {} with lots of text", i).unwrap(); + } + + let config = SnippetConfig { + head_lines: 500, + tail_lines: 500, + max_snippet_bytes: 1_000, // Very small limit + }; + + let snippet = generate_snippet(temp.path(), &config).unwrap(); + + assert!(snippet.truncated); + // Should fall back to head only when combined exceeds max + assert!(snippet.head.is_some()); + let head = snippet.head.unwrap(); + assert!(head.len() <= config.max_snippet_bytes); + } +} diff --git a/crates/dirigent_tools/src/error.rs b/crates/dirigent_tools/src/error.rs new file mode 100644 index 0000000..af185b3 --- /dev/null +++ b/crates/dirigent_tools/src/error.rs @@ -0,0 +1,108 @@ +//! Error types for tool operations. + +use thiserror::Error; + +/// Result type for tool operations. +pub type ToolResult<T> = Result<T, ToolError>; + +/// Errors that can occur during tool operations. +#[derive(Error, Debug)] +pub enum ToolError { + /// File or directory not found. + #[error("Not found: {path}")] + NotFound { path: String }, + + /// Permission denied by OS or sandbox policy. + #[error("Permission denied: {reason}")] + PermissionDenied { reason: String }, + + /// Path outside allowed sandbox roots. + #[error("Sandbox violation: {reason}")] + SandboxViolation { reason: String }, + + /// Path matched blocklist pattern. + #[error("Blocked path: {reason}")] + BlockedPath { reason: String }, + + /// File too large to process. + #[error("File too large: {size} bytes exceeds limit of {limit} bytes")] + FileTooLarge { size: u64, limit: u64 }, + + /// Invalid encoding (non-UTF-8). + #[error("Encoding not supported: {encoding}")] + EncodingUnsupported { encoding: String }, + + /// User rejected permission prompt. + #[error("Permission rejected by user")] + PermissionRejected, + + /// Terminal operation failed. + #[error("Terminal error: {message}")] + TerminalError { message: String }, + + /// Terminal not found. + #[error("Terminal not found: {terminal_id}")] + TerminalNotFound { terminal_id: String }, + + /// Search operation exceeded limits. + #[error("Search limit exceeded: {reason}")] + SearchLimitExceeded { reason: String }, + + /// Invalid configuration. + #[error("Invalid configuration: {0}")] + InvalidConfig(String), + + /// File read error with detailed information. + #[error("Failed to read file {path}: {source}")] + FileReadError { + path: String, + #[source] + source: std::io::Error, + }, + + /// Invalid input or parameters. + #[error("Invalid input: {0}")] + InvalidInput(String), + + /// I/O error. + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), + + /// JSON serialization error. + #[error("JSON error: {0}")] + Json(#[from] serde_json::Error), + + /// Generic error. + #[error("{0}")] + Other(#[from] anyhow::Error), +} + +impl ToolError { + /// Create a sandbox violation error without exposing the full path. + pub fn sandbox_violation(reason: impl Into<String>) -> Self { + Self::SandboxViolation { + reason: reason.into(), + } + } + + /// Create a blocked path error without exposing the full path. + pub fn blocked_path(reason: impl Into<String>) -> Self { + Self::BlockedPath { + reason: reason.into(), + } + } + + /// Create a permission denied error. + pub fn permission_denied(reason: impl Into<String>) -> Self { + Self::PermissionDenied { + reason: reason.into(), + } + } + + /// Create a terminal error. + pub fn terminal_error(message: impl Into<String>) -> Self { + Self::TerminalError { + message: message.into(), + } + } +} diff --git a/crates/dirigent_tools/src/floor/mod.rs b/crates/dirigent_tools/src/floor/mod.rs new file mode 100644 index 0000000..9e216f2 --- /dev/null +++ b/crates/dirigent_tools/src/floor/mod.rs @@ -0,0 +1,189 @@ +//! Hardcoded security floor. Settings cannot bypass these rules. +//! +//! Initial rule set lifted from Zed's `HARDCODED_SECURITY_RULES`: +//! catastrophic recursive deletes (`rm -rf /`, `~`, `$HOME`, `.`, `..`). + +pub mod shell; + +use crate::tool::ToolContext; +use regex::Regex; +use std::sync::OnceLock; + +#[derive(Debug)] +pub enum FloorDecision { + Pass, + Block { reason: &'static str }, +} + +pub struct SecurityFloor { + terminal_deny: Vec<Regex>, +} + +impl SecurityFloor { + pub fn new() -> Self { + Self { terminal_deny: terminal_deny_patterns().clone() } + } + + /// Check whether the given tool invocation hits a hard rule. + /// `tool` is the tool name; `input` is the JSON arguments object. + pub fn check(&self, tool: &str, input: &serde_json::Value, _ctx: &ToolContext) -> FloorDecision { + if tool != "terminal" && tool != "bash" { + return FloorDecision::Pass; + } + let Some(cmd) = extract_command(input) else { return FloorDecision::Pass }; + + // Check the raw command first + if self.matches_any(cmd) { + return FloorDecision::Block { + reason: "Blocked by built-in security rule (recursive delete of \ + root, home, current, or parent directory).", + }; + } + // Then each sub-command in a chain + for sub in shell::split_chain(cmd) { + if self.matches_any(sub) { + return FloorDecision::Block { + reason: "Blocked by built-in security rule (chained recursive \ + delete detected).", + }; + } + } + if interpolation_pattern().is_match(cmd) { + return FloorDecision::Block { + reason: "Blocked: shell substitutions/interpolations are not allowed \ + in terminal commands. Resolve $VAR, ${VAR}, $(...), backticks, \ + $((...)), <(...), >(...) before calling.", + }; + } + FloorDecision::Pass + } + + fn matches_any(&self, cmd: &str) -> bool { + self.terminal_deny.iter().any(|re| re.is_match(cmd)) + } +} + +impl Default for SecurityFloor { fn default() -> Self { Self::new() } } + +fn extract_command(input: &serde_json::Value) -> Option<&str> { + input.get("command").and_then(|v| v.as_str()) +} + +fn terminal_deny_patterns() -> &'static Vec<Regex> { + static SET: OnceLock<Vec<Regex>> = OnceLock::new(); + SET.get_or_init(|| { + const FLAGS: &str = r"(--[a-zA-Z0-9][-a-zA-Z0-9_]*(=[^\s]*)?\s+|-[a-zA-Z]+\s+)*"; + const TRAIL: &str = r"(\s+--[a-zA-Z0-9][-a-zA-Z0-9_]*(=[^\s]*)?|\s+-[a-zA-Z]+)*\s*"; + vec![ + Regex::new(&format!(r"\brm\s+{FLAGS}(--\s+)?/\*?{TRAIL}$")).unwrap(), + Regex::new(&format!(r"\brm\s+{FLAGS}(--\s+)?~/?\*?{TRAIL}$")).unwrap(), + Regex::new(&format!(r"\brm\s+{FLAGS}(--\s+)?(\$HOME|\$\{{HOME\}})/?\*?{TRAIL}$")).unwrap(), + Regex::new(&format!(r"\brm\s+{FLAGS}(--\s+)?\./?\*?{TRAIL}$")).unwrap(), + Regex::new(&format!(r"\brm\s+{FLAGS}(--\s+)?\.\./?\*?{TRAIL}$")).unwrap(), + ] + }) +} + +fn interpolation_pattern() -> &'static Regex { + static RE: OnceLock<Regex> = OnceLock::new(); + RE.get_or_init(|| { + // $VAR, ${...}, $(...), $((...)), `...`, <(...), >(...) + Regex::new(r#"(\$[A-Za-z_][A-Za-z0-9_]*|\$\{[^}]*\}|\$\([^)]*\)|\$\(\([^)]*\)\)|`[^`]*`|<\([^)]*\)|>\([^)]*\))"#).unwrap() + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{PermissionConfig, SandboxConfig, WhitelistConfig}; + use crate::permission::check::PermissionContext; + use crate::permission::whitelist::CompiledWhitelist; + use crate::tool::{ClientKind, ProtocolKind}; + use std::path::PathBuf; + + fn ctx() -> ToolContext { + let wl = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap(); + let pc = PermissionContext::new("c".to_string(), None, wl); + ToolContext::for_test( + "c", ClientKind::claude(), ProtocolKind::acp(), + PathBuf::from("/tmp"), + SandboxConfig::default(), PermissionConfig::default(), pc, + ) + } + + fn input(cmd: &str) -> serde_json::Value { + serde_json::json!({ "command": cmd }) + } + + #[test] + fn floor_passes_non_terminal_tools() { + let f = SecurityFloor::new(); + assert!(matches!(f.check("read", &input("rm -rf /"), &ctx()), FloorDecision::Pass)); + } + + #[test] + fn floor_blocks_rm_rf_root() { + let f = SecurityFloor::new(); + for cmd in ["rm -rf /", "rm -rfv /", "rm -rf /*"] { + assert!(matches!(f.check("terminal", &input(cmd), &ctx()), FloorDecision::Block { .. }), + "expected block for {cmd}"); + } + } + + #[test] + fn floor_blocks_home_variants() { + let f = SecurityFloor::new(); + for cmd in ["rm -rf ~", "rm -rf ~/", "rm -rf $HOME", "rm -rf ${HOME}", "rm -rf ${HOME}/*"] { + assert!(matches!(f.check("terminal", &input(cmd), &ctx()), FloorDecision::Block { .. }), + "expected block for {cmd}"); + } + } + + #[test] + fn floor_blocks_dot_dotdot() { + let f = SecurityFloor::new(); + for cmd in ["rm -rf .", "rm -rf ./", "rm -rf ./*", "rm -rf ..", "rm -rf ../*"] { + assert!(matches!(f.check("terminal", &input(cmd), &ctx()), FloorDecision::Block { .. }), + "expected block for {cmd}"); + } + } + + #[test] + fn floor_passes_safe_commands() { + let f = SecurityFloor::new(); + for cmd in ["ls", "rm foo.txt", "rm -rf ./build", "rm -rf /tmp/work"] { + assert!(matches!(f.check("terminal", &input(cmd), &ctx()), FloorDecision::Pass), + "expected pass for {cmd}"); + } + } + + #[test] + fn floor_blocks_chained_rm_rf_root() { + let f = SecurityFloor::new(); + assert!(matches!( + f.check("terminal", &input("ls && rm -rf /"), &ctx()), + FloorDecision::Block { .. } + )); + } + + #[test] + fn floor_blocks_interpolation_in_terminal() { + let f = SecurityFloor::new(); + for cmd in ["echo $VAR", "echo ${VAR}", "echo $(date)", "echo `date`", + "echo $((1+1))", "cat <(ls)", "tee >(ls)"] { + assert!(matches!( + f.check("terminal", &input(cmd), &ctx()), + FloorDecision::Block { .. } + ), "expected block for {cmd}"); + } + } + + #[test] + fn floor_passes_command_without_interpolation() { + let f = SecurityFloor::new(); + assert!(matches!( + f.check("terminal", &input("git status"), &ctx()), + FloorDecision::Pass + )); + } +} diff --git a/crates/dirigent_tools/src/floor/shell.rs b/crates/dirigent_tools/src/floor/shell.rs new file mode 100644 index 0000000..f5fee3c --- /dev/null +++ b/crates/dirigent_tools/src/floor/shell.rs @@ -0,0 +1,66 @@ +//! Lightweight POSIX shell command decomposition. +//! +//! Splits a one-line shell input on `&&`, `||`, `;`, `|` operators. +//! Quoted regions are preserved so chain tokens inside quotes do not split. +//! This is a heuristic — not a full shell parser. It is sufficient to ensure +//! a chained `rm -rf /` is not hidden behind a leading benign command. + +/// Split `input` into individual sub-commands by POSIX chain operators. +/// Returns the original string as a single element if no chains are detected. +pub fn split_chain(input: &str) -> Vec<&str> { + let mut out = Vec::new(); + let bytes = input.as_bytes(); + let mut start = 0; + let mut i = 0; + let mut quote: Option<u8> = None; + let mut escape = false; + + while i < bytes.len() { + let b = bytes[i]; + if escape { escape = false; i += 1; continue; } + if b == b'\\' && quote != Some(b'\'') { escape = true; i += 1; continue; } + + match quote { + Some(q) if q == b => { quote = None; i += 1; continue; } + Some(_) => { i += 1; continue; } + None => { + if b == b'\'' || b == b'"' { quote = Some(b); i += 1; continue; } + } + } + + let two = bytes.get(i..i+2); + if two == Some(b"&&") || two == Some(b"||") { + out.push(input[start..i].trim()); + i += 2; start = i; continue; + } + if b == b';' || b == b'|' { + out.push(input[start..i].trim()); + i += 1; start = i; continue; + } + i += 1; + } + let tail = input[start..].trim(); + if !tail.is_empty() { out.push(tail); } + out +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] fn splits_on_and() { + assert_eq!(split_chain("ls && rm -rf /"), vec!["ls", "rm -rf /"]); + } + #[test] fn splits_on_semicolon_and_pipe() { + assert_eq!(split_chain("ls ; rm -rf / | wc"), vec!["ls", "rm -rf /", "wc"]); + } + #[test] fn keeps_quoted_chains_intact() { + assert_eq!(split_chain("echo 'a && b'"), vec!["echo 'a && b'"]); + } + #[test] fn handles_no_chain() { + assert_eq!(split_chain("ls -la"), vec!["ls -la"]); + } + #[test] fn handles_escaped_pipe() { + assert_eq!(split_chain(r"echo a\|b"), vec![r"echo a\|b"]); + } +} diff --git a/crates/dirigent_tools/src/fs.rs b/crates/dirigent_tools/src/fs.rs new file mode 100644 index 0000000..e269dd7 --- /dev/null +++ b/crates/dirigent_tools/src/fs.rs @@ -0,0 +1,28 @@ +//! File system operations with sandbox enforcement. +//! +//! This module provides: +//! - `read_text_file()` - Read UTF-8 text files with line range support (TOOLS-FS-01) +//! - `write_text_file()` - Write UTF-8 text files with atomic writes (TOOLS-FS-02) +//! - `generate_diff()` - Generate unified diffs for changes (TOOLS-FS-03) +//! - `edit_file()` - Apply text replacements and generate diffs (TOOLS-FS-04) +//! +//! All operations are subject to: +//! - Sandbox containment checks +//! - Blocklist enforcement +//! - Size limits +//! - Permission prompts (for writes) +//! +//! **Status**: All functions stubbed, implementation pending + +pub mod read; +pub mod write; +pub mod diff; +pub mod edit; +pub mod file_type; + +// Re-export main types and functions +pub use read::{read_text_file, ReadTextFileRequest, ReadTextFileResponse}; +pub use write::{write_text_file, normalize_eol, WriteTextFileRequest, WriteTextFileResponse}; +pub use diff::generate_diff; +pub use edit::{edit_file, EditFileRequest, EditFileResponse, EditOperation}; +pub use file_type::{detect_file_type, is_valid_utf8, FileTypeInfo}; diff --git a/crates/dirigent_tools/src/fs/diff.rs b/crates/dirigent_tools/src/fs/diff.rs new file mode 100644 index 0000000..7647599 --- /dev/null +++ b/crates/dirigent_tools/src/fs/diff.rs @@ -0,0 +1,195 @@ +//! Unified diff generation for write operations. +//! +//! **Status**: Implemented (TOOLS-FS-03) +//! +//! This module implements: +//! - Unified diff generation using similar crate +//! - Edge case handling (new files, deleted files, no changes) +//! - Diff size limiting for UI rendering +//! - Binary file detection and fallback + +use similar::{ChangeTag, TextDiff}; +use std::path::Path; + +/// Maximum diff size in characters before truncation. +const MAX_DIFF_SIZE: usize = 100_000; + +/// Generate a unified diff between old and new file contents. +/// +/// Returns a unified diff in standard format, suitable for UI rendering. +/// +/// ## Edge Cases +/// +/// - New file (old = empty) → Shows all lines as additions +/// - Deleted file (new = empty) → Shows all lines as deletions +/// - No change (old == new) → Returns empty string +/// - Very large diff → Truncated with message +/// +/// ## Format +/// +/// Standard unified diff format: +/// ```text +/// --- path/to/file.txt +/// +++ path/to/file.txt +/// @@ -1,3 +1,3 @@ +/// context line +/// -old line +/// +new line +/// context line +/// ``` +/// +/// ## Error Handling +/// +/// Diff generation is best-effort: +/// - Should never panic +/// - Falls back to generic message on error +/// - Logs warnings for debugging +/// +/// ## See Also +/// +/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-FS-03` +pub fn generate_diff(old_content: &str, new_content: &str, path: &Path) -> String { + // Early exit if contents are identical + if old_content == new_content { + return String::new(); + } + + // Handle new file case (old is empty) + if old_content.is_empty() { + let line_count = new_content.lines().count(); + let mut result = format!( + "--- /dev/null\n+++ {}\n@@ -0,0 +1,{} @@\n", + path.display(), + line_count + ); + for line in new_content.lines() { + result.push_str(&format!("+{}\n", line)); + } + return truncate_diff(result); + } + + // Handle deleted file case (new is empty) + if new_content.is_empty() { + let line_count = old_content.lines().count(); + let mut result = format!( + "--- {}\n+++ /dev/null\n@@ -1,{} +0,0 @@\n", + path.display(), + line_count + ); + for line in old_content.lines() { + result.push_str(&format!("-{}\n", line)); + } + return truncate_diff(result); + } + + // Generate unified diff using similar crate + match generate_unified_diff(old_content, new_content, path) { + Ok(diff) => truncate_diff(diff), + Err(e) => { + tracing::warn!(path = %path.display(), error = %e, "Failed to generate diff"); + format!("# Content changed (diff generation failed: {})\n", e) + } + } +} + +/// Generate unified diff using the similar crate. +fn generate_unified_diff(old_content: &str, new_content: &str, path: &Path) -> Result<String, String> { + let diff = TextDiff::from_lines(old_content, new_content); + + let mut output = String::new(); + output.push_str(&format!("--- {}\n", path.display())); + output.push_str(&format!("+++ {}\n", path.display())); + + // Generate unified diff format with context + for hunk in diff.unified_diff().iter_hunks() { + // Write hunk header + output.push_str(&format!("{}", hunk.header())); + + // Write changes + for change in hunk.iter_changes() { + let sign = match change.tag() { + ChangeTag::Delete => "-", + ChangeTag::Insert => "+", + ChangeTag::Equal => " ", + }; + output.push_str(&format!("{}{}", sign, change.value())); + } + } + + Ok(output) +} + +/// Truncate diff if it exceeds maximum size. +fn truncate_diff(mut diff: String) -> String { + if diff.len() > MAX_DIFF_SIZE { + diff.truncate(MAX_DIFF_SIZE); + diff.push_str("\n... [diff truncated for display] ...\n"); + } + diff +} + +#[cfg(test)] +mod tests { + use super::*; + use std::path::PathBuf; + + #[test] + fn test_generate_diff_identical() { + let path = PathBuf::from("test.txt"); + let content = "line1\nline2\nline3"; + let diff = generate_diff(content, content, &path); + assert_eq!(diff, ""); + } + + #[test] + fn test_generate_diff_new_file() { + let path = PathBuf::from("test.txt"); + let old = ""; + let new = "line1\nline2\nline3"; + let diff = generate_diff(old, new, &path); + + assert!(diff.contains("--- /dev/null")); + assert!(diff.contains("+++ test.txt")); + assert!(diff.contains("+line1")); + assert!(diff.contains("+line2")); + assert!(diff.contains("+line3")); + } + + #[test] + fn test_generate_diff_deleted_file() { + let path = PathBuf::from("test.txt"); + let old = "line1\nline2\nline3"; + let new = ""; + let diff = generate_diff(old, new, &path); + + assert!(diff.contains("--- test.txt")); + assert!(diff.contains("+++ /dev/null")); + assert!(diff.contains("-line1")); + assert!(diff.contains("-line2")); + assert!(diff.contains("-line3")); + } + + #[test] + fn test_generate_diff_change() { + let path = PathBuf::from("test.txt"); + let old = "line1\nline2\nline3"; + let new = "line1\nmodified\nline3"; + let diff = generate_diff(old, new, &path); + + assert!(diff.contains("--- test.txt")); + assert!(diff.contains("+++ test.txt")); + assert!(diff.contains("-line2")); + assert!(diff.contains("+modified")); + } + + #[test] + fn test_truncate_diff() { + let short_diff = "short diff"; + assert_eq!(truncate_diff(short_diff.to_string()), short_diff); + + let long_diff = "x".repeat(MAX_DIFF_SIZE + 1000); + let truncated = truncate_diff(long_diff); + assert!(truncated.len() <= MAX_DIFF_SIZE + 100); // Allow for truncation message + assert!(truncated.contains("[diff truncated for display]")); + } +} diff --git a/crates/dirigent_tools/src/fs/edit.rs b/crates/dirigent_tools/src/fs/edit.rs new file mode 100644 index 0000000..be8d935 --- /dev/null +++ b/crates/dirigent_tools/src/fs/edit.rs @@ -0,0 +1,258 @@ +//! Internal edit helper for read + transform + write operations. +//! +//! **Status**: Implemented (TOOLS-FS-04) +//! +//! This module implements: +//! - Edit operation abstraction (not ACP-native, internal API) +//! - Read + transform + write flow +//! - Automatic diff generation +//! - String replacement operations + +use crate::config::SandboxConfig; +use crate::error::{ToolError, ToolResult}; +use crate::fs::diff::generate_diff; +use crate::fs::read::{read_text_file, ReadTextFileRequest}; +use crate::fs::write::{write_text_file, WriteTextFileRequest}; +use crate::path::validate_path; +use serde::{Deserialize, Serialize}; + +/// Request to edit a file via transformation operations. +/// +/// **Note**: This is an internal API, not exposed via ACP directly. +/// Agents use fs/write_text_file; this is for richer Dirigent UX. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EditFileRequest { + /// Absolute path to the file to edit. + pub path: String, + + /// Ordered list of edit operations to apply. + pub edits: Vec<EditOperation>, +} + +/// A single edit operation to apply to file content. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum EditOperation { + /// Replace occurrences of old_text with new_text. + Replace { + /// Text to find. + old_text: String, + + /// Text to replace with. + new_text: String, + + /// Replace all occurrences (true) or first only (false). + replace_all: bool, + }, + + /// Apply a unified diff patch (future). + #[allow(dead_code)] + Patch { + /// Unified diff string. + diff: String, + }, +} + +/// Response from editing a file. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EditFileResponse { + /// Unified diff showing changes made. + pub diff: String, +} + +/// Edit a file by applying transformation operations. +/// +/// ## Implementation +/// +/// This function: +/// 1. Reads existing file content +/// 2. Applies edits in order: +/// - Replace: String find/replace (once or all) +/// - Patch: Apply unified diff (future) +/// 3. Writes transformed content +/// 4. Generates and returns diff +/// +/// ## Algorithm +/// +/// ```text +/// 1. old_content = read_text_file(path) +/// 2. new_content = old_content +/// 3. for each edit in edits: +/// new_content = apply_edit(new_content, edit) +/// 4. write_text_file(path, new_content) +/// 5. diff = generate_diff(old_content, new_content, path) +/// 6. return EditFileResponse { diff } +/// ``` +/// +/// ## Error Cases +/// +/// - File not found → `ToolError::NotFound` +/// - Edit on new file → `ToolError::NotFound` (edits require existing content) +/// - Same sandboxing/permission errors as read/write +/// +/// ## Tool Call Rendering +/// +/// Always emits `ToolCallContent::diff` for UX visualization. +/// +/// ## See Also +/// +/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-FS-04` +pub async fn edit_file( + request: EditFileRequest, + sandbox_config: &SandboxConfig, + permission_config: &crate::config::PermissionConfig, + permission_context: &crate::permission::check::PermissionContext, +) -> ToolResult<EditFileResponse> { + // Validate path first to get canonical path for diff generation + let canonical_path = validate_path(&request.path, sandbox_config)?; + + // Read existing file content + let read_request = ReadTextFileRequest { + path: request.path.clone(), + line: None, + limit: None, + }; + + let read_response = read_text_file(read_request, sandbox_config).await?; + let old_content = read_response.content; + + // Apply all edits in order + let mut new_content = old_content.clone(); + for edit in &request.edits { + new_content = apply_edit(&new_content, edit)?; + } + + // Write transformed content + let write_request = WriteTextFileRequest { + path: request.path.clone(), + content: new_content.clone(), + }; + + write_text_file(write_request, sandbox_config, permission_config, permission_context).await?; + + // Generate diff for UX + let diff = generate_diff(&old_content, &new_content, &canonical_path); + + tracing::info!( + path = %request.path, + edit_count = request.edits.len(), + "File edited successfully" + ); + + Ok(EditFileResponse { diff }) +} + +/// Apply a single edit operation to content. +fn apply_edit(content: &str, edit: &EditOperation) -> ToolResult<String> { + match edit { + EditOperation::Replace { + old_text, + new_text, + replace_all, + } => { + if *replace_all { + // Replace all occurrences + Ok(content.replace(old_text, new_text)) + } else { + // Replace only the first occurrence + if let Some(pos) = content.find(old_text) { + let mut result = String::with_capacity(content.len()); + result.push_str(&content[..pos]); + result.push_str(new_text); + result.push_str(&content[pos + old_text.len()..]); + Ok(result) + } else { + // No match found - return content unchanged + // This could be a warning, but we'll allow it + tracing::warn!( + old_text = %old_text, + "Edit operation: old_text not found in content" + ); + Ok(content.to_string()) + } + } + } + EditOperation::Patch { diff: _ } => { + // Future: Apply unified diff patch + Err(ToolError::InvalidInput( + "Patch edit operation not yet implemented".to_string(), + )) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_apply_edit_replace_all() { + let content = "foo bar foo baz foo"; + let edit = EditOperation::Replace { + old_text: "foo".to_string(), + new_text: "qux".to_string(), + replace_all: true, + }; + let result = apply_edit(content, &edit).unwrap(); + assert_eq!(result, "qux bar qux baz qux"); + } + + #[test] + fn test_apply_edit_replace_first() { + let content = "foo bar foo baz foo"; + let edit = EditOperation::Replace { + old_text: "foo".to_string(), + new_text: "qux".to_string(), + replace_all: false, + }; + let result = apply_edit(content, &edit).unwrap(); + assert_eq!(result, "qux bar foo baz foo"); + } + + #[test] + fn test_apply_edit_no_match() { + let content = "foo bar baz"; + let edit = EditOperation::Replace { + old_text: "qux".to_string(), + new_text: "quux".to_string(), + replace_all: false, + }; + let result = apply_edit(content, &edit).unwrap(); + assert_eq!(result, content); // Unchanged + } + + #[test] + fn test_apply_edit_empty_replacement() { + let content = "foo bar foo"; + let edit = EditOperation::Replace { + old_text: "foo".to_string(), + new_text: "".to_string(), + replace_all: true, + }; + let result = apply_edit(content, &edit).unwrap(); + assert_eq!(result, " bar "); + } + + #[test] + fn test_apply_edit_multiline() { + let content = "line1\nfoo\nline3\nfoo\nline5"; + let edit = EditOperation::Replace { + old_text: "foo".to_string(), + new_text: "bar".to_string(), + replace_all: true, + }; + let result = apply_edit(content, &edit).unwrap(); + assert_eq!(result, "line1\nbar\nline3\nbar\nline5"); + } + + #[test] + fn test_apply_edit_patch_unimplemented() { + let content = "foo bar baz"; + let edit = EditOperation::Patch { + diff: "some diff".to_string(), + }; + let result = apply_edit(content, &edit); + assert!(result.is_err()); + assert!(matches!(result, Err(ToolError::InvalidInput(_)))); + } +} diff --git a/crates/dirigent_tools/src/fs/file_type.rs b/crates/dirigent_tools/src/fs/file_type.rs new file mode 100644 index 0000000..a6d7522 --- /dev/null +++ b/crates/dirigent_tools/src/fs/file_type.rs @@ -0,0 +1,331 @@ +//! File type detection and MIME type guessing. +//! +//! This module provides utilities for detecting whether a file is text or binary, +//! and guessing appropriate MIME types based on file extensions. + +use std::fs; +use std::io::Read; +use std::path::Path; + +use crate::error::ToolResult; + +/// Information about a file's type and encoding. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct FileTypeInfo { + /// Whether the file appears to be text (not binary). + pub is_text: bool, + /// Whether the file appears to be binary (not text). + pub is_binary: bool, + /// MIME type guess based on extension and content. + pub mime_type: Option<String>, + /// Character set (always "utf-8" for text files in Phase 03). + pub charset: Option<String>, +} + +impl FileTypeInfo { + /// Create a text file type info. + pub fn text(mime_type: impl Into<String>) -> Self { + Self { + is_text: true, + is_binary: false, + mime_type: Some(mime_type.into()), + charset: Some("utf-8".to_string()), + } + } + + /// Create a binary file type info. + pub fn binary(mime_type: impl Into<String>) -> Self { + Self { + is_text: false, + is_binary: true, + mime_type: Some(mime_type.into()), + charset: None, + } + } + + /// Create an unknown file type info (defaults to binary). + pub fn unknown() -> Self { + Self { + is_text: false, + is_binary: true, + mime_type: Some("application/octet-stream".to_string()), + charset: None, + } + } +} + +/// Detect file type based on extension and content analysis. +/// +/// This uses a fast path (extension-based detection) followed by optional +/// content sniffing for extensionless files. +/// +/// # Arguments +/// +/// * `path` - Path to the file to analyze +/// +/// # Returns +/// +/// File type information including text/binary classification and MIME type. +pub fn detect_file_type(path: &Path) -> ToolResult<FileTypeInfo> { + // First, try extension-based detection (fast path) + if let Some(info) = detect_by_extension(path) { + return Ok(info); + } + + // For extensionless files, do content sniffing + detect_by_content(path) +} + +/// Detect file type by extension. +/// +/// This is the fast path for common file types. Returns None if the extension +/// is unknown or missing. +fn detect_by_extension(path: &Path) -> Option<FileTypeInfo> { + let ext = path.extension()?.to_str()?.to_lowercase(); + + match ext.as_str() { + // Programming languages + "rs" => Some(FileTypeInfo::text("text/x-rust")), + "ts" => Some(FileTypeInfo::text("text/typescript")), + "tsx" => Some(FileTypeInfo::text("text/tsx")), + "js" => Some(FileTypeInfo::text("text/javascript")), + "jsx" => Some(FileTypeInfo::text("text/jsx")), + "py" => Some(FileTypeInfo::text("text/x-python")), + "rb" => Some(FileTypeInfo::text("text/x-ruby")), + "go" => Some(FileTypeInfo::text("text/x-go")), + "java" => Some(FileTypeInfo::text("text/x-java")), + "c" => Some(FileTypeInfo::text("text/x-c")), + "cpp" | "cc" | "cxx" => Some(FileTypeInfo::text("text/x-c++")), + "h" | "hpp" => Some(FileTypeInfo::text("text/x-c-header")), + "cs" => Some(FileTypeInfo::text("text/x-csharp")), + "sh" | "bash" | "zsh" => Some(FileTypeInfo::text("text/x-shellscript")), + "ps1" => Some(FileTypeInfo::text("text/x-powershell")), + + // Data formats + "json" => Some(FileTypeInfo::text("application/json")), + "toml" => Some(FileTypeInfo::text("application/toml")), + "yaml" | "yml" => Some(FileTypeInfo::text("application/yaml")), + "xml" => Some(FileTypeInfo::text("application/xml")), + "csv" => Some(FileTypeInfo::text("text/csv")), + + // Markup + "html" | "htm" => Some(FileTypeInfo::text("text/html")), + "css" => Some(FileTypeInfo::text("text/css")), + "scss" | "sass" => Some(FileTypeInfo::text("text/x-scss")), + "md" | "markdown" => Some(FileTypeInfo::text("text/markdown")), + "rst" => Some(FileTypeInfo::text("text/x-rst")), + + // Plain text + "txt" => Some(FileTypeInfo::text("text/plain")), + "log" => Some(FileTypeInfo::text("text/plain")), + + // Config files + "ini" | "cfg" | "conf" => Some(FileTypeInfo::text("text/plain")), + "env" => Some(FileTypeInfo::text("text/plain")), + + // Binary formats + "png" => Some(FileTypeInfo::binary("image/png")), + "jpg" | "jpeg" => Some(FileTypeInfo::binary("image/jpeg")), + "gif" => Some(FileTypeInfo::binary("image/gif")), + "webp" => Some(FileTypeInfo::binary("image/webp")), + "svg" => Some(FileTypeInfo::text("image/svg+xml")), // SVG is text + "pdf" => Some(FileTypeInfo::binary("application/pdf")), + "zip" => Some(FileTypeInfo::binary("application/zip")), + "gz" | "gzip" => Some(FileTypeInfo::binary("application/gzip")), + "tar" => Some(FileTypeInfo::binary("application/x-tar")), + "mp3" => Some(FileTypeInfo::binary("audio/mpeg")), + "mp4" => Some(FileTypeInfo::binary("video/mp4")), + "exe" | "dll" | "so" | "dylib" => Some(FileTypeInfo::binary("application/octet-stream")), + + // Unknown extension + _ => None, + } +} + +/// Detect file type by analyzing content. +/// +/// Reads the first 8KB of the file and checks for UTF-8 validity and +/// presence of null bytes (indicating binary). +fn detect_by_content(path: &Path) -> ToolResult<FileTypeInfo> { + // Open file + let mut file = fs::File::open(path)?; + + // Read first 8KB for analysis + let mut buffer = vec![0u8; 8192]; + let bytes_read = file.read(&mut buffer)?; + buffer.truncate(bytes_read); + + // Check for null bytes (strong indicator of binary) + if buffer.contains(&0) { + return Ok(FileTypeInfo::binary("application/octet-stream")); + } + + // Try to validate as UTF-8 + match std::str::from_utf8(&buffer) { + Ok(_) => Ok(FileTypeInfo::text("text/plain")), + Err(_) => Ok(FileTypeInfo::binary("application/octet-stream")), + } +} + +/// Validate that a file contains valid UTF-8 text. +/// +/// This is used to ensure text files can be safely read and embedded. +/// +/// # Arguments +/// +/// * `path` - Path to the file to validate +/// +/// # Returns +/// +/// `Ok(true)` if the file is valid UTF-8, `Ok(false)` if not. +pub fn is_valid_utf8(path: &Path) -> ToolResult<bool> { + let content = fs::read(path)?; + Ok(std::str::from_utf8(&content).is_ok()) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + use tempfile::NamedTempFile; + + #[test] + fn test_detect_by_extension_rust() { + let path = Path::new("test.rs"); + let info = detect_by_extension(path).unwrap(); + + assert!(info.is_text); + assert!(!info.is_binary); + assert_eq!(info.mime_type, Some("text/x-rust".to_string())); + assert_eq!(info.charset, Some("utf-8".to_string())); + } + + #[test] + fn test_detect_by_extension_typescript() { + let path = Path::new("test.ts"); + let info = detect_by_extension(path).unwrap(); + + assert!(info.is_text); + assert_eq!(info.mime_type, Some("text/typescript".to_string())); + } + + #[test] + fn test_detect_by_extension_json() { + let path = Path::new("config.json"); + let info = detect_by_extension(path).unwrap(); + + assert!(info.is_text); + assert_eq!(info.mime_type, Some("application/json".to_string())); + } + + #[test] + fn test_detect_by_extension_binary_png() { + let path = Path::new("image.png"); + let info = detect_by_extension(path).unwrap(); + + assert!(!info.is_text); + assert!(info.is_binary); + assert_eq!(info.mime_type, Some("image/png".to_string())); + assert_eq!(info.charset, None); + } + + #[test] + fn test_detect_by_extension_unknown() { + let path = Path::new("file.unknownext"); + let info = detect_by_extension(path); + + assert!(info.is_none()); + } + + #[test] + fn test_detect_by_extension_no_extension() { + let path = Path::new("Makefile"); + let info = detect_by_extension(path); + + assert!(info.is_none()); + } + + #[test] + fn test_detect_by_content_text() { + let mut temp = NamedTempFile::new().unwrap(); + temp.write_all(b"Hello, world!\nThis is a text file.\n") + .unwrap(); + + let info = detect_by_content(temp.path()).unwrap(); + + assert!(info.is_text); + assert!(!info.is_binary); + assert_eq!(info.mime_type, Some("text/plain".to_string())); + } + + #[test] + fn test_detect_by_content_binary_with_nulls() { + let mut temp = NamedTempFile::new().unwrap(); + temp.write_all(b"\x00\x01\x02\x03binary data").unwrap(); + + let info = detect_by_content(temp.path()).unwrap(); + + assert!(!info.is_text); + assert!(info.is_binary); + assert_eq!(info.mime_type, Some("application/octet-stream".to_string())); + } + + #[test] + fn test_detect_by_content_invalid_utf8() { + let mut temp = NamedTempFile::new().unwrap(); + // Invalid UTF-8 sequence + temp.write_all(&[0xFF, 0xFE, 0xFD]).unwrap(); + + let info = detect_by_content(temp.path()).unwrap(); + + assert!(!info.is_text); + assert!(info.is_binary); + } + + #[test] + fn test_detect_file_type_with_extension() { + let path = Path::new("test.rs"); + // Will use extension-based detection (no file needed) + let info = detect_by_extension(path).unwrap(); + + assert!(info.is_text); + assert_eq!(info.mime_type, Some("text/x-rust".to_string())); + } + + #[test] + fn test_is_valid_utf8_valid() { + let mut temp = NamedTempFile::new().unwrap(); + temp.write_all("Hello, UTF-8! 你好世界".as_bytes()) + .unwrap(); + + let result = is_valid_utf8(temp.path()).unwrap(); + assert!(result); + } + + #[test] + fn test_is_valid_utf8_invalid() { + let mut temp = NamedTempFile::new().unwrap(); + temp.write_all(&[0xFF, 0xFE, 0xFD]).unwrap(); + + let result = is_valid_utf8(temp.path()).unwrap(); + assert!(!result); + } + + #[test] + fn test_file_type_info_constructors() { + let text = FileTypeInfo::text("text/plain"); + assert!(text.is_text); + assert!(!text.is_binary); + assert_eq!(text.charset, Some("utf-8".to_string())); + + let binary = FileTypeInfo::binary("application/pdf"); + assert!(!binary.is_text); + assert!(binary.is_binary); + assert_eq!(binary.charset, None); + + let unknown = FileTypeInfo::unknown(); + assert!(!unknown.is_text); + assert!(unknown.is_binary); + assert_eq!(unknown.mime_type, Some("application/octet-stream".to_string())); + } +} diff --git a/crates/dirigent_tools/src/fs/read.rs b/crates/dirigent_tools/src/fs/read.rs new file mode 100644 index 0000000..d087bbd --- /dev/null +++ b/crates/dirigent_tools/src/fs/read.rs @@ -0,0 +1,215 @@ +//! File read operation with sandboxing and line/limit support. +//! +//! **Status**: Implemented (TOOLS-FS-01) +//! +//! This module implements: +//! - Path validation using the security layer +//! - Asynchronous file reading +//! - UTF-8 encoding validation +//! - Line range and limit semantics +//! - Sandbox containment checks +//! - Blocklist enforcement +//! - Size limit soft caps + +use crate::config::SandboxConfig; +use crate::error::{ToolError, ToolResult}; +use crate::path::validate_path; +use serde::{Deserialize, Serialize}; + +/// Request to read a text file. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReadTextFileRequest { + /// Absolute path to the file to read. + pub path: String, + + /// Optional starting line (1-indexed). + /// + /// If provided with limit, reads from this line. + /// If provided without limit, reads from this line to end. + pub line: Option<usize>, + + /// Optional maximum number of lines to read. + /// + /// If provided without line, reads first N lines. + /// If provided with line, reads N lines starting from line. + pub limit: Option<usize>, +} + +/// Response from reading a text file. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReadTextFileResponse { + /// UTF-8 text content of the file (or portion thereof). + pub content: String, +} + +/// Read a text file with sandboxing and line/limit support. +/// +/// ## Line/Limit Semantics +/// +/// - `line: None, limit: None` → Read entire file +/// - `line: Some(n), limit: None` → Read from line n to end +/// - `line: None, limit: Some(m)` → Read first m lines +/// - `line: Some(n), limit: Some(m)` → Read m lines starting from line n +/// +/// ## Error Cases +/// +/// - Read disabled → `ToolError::PermissionDenied` +/// - Path outside allowed roots → `ToolError::SandboxViolation` +/// - Path matches blocklist → `ToolError::BlockedPath` +/// - File not found → `ToolError::NotFound` +/// - Non-UTF-8 content → `ToolError::EncodingUnsupported` +/// - File too large (soft cap) → Warning logged, content returned +/// +/// ## See Also +/// +/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-FS-01` +/// - Security spec: `docs/building/04_acp_client/03_fs_sandboxing_and_permissions_spec.md` +pub async fn read_text_file( + request: ReadTextFileRequest, + config: &SandboxConfig, +) -> ToolResult<ReadTextFileResponse> { + // Check if read is enabled + if !config.read_enabled { + return Err(ToolError::permission_denied("Read operations are disabled")); + } + + // Validate and canonicalize path (checks containment and blocklist) + let canonical_path = validate_path(&request.path, config)?; + + // Read file asynchronously + let content_bytes = tokio::fs::read(&canonical_path) + .await + .map_err(|e| { + if e.kind() == std::io::ErrorKind::NotFound { + ToolError::NotFound { + path: request.path.clone(), + } + } else { + ToolError::FileReadError { + path: request.path.clone(), + source: e, + } + } + })?; + + // Check soft cap and log warning if exceeded + let file_size = content_bytes.len() as u64; + if file_size > config.max_read_bytes { + tracing::warn!( + path = %request.path, + size = file_size, + limit = config.max_read_bytes, + "File size exceeds max_read_bytes soft cap" + ); + } + + // Validate UTF-8 encoding + let content = String::from_utf8(content_bytes).map_err(|_| { + ToolError::EncodingUnsupported { + encoding: "non-UTF-8".to_string(), + } + })?; + + // Apply line/limit semantics + let final_content = apply_line_limit(&content, request.line, request.limit); + + Ok(ReadTextFileResponse { + content: final_content, + }) +} + +/// Apply line/limit semantics to file content. +/// +/// ## Semantics +/// +/// - `line: None, limit: None` → Return entire content +/// - `line: Some(n), limit: None` → Return from line n to end (1-indexed) +/// - `line: None, limit: Some(m)` → Return first m lines +/// - `line: Some(n), limit: Some(m)` → Return m lines starting from line n (1-indexed) +fn apply_line_limit(content: &str, line: Option<usize>, limit: Option<usize>) -> String { + match (line, limit) { + // No line or limit specified - return entire content + (None, None) => content.to_string(), + + // Only limit specified - return first N lines + (None, Some(limit)) => { + content + .lines() + .take(limit) + .collect::<Vec<_>>() + .join("\n") + } + + // Only line specified - return from line N to end (1-indexed) + (Some(start_line), None) => { + if start_line == 0 { + return content.to_string(); + } + content + .lines() + .skip(start_line.saturating_sub(1)) + .collect::<Vec<_>>() + .join("\n") + } + + // Both line and limit specified - return N lines starting from line M (1-indexed) + (Some(start_line), Some(limit)) => { + if start_line == 0 { + return content + .lines() + .take(limit) + .collect::<Vec<_>>() + .join("\n"); + } + content + .lines() + .skip(start_line.saturating_sub(1)) + .take(limit) + .collect::<Vec<_>>() + .join("\n") + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_apply_line_limit_no_params() { + let content = "line1\nline2\nline3"; + assert_eq!(apply_line_limit(content, None, None), content); + } + + #[test] + fn test_apply_line_limit_only_limit() { + let content = "line1\nline2\nline3\nline4"; + assert_eq!(apply_line_limit(content, None, Some(2)), "line1\nline2"); + } + + #[test] + fn test_apply_line_limit_only_line() { + let content = "line1\nline2\nline3\nline4"; + assert_eq!(apply_line_limit(content, Some(2), None), "line2\nline3\nline4"); + } + + #[test] + fn test_apply_line_limit_both() { + let content = "line1\nline2\nline3\nline4\nline5"; + assert_eq!(apply_line_limit(content, Some(2), Some(2)), "line2\nline3"); + } + + #[test] + fn test_apply_line_limit_start_zero() { + let content = "line1\nline2\nline3"; + assert_eq!(apply_line_limit(content, Some(0), None), content); + assert_eq!(apply_line_limit(content, Some(0), Some(2)), "line1\nline2"); + } + + #[test] + fn test_apply_line_limit_beyond_end() { + let content = "line1\nline2\nline3"; + assert_eq!(apply_line_limit(content, Some(10), None), ""); + assert_eq!(apply_line_limit(content, Some(2), Some(10)), "line2\nline3"); + } +} diff --git a/crates/dirigent_tools/src/fs/write.rs b/crates/dirigent_tools/src/fs/write.rs new file mode 100644 index 0000000..8fa9d4a --- /dev/null +++ b/crates/dirigent_tools/src/fs/write.rs @@ -0,0 +1,232 @@ +//! File write operation with sandboxing, permissions, and atomic writes. +//! +//! **Status**: Implemented (TOOLS-FS-02) +//! +//! This module implements: +//! - Path validation and canonicalization +//! - Permission checks (stubbed for now) +//! - Atomic write operations (temp + rename) +//! - EOL normalization +//! - Parent directory creation +//! - Size limit enforcement + +use crate::config::{EolPolicy, PermissionConfig, SandboxConfig}; +use crate::error::{ToolError, ToolResult}; +use crate::path::validate_path; +use crate::permission::check::{check_permission, PermissionContext}; +use crate::permission::cache::PermissionDecision; +use crate::permission::whitelist::PermissionOperation; +use serde::{Deserialize, Serialize}; +use std::path::Path; + +/// Request to write a text file. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WriteTextFileRequest { + /// Absolute path to the file to write. + pub path: String, + + /// UTF-8 text content to write. + pub content: String, +} + +/// Response from writing a text file. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WriteTextFileResponse {} + +/// Write a text file with sandboxing, permission gating, and atomic writes. +/// +/// ## Implementation +/// +/// This function: +/// 1. Validates `write_enabled = true` in config +/// 2. Validates and canonicalizes path +/// 3. Checks containment and blocklist +/// 4. Checks permission (TODO: integrate with permission system) +/// 5. Validates content size against max_write_bytes +/// 6. Applies EOL policy normalization +/// 7. Creates parent directories if needed +/// 8. Performs atomic write (temp file + rename) +/// +/// ## EOL Normalization +/// +/// Applies configured EOL policy: +/// - `Preserve` → Keep original line endings +/// - `Lf` → Normalize to LF (\n) +/// - `Crlf` → Normalize to CRLF (\r\n) +/// +/// ## Atomic Writes +/// +/// To prevent partial writes: +/// 1. Write content to temporary file in same directory +/// 2. Rename temp file to target path (atomic on POSIX, near-atomic on Windows) +/// +/// ## Error Cases +/// +/// - Write disabled → `ToolError::PermissionDenied` +/// - Path outside allowed roots → `ToolError::SandboxViolation` +/// - Path matches blocklist → `ToolError::BlockedPath` +/// - Permission denied → `ToolError::PermissionDenied` +/// - Content too large → `ToolError::FileTooLarge` +/// - I/O errors → `ToolError::Io` +/// +/// ## See Also +/// +/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-FS-02` +/// - Security spec: `docs/building/04_acp_client/03_fs_sandboxing_and_permissions_spec.md` +pub async fn write_text_file( + request: WriteTextFileRequest, + sandbox_config: &SandboxConfig, + permission_config: &PermissionConfig, + permission_context: &PermissionContext, +) -> ToolResult<WriteTextFileResponse> { + // Check if write is enabled + if !sandbox_config.write_enabled { + return Err(ToolError::permission_denied("Write operations are disabled")); + } + + // Validate content size + let content_size = request.content.len() as u64; + if content_size > sandbox_config.max_write_bytes { + return Err(ToolError::FileTooLarge { + size: content_size, + limit: sandbox_config.max_write_bytes, + }); + } + + // Validate and canonicalize path (checks containment and blocklist) + let canonical_path = validate_path(&request.path, sandbox_config)?; + + // Check permission via permission system (TOOLS-PERM-01) + let operation = PermissionOperation::Write { + path: request.path.clone(), + }; + + let decision = check_permission(&operation, permission_context, permission_config).await?; + + match decision { + PermissionDecision::Allowed => { + // Permission granted, proceed with write + tracing::debug!( + path = %request.path, + "Permission granted for write operation" + ); + } + PermissionDecision::Denied => { + return Err(ToolError::PermissionRejected); + } + PermissionDecision::Cancelled => { + return Err(ToolError::permission_denied("Permission prompt was cancelled")); + } + } + + // Apply EOL normalization + let normalized_content = normalize_eol(&request.content, sandbox_config.eol_policy); + + // Create parent directories if they don't exist + if let Some(parent) = canonical_path.parent() { + tokio::fs::create_dir_all(parent).await.map_err(|e| { + ToolError::FileReadError { + path: parent.display().to_string(), + source: e, + } + })?; + } + + // Perform atomic write + write_atomic(&canonical_path, &normalized_content).await?; + + tracing::info!( + path = %request.path, + size = content_size, + "File written successfully" + ); + + Ok(WriteTextFileResponse {}) +} + +/// Normalize line endings according to configured policy. +pub fn normalize_eol(content: &str, policy: EolPolicy) -> String { + match policy { + EolPolicy::Preserve => content.to_string(), + EolPolicy::Lf => { + // Replace CRLF with LF, then ensure all CR are removed + content.replace("\r\n", "\n").replace('\r', "\n") + } + EolPolicy::Crlf => { + // First normalize to LF, then replace with CRLF + let lf_normalized = content.replace("\r\n", "\n").replace('\r', "\n"); + lf_normalized.replace('\n', "\r\n") + } + } +} + +/// Write content to a file atomically using temp file + rename. +/// +/// This minimizes the risk of partial writes by: +/// 1. Writing to a temporary file in the same directory +/// 2. Renaming the temp file to the target path +/// +/// The rename operation is atomic on POSIX systems and near-atomic on Windows. +async fn write_atomic(path: &Path, content: &str) -> ToolResult<()> { + // Create a temporary file in the same directory + let parent = path.parent().unwrap_or_else(|| Path::new(".")); + let file_name = path.file_name().unwrap_or_default().to_string_lossy(); + let temp_path = parent.join(format!(".{}.tmp", file_name)); + + // Write content to temporary file + tokio::fs::write(&temp_path, content).await.map_err(|e| { + ToolError::FileReadError { + path: temp_path.display().to_string(), + source: e, + } + })?; + + // Rename temp file to target (atomic operation) + tokio::fs::rename(&temp_path, path).await.map_err(|e| { + // Clean up temp file on error + let _ = std::fs::remove_file(&temp_path); + ToolError::FileReadError { + path: path.display().to_string(), + source: e, + } + })?; + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_normalize_eol_preserve() { + let content = "line1\nline2\r\nline3\rline4"; + assert_eq!(normalize_eol(content, EolPolicy::Preserve), content); + } + + #[test] + fn test_normalize_eol_lf() { + let content = "line1\nline2\r\nline3\rline4"; + let expected = "line1\nline2\nline3\nline4"; + assert_eq!(normalize_eol(content, EolPolicy::Lf), expected); + } + + #[test] + fn test_normalize_eol_crlf() { + let content = "line1\nline2\r\nline3\rline4"; + let expected = "line1\r\nline2\r\nline3\r\nline4"; + assert_eq!(normalize_eol(content, EolPolicy::Crlf), expected); + } + + #[test] + fn test_normalize_eol_edge_cases() { + // Empty string + assert_eq!(normalize_eol("", EolPolicy::Lf), ""); + + // Only newlines + assert_eq!(normalize_eol("\n\n\n", EolPolicy::Crlf), "\r\n\r\n\r\n"); + + // Mixed line endings to LF + assert_eq!(normalize_eol("a\rb\nc\r\nd", EolPolicy::Lf), "a\nb\nc\nd"); + } +} diff --git a/crates/dirigent_tools/src/lib.rs b/crates/dirigent_tools/src/lib.rs new file mode 100644 index 0000000..0b69886 --- /dev/null +++ b/crates/dirigent_tools/src/lib.rs @@ -0,0 +1,78 @@ +//! # dirigent_tools +//! +//! Tool implementations for ACP (Agent-Client Protocol) client operations. +//! +//! This package provides: +//! - **File operations** (read, write, edit) with sandboxing +//! - **Terminal operations** (create, output, wait, kill) with isolation +//! - **Search operations** (glob, grep, ls) with result limiting +//! - **Permission system** for securing write and execute operations +//! - **Path sandboxing** with configurable allowed roots and blocklists +//! +//! ## Module Overview +//! +//! - [`error`] - Error types for tool operations +//! - [`config`] - Configuration types for sandboxing, permissions, and limits +//! - [`path`] - Path normalization and containment checking (cross-platform) +//! - [`fs`] - File system operations (read, write) with sandbox enforcement +//! - [`search`] - Search operations (glob, grep, ls) with result limiting +//! - [`terminal`] - Terminal/command execution with output capture +//! - [`permission`] - Permission prompts and decision caching +//! - [`audit`] - Audit logging for sensitive operations +//! - [`tool`] - Tool trait, events, and per-call context +//! - [`registry`] - ToolRegistry for built-in and dynamic tools +//! - [`floor`] - Hardcoded SecurityFloor (cannot be bypassed by settings) +//! - [`dispatch`] - Dispatch pipeline (registry → floor → run) +//! - [`tools`] - Built-in tool implementations (read, ...) +//! +//! ## Safety and Security +//! +//! All file and terminal operations are subject to: +//! - **Sandbox containment** - Operations restricted to configured allowed roots +//! - **Blocklist enforcement** - Sensitive paths can be explicitly denied +//! - **Permission prompts** - Write and execute operations can require user approval +//! - **Resource limits** - File size, search results, and terminal output are bounded +//! - **Audit logging** - All operations are logged with structured context +//! +//! ## Platform Support +//! +//! This crate is designed with Windows as a first-class platform: +//! - Handles Windows paths (backslashes, drive letters, UNC shares, `\\?\` prefixes) +//! - Supports MINGW-style paths (`/c/...`) +//! - Normalizes path separators for consistent policy enforcement +//! - Tests run on Windows, Linux, and macOS + +pub mod error; +pub mod config; +pub mod path; +pub mod fs; +pub mod search; +pub mod terminal; +pub mod permission; +pub mod audit; +pub mod embedding; +pub mod tool; +pub mod registry; +pub mod floor; +pub mod dispatch; +pub mod tools; + +// Re-export commonly used types +pub use error::{ToolError, ToolResult}; +pub use config::{ + SandboxConfig, PermissionConfig, TerminalConfig, SearchConfig, EmbeddingConfig, +}; +pub use tool::{ + AnyTool, AnyToolInput, ClientKind, Erased, PermissionRequestId, ProtocolKind, Tool, + ToolContext, ToolEvent, ToolEventSink, ToolInput, ToolKind, ToolLocation, ToolResultContent, +}; +pub use registry::{CollisionPolicy, DynamicEntry, ToolRegistry, ToolSource, Winner}; +pub use floor::{FloorDecision, SecurityFloor}; +pub use dispatch::{dispatch, DispatchResult}; + +/// Re-exports of the policy types `dirigent_tools` consumes from +/// `dirigent_fermata`. Keeping this here pins the dependency direction: +/// `dirigent_tools` → `dirigent_fermata`, never the reverse. +pub mod policy { + pub use dirigent_fermata::core::{Decision, Op, Policy, Reason, Rule}; +} diff --git a/crates/dirigent_tools/src/path.rs b/crates/dirigent_tools/src/path.rs new file mode 100644 index 0000000..d6d55a3 --- /dev/null +++ b/crates/dirigent_tools/src/path.rs @@ -0,0 +1,114 @@ +//! Path normalization and containment checking. +//! +//! This module provides cross-platform path utilities with special attention to Windows: +//! - Normalizes path separators (backslash vs forward slash) +//! - Handles Windows drive letters (C:\, /c/, etc.) +//! - Handles UNC paths (\\server\share\...) +//! - Handles long path prefixes (\\?\...) +//! - Handles MINGW-style paths (/c/Users/...) +//! - Canonical path resolution (symlinks, junctions) +//! - Containment checking for sandbox enforcement +//! +//! # Security-Critical +//! +//! This module is the security foundation for tool sandboxing. All operations must: +//! - Prevent path traversal attacks +//! - Prevent symlink escape attacks +//! - Handle all Windows path edge cases correctly +//! - Never expose disallowed paths in error messages + +pub mod canonicalize; +pub mod containment; +pub mod blocklist; +pub mod validate; + +// Re-export public API +pub use canonicalize::{canonicalize_path, SymlinkPolicy}; +pub use containment::check_containment; +pub use blocklist::check_blocklist; +pub use validate::validate_path; + +use std::path::{Path, PathBuf}; + +/// Get the basename (final component) of a path for safe error messages. +/// +/// This is used to avoid leaking full paths in error messages. +pub fn basename(path: &Path) -> String { + path.file_name() + .and_then(|s| s.to_str()) + .map(|s| s.to_string()) + .unwrap_or_else(|| String::new()) +} + +/// Check if a path is absolute. +/// +/// On Windows, this handles: +/// - Standard absolute paths (C:\...) +/// - UNC paths (\\server\share\...) +/// - Long-path prefixes (\\?\...) +/// - Verbatim UNC (\\?\UNC\server\share\...) +pub fn is_absolute(path: &Path) -> bool { + path.is_absolute() +} + +/// Normalize a path to use the platform's standard separator. +/// +/// On Windows, this converts forward slashes to backslashes. +/// On Unix, this is a no-op. +pub fn normalize_separators(path: &Path) -> PathBuf { + #[cfg(windows)] + { + let s = path.to_string_lossy(); + let normalized = s.replace('/', "\\"); + PathBuf::from(normalized) + } + + #[cfg(not(windows))] + { + path.to_path_buf() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_basename() { + assert_eq!(basename(Path::new("/foo/bar/baz.txt")), "baz.txt"); + assert_eq!(basename(Path::new("baz.txt")), "baz.txt"); + assert_eq!(basename(Path::new("/")), ""); + } + + #[cfg(windows)] + #[test] + fn test_is_absolute_windows() { + assert!(is_absolute(Path::new("C:\\Users"))); + assert!(is_absolute(Path::new("C:/Users"))); + assert!(is_absolute(Path::new("\\\\server\\share"))); + assert!(is_absolute(Path::new("\\\\?\\C:\\Users"))); + assert!(!is_absolute(Path::new("relative\\path"))); + assert!(!is_absolute(Path::new("..\\parent"))); + } + + #[cfg(unix)] + #[test] + fn test_is_absolute_unix() { + assert!(is_absolute(Path::new("/home/user"))); + assert!(!is_absolute(Path::new("relative/path"))); + assert!(!is_absolute(Path::new("../parent"))); + } + + #[cfg(windows)] + #[test] + fn test_normalize_separators_windows() { + assert_eq!( + normalize_separators(Path::new("C:/Users/foo/bar.txt")), + PathBuf::from("C:\\Users\\foo\\bar.txt") + ); + assert_eq!( + normalize_separators(Path::new("C:\\Users\\foo\\bar.txt")), + PathBuf::from("C:\\Users\\foo\\bar.txt") + ); + } +} diff --git a/crates/dirigent_tools/src/path/blocklist.rs b/crates/dirigent_tools/src/path/blocklist.rs new file mode 100644 index 0000000..06547ee --- /dev/null +++ b/crates/dirigent_tools/src/path/blocklist.rs @@ -0,0 +1,220 @@ +//! Blocklist evaluation using glob patterns. +//! +//! This module checks if paths match configured blocklist patterns, +//! allowing fine-grained denial of sensitive paths even within allowed roots. + +use crate::error::{ToolError, ToolResult}; +use globset::{Glob, GlobSet, GlobSetBuilder}; +use std::path::Path; + +/// Check if a path matches any blocklist patterns. +/// +/// Blocklist patterns can be: +/// - Absolute paths (exact match) +/// - Glob patterns (e.g., `**/.env`, `**/secrets/**`) +/// +/// # Arguments +/// +/// * `canonical_path` - The canonical path to check +/// * `blocklist_patterns` - List of path patterns or globs to deny +/// +/// # Returns +/// +/// Ok(()) if the path is not blocked, or an error if it matches a blocklist pattern. +/// +/// # Performance +/// +/// For best performance, pre-compile patterns into a `CompiledBlocklist` and use +/// `check_blocklist_compiled` instead. +pub fn check_blocklist( + canonical_path: &Path, + blocklist_patterns: &[String], +) -> ToolResult<()> { + if blocklist_patterns.is_empty() { + return Ok(()); + } + + // Compile patterns on the fly (slower) + let compiled = compile_blocklist(blocklist_patterns)?; + check_blocklist_compiled(canonical_path, &compiled) +} + +/// Pre-compiled blocklist for efficient matching. +pub struct CompiledBlocklist { + glob_set: GlobSet, +} + +impl CompiledBlocklist { + /// Create a new compiled blocklist. + pub fn new(glob_set: GlobSet) -> Self { + Self { glob_set } + } + + /// Get the underlying GlobSet. + pub fn glob_set(&self) -> &GlobSet { + &self.glob_set + } +} + +/// Compile blocklist patterns into a GlobSet for efficient matching. +/// +/// This should be done once at configuration load time. +/// +/// # Arguments +/// +/// * `patterns` - List of glob patterns or absolute paths +/// +/// # Returns +/// +/// A compiled blocklist ready for fast matching. +pub fn compile_blocklist(patterns: &[String]) -> ToolResult<CompiledBlocklist> { + let mut builder = GlobSetBuilder::new(); + + for pattern in patterns { + let glob = Glob::new(pattern).map_err(|e| { + ToolError::InvalidConfig(format!("Invalid blocklist pattern '{}': {}", pattern, e)) + })?; + builder.add(glob); + } + + let glob_set = builder.build().map_err(|e| { + ToolError::InvalidConfig(format!("Failed to compile blocklist patterns: {}", e)) + })?; + + Ok(CompiledBlocklist::new(glob_set)) +} + +/// Check if a path matches a pre-compiled blocklist. +/// +/// This is the fast path for blocklist checking. +/// +/// # Arguments +/// +/// * `canonical_path` - The canonical path to check +/// * `compiled` - Pre-compiled blocklist +/// +/// # Returns +/// +/// Ok(()) if the path is not blocked, or an error if it matches a pattern. +pub fn check_blocklist_compiled( + canonical_path: &Path, + compiled: &CompiledBlocklist, +) -> ToolResult<()> { + if compiled.glob_set.is_match(canonical_path) { + return Err(ToolError::blocked_path(format!( + "Path matches blocklist: {}", + super::basename(canonical_path) + ))); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_empty_blocklist() { + let result = check_blocklist(Path::new("/any/path"), &[]); + assert!(result.is_ok()); + } + + #[test] + fn test_exact_path_match() { + let patterns = vec![ + "/etc/passwd".to_string(), + "/home/user/.ssh/id_rsa".to_string(), + ]; + + // Blocked + assert!(check_blocklist(Path::new("/etc/passwd"), &patterns).is_err()); + assert!(check_blocklist(Path::new("/home/user/.ssh/id_rsa"), &patterns).is_err()); + + // Not blocked + assert!(check_blocklist(Path::new("/etc/hosts"), &patterns).is_ok()); + } + + #[test] + fn test_glob_patterns() { + let patterns = vec![ + "**/.env".to_string(), + "**/secrets/**".to_string(), + "**/.git/**".to_string(), + ]; + + // Blocked by .env pattern + assert!(check_blocklist(Path::new("/project/.env"), &patterns).is_err()); + assert!(check_blocklist(Path::new("/project/subdir/.env"), &patterns).is_err()); + + // Blocked by secrets pattern + assert!(check_blocklist(Path::new("/project/secrets/key.txt"), &patterns).is_err()); + assert!(check_blocklist(Path::new("/app/secrets/api_key"), &patterns).is_err()); + + // Blocked by .git pattern + assert!(check_blocklist(Path::new("/project/.git/config"), &patterns).is_err()); + + // Not blocked + assert!(check_blocklist(Path::new("/project/src/main.rs"), &patterns).is_ok()); + assert!(check_blocklist(Path::new("/project/README.md"), &patterns).is_ok()); + } + + #[test] + fn test_compile_blocklist() { + let patterns = vec![ + "**/.env".to_string(), + "**/secrets/**".to_string(), + ]; + + let compiled = compile_blocklist(&patterns).unwrap(); + + // Blocked + assert!(check_blocklist_compiled( + Path::new("/project/.env"), + &compiled + ).is_err()); + + // Not blocked + assert!(check_blocklist_compiled( + Path::new("/project/src/main.rs"), + &compiled + ).is_ok()); + } + + #[test] + fn test_invalid_pattern() { + let patterns = vec![ + "[invalid".to_string(), // Invalid glob syntax + ]; + + let result = compile_blocklist(&patterns); + assert!(result.is_err()); + } + + #[cfg(windows)] + #[test] + fn test_windows_paths() { + let patterns = vec![ + "**/.env".to_string(), + "C:/secrets/**".to_string(), + ]; + + let compiled = compile_blocklist(&patterns).unwrap(); + + // Blocked + assert!(check_blocklist_compiled( + Path::new("C:\\project\\.env"), + &compiled + ).is_err()); + assert!(check_blocklist_compiled( + Path::new("C:\\secrets\\key.txt"), + &compiled + ).is_err()); + + // Not blocked + assert!(check_blocklist_compiled( + Path::new("C:\\project\\src\\main.rs"), + &compiled + ).is_ok()); + } +} diff --git a/crates/dirigent_tools/src/path/canonicalize.rs b/crates/dirigent_tools/src/path/canonicalize.rs new file mode 100644 index 0000000..8eb7f36 --- /dev/null +++ b/crates/dirigent_tools/src/path/canonicalize.rs @@ -0,0 +1,387 @@ +//! Path canonicalization with cross-platform support. +//! +//! This module handles: +//! - Converting relative paths to absolute +//! - Resolving symlinks and junctions per policy +//! - Normalizing Windows path formats (UNC, long-path, MINGW) +//! - Rejecting Windows reserved device names +//! - Handling non-existent paths (for write operations) + +use crate::error::{ToolError, ToolResult}; +use std::path::{Path, PathBuf, Component}; + +/// Symlink handling policy for path canonicalization. +#[derive(Debug, Clone, Copy)] +pub struct SymlinkPolicy { + /// Allow symlinks to escape allowed roots. + /// + /// If false (recommended), symlinks pointing outside allowed roots are rejected. + pub allow_symlink_escape: bool, + + /// Follow symlinks within allowed roots. + /// + /// If true, symlinks within allowed roots are followed during canonicalization. + pub follow_symlinks_within_roots: bool, +} + +impl Default for SymlinkPolicy { + fn default() -> Self { + Self { + allow_symlink_escape: false, + follow_symlinks_within_roots: true, + } + } +} + +impl SymlinkPolicy { + /// Create a policy that follows all symlinks (least restrictive). + pub fn follow_all() -> Self { + Self { + allow_symlink_escape: true, + follow_symlinks_within_roots: true, + } + } + + /// Create a policy that never follows symlinks (most restrictive). + pub fn follow_none() -> Self { + Self { + allow_symlink_escape: false, + follow_symlinks_within_roots: false, + } + } +} + +/// Canonicalize a path with the given symlink policy. +/// +/// This function: +/// 1. Rejects empty or relative paths (ACP requires absolute paths) +/// 2. Normalizes separators to platform standard +/// 3. Converts MINGW-style paths (/c/...) to native Windows paths (C:\...) +/// 4. Strips long-path prefixes (\\?\) for policy comparison +/// 5. Resolves symlinks/junctions per policy +/// 6. For non-existent paths, canonicalizes parent + appends remainder +/// 7. Ensures no ".." escapes remain +/// 8. Rejects Windows reserved device names (CON, NUL, etc.) +/// +/// # Arguments +/// +/// * `user_path` - The path to canonicalize (must be absolute) +/// * `policy` - Symlink handling policy +/// +/// # Returns +/// +/// The canonical absolute path, or an error if the path is invalid or violates policy. +/// +/// # Security +/// +/// This is a security-critical function. It must prevent: +/// - Path traversal attacks via ".." components +/// - Symlink escape attacks +/// - Access to Windows reserved device names +pub fn canonicalize_path(user_path: &Path, _policy: &SymlinkPolicy) -> ToolResult<PathBuf> { + // Step 1: Reject empty paths + if user_path.as_os_str().is_empty() { + return Err(ToolError::sandbox_violation("Path cannot be empty")); + } + + // Step 2: Convert to string for preprocessing + let path_str = user_path.to_string_lossy(); + + // Step 3: Handle MINGW-style paths on Windows (/c/... -> C:\...) + #[cfg(windows)] + let path_str = convert_mingw_path(&path_str); + + let mut path = PathBuf::from(path_str.as_ref()); + + // Step 4: Normalize separators + path = super::normalize_separators(&path); + + // Step 5: Reject relative paths + if !path.is_absolute() { + return Err(ToolError::sandbox_violation(format!( + "Path must be absolute: {}", + super::basename(&path) + ))); + } + + // Step 6: Strip long-path prefix for normalization (\\?\C:\... -> C:\...) + #[cfg(windows)] + { + path = strip_long_path_prefix(&path); + } + + // Step 7: Check for reserved device names on Windows + #[cfg(windows)] + { + if is_reserved_device_name(&path) { + return Err(ToolError::sandbox_violation(format!( + "Path is a reserved device name: {}", + super::basename(&path) + ))); + } + } + + // Step 8: Normalize drive letter case on Windows + #[cfg(windows)] + { + path = normalize_drive_letter(&path); + } + + // Step 9: Try to canonicalize using dunce (follows symlinks on all platforms) + // If the path doesn't exist, find the first existing ancestor and canonicalize it, + // then append all non-existent components + let canonical = match dunce::canonicalize(&path) { + Ok(canonical) => canonical, + Err(_) => { + // Path doesn't exist - find first existing ancestor + canonicalize_non_existent_path(&path)? + } + }; + + // Step 10: Verify no ".." components remain (security check) + for component in canonical.components() { + if component == Component::ParentDir { + return Err(ToolError::sandbox_violation( + "Path contains '..' after canonicalization (potential traversal attack)", + )); + } + } + + // Step 11: Strip long-path prefix again if added by canonicalize + #[cfg(windows)] + let canonical = strip_long_path_prefix(&canonical); + + Ok(canonical) +} + +/// Canonicalize a non-existent path by finding the first existing ancestor. +/// +/// This function walks up the directory tree until it finds an existing directory, +/// then builds the canonical path by appending the non-existent components. +fn canonicalize_non_existent_path(path: &Path) -> ToolResult<PathBuf> { + // Collect components of the non-existent path parts + let mut components_to_append: Vec<std::ffi::OsString> = Vec::new(); + let mut current = path.to_path_buf(); + + // Walk up the directory tree until we find an existing directory + loop { + if let Some(parent) = current.parent() { + if parent == current { + // Hit root without finding existing directory + // This shouldn't happen with absolute paths, but handle it + return Ok(path.to_path_buf()); + } + + // Try to canonicalize the parent + match dunce::canonicalize(parent) { + Ok(canonical_parent) => { + // Found existing parent - build the full path + components_to_append.reverse(); + let mut result = canonical_parent; + + // First add the current file_name (if any) + if let Some(file_name) = current.file_name() { + result = result.join(file_name); + } + + // Then add all the rest + for component in components_to_append { + result = result.join(component); + } + + return Ok(result); + } + Err(_) => { + // Parent doesn't exist - collect the component and continue up + if let Some(file_name) = current.file_name() { + components_to_append.push(file_name.to_os_string()); + } + current = parent.to_path_buf(); + } + } + } else { + // No parent (shouldn't reach here with absolute paths) + return Ok(path.to_path_buf()); + } + } +} + +/// Convert MINGW-style path (/c/Users/...) to Windows path (C:\Users\...). +/// +/// Only applies on Windows. Detects paths like: +/// - /c/... -> C:\... +/// - /d/... -> D:\... +#[cfg(windows)] +fn convert_mingw_path(path: &str) -> std::borrow::Cow<'_, str> { + // Check for MINGW-style path: starts with /<letter>/ + if path.len() >= 3 + && path.starts_with('/') + && path.chars().nth(1).map_or(false, |c| c.is_ascii_alphabetic()) + && (path.len() == 3 || path.chars().nth(2) == Some('/')) + { + let drive_letter = path.chars().nth(1).unwrap().to_ascii_uppercase(); + let rest = if path.len() > 3 { + // Normalize forward slashes to backslashes in the remainder + path[3..].replace('/', "\\") + } else { + String::new() + }; + format!("{}:\\{}", drive_letter, rest).into() + } else { + path.into() + } +} + +/// Strip the long-path prefix (\\?\) from a Windows path. +/// +/// Also handles verbatim UNC paths (\\?\UNC\server\share -> \\server\share). +#[cfg(windows)] +fn strip_long_path_prefix(path: &Path) -> PathBuf { + let path_str = path.to_string_lossy(); + + // Handle \\?\UNC\server\share -> \\server\share + if path_str.starts_with(r"\\?\UNC\") { + let without_prefix = &path_str[r"\\?\UNC\".len()..]; + return PathBuf::from(format!(r"\\{}", without_prefix)); + } + + // Handle \\?\C:\ -> C:\ + if path_str.starts_with(r"\\?\") { + let without_prefix = &path_str[r"\\?\".len()..]; + return PathBuf::from(without_prefix); + } + + path.to_path_buf() +} + +/// Normalize drive letter to uppercase (C:\ instead of c:\). +#[cfg(windows)] +fn normalize_drive_letter(path: &Path) -> PathBuf { + let path_str = path.to_string_lossy(); + + // Check if path starts with a drive letter (e.g., "c:\") + if path_str.len() >= 2 + && path_str.chars().nth(0).map_or(false, |c| c.is_ascii_alphabetic()) + && path_str.chars().nth(1) == Some(':') + { + let mut chars: Vec<char> = path_str.chars().collect(); + chars[0] = chars[0].to_ascii_uppercase(); + return PathBuf::from(chars.iter().collect::<String>()); + } + + path.to_path_buf() +} + +/// Check if a path is a Windows reserved device name. +/// +/// Reserved names include: +/// - CON, PRN, AUX, NUL +/// - COM1-COM9 +/// - LPT1-LPT9 +/// +/// These can appear with or without extensions (e.g., CON.txt is also reserved). +#[cfg(windows)] +fn is_reserved_device_name(path: &Path) -> bool { + const RESERVED_NAMES: &[&str] = &[ + "CON", "PRN", "AUX", "NUL", + "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", "COM8", "COM9", + "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9", + ]; + + if let Some(file_name) = path.file_name() { + let name = file_name.to_string_lossy().to_uppercase(); + + // Check exact match + if RESERVED_NAMES.contains(&name.as_str()) { + return true; + } + + // Check with extension (e.g., CON.txt) + if let Some(stem) = Path::new(&*name).file_stem() { + let stem_str = stem.to_string_lossy(); + if RESERVED_NAMES.contains(&stem_str.as_ref()) { + return true; + } + } + } + + false +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_empty_path() { + let result = canonicalize_path(Path::new(""), &SymlinkPolicy::default()); + assert!(result.is_err()); + } + + #[cfg(windows)] + #[test] + fn test_convert_mingw_path() { + assert_eq!(convert_mingw_path("/c/Users/foo"), "C:\\Users\\foo"); + assert_eq!(convert_mingw_path("/d/Projects"), "D:\\Projects"); + assert_eq!(convert_mingw_path("/c/"), "C:\\"); + assert_eq!(convert_mingw_path("C:\\Users"), "C:\\Users"); // No conversion + } + + #[cfg(windows)] + #[test] + fn test_strip_long_path_prefix() { + assert_eq!( + strip_long_path_prefix(Path::new(r"\\?\C:\Users\foo")), + PathBuf::from(r"C:\Users\foo") + ); + assert_eq!( + strip_long_path_prefix(Path::new(r"\\?\UNC\server\share\file")), + PathBuf::from(r"\\server\share\file") + ); + assert_eq!( + strip_long_path_prefix(Path::new(r"C:\Users\foo")), + PathBuf::from(r"C:\Users\foo") + ); + } + + #[cfg(windows)] + #[test] + fn test_normalize_drive_letter() { + assert_eq!( + normalize_drive_letter(Path::new("c:\\Users")), + PathBuf::from("C:\\Users") + ); + assert_eq!( + normalize_drive_letter(Path::new("C:\\Users")), + PathBuf::from("C:\\Users") + ); + } + + #[cfg(windows)] + #[test] + fn test_is_reserved_device_name() { + assert!(is_reserved_device_name(Path::new("CON"))); + assert!(is_reserved_device_name(Path::new("con"))); + assert!(is_reserved_device_name(Path::new("CON.txt"))); + assert!(is_reserved_device_name(Path::new("C:\\path\\to\\NUL"))); + assert!(is_reserved_device_name(Path::new("COM1"))); + assert!(is_reserved_device_name(Path::new("LPT5"))); + assert!(!is_reserved_device_name(Path::new("CONNECT.txt"))); + assert!(!is_reserved_device_name(Path::new("file.txt"))); + } + + #[test] + fn test_symlink_policy() { + let default = SymlinkPolicy::default(); + assert!(!default.allow_symlink_escape); + assert!(default.follow_symlinks_within_roots); + + let follow_all = SymlinkPolicy::follow_all(); + assert!(follow_all.allow_symlink_escape); + assert!(follow_all.follow_symlinks_within_roots); + + let follow_none = SymlinkPolicy::follow_none(); + assert!(!follow_none.allow_symlink_escape); + assert!(!follow_none.follow_symlinks_within_roots); + } +} diff --git a/crates/dirigent_tools/src/path/containment.rs b/crates/dirigent_tools/src/path/containment.rs new file mode 100644 index 0000000..fc90921 --- /dev/null +++ b/crates/dirigent_tools/src/path/containment.rs @@ -0,0 +1,265 @@ +//! Path containment checking for sandbox enforcement. +//! +//! This module verifies that a canonical path is within allowed sandbox roots. + +use crate::error::{ToolError, ToolResult}; +use std::path::{Path, PathBuf, Component}; + +/// Check if a canonical path is contained within allowed roots. +/// +/// This function performs component-wise prefix matching (not string matching) +/// to ensure the path is strictly within at least one allowed root. +/// +/// # Arguments +/// +/// * `canonical_path` - The canonical path to check (must already be canonicalized) +/// * `allowed_roots` - List of allowed root paths (must already be canonical) +/// +/// # Returns +/// +/// Ok(()) if the path is contained, or an error if it's outside all roots. +/// +/// # Security +/// +/// This is a security-critical function. It uses component-wise comparison to prevent: +/// - String prefix attacks (e.g., "/a" should not contain "/ab") +/// - Path traversal attacks +pub fn check_containment( + canonical_path: &Path, + allowed_roots: &[PathBuf], +) -> ToolResult<()> { + // If no roots configured, deny all paths + if allowed_roots.is_empty() { + return Err(ToolError::sandbox_violation( + "No allowed roots configured - all paths denied", + )); + } + + // Check if path is contained by at least one root + for root in allowed_roots { + if is_contained_by(canonical_path, root) { + return Ok(()); + } + } + + // Path is outside all roots + Err(ToolError::sandbox_violation(format!( + "Path is outside allowed roots: {}", + super::basename(canonical_path) + ))) +} + +/// Check if a path is strictly contained by a root. +/// +/// This performs component-wise comparison to ensure proper containment: +/// - "/a/b/c" is contained by "/a" +/// - "/a/b/c" is NOT contained by "/a/b/c" (must be strict) +/// - "/ab" is NOT contained by "/a" (not a string prefix match) +/// +/// On Windows, comparison is case-insensitive to match filesystem behavior. +/// +/// # Arguments +/// +/// * `path` - The path to check (must be canonical) +/// * `root` - The root path (must be canonical) +/// +/// # Returns +/// +/// `true` if path is strictly contained within root, `false` otherwise. +pub fn is_contained_by(path: &Path, root: &Path) -> bool { + // Get components + let path_components: Vec<_> = path.components().collect(); + let root_components: Vec<_> = root.components().collect(); + + // Path must have MORE components than root (strict containment) + // Equal components means path == root, which is not strict containment + if path_components.len() <= root_components.len() { + return false; + } + + // Check if all root components match path components (prefix) + for (i, root_comp) in root_components.iter().enumerate() { + let path_comp = &path_components[i]; + + if !components_equal(path_comp, root_comp) { + return false; + } + } + + true +} + +/// Compare two path components for equality. +/// +/// On Windows, this is case-insensitive. +/// On Unix, this is case-sensitive. +fn components_equal(a: &Component, b: &Component) -> bool { + #[cfg(windows)] + { + // Case-insensitive comparison on Windows + let a_str = component_to_string(a).to_lowercase(); + let b_str = component_to_string(b).to_lowercase(); + a_str == b_str + } + + #[cfg(not(windows))] + { + // Case-sensitive comparison on Unix + a == b + } +} + +/// Convert a path component to a string for comparison. +#[cfg(windows)] +fn component_to_string(comp: &Component) -> String { + match comp { + Component::Prefix(prefix) => prefix.as_os_str().to_string_lossy().to_string(), + Component::RootDir => "/".to_string(), + Component::CurDir => ".".to_string(), + Component::ParentDir => "..".to_string(), + Component::Normal(s) => s.to_string_lossy().to_string(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_is_contained_by_unix() { + // Strict containment (child must be strictly inside parent) + assert!(is_contained_by( + Path::new("/a/b/c"), + Path::new("/a") + )); + assert!(is_contained_by( + Path::new("/a/b/c"), + Path::new("/a/b") + )); + + // Equal paths are NOT strictly contained + assert!(!is_contained_by( + Path::new("/a/b"), + Path::new("/a/b") + )); + + // Parent is not contained by child + assert!(!is_contained_by( + Path::new("/a"), + Path::new("/a/b") + )); + + // Different trees + assert!(!is_contained_by( + Path::new("/a/b"), + Path::new("/c") + )); + + // String prefix but not path prefix + assert!(!is_contained_by( + Path::new("/ab"), + Path::new("/a") + )); + } + + #[cfg(windows)] + #[test] + fn test_is_contained_by_windows_case_insensitive() { + // Windows is case-insensitive + assert!(is_contained_by( + Path::new("C:\\Users\\foo\\bar"), + Path::new("C:\\Users") + )); + assert!(is_contained_by( + Path::new("c:\\users\\foo\\bar"), + Path::new("C:\\Users") + )); + assert!(is_contained_by( + Path::new("C:\\Users\\foo"), + Path::new("c:\\users") + )); + } + + #[test] + fn test_check_containment_no_roots() { + let result = check_containment(Path::new("/a/b/c"), &[]); + assert!(result.is_err()); + } + + #[test] + fn test_check_containment_allowed() { + let roots = vec![ + PathBuf::from("/home/user/project"), + PathBuf::from("/tmp"), + ]; + + // Inside first root + assert!(check_containment( + Path::new("/home/user/project/src/main.rs"), + &roots + ).is_ok()); + + // Inside second root + assert!(check_containment( + Path::new("/tmp/foo.txt"), + &roots + ).is_ok()); + } + + #[test] + fn test_check_containment_denied() { + let roots = vec![ + PathBuf::from("/home/user/project"), + ]; + + // Outside all roots + assert!(check_containment( + Path::new("/etc/passwd"), + &roots + ).is_err()); + + // Sibling directory (string prefix but not path prefix) + assert!(check_containment( + Path::new("/home/user/project2/file.txt"), + &roots + ).is_err()); + + // Parent directory + assert!(check_containment( + Path::new("/home/user/other.txt"), + &roots + ).is_err()); + + // Equal to root (not strictly contained) + assert!(check_containment( + Path::new("/home/user/project"), + &roots + ).is_err()); + } + + #[cfg(windows)] + #[test] + fn test_check_containment_windows() { + let roots = vec![ + PathBuf::from("C:\\Users\\foo"), + ]; + + // Allowed + assert!(check_containment( + Path::new("C:\\Users\\foo\\Documents\\file.txt"), + &roots + ).is_ok()); + + // Denied - different drive + assert!(check_containment( + Path::new("D:\\file.txt"), + &roots + ).is_err()); + + // Denied - outside root + assert!(check_containment( + Path::new("C:\\Windows\\system32"), + &roots + ).is_err()); + } +} diff --git a/crates/dirigent_tools/src/path/validate.rs b/crates/dirigent_tools/src/path/validate.rs new file mode 100644 index 0000000..295e348 --- /dev/null +++ b/crates/dirigent_tools/src/path/validate.rs @@ -0,0 +1,165 @@ +//! Path validation facade combining all security checks. +//! +//! This module provides a single entry point for path validation that: +//! 1. Canonicalizes the path +//! 2. Checks containment within allowed roots +//! 3. Evaluates blocklist patterns +//! +//! All tool operations should use this facade for consistent security enforcement. + +use crate::config::SandboxConfig; +use crate::error::ToolResult; +use std::path::{Path, PathBuf}; + +use super::blocklist::{compile_blocklist, check_blocklist_compiled, CompiledBlocklist}; +use super::canonicalize::{canonicalize_path, SymlinkPolicy}; +use super::containment::check_containment; + +/// Validate a path against sandbox configuration. +/// +/// This is the main entry point for path validation. It performs all security checks: +/// 1. Path canonicalization (handles symlinks, Windows paths, etc.) +/// 2. Containment checking (ensures path is within allowed roots) +/// 3. Blocklist evaluation (checks against denied patterns) +/// +/// # Arguments +/// +/// * `user_path` - The path provided by the user/agent (must be absolute) +/// * `config` - Sandbox configuration with allowed roots and blocklist +/// +/// # Returns +/// +/// The canonical path if all checks pass, or a security error otherwise. +/// +/// # Example +/// +/// ```rust,no_run +/// use dirigent_tools::path::validate_path; +/// use dirigent_tools::config::SandboxConfig; +/// use std::path::PathBuf; +/// +/// let mut config = SandboxConfig::default(); +/// config.allowed_roots = vec![PathBuf::from("/home/user/project")]; +/// config.blocked_paths = vec!["**/.env".to_string()]; +/// +/// // Valid path +/// let result = validate_path("/home/user/project/src/main.rs", &config); +/// assert!(result.is_ok()); +/// +/// // Blocked path +/// let result = validate_path("/home/user/project/.env", &config); +/// assert!(result.is_err()); +/// +/// // Outside roots +/// let result = validate_path("/etc/passwd", &config); +/// assert!(result.is_err()); +/// ``` +pub fn validate_path(user_path: &str, config: &SandboxConfig) -> ToolResult<PathBuf> { + // Step 1: Canonicalize the path + let symlink_policy = SymlinkPolicy { + allow_symlink_escape: config.allow_symlink_escape, + follow_symlinks_within_roots: config.follow_symlinks_within_roots, + }; + + let canonical_path = canonicalize_path(Path::new(user_path), &symlink_policy)?; + + // Step 2: Check containment + check_containment(&canonical_path, &config.allowed_roots)?; + + // Step 3: Check blocklist + if !config.blocked_paths.is_empty() { + // Compile blocklist on the fly (for now) + // TODO: Pre-compile blocklist in SandboxConfig for better performance + let compiled = compile_blocklist(&config.blocked_paths)?; + check_blocklist_compiled(&canonical_path, &compiled)?; + } + + Ok(canonical_path) +} + +/// Validate a path with a pre-compiled blocklist (for performance). +/// +/// This is a faster variant of `validate_path` that uses a pre-compiled blocklist. +/// Useful when validating many paths with the same configuration. +/// +/// # Arguments +/// +/// * `user_path` - The path provided by the user/agent (must be absolute) +/// * `config` - Sandbox configuration +/// * `compiled_blocklist` - Pre-compiled blocklist (can be None if blocklist is empty) +/// +/// # Returns +/// +/// The canonical path if all checks pass, or a security error otherwise. +pub fn validate_path_compiled( + user_path: &str, + config: &SandboxConfig, + compiled_blocklist: Option<&CompiledBlocklist>, +) -> ToolResult<PathBuf> { + // Step 1: Canonicalize the path + let symlink_policy = SymlinkPolicy { + allow_symlink_escape: config.allow_symlink_escape, + follow_symlinks_within_roots: config.follow_symlinks_within_roots, + }; + + let canonical_path = canonicalize_path(Path::new(user_path), &symlink_policy)?; + + // Step 2: Check containment + check_containment(&canonical_path, &config.allowed_roots)?; + + // Step 3: Check blocklist (if provided) + if let Some(compiled) = compiled_blocklist { + check_blocklist_compiled(&canonical_path, &compiled)?; + } + + Ok(canonical_path) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::ToolError; + + #[test] + fn test_validate_path_no_roots() { + let mut config = SandboxConfig::default(); + config.allowed_roots = vec![]; // No roots allowed + + // Any path should fail with no roots + let result = validate_path("/any/path", &config); + assert!(result.is_err()); + assert!(matches!(result, Err(ToolError::SandboxViolation { .. }))); + } + + #[test] + fn test_validate_path_empty_blocklist() { + // Note: Integration tests with real filesystem are in tests/path_normalization.rs + // Unit tests here focus on error conditions without filesystem dependencies + let config = SandboxConfig::default(); + assert_eq!(config.blocked_paths.len(), 2); // Default has .env and secrets + } + + #[test] + fn test_compile_blocklist_empty() { + let compiled = compile_blocklist(&[]).unwrap(); + assert_eq!(compiled.glob_set().len(), 0); + } + + #[test] + fn test_compile_blocklist_valid() { + let patterns = vec!["**/.env".to_string(), "**/secrets/**".to_string()]; + let result = compile_blocklist(&patterns); + assert!(result.is_ok()); + + let compiled = result.unwrap(); + assert_eq!(compiled.glob_set().len(), 2); + } + + #[test] + fn test_compile_blocklist_invalid() { + let patterns = vec!["[invalid".to_string()]; + let result = compile_blocklist(&patterns); + assert!(result.is_err()); + assert!(matches!(result, Err(ToolError::InvalidConfig { .. }))); + } +} diff --git a/crates/dirigent_tools/src/permission.rs b/crates/dirigent_tools/src/permission.rs new file mode 100644 index 0000000..b634935 --- /dev/null +++ b/crates/dirigent_tools/src/permission.rs @@ -0,0 +1,59 @@ +//! Permission prompt system and decision caching. +//! +//! **Status**: Implemented (TOOLS-PERM-01 through TOOLS-PERM-04) +//! +//! This module provides: +//! - **Permission checking** - Core algorithm integrating cache, whitelist, and ACP prompts +//! - **Decision caching** - Thread-safe cache with TTL and scope support +//! - **Whitelist matching** - Pattern-based auto-approval for safe operations +//! - **ACP integration** - User prompts via session/request_permission +//! +//! ## Module Structure +//! +//! - [`check`] - Core permission check function +//! - [`cache`] - Decision cache with TTL +//! - [`whitelist`] - Whitelist pattern matching +//! - [`acp`] - ACP integration for user prompts +//! +//! ## Quick Start +//! +//! ```rust,no_run +//! use dirigent_tools::config::{PermissionConfig, PermissionMode, WhitelistConfig}; +//! use dirigent_tools::permission::check::{PermissionContext, check_permission}; +//! use dirigent_tools::permission::whitelist::{CompiledWhitelist, PermissionOperation}; +//! +//! # async fn example() { +//! // Configure permission system +//! let config = PermissionConfig { +//! mode: PermissionMode::Whitelist, +//! remember_decisions: true, +//! remember_ttl_secs: 3600, +//! ..Default::default() +//! }; +//! +//! // Create context with whitelist +//! let whitelist = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap(); +//! let context = PermissionContext::new( +//! "connector-1".to_string(), +//! Some("session-1".to_string()), +//! whitelist, +//! ); +//! +//! // Check permission for an operation +//! let operation = PermissionOperation::Write { +//! path: "/path/to/file".to_string(), +//! }; +//! let decision = check_permission(&operation, &context, &config).await.unwrap(); +//! # } +//! ``` + +pub mod check; +pub mod cache; +pub mod whitelist; +pub mod acp; + +// Re-export commonly used types +pub use check::{check_permission, PermissionContext}; +pub use cache::{CacheKey, DecisionCache, PermissionDecision}; +pub use whitelist::{CompiledWhitelist, PermissionOperation, matches_whitelist}; +pub use acp::{AcpPermissionContext, PermissionOutcome, request_permission_from_user}; diff --git a/crates/dirigent_tools/src/permission/acp.rs b/crates/dirigent_tools/src/permission/acp.rs new file mode 100644 index 0000000..cfe872b --- /dev/null +++ b/crates/dirigent_tools/src/permission/acp.rs @@ -0,0 +1,242 @@ +//! ACP integration for permission prompts. +//! +//! **Status**: Implemented (TOOLS-PERM-03) - Stub for ACP integration +//! +//! This module provides integration with ACP's `session/request_permission` capability +//! to prompt users for permissions when operations require approval. +//! +//! ## Permission Outcomes +//! +//! Users can respond with: +//! - **AllowOnce**: Approve this operation only +//! - **AllowAlways**: Approve this operation and remember the decision +//! - **RejectOnce**: Deny this operation only +//! - **RejectAlways**: Deny this operation and remember the decision +//! - **Cancelled**: User cancelled the prompt (treat as rejection) +//! +//! ## ACP Protocol Integration +//! +//! This module calls the ACP client's `session/request_permission` handler (agent → client). +//! The actual implementation depends on the ACP client infrastructure in `dirigent_core`. +//! +//! For now, this provides a stub that can be replaced with real ACP calls when the +//! infrastructure is ready. + +use crate::error::ToolResult; +use crate::permission::whitelist::PermissionOperation; + +/// Permission outcome from user prompt. +/// +/// Maps directly to ACP session/request_permission response options. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PermissionOutcome { + /// Allow this operation once (do not cache) + AllowOnce, + /// Allow this operation and remember for future (cache with TTL) + AllowAlways, + /// Reject this operation once (do not cache) + RejectOnce, + /// Reject this operation and remember for future (cache with TTL) + RejectAlways, + /// User cancelled the prompt (treat as rejection) + Cancelled, +} + +impl PermissionOutcome { + /// Check if this outcome should be cached. + pub fn should_cache(&self) -> bool { + matches!(self, Self::AllowAlways | Self::RejectAlways) + } + + /// Check if this outcome allows the operation. + pub fn is_allowed(&self) -> bool { + matches!(self, Self::AllowOnce | Self::AllowAlways) + } + + /// Convert to cache decision if this outcome should be cached. + pub fn to_cache_decision(&self) -> Option<crate::permission::cache::PermissionDecision> { + use crate::permission::cache::PermissionDecision; + + match self { + Self::AllowAlways => Some(PermissionDecision::Allowed), + Self::RejectAlways => Some(PermissionDecision::Denied), + Self::Cancelled => Some(PermissionDecision::Cancelled), + Self::AllowOnce | Self::RejectOnce => None, + } + } +} + +/// Context for ACP permission requests. +/// +/// Contains information needed to make the ACP call: +/// - Connector and session identifiers +/// - Reference to the ACP session object (when available) +#[derive(Debug, Clone)] +pub struct AcpPermissionContext { + /// Connector ID + pub connector_id: String, + /// Session ID (if in a session) + pub session_id: Option<String>, + // TODO: Add ACP session handle when available + // pub session_handle: Arc<AcpSession>, +} + +/// Request permission from user via ACP session/request_permission. +/// +/// This calls the ACP protocol's `session/request_permission` capability to present +/// a permission prompt to the user with the following options: +/// - Allow once +/// - Allow always (remember decision) +/// - Reject once +/// - Reject always (remember decision) +/// +/// ## Implementation Status +/// +/// **TODO**: This is currently a stub that returns `AllowOnce` for all operations. +/// Once the ACP client infrastructure in `dirigent_core` is ready, this should be +/// replaced with actual ACP calls. +/// +/// ## Expected ACP Call +/// +/// ```json +/// { +/// "method": "session/request_permission", +/// "params": { +/// "session_id": "session-123", +/// "operation": "write", +/// "path": "C:/work/project/file.txt", +/// "description": "Write to file C:/work/project/file.txt" +/// } +/// } +/// ``` +/// +/// ## Expected Response +/// +/// ```json +/// { +/// "result": { +/// "outcome": "allow_always" | "allow_once" | "reject_always" | "reject_once" | "cancelled" +/// } +/// } +/// ``` +/// +/// ## Errors +/// +/// - `ToolError::PermissionDenied` if the ACP call fails or times out +/// - `ToolError::AcpError` if the ACP protocol reports an error +pub async fn request_permission_from_user( + operation: &PermissionOperation, + context: &AcpPermissionContext, +) -> ToolResult<PermissionOutcome> { + // TODO: Replace with actual ACP call when infrastructure is ready + // + // Expected implementation: + // 1. Get session handle from context + // 2. Call session.request_permission(operation, description) + // 3. Wait for user response (with timeout) + // 4. Map ACP response to PermissionOutcome + // 5. Handle errors and timeouts + + tracing::warn!( + operation = %operation.display(), + connector_id = %context.connector_id, + session_id = ?context.session_id, + "TODO: Actual ACP request_permission call not yet implemented (TOOLS-PERM-03)" + ); + + // For now, always return AllowOnce as a safe default for development + // In production, this should prompt the user via ACP + Ok(PermissionOutcome::AllowOnce) +} + +/// Mock implementation for testing (when ACP infrastructure is not available). +/// +/// This allows tests to simulate different permission outcomes without requiring +/// a full ACP stack. +#[cfg(test)] +pub async fn request_permission_mock( + _operation: &PermissionOperation, + outcome: PermissionOutcome, +) -> ToolResult<PermissionOutcome> { + Ok(outcome) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_permission_outcome_should_cache() { + assert!(PermissionOutcome::AllowAlways.should_cache()); + assert!(PermissionOutcome::RejectAlways.should_cache()); + assert!(!PermissionOutcome::AllowOnce.should_cache()); + assert!(!PermissionOutcome::RejectOnce.should_cache()); + assert!(!PermissionOutcome::Cancelled.should_cache()); + } + + #[test] + fn test_permission_outcome_is_allowed() { + assert!(PermissionOutcome::AllowOnce.is_allowed()); + assert!(PermissionOutcome::AllowAlways.is_allowed()); + assert!(!PermissionOutcome::RejectOnce.is_allowed()); + assert!(!PermissionOutcome::RejectAlways.is_allowed()); + assert!(!PermissionOutcome::Cancelled.is_allowed()); + } + + #[test] + fn test_permission_outcome_to_cache_decision() { + use crate::permission::cache::PermissionDecision; + + assert_eq!( + PermissionOutcome::AllowAlways.to_cache_decision(), + Some(PermissionDecision::Allowed) + ); + assert_eq!( + PermissionOutcome::RejectAlways.to_cache_decision(), + Some(PermissionDecision::Denied) + ); + assert_eq!( + PermissionOutcome::Cancelled.to_cache_decision(), + Some(PermissionDecision::Cancelled) + ); + assert_eq!(PermissionOutcome::AllowOnce.to_cache_decision(), None); + assert_eq!(PermissionOutcome::RejectOnce.to_cache_decision(), None); + } + + #[tokio::test] + async fn test_request_permission_stub() { + let operation = PermissionOperation::Write { + path: "/test/path".to_string(), + }; + let context = AcpPermissionContext { + connector_id: "test-connector".to_string(), + session_id: Some("test-session".to_string()), + }; + + // Stub currently returns AllowOnce + let outcome = request_permission_from_user(&operation, &context) + .await + .unwrap(); + assert_eq!(outcome, PermissionOutcome::AllowOnce); + } + + #[tokio::test] + async fn test_request_permission_mock() { + let operation = PermissionOperation::Execute { + command: "test".to_string(), + cwd: "/".to_string(), + }; + + // Test all outcomes + for outcome in [ + PermissionOutcome::AllowOnce, + PermissionOutcome::AllowAlways, + PermissionOutcome::RejectOnce, + PermissionOutcome::RejectAlways, + PermissionOutcome::Cancelled, + ] { + let result = request_permission_mock(&operation, outcome).await.unwrap(); + assert_eq!(result, outcome); + } + } +} diff --git a/crates/dirigent_tools/src/permission/cache.rs b/crates/dirigent_tools/src/permission/cache.rs new file mode 100644 index 0000000..cf423dd --- /dev/null +++ b/crates/dirigent_tools/src/permission/cache.rs @@ -0,0 +1,389 @@ +//! Decision cache with TTL and scope support. +//! +//! **Status**: Implemented (TOOLS-PERM-02) +//! +//! This module provides thread-safe caching of permission decisions with: +//! - TTL (time-to-live) for cached decisions +//! - Scope support (per-connector or per-session) +//! - Automatic expiration of stale entries +//! - Hash-based cache keys for efficient lookups + +use crate::config::DecisionScope; +use std::collections::HashMap; +use std::hash::Hash; +use std::time::{Duration, Instant}; + +/// Permission decision outcome. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PermissionDecision { + Allowed, + Denied, + Cancelled, +} + +/// Cache key for permission decisions. +/// +/// Keys are constructed from: +/// - Operation kind (read, write, execute) +/// - Normalized path or command +/// - Scope identifier (connector_id or session_id) +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct CacheKey { + /// Operation discriminant + operation_kind: OperationKind, + /// Normalized path or command string + target: String, + /// Scope identifier (connector or session) + scope_id: String, +} + +/// Operation kind for cache key discrimination. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +enum OperationKind { + Read, + Write, + Execute, +} + +impl CacheKey { + /// Create a cache key for a read operation. + pub fn read(path: &str, connector_id: &str, scope: DecisionScope) -> Self { + Self { + operation_kind: OperationKind::Read, + target: path.to_string(), + scope_id: Self::scope_id(connector_id, None, scope), + } + } + + /// Create a cache key for a write operation. + pub fn write(path: &str, connector_id: &str, session_id: Option<&str>, scope: DecisionScope) -> Self { + Self { + operation_kind: OperationKind::Write, + target: path.to_string(), + scope_id: Self::scope_id(connector_id, session_id, scope), + } + } + + /// Create a cache key for an execute operation. + pub fn execute(command: &str, connector_id: &str, session_id: Option<&str>, scope: DecisionScope) -> Self { + Self { + operation_kind: OperationKind::Execute, + target: command.to_string(), + scope_id: Self::scope_id(connector_id, session_id, scope), + } + } + + /// Construct scope identifier based on scope policy. + fn scope_id(connector_id: &str, session_id: Option<&str>, scope: DecisionScope) -> String { + match scope { + DecisionScope::PerConnector => connector_id.to_string(), + DecisionScope::PerSession => { + format!("{}:{}", connector_id, session_id.unwrap_or("default")) + } + } + } +} + +/// Cached permission decision with expiration. +#[derive(Debug, Clone)] +struct CachedDecision { + /// The permission decision + decision: PermissionDecision, + /// When this entry expires + expires_at: Instant, +} + +impl CachedDecision { + /// Create a new cached decision with TTL. + fn new(decision: PermissionDecision, ttl: Duration) -> Self { + Self { + decision, + expires_at: Instant::now() + ttl, + } + } + + /// Check if this entry has expired. + fn is_expired(&self) -> bool { + Instant::now() >= self.expires_at + } +} + +/// Thread-safe decision cache with TTL support. +/// +/// The cache stores permission decisions keyed by operation, path/command, and scope. +/// Entries automatically expire after their TTL, and expired entries are pruned on access. +/// +/// ## Thread Safety +/// +/// The cache is designed to be wrapped in `Arc<Mutex<DecisionCache>>` for thread-safe access. +/// Individual operations (get, insert, clear) should acquire the lock briefly. +/// +/// ## Example +/// +/// ```rust +/// use dirigent_tools::permission::cache::{DecisionCache, PermissionDecision, CacheKey}; +/// use dirigent_tools::config::DecisionScope; +/// use std::time::Duration; +/// +/// let mut cache = DecisionCache::new(); +/// let key = CacheKey::write("/path/to/file", "connector-1", None, DecisionScope::PerConnector); +/// +/// // Cache a decision +/// cache.insert(key.clone(), PermissionDecision::Allowed, Duration::from_secs(300)); +/// +/// // Retrieve it +/// assert_eq!(cache.get(&key), Some(PermissionDecision::Allowed)); +/// ``` +#[derive(Debug)] +pub struct DecisionCache { + entries: HashMap<CacheKey, CachedDecision>, +} + +impl DecisionCache { + /// Create a new empty decision cache. + pub fn new() -> Self { + Self { + entries: HashMap::new(), + } + } + + /// Get a cached decision if it exists and hasn't expired. + /// + /// Expired entries are automatically removed. + /// + /// Returns `Some(decision)` if a valid cached entry exists, `None` otherwise. + pub fn get(&mut self, key: &CacheKey) -> Option<PermissionDecision> { + // Check if entry exists + if let Some(cached) = self.entries.get(key) { + // Check if expired + if cached.is_expired() { + // Remove expired entry + self.entries.remove(key); + None + } else { + // Return valid decision + Some(cached.decision) + } + } else { + None + } + } + + /// Insert a new decision into the cache with the given TTL. + /// + /// If an entry already exists for this key, it will be replaced. + pub fn insert(&mut self, key: CacheKey, decision: PermissionDecision, ttl: Duration) { + self.entries.insert(key, CachedDecision::new(decision, ttl)); + } + + /// Clear all cached decisions. + /// + /// Useful for manual cache invalidation or testing. + pub fn clear(&mut self) { + self.entries.clear(); + } + + /// Remove all expired entries from the cache. + /// + /// This is automatically done during `get()` operations, but can be called + /// periodically to prune the cache and free memory. + /// + /// Returns the number of expired entries removed. + pub fn clear_expired(&mut self) -> usize { + let before = self.entries.len(); + self.entries.retain(|_, cached| !cached.is_expired()); + before - self.entries.len() + } + + /// Get the number of cached entries (including expired). + /// + /// For testing and monitoring purposes. + pub fn len(&self) -> usize { + self.entries.len() + } + + /// Check if the cache is empty. + pub fn is_empty(&self) -> bool { + self.entries.is_empty() + } +} + +impl Default for DecisionCache { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cache_key_creation() { + let key1 = CacheKey::write("/path/to/file", "conn-1", None, DecisionScope::PerConnector); + let key2 = CacheKey::write("/path/to/file", "conn-1", None, DecisionScope::PerConnector); + let key3 = CacheKey::write("/path/to/file", "conn-2", None, DecisionScope::PerConnector); + + // Same connector and path should produce equal keys + assert_eq!(key1, key2); + + // Different connector should produce different keys + assert_ne!(key1, key3); + } + + #[test] + fn test_cache_key_scope() { + let key_connector = CacheKey::write( + "/path", + "conn-1", + Some("session-1"), + DecisionScope::PerConnector, + ); + let key_session = CacheKey::write( + "/path", + "conn-1", + Some("session-1"), + DecisionScope::PerSession, + ); + + // Different scopes should produce different keys + assert_ne!(key_connector, key_session); + } + + #[test] + fn test_cache_key_operation_kind() { + let key_read = CacheKey::read("/path", "conn-1", DecisionScope::PerConnector); + let key_write = CacheKey::write("/path", "conn-1", None, DecisionScope::PerConnector); + + // Different operations should produce different keys + assert_ne!(key_read, key_write); + } + + #[test] + fn test_cache_insert_and_get() { + let mut cache = DecisionCache::new(); + let key = CacheKey::write("/path", "conn-1", None, DecisionScope::PerConnector); + + // Initially empty + assert_eq!(cache.get(&key), None); + + // Insert decision + cache.insert(key.clone(), PermissionDecision::Allowed, Duration::from_secs(300)); + + // Should retrieve it + assert_eq!(cache.get(&key), Some(PermissionDecision::Allowed)); + } + + #[test] + fn test_cache_ttl_expiration() { + let mut cache = DecisionCache::new(); + let key = CacheKey::write("/path", "conn-1", None, DecisionScope::PerConnector); + + // Insert with very short TTL + cache.insert(key.clone(), PermissionDecision::Allowed, Duration::from_millis(1)); + + // Wait for expiration + std::thread::sleep(Duration::from_millis(10)); + + // Should not retrieve expired entry + assert_eq!(cache.get(&key), None); + + // Entry should be removed from cache + assert_eq!(cache.len(), 0); + } + + #[test] + fn test_cache_clear() { + let mut cache = DecisionCache::new(); + let key1 = CacheKey::write("/path1", "conn-1", None, DecisionScope::PerConnector); + let key2 = CacheKey::write("/path2", "conn-1", None, DecisionScope::PerConnector); + + cache.insert(key1.clone(), PermissionDecision::Allowed, Duration::from_secs(300)); + cache.insert(key2.clone(), PermissionDecision::Denied, Duration::from_secs(300)); + + assert_eq!(cache.len(), 2); + + // Clear all entries + cache.clear(); + + assert_eq!(cache.len(), 0); + assert_eq!(cache.get(&key1), None); + assert_eq!(cache.get(&key2), None); + } + + #[test] + fn test_cache_clear_expired() { + let mut cache = DecisionCache::new(); + let key1 = CacheKey::write("/path1", "conn-1", None, DecisionScope::PerConnector); + let key2 = CacheKey::write("/path2", "conn-1", None, DecisionScope::PerConnector); + + // Insert one short-lived and one long-lived entry + cache.insert(key1.clone(), PermissionDecision::Allowed, Duration::from_millis(1)); + cache.insert(key2.clone(), PermissionDecision::Denied, Duration::from_secs(300)); + + assert_eq!(cache.len(), 2); + + // Wait for first to expire + std::thread::sleep(Duration::from_millis(10)); + + // Clear expired entries + let removed = cache.clear_expired(); + assert_eq!(removed, 1); + assert_eq!(cache.len(), 1); + + // Second entry should still be accessible + assert_eq!(cache.get(&key2), Some(PermissionDecision::Denied)); + } + + #[test] + fn test_cache_replace_entry() { + let mut cache = DecisionCache::new(); + let key = CacheKey::write("/path", "conn-1", None, DecisionScope::PerConnector); + + // Insert initial decision + cache.insert(key.clone(), PermissionDecision::Allowed, Duration::from_secs(300)); + assert_eq!(cache.get(&key), Some(PermissionDecision::Allowed)); + + // Replace with different decision + cache.insert(key.clone(), PermissionDecision::Denied, Duration::from_secs(300)); + assert_eq!(cache.get(&key), Some(PermissionDecision::Denied)); + + // Should only have one entry + assert_eq!(cache.len(), 1); + } + + #[test] + fn test_cache_different_sessions() { + let mut cache = DecisionCache::new(); + let key_session1 = CacheKey::write( + "/path", + "conn-1", + Some("session-1"), + DecisionScope::PerSession, + ); + let key_session2 = CacheKey::write( + "/path", + "conn-1", + Some("session-2"), + DecisionScope::PerSession, + ); + + // Insert different decisions for different sessions + cache.insert(key_session1.clone(), PermissionDecision::Allowed, Duration::from_secs(300)); + cache.insert(key_session2.clone(), PermissionDecision::Denied, Duration::from_secs(300)); + + // Each session should have its own decision + assert_eq!(cache.get(&key_session1), Some(PermissionDecision::Allowed)); + assert_eq!(cache.get(&key_session2), Some(PermissionDecision::Denied)); + } + + #[test] + fn test_cache_cancelled_decision() { + let mut cache = DecisionCache::new(); + let key = CacheKey::write("/path", "conn-1", None, DecisionScope::PerConnector); + + // Cancelled decisions can also be cached + cache.insert(key.clone(), PermissionDecision::Cancelled, Duration::from_secs(300)); + assert_eq!(cache.get(&key), Some(PermissionDecision::Cancelled)); + } +} diff --git a/crates/dirigent_tools/src/permission/check.rs b/crates/dirigent_tools/src/permission/check.rs new file mode 100644 index 0000000..7c895e4 --- /dev/null +++ b/crates/dirigent_tools/src/permission/check.rs @@ -0,0 +1,366 @@ +//! Core permission check function integrating cache, whitelist, and ACP prompts. +//! +//! **Status**: Implemented (TOOLS-PERM-01) +//! +//! This module provides the main `check_permission` function that orchestrates: +//! - Permission mode evaluation (yolo, whitelist, ask) +//! - Decision cache lookup and updates +//! - Whitelist pattern matching +//! - ACP permission prompts +//! +//! ## Algorithm +//! +//! 1. **Yolo mode**: Always allow (with audit logging) +//! 2. **Check cache**: Return cached decision if present and valid +//! 3. **Whitelist mode**: Auto-approve if operation matches whitelist +//! 4. **Ask mode / No whitelist match**: Prompt user via ACP +//! 5. **Cache decision**: Store if user selected "always" option +//! 6. **Return decision**: Allow or deny the operation + +use crate::config::{DecisionScope, PermissionConfig, PermissionMode}; +use crate::error::ToolResult; +use crate::permission::acp::{request_permission_from_user, AcpPermissionContext, PermissionOutcome}; +use crate::permission::cache::{CacheKey, DecisionCache, PermissionDecision}; +use crate::permission::whitelist::{matches_whitelist, CompiledWhitelist, PermissionOperation}; +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +/// Context for permission checks. +/// +/// This should be created once per connector or session and reused for all +/// permission checks to maintain consistent cache state. +#[derive(Clone)] +pub struct PermissionContext { + /// Connector ID + pub connector_id: String, + /// Session ID (if in a session) + pub session_id: Option<String>, + /// Shared decision cache (thread-safe) + pub cache: Arc<Mutex<DecisionCache>>, + /// Compiled whitelist for fast matching + pub whitelist: Arc<CompiledWhitelist>, +} + +impl PermissionContext { + /// Create a new permission context. + pub fn new( + connector_id: String, + session_id: Option<String>, + whitelist: CompiledWhitelist, + ) -> Self { + Self { + connector_id, + session_id, + cache: Arc::new(Mutex::new(DecisionCache::new())), + whitelist: Arc::new(whitelist), + } + } + + /// Clear the decision cache (for testing or manual reset). + pub fn clear_cache(&self) { + if let Ok(mut cache) = self.cache.lock() { + cache.clear(); + } + } + + /// Get cache statistics (for monitoring/debugging). + pub fn cache_size(&self) -> usize { + self.cache.lock().map(|c| c.len()).unwrap_or(0) + } +} + +/// Check if an operation requires permission and prompt if needed. +/// +/// This is the main entry point for permission checks. It implements the full +/// permission algorithm including mode evaluation, caching, whitelist matching, +/// and ACP prompts. +/// +/// ## Algorithm +/// +/// 1. If mode is `Yolo`: Always allow (with audit log) +/// 2. Check decision cache (with TTL) +/// 3. If cached decision exists: Return it +/// 4. If mode is `Whitelist` and operation matches: Allow +/// 5. Otherwise: Call ACP `session/request_permission` +/// 6. If outcome is "always": Cache the decision +/// 7. Return decision (Allowed/Denied/Cancelled) +/// +/// ## Examples +/// +/// ```rust,no_run +/// use dirigent_tools::config::{PermissionConfig, PermissionMode, WhitelistConfig}; +/// use dirigent_tools::permission::check::{PermissionContext, check_permission}; +/// use dirigent_tools::permission::whitelist::{CompiledWhitelist, PermissionOperation}; +/// +/// # async fn example() { +/// let config = PermissionConfig { +/// mode: PermissionMode::Ask, +/// remember_decisions: true, +/// remember_ttl_secs: 3600, +/// ..Default::default() +/// }; +/// +/// let whitelist = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap(); +/// let context = PermissionContext::new( +/// "connector-1".to_string(), +/// Some("session-1".to_string()), +/// whitelist, +/// ); +/// +/// let operation = PermissionOperation::Write { +/// path: "/path/to/file".to_string(), +/// }; +/// +/// let decision = check_permission(&operation, &context, &config).await.unwrap(); +/// # } +/// ``` +pub async fn check_permission( + operation: &PermissionOperation, + context: &PermissionContext, + config: &PermissionConfig, +) -> ToolResult<PermissionDecision> { + // Step 1: Yolo mode - always allow + if config.mode == PermissionMode::Yolo { + tracing::info!( + operation = %operation.display(), + mode = "yolo", + "Permission check: auto-approved (yolo mode)" + ); + return Ok(PermissionDecision::Allowed); + } + + // Step 2: Check decision cache + if config.remember_decisions { + let cache_key = create_cache_key(operation, context, config.scope); + + if let Ok(mut cache) = context.cache.lock() { + if let Some(cached_decision) = cache.get(&cache_key) { + tracing::debug!( + operation = %operation.display(), + decision = ?cached_decision, + "Permission check: using cached decision" + ); + return Ok(cached_decision); + } + } + } + + // Step 3: Whitelist mode - check if operation matches whitelist + if config.mode == PermissionMode::Whitelist { + if matches_whitelist(operation, &context.whitelist) { + tracing::info!( + operation = %operation.display(), + mode = "whitelist", + "Permission check: auto-approved (whitelist match)" + ); + return Ok(PermissionDecision::Allowed); + } + } + + // Step 4: Prompt user via ACP + tracing::debug!( + operation = %operation.display(), + mode = ?config.mode, + "Permission check: prompting user" + ); + + let acp_context = AcpPermissionContext { + connector_id: context.connector_id.clone(), + session_id: context.session_id.clone(), + }; + + let outcome = request_permission_from_user(operation, &acp_context).await?; + + // Step 5: Convert outcome to decision + let decision = if outcome.is_allowed() { + PermissionDecision::Allowed + } else { + match outcome { + PermissionOutcome::Cancelled => PermissionDecision::Cancelled, + _ => PermissionDecision::Denied, + } + }; + + // Step 6: Cache decision if "always" option was selected + if config.remember_decisions && outcome.should_cache() { + if let Some(cache_decision) = outcome.to_cache_decision() { + let cache_key = create_cache_key(operation, context, config.scope); + let ttl = Duration::from_secs(config.remember_ttl_secs); + + if let Ok(mut cache) = context.cache.lock() { + cache.insert(cache_key, cache_decision, ttl); + tracing::debug!( + operation = %operation.display(), + decision = ?cache_decision, + ttl_secs = config.remember_ttl_secs, + "Cached permission decision" + ); + } + } + } + + tracing::info!( + operation = %operation.display(), + decision = ?decision, + outcome = ?outcome, + "Permission check: decision made" + ); + + Ok(decision) +} + +/// Create a cache key for an operation with the appropriate scope. +fn create_cache_key( + operation: &PermissionOperation, + context: &PermissionContext, + scope: DecisionScope, +) -> CacheKey { + match operation { + PermissionOperation::Read { path } => { + CacheKey::read(path, &context.connector_id, scope) + } + PermissionOperation::Write { path } => { + CacheKey::write(path, &context.connector_id, context.session_id.as_deref(), scope) + } + PermissionOperation::Execute { command, .. } => { + CacheKey::execute(command, &context.connector_id, context.session_id.as_deref(), scope) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::WhitelistConfig; + + fn create_test_config(mode: PermissionMode) -> PermissionConfig { + PermissionConfig { + mode, + remember_decisions: true, + remember_ttl_secs: 300, + scope: DecisionScope::PerConnector, + whitelist: WhitelistConfig::default(), + } + } + + fn create_test_context() -> PermissionContext { + let whitelist = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap(); + PermissionContext::new( + "test-connector".to_string(), + Some("test-session".to_string()), + whitelist, + ) + } + + #[tokio::test] + async fn test_yolo_mode_always_allows() { + let config = create_test_config(PermissionMode::Yolo); + let context = create_test_context(); + + let operations = vec![ + PermissionOperation::Read { path: "/any/path".to_string() }, + PermissionOperation::Write { path: "/any/path".to_string() }, + PermissionOperation::Execute { + command: "dangerous_command".to_string(), + cwd: "/".to_string(), + }, + ]; + + for operation in operations { + let decision = check_permission(&operation, &context, &config).await.unwrap(); + assert_eq!(decision, PermissionDecision::Allowed); + } + } + + #[tokio::test] + async fn test_whitelist_mode_read_always_allowed() { + let config = create_test_config(PermissionMode::Whitelist); + let context = create_test_context(); + + let operation = PermissionOperation::Read { + path: "/any/path".to_string(), + }; + + let decision = check_permission(&operation, &context, &config).await.unwrap(); + assert_eq!(decision, PermissionDecision::Allowed); + } + + #[tokio::test] + async fn test_whitelist_mode_with_pattern() { + let mut config = create_test_config(PermissionMode::Whitelist); + config.whitelist = WhitelistConfig { + write_paths: vec!["C:/work/**".to_string()], + execute_commands: vec!["cargo".to_string()], + }; + + let whitelist = CompiledWhitelist::compile(&config.whitelist).unwrap(); + let context = PermissionContext::new( + "test-connector".to_string(), + None, + whitelist, + ); + + // Should match whitelist + let write_ok = PermissionOperation::Write { + path: "C:/work/project/file.txt".to_string(), + }; + let decision = check_permission(&write_ok, &context, &config).await.unwrap(); + assert_eq!(decision, PermissionDecision::Allowed); + + // TODO: Test non-matching write would require ACP mock + // This is tested in integration tests + } + + #[tokio::test] + async fn test_cache_key_creation() { + let context = create_test_context(); + let scope = DecisionScope::PerConnector; + + let read_op = PermissionOperation::Read { + path: "/path".to_string(), + }; + let key1 = create_cache_key(&read_op, &context, scope); + let key2 = create_cache_key(&read_op, &context, scope); + + // Same operation should produce same key + assert_eq!(key1, key2); + } + + #[tokio::test] + async fn test_context_cache_operations() { + let context = create_test_context(); + + assert_eq!(context.cache_size(), 0); + + // Add entry to cache + { + let mut cache = context.cache.lock().unwrap(); + let key = CacheKey::write("/path", "test", None, DecisionScope::PerConnector); + cache.insert(key, PermissionDecision::Allowed, Duration::from_secs(300)); + } + + assert_eq!(context.cache_size(), 1); + + // Clear cache + context.clear_cache(); + assert_eq!(context.cache_size(), 0); + } + + #[test] + fn test_permission_context_clone() { + let context = create_test_context(); + let cloned = context.clone(); + + // Should share the same cache + assert_eq!(context.cache_size(), cloned.cache_size()); + + // Modifications to one should affect the other + { + let mut cache = context.cache.lock().unwrap(); + let key = CacheKey::write("/path", "test", None, DecisionScope::PerConnector); + cache.insert(key, PermissionDecision::Allowed, Duration::from_secs(300)); + } + + assert_eq!(cloned.cache_size(), 1); + } +} diff --git a/crates/dirigent_tools/src/permission/whitelist.rs b/crates/dirigent_tools/src/permission/whitelist.rs new file mode 100644 index 0000000..87df219 --- /dev/null +++ b/crates/dirigent_tools/src/permission/whitelist.rs @@ -0,0 +1,397 @@ +//! Whitelist pattern matching for auto-approval of safe operations. +//! +//! **Status**: Implemented (TOOLS-PERM-04) +//! +//! This module provides pattern matching against configured whitelists to +//! automatically approve safe operations without prompting the user. +//! +//! ## Whitelist Strategy +//! +//! - **Read operations**: Always approved in whitelist mode (reads are safe) +//! - **Write operations**: Match path against `write_paths` glob patterns +//! - **Execute operations**: Match command against `execute_commands` glob patterns +//! +//! ## Pattern Matching +//! +//! Uses `globset` for efficient compiled glob patterns: +//! - `**` matches any number of path segments +//! - `*` matches any characters within a segment +//! - `?` matches a single character +//! - Character classes like `[abc]` are supported +//! +//! ## Performance +//! +//! Globs are pre-compiled at configuration load time and stored in `CompiledWhitelist` +//! for fast matching during permission checks. + +use crate::config::WhitelistConfig; +use crate::error::{ToolError, ToolResult}; +use globset::{Glob, GlobSet, GlobSetBuilder}; +use std::path::Path; + +/// Compiled whitelist with pre-built glob sets for efficient matching. +/// +/// This should be constructed once at configuration load time and reused +/// for all permission checks. +#[derive(Debug, Clone)] +pub struct CompiledWhitelist { + /// Compiled glob patterns for write path matching + write_paths: GlobSet, + /// Compiled glob patterns for execute command matching + execute_commands: GlobSet, +} + +impl CompiledWhitelist { + /// Compile a whitelist configuration into an efficient matcher. + /// + /// ## Errors + /// + /// Returns an error if any glob pattern is invalid. + pub fn compile(config: &WhitelistConfig) -> ToolResult<Self> { + let write_paths = Self::compile_patterns(&config.write_paths)?; + let execute_commands = Self::compile_patterns(&config.execute_commands)?; + + Ok(Self { + write_paths, + execute_commands, + }) + } + + /// Compile a list of glob patterns into a GlobSet. + fn compile_patterns(patterns: &[String]) -> ToolResult<GlobSet> { + let mut builder = GlobSetBuilder::new(); + + for pattern in patterns { + let glob = Glob::new(pattern).map_err(|e| { + ToolError::InvalidConfig(format!("Invalid whitelist glob pattern '{}': {}", pattern, e)) + })?; + builder.add(glob); + } + + builder.build().map_err(|e| { + ToolError::InvalidConfig(format!("Failed to build whitelist glob set: {}", e)) + }) + } + + /// Check if a read operation matches the whitelist. + /// + /// In whitelist mode, all read operations are considered safe and return true. + pub fn matches_read(&self, _path: &Path) -> bool { + // Reads are always safe in whitelist mode + true + } + + /// Check if a write operation matches the whitelist. + /// + /// Returns true if the path matches any of the configured write_paths patterns. + pub fn matches_write(&self, path: &Path) -> bool { + self.write_paths.is_match(path) + } + + /// Check if an execute operation matches the whitelist. + /// + /// Returns true if the command matches any of the configured execute_commands patterns. + /// + /// The command is compared as a string (not a path) to support both: + /// - Simple command names (e.g., "cargo", "npm") + /// - Command patterns (e.g., "cargo*", "npm*") + pub fn matches_execute(&self, command: &str) -> bool { + // For command matching, we treat the command as a simple string path + // This allows patterns like "cargo", "npm*", etc. + self.execute_commands.is_match(command) + } + + /// Check if the whitelist has any write patterns configured. + pub fn has_write_patterns(&self) -> bool { + !self.write_paths.is_empty() + } + + /// Check if the whitelist has any execute patterns configured. + pub fn has_execute_patterns(&self) -> bool { + !self.execute_commands.is_empty() + } +} + +/// Operation type for whitelist matching. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PermissionOperation { + Read { path: String }, + Write { path: String }, + Execute { command: String, cwd: String }, +} + +impl PermissionOperation { + /// Get a display string for this operation (for logging/debugging). + pub fn display(&self) -> String { + match self { + Self::Read { path } => format!("read {}", path), + Self::Write { path } => format!("write {}", path), + Self::Execute { command, cwd } => format!("execute '{}' in {}", command, cwd), + } + } + + /// Get the operation kind as a string. + pub fn kind(&self) -> &'static str { + match self { + Self::Read { .. } => "read", + Self::Write { .. } => "write", + Self::Execute { .. } => "execute", + } + } +} + +/// Check if an operation matches the configured whitelist. +/// +/// ## Returns +/// +/// - `true` if the operation should be auto-approved +/// - `false` if the operation requires a permission prompt +/// +/// ## Examples +/// +/// ```rust +/// use dirigent_tools::config::WhitelistConfig; +/// use dirigent_tools::permission::whitelist::{CompiledWhitelist, PermissionOperation, matches_whitelist}; +/// +/// let config = WhitelistConfig { +/// write_paths: vec!["C:/work/project/**".to_string()], +/// execute_commands: vec!["cargo".to_string(), "npm".to_string()], +/// }; +/// +/// let whitelist = CompiledWhitelist::compile(&config).unwrap(); +/// +/// // Read operations always match +/// let read_op = PermissionOperation::Read { path: "C:/anywhere/file.txt".to_string() }; +/// assert!(matches_whitelist(&read_op, &whitelist)); +/// +/// // Write within allowed path +/// let write_op = PermissionOperation::Write { path: "C:/work/project/src/main.rs".to_string() }; +/// assert!(matches_whitelist(&write_op, &whitelist)); +/// +/// // Execute whitelisted command +/// let exec_op = PermissionOperation::Execute { +/// command: "cargo".to_string(), +/// cwd: "C:/work/project".to_string(), +/// }; +/// assert!(matches_whitelist(&exec_op, &whitelist)); +/// ``` +pub fn matches_whitelist(operation: &PermissionOperation, whitelist: &CompiledWhitelist) -> bool { + match operation { + PermissionOperation::Read { path } => { + whitelist.matches_read(Path::new(path)) + } + PermissionOperation::Write { path } => { + whitelist.matches_write(Path::new(path)) + } + PermissionOperation::Execute { command, .. } => { + whitelist.matches_execute(command) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn create_test_whitelist() -> CompiledWhitelist { + let config = WhitelistConfig { + write_paths: vec![ + "C:/work/project/**".to_string(), + "/home/user/project/**".to_string(), + "**/safe_dir/**".to_string(), + ], + execute_commands: vec![ + "cargo".to_string(), + "npm".to_string(), + "git".to_string(), + "python*".to_string(), + ], + }; + CompiledWhitelist::compile(&config).unwrap() + } + + #[test] + fn test_read_always_matches() { + let whitelist = create_test_whitelist(); + + // Any read operation should match + assert!(whitelist.matches_read(Path::new("C:/anywhere/file.txt"))); + assert!(whitelist.matches_read(Path::new("/random/path"))); + assert!(whitelist.matches_read(Path::new("../../../etc/passwd"))); + } + + #[test] + fn test_write_matches_patterns() { + let whitelist = create_test_whitelist(); + + // Should match configured patterns + assert!(whitelist.matches_write(Path::new("C:/work/project/src/main.rs"))); + assert!(whitelist.matches_write(Path::new("C:/work/project/Cargo.toml"))); + assert!(whitelist.matches_write(Path::new("/home/user/project/README.md"))); + + // Should not match paths outside patterns + assert!(!whitelist.matches_write(Path::new("C:/other/project/file.txt"))); + assert!(!whitelist.matches_write(Path::new("/tmp/file.txt"))); + } + + #[test] + fn test_write_matches_relative_patterns() { + let whitelist = create_test_whitelist(); + + // Pattern **/safe_dir/** should match anywhere + assert!(whitelist.matches_write(Path::new("some/path/safe_dir/file.txt"))); + assert!(whitelist.matches_write(Path::new("safe_dir/file.txt"))); + } + + #[test] + fn test_execute_matches_commands() { + let whitelist = create_test_whitelist(); + + // Should match exact command names + assert!(whitelist.matches_execute("cargo")); + assert!(whitelist.matches_execute("npm")); + assert!(whitelist.matches_execute("git")); + + // Should match patterns + assert!(whitelist.matches_execute("python")); + assert!(whitelist.matches_execute("python3")); + + // Should not match unlisted commands + assert!(!whitelist.matches_execute("rm")); + assert!(!whitelist.matches_execute("format")); + assert!(!whitelist.matches_execute("del")); + } + + #[test] + fn test_empty_whitelist() { + let config = WhitelistConfig { + write_paths: vec![], + execute_commands: vec![], + }; + let whitelist = CompiledWhitelist::compile(&config).unwrap(); + + // Reads should still match + assert!(whitelist.matches_read(Path::new("any/path"))); + + // Writes and executes should not match empty whitelist + assert!(!whitelist.matches_write(Path::new("any/path"))); + assert!(!whitelist.matches_execute("any_command")); + } + + #[test] + fn test_invalid_glob_pattern() { + let config = WhitelistConfig { + write_paths: vec!["[invalid".to_string()], // Unclosed bracket + execute_commands: vec![], + }; + + let result = CompiledWhitelist::compile(&config); + assert!(result.is_err()); + } + + #[test] + fn test_operation_display() { + let read_op = PermissionOperation::Read { + path: "/path/to/file".to_string(), + }; + assert_eq!(read_op.display(), "read /path/to/file"); + + let write_op = PermissionOperation::Write { + path: "/path/to/file".to_string(), + }; + assert_eq!(write_op.display(), "write /path/to/file"); + + let exec_op = PermissionOperation::Execute { + command: "cargo test".to_string(), + cwd: "/project".to_string(), + }; + assert_eq!(exec_op.display(), "execute 'cargo test' in /project"); + } + + #[test] + fn test_operation_kind() { + let read_op = PermissionOperation::Read { + path: "/path".to_string(), + }; + assert_eq!(read_op.kind(), "read"); + + let write_op = PermissionOperation::Write { + path: "/path".to_string(), + }; + assert_eq!(write_op.kind(), "write"); + + let exec_op = PermissionOperation::Execute { + command: "cmd".to_string(), + cwd: "/".to_string(), + }; + assert_eq!(exec_op.kind(), "execute"); + } + + #[test] + fn test_matches_whitelist_function() { + let whitelist = create_test_whitelist(); + + // Test read + let read_op = PermissionOperation::Read { + path: "C:/any/file.txt".to_string(), + }; + assert!(matches_whitelist(&read_op, &whitelist)); + + // Test write (match) + let write_ok = PermissionOperation::Write { + path: "C:/work/project/file.txt".to_string(), + }; + assert!(matches_whitelist(&write_ok, &whitelist)); + + // Test write (no match) + let write_fail = PermissionOperation::Write { + path: "C:/other/file.txt".to_string(), + }; + assert!(!matches_whitelist(&write_fail, &whitelist)); + + // Test execute (match) + let exec_ok = PermissionOperation::Execute { + command: "cargo".to_string(), + cwd: "/project".to_string(), + }; + assert!(matches_whitelist(&exec_ok, &whitelist)); + + // Test execute (no match) + let exec_fail = PermissionOperation::Execute { + command: "rm".to_string(), + cwd: "/".to_string(), + }; + assert!(!matches_whitelist(&exec_fail, &whitelist)); + } + + #[test] + fn test_has_patterns() { + let config = WhitelistConfig { + write_paths: vec!["some/path/**".to_string()], + execute_commands: vec![], + }; + let whitelist = CompiledWhitelist::compile(&config).unwrap(); + + assert!(whitelist.has_write_patterns()); + assert!(!whitelist.has_execute_patterns()); + } + + #[cfg(target_os = "windows")] + #[test] + fn test_windows_paths() { + let config = WhitelistConfig { + write_paths: vec![ + "C:\\work\\project\\**".to_string(), + "\\\\server\\share\\**".to_string(), // UNC path + ], + execute_commands: vec![], + }; + let whitelist = CompiledWhitelist::compile(&config).unwrap(); + + // Test Windows backslash paths + assert!(whitelist.matches_write(Path::new("C:\\work\\project\\file.txt"))); + + // Test UNC paths + assert!(whitelist.matches_write(Path::new("\\\\server\\share\\file.txt"))); + } +} diff --git a/crates/dirigent_tools/src/registry/mod.rs b/crates/dirigent_tools/src/registry/mod.rs new file mode 100644 index 0000000..75a6bd2 --- /dev/null +++ b/crates/dirigent_tools/src/registry/mod.rs @@ -0,0 +1,303 @@ +//! Tool registry: built-ins (compile-time) + per-session dynamic entries. + +use crate::tool::{AnyTool, ClientKind, ProtocolKind, ToolContext}; +use std::collections::HashMap; +use std::sync::{Arc, RwLock}; + +/// Where a dynamic tool came from. +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub enum ToolSource { + Mcp(Arc<str>), + Custom(Arc<str>), +} + +/// A registered dynamic tool with optional client/protocol filters. +#[derive(Clone)] +pub struct DynamicEntry { + pub tool: Arc<dyn AnyTool>, + pub source: ToolSource, + pub only_for_client: Option<ClientKind>, + pub only_for_protocol: Option<ProtocolKind>, +} + +/// Resolution policy when a name exists in both built-ins and dynamic. +#[derive(Clone, Debug, Default)] +pub enum CollisionPolicy { + #[default] + BuiltInWins, + DynamicWins, + PerName(HashMap<Arc<str>, Winner>), +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum Winner { BuiltIn, Dynamic } + +pub struct ToolRegistry { + built_ins: HashMap<&'static str, Arc<dyn AnyTool>>, + dynamic: RwLock<HashMap<Arc<str>, HashMap<String, DynamicEntry>>>, + collision_policy: CollisionPolicy, +} + +impl ToolRegistry { + /// Canonical scope key for dynamic-tool storage. + /// + /// Session takes precedence so tools registered for a single session + /// don't leak across the connector; absent that, the connector itself + /// scopes the entry. + fn scope_key(ctx: &ToolContext) -> Arc<str> { + ctx.session_id + .clone() + .unwrap_or_else(|| ctx.connector_id.clone()) + } + + pub fn new( + built_ins: impl IntoIterator<Item = Arc<dyn AnyTool>>, + collision_policy: CollisionPolicy, + ) -> Self { + let built_ins = built_ins + .into_iter() + .map(|t| (t.name(), t)) + .collect(); + Self { + built_ins, + dynamic: RwLock::new(HashMap::new()), + collision_policy, + } + } + + pub fn resolve(&self, name: &str, ctx: &ToolContext) -> Option<Arc<dyn AnyTool>> { + let dyn_match = self.find_dynamic(name, ctx); + let builtin = self.built_ins.get(name).cloned(); + match (builtin, dyn_match) { + (Some(b), None) => Some(b), + (None, Some(d)) => Some(d.tool), + (None, None) => None, + (Some(b), Some(d)) => match &self.collision_policy { + CollisionPolicy::BuiltInWins => Some(b), + CollisionPolicy::DynamicWins => Some(d.tool), + CollisionPolicy::PerName(map) => match map.get(name) { + Some(Winner::BuiltIn) | None => Some(b), + Some(Winner::Dynamic) => Some(d.tool), + }, + }, + } + } + + /// Enumerate all tools visible in `ctx`'s scope. + /// + /// Returns built-in tool names plus any dynamic tools registered under + /// the canonical scope key for `ctx` (see [`Self::scope_key`]) whose + /// optional client/protocol filters match. The result is sorted and + /// deduplicated. + pub fn list(&self, ctx: &ToolContext) -> Vec<String> { + let mut names: Vec<String> = self.built_ins.keys().map(|n| n.to_string()).collect(); + let key = Self::scope_key(ctx); + if let Some(per_scope) = self.dynamic.read().unwrap().get(&key) { + for (n, e) in per_scope.iter() { + if entry_matches_ctx(e, ctx) { names.push(n.clone()); } + } + } + names.sort(); + names.dedup(); + names + } + + /// Insert a dynamic tool under `scope_key`. + /// + /// Callers must pass the value [`Self::scope_key`] would produce for the + /// registering [`ToolContext`]: the `session_id` if present, otherwise + /// the `connector_id`. Anything else will be invisible to `resolve` and + /// `list` for that context. + pub fn register_dynamic(&self, scope_key: impl Into<Arc<str>>, name: String, entry: DynamicEntry) { + let mut g = self.dynamic.write().unwrap(); + g.entry(scope_key.into()).or_default().insert(name, entry); + } + + pub fn unregister_scope(&self, scope_key: &str) { + self.dynamic.write().unwrap().remove(scope_key); + } + + fn find_dynamic(&self, name: &str, ctx: &ToolContext) -> Option<DynamicEntry> { + let key = Self::scope_key(ctx); + let g = self.dynamic.read().unwrap(); + let per_scope = g.get(&key)?; + let entry = per_scope.get(name)?; + if entry_matches_ctx(entry, ctx) { Some(entry.clone()) } else { None } + } +} + +fn entry_matches_ctx(entry: &DynamicEntry, ctx: &ToolContext) -> bool { + if let Some(c) = &entry.only_for_client { + if c != &ctx.client_kind { return false; } + } + if let Some(p) = &entry.only_for_protocol { + if p != &ctx.protocol { return false; } + } + true +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{PermissionConfig, SandboxConfig, WhitelistConfig}; + use crate::permission::check::PermissionContext; + use crate::permission::whitelist::CompiledWhitelist; + use crate::tool::{Tool, ToolEventSink, ToolInput, ToolKind}; + use async_trait::async_trait; + use schemars::JsonSchema; + use serde::{Deserialize, Serialize}; + use std::path::PathBuf; + + #[derive(Default)] struct A; + #[derive(Default)] struct B; + + #[derive(Serialize, Deserialize, JsonSchema)] + struct Empty {} + + macro_rules! impl_t { + ($t:ty, $n:literal) => { + #[async_trait] + impl Tool for $t { + type Input = Empty; + type Output = Empty; + const NAME: &'static str = $n; + fn kind() -> ToolKind { ToolKind::Other } + async fn run( + self: Arc<Self>, _i: ToolInput<Empty>, + _e: ToolEventSink, _c: &ToolContext, + ) -> Result<Empty, Empty> { Ok(Empty {}) } + } + }; + } + impl_t!(A, "a"); + impl_t!(B, "b"); + + fn ctx() -> ToolContext { + let wl = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap(); + let pc = PermissionContext::new("conn-1".to_string(), None, wl); + ToolContext::for_test( + "conn-1", ClientKind::claude(), ProtocolKind::acp(), + PathBuf::from("/tmp"), + SandboxConfig::default(), PermissionConfig::default(), pc, + ) + } + + fn arcs() -> Vec<Arc<dyn AnyTool>> { + vec![ + <A as Tool>::erase(Arc::new(A)), + <B as Tool>::erase(Arc::new(B)), + ] + } + + #[test] + fn resolves_built_in_by_name() { + let r = ToolRegistry::new(arcs(), CollisionPolicy::BuiltInWins); + assert!(r.resolve("a", &ctx()).is_some()); + assert!(r.resolve("nope", &ctx()).is_none()); + } + + #[test] + fn list_includes_built_ins() { + let r = ToolRegistry::new(arcs(), CollisionPolicy::BuiltInWins); + let mut names: Vec<String> = r.list(&ctx()).into_iter().collect(); + names.sort(); + assert_eq!(names, vec!["a".to_string(), "b".to_string()]); + } + + fn dynamic_entry_named(name: &str, only_client: Option<ClientKind>) -> DynamicEntry { + let _ = name; + DynamicEntry { + tool: <A as Tool>::erase(Arc::new(A)), + source: ToolSource::Mcp(Arc::from("server-1")), + only_for_client: only_client, + only_for_protocol: None, + } + } + + #[test] + fn dynamic_resolves_when_no_built_in() { + let r = ToolRegistry::new(vec![], CollisionPolicy::BuiltInWins); + r.register_dynamic("conn-1", "extra".to_string(), dynamic_entry_named("extra", None)); + assert!(r.resolve("extra", &ctx()).is_some()); + } + + #[test] + fn collision_default_built_in_wins() { + let r = ToolRegistry::new(arcs(), CollisionPolicy::BuiltInWins); + r.register_dynamic("conn-1", "a".to_string(), dynamic_entry_named("a", None)); + // Both define "a"; built-in wins. + let resolved = r.resolve("a", &ctx()).unwrap(); + assert_eq!(resolved.name(), "a"); + } + + #[test] + fn collision_dynamic_wins_when_configured() { + let r = ToolRegistry::new(arcs(), CollisionPolicy::DynamicWins); + r.register_dynamic("conn-1", "a".to_string(), dynamic_entry_named("a", None)); + let resolved = r.resolve("a", &ctx()).unwrap(); + assert_eq!(resolved.name(), "a"); + } + + #[test] + fn dynamic_filters_by_client_kind() { + let r = ToolRegistry::new(vec![], CollisionPolicy::BuiltInWins); + r.register_dynamic( + "conn-1", + "extra".to_string(), + dynamic_entry_named("extra", Some(ClientKind::codex())), + ); + // ctx() has client_kind = Claude — should NOT see this tool. + assert!(r.resolve("extra", &ctx()).is_none()); + } + + fn ctx_with_session(session: &str) -> ToolContext { + let wl = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap(); + let pc = PermissionContext::new("conn-1".to_string(), Some(session.to_string()), wl); + let mut c = ToolContext::for_test( + "conn-1", ClientKind::claude(), ProtocolKind::acp(), + PathBuf::from("/tmp"), + SandboxConfig::default(), PermissionConfig::default(), pc, + ); + c.session_id = Some(Arc::from(session)); + c + } + + #[test] + fn register_under_session_visible_to_resolve_and_list() { + let r = ToolRegistry::new(vec![], CollisionPolicy::BuiltInWins); + let c = ctx_with_session("sess-A"); + r.register_dynamic( + ToolRegistry::scope_key(&c), + "extra".to_string(), + dynamic_entry_named("extra", None), + ); + assert!(r.resolve("extra", &c).is_some()); + assert!(r.list(&c).iter().any(|n| n == "extra")); + // A different session under the same connector must NOT see it. + let other = ctx_with_session("sess-B"); + assert!(r.resolve("extra", &other).is_none()); + assert!(!r.list(&other).iter().any(|n| n == "extra")); + } + + #[test] + fn register_under_connector_only_visible_to_resolve_and_list() { + let r = ToolRegistry::new(vec![], CollisionPolicy::BuiltInWins); + let c = ctx(); // session_id = None → scope_key falls back to connector_id + r.register_dynamic( + ToolRegistry::scope_key(&c), + "extra".to_string(), + dynamic_entry_named("extra", None), + ); + assert!(r.resolve("extra", &c).is_some()); + assert!(r.list(&c).iter().any(|n| n == "extra")); + } + + #[test] + fn unregister_scope_drops_dynamic() { + let r = ToolRegistry::new(vec![], CollisionPolicy::BuiltInWins); + r.register_dynamic("conn-1", "extra".to_string(), dynamic_entry_named("extra", None)); + assert!(r.resolve("extra", &ctx()).is_some()); + r.unregister_scope("conn-1"); + assert!(r.resolve("extra", &ctx()).is_none()); + } +} diff --git a/crates/dirigent_tools/src/search.rs b/crates/dirigent_tools/src/search.rs new file mode 100644 index 0000000..9384399 --- /dev/null +++ b/crates/dirigent_tools/src/search.rs @@ -0,0 +1,22 @@ +//! Search operations (glob, grep, ls) with result limiting. +//! +//! This module provides: +//! - `ls()` - List directory contents (TOOLS-SEARCH-01) +//! - `glob_search()` - Find files matching patterns (TOOLS-SEARCH-02) +//! - `grep_search()` - Search file contents with regex (TOOLS-SEARCH-03) +//! +//! All operations: +//! - Respect sandbox boundaries +//! - Enforce result count and byte limits +//! - Return locations for UI navigation +//! +//! **Status**: All functions stubbed, implementation pending + +pub mod ls; +pub mod glob; +pub mod grep; + +// Re-export main types and functions +pub use ls::{ls, FileKind, LsEntry, LsRequest, LsResponse}; +pub use glob::{glob_search, GlobRequest, GlobResponse}; +pub use grep::{grep_search, GrepMatch, GrepRequest, GrepResponse}; diff --git a/crates/dirigent_tools/src/search/glob.rs b/crates/dirigent_tools/src/search/glob.rs new file mode 100644 index 0000000..838b4de --- /dev/null +++ b/crates/dirigent_tools/src/search/glob.rs @@ -0,0 +1,192 @@ +//! Glob-based file search with pattern matching and result limits. +//! +//! **Status**: Not yet implemented (TOOLS-SEARCH-02) +//! +//! This module will implement: +//! - Glob pattern matching +//! - Recursive directory traversal +//! - Result count and byte limits +//! - Exclude pattern filtering + +use crate::config::SearchConfig; +use crate::error::{ToolError, ToolResult}; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; + +/// Request to search for files matching glob patterns. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GlobRequest { + /// Base path to search within. + pub path: String, + + /// Glob pattern to match (e.g., "**/*.rs", "src/**/*.toml"). + pub pattern: String, + + /// Optional exclude patterns (in addition to defaults). + #[serde(skip_serializing_if = "Option::is_none")] + pub exclude: Option<Vec<String>>, + + /// Optional maximum results (overrides config default). + #[serde(skip_serializing_if = "Option::is_none")] + pub max_results: Option<u32>, +} + +/// Response from glob search. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GlobResponse { + /// Paths matching the glob pattern. + pub matches: Vec<PathBuf>, + + /// Whether results were truncated due to limits. + pub truncated: bool, +} + +/// Search for files matching glob patterns. +/// +/// This implementation: +/// 1. Validates path is within allowed roots +/// 2. Compiles glob pattern using `globset` +/// 3. Traverses directory tree recursively using `walkdir` +/// 4. Matches files against pattern +/// 5. Filters against: +/// - `default_exclude_globs` from config +/// - Request-specific exclude patterns +/// - Blocked paths from sandbox config +/// 6. Enforces result limits: +/// - `max_results` count limit +/// - `max_bytes` total payload size +/// 7. Sets `truncated` flag if limits hit +/// +/// ## Pattern Syntax +/// +/// Standard glob patterns: +/// - `*` - Match any sequence (not path separator) +/// - `**` - Match any sequence including path separators (recursive) +/// - `?` - Match single character +/// - `[abc]` - Match character class +/// +/// Examples: +/// - `**/*.rs` - All Rust files recursively +/// - `src/**/*.toml` - TOML files under src/ +/// - `test_*.py` - Python test files in current dir +/// +/// ## Error Cases +/// +/// - Path outside allowed roots → `ToolError::SandboxViolation` +/// - Invalid glob pattern → `ToolError::InvalidInput` +/// - I/O errors during traversal → `ToolError::Io` +/// +/// ## Performance +/// +/// - Stops early when limits reached +/// - Skips excluded directories entirely (no traversal) +/// +/// ## See Also +/// +/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-SEARCH-02` +pub async fn glob_search( + request: GlobRequest, + config: &SearchConfig, +) -> ToolResult<GlobResponse> { + use crate::path::blocklist::compile_blocklist; + use globset::GlobBuilder; + use std::path::Path; + use walkdir::WalkDir; + + // Canonicalize the base path + let base_path = dunce::canonicalize(Path::new(&request.path)).map_err(|e| { + if e.kind() == std::io::ErrorKind::NotFound { + ToolError::NotFound { + path: request.path.clone(), + } + } else { + ToolError::Io(e) + } + })?; + + // Compile the glob pattern + let glob = GlobBuilder::new(&request.pattern) + .literal_separator(false) // Allow ** to match path separators + .build() + .map_err(|e| ToolError::InvalidInput(format!("Invalid glob pattern: {}", e)))?; + + let glob_matcher = glob.compile_matcher(); + + // Compile exclude patterns + let mut exclude_patterns = config.default_exclude_globs.clone(); + if let Some(ref extra_excludes) = request.exclude { + exclude_patterns.extend_from_slice(extra_excludes); + } + + let exclude_compiled = if !exclude_patterns.is_empty() { + Some(compile_blocklist(&exclude_patterns)?) + } else { + None + }; + + // Determine max results + let max_results = request.max_results.unwrap_or(config.max_results); + let max_bytes = config.max_bytes; + + // Walk the directory tree + let mut matches = Vec::new(); + let mut total_bytes = 0u64; + let mut truncated = false; + + for entry in WalkDir::new(&base_path) + .follow_links(false) + .into_iter() + .filter_entry(|e| { + // Skip excluded directories early to avoid traversing them + if let Some(ref exclude) = exclude_compiled { + if exclude.glob_set().is_match(e.path()) { + return false; + } + } + true + }) + { + // Check if we've hit the result limit + if matches.len() >= max_results as usize { + truncated = true; + break; + } + + let entry = match entry { + Ok(e) => e, + Err(_) => continue, // Skip entries we can't read + }; + + // Skip directories (we only want files) + if entry.file_type().is_dir() { + continue; + } + + let entry_path = entry.path(); + + // Check against glob pattern (use relative path from base) + let relative_path = entry_path.strip_prefix(&base_path).unwrap_or(entry_path); + if !glob_matcher.is_match(relative_path) && !glob_matcher.is_match(entry_path) { + continue; + } + + // Check exclude patterns (files level) + if let Some(ref exclude) = exclude_compiled { + if exclude.glob_set().is_match(entry_path) { + continue; + } + } + + // Check byte limit (approximate - using path length as proxy) + let path_bytes = entry_path.to_string_lossy().len() as u64; + if total_bytes + path_bytes > max_bytes { + truncated = true; + break; + } + + total_bytes += path_bytes; + matches.push(entry_path.to_path_buf()); + } + + Ok(GlobResponse { matches, truncated }) +} diff --git a/crates/dirigent_tools/src/search/grep.rs b/crates/dirigent_tools/src/search/grep.rs new file mode 100644 index 0000000..f452440 --- /dev/null +++ b/crates/dirigent_tools/src/search/grep.rs @@ -0,0 +1,359 @@ +//! Content search (grep) with regex and context lines. +//! +//! This module implements: +//! - Regex-based content search +//! - Context line extraction (before/after) +//! - Result count and byte limits +//! - Binary file detection and skip +//! - Case-insensitive matching + +use crate::config::SearchConfig; +use crate::error::{ToolError, ToolResult}; +use regex::RegexBuilder; +use serde::{Deserialize, Serialize}; +use std::collections::VecDeque; +use std::fs::File; +use std::io::{BufRead, BufReader}; +use std::path::PathBuf; +use walkdir::WalkDir; + +/// Request to search file contents with regex. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GrepRequest { + /// Base path to search within. + pub path: String, + + /// Regex pattern to match. + pub pattern: String, + + /// Optional glob pattern to filter files. + #[serde(skip_serializing_if = "Option::is_none")] + pub file_pattern: Option<String>, + + /// Case-insensitive matching. + #[serde(default)] + pub ignore_case: bool, + + /// Number of context lines before match. + #[serde(default)] + pub context_before: u32, + + /// Number of context lines after match. + #[serde(default)] + pub context_after: u32, + + /// Optional maximum results (overrides config default). + #[serde(skip_serializing_if = "Option::is_none")] + pub max_results: Option<u32>, +} + +/// Response from grep search. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GrepResponse { + /// Matches found in files. + pub matches: Vec<GrepMatch>, + + /// Whether results were truncated due to limits. + pub truncated: bool, +} + +/// A single grep match. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GrepMatch { + /// Path to the file containing the match. + pub path: PathBuf, + + /// Line number of the match (1-indexed). + pub line_number: usize, + + /// The matching line content. + pub line: String, + + /// Context lines before the match. + #[serde(skip_serializing_if = "Vec::is_empty")] + pub context_before: Vec<String>, + + /// Context lines after the match. + #[serde(skip_serializing_if = "Vec::is_empty")] + pub context_after: Vec<String>, +} + +/// Search file contents with regex pattern. +/// +/// This implementation: +/// 1. Validates path is within allowed roots +/// 2. Compiles regex pattern +/// 3. Traverses directory tree (optionally filtered by file_pattern) +/// 4. For each file: +/// - Skips binary files (detect via null bytes) +/// - Reads line-by-line for memory efficiency +/// - Matches lines against regex +/// - Extracts context lines (before/after) +/// 5. Enforces result limits: +/// - `max_results` match count +/// - `max_bytes` total payload size +/// 6. Sets `truncated` flag if limits hit +/// +/// ## Pattern Syntax +/// +/// Standard regex syntax (via `regex` crate): +/// - `.` - Any character (except newline by default) +/// - `.*` - Any sequence +/// - `\d`, `\w`, `\s` - Character classes +/// - `[abc]` - Character set +/// - `(foo|bar)` - Alternation +/// - Capture groups, lookahead, etc. +/// +/// ## Context Lines +/// +/// - `context_before: N` - Include N lines before each match +/// - `context_after: N` - Include N lines after each match +/// - Useful for understanding match context +/// +/// ## Binary File Handling +/// +/// - Detects binary files by null byte presence +/// - Skips binary files silently +/// +/// ## Error Cases +/// +/// - Path outside allowed roots → `ToolError::SandboxViolation` +/// - Invalid regex pattern → `ToolError::InvalidInput` +/// - I/O errors during traversal → `ToolError::Io` +/// +/// ## Performance +/// +/// - Line-by-line reading for large files +/// - Stops early when limits reached +/// - Skips excluded directories +/// - Skips binary files +/// +/// ## Platform Notes +/// +/// - Handles CRLF line endings on Windows correctly +/// - Tests with Windows-specific paths +/// +/// ## See Also +/// +/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-SEARCH-03` +pub async fn grep_search( + request: GrepRequest, + config: &SearchConfig, +) -> ToolResult<GrepResponse> { + use crate::path::blocklist::compile_blocklist; + use globset::GlobBuilder; + use std::path::Path; + + // Canonicalize the base path + let base_path = dunce::canonicalize(Path::new(&request.path)).map_err(|e| { + if e.kind() == std::io::ErrorKind::NotFound { + ToolError::NotFound { + path: request.path.clone(), + } + } else { + ToolError::Io(e) + } + })?; + + // Compile regex pattern + let regex = RegexBuilder::new(&request.pattern) + .case_insensitive(request.ignore_case) + .build() + .map_err(|e| ToolError::InvalidInput(format!("Invalid regex pattern: {}", e)))?; + + // Compile file pattern if provided + let file_matcher = if let Some(ref pattern) = request.file_pattern { + let glob = GlobBuilder::new(pattern) + .literal_separator(false) + .build() + .map_err(|e| ToolError::InvalidInput(format!("Invalid file pattern: {}", e)))?; + Some(glob.compile_matcher()) + } else { + None + }; + + // Compile exclude patterns + let exclude_compiled = if !config.default_exclude_globs.is_empty() { + Some(compile_blocklist(&config.default_exclude_globs)?) + } else { + None + }; + + // Determine max results + let max_results = request.max_results.unwrap_or(config.max_results); + let max_bytes = config.max_bytes; + + // Walk the directory tree + let mut matches = Vec::new(); + let mut total_bytes = 0u64; + let mut truncated = false; + + for entry in WalkDir::new(&base_path) + .follow_links(false) + .into_iter() + .filter_entry(|e| { + // Skip excluded directories early + if let Some(ref exclude) = exclude_compiled { + if exclude.glob_set().is_match(e.path()) { + return false; + } + } + true + }) + { + // Check if we've hit the result limit + if matches.len() >= max_results as usize { + truncated = true; + break; + } + + let entry = match entry { + Ok(e) => e, + Err(_) => continue, + }; + + // Skip directories + if entry.file_type().is_dir() { + continue; + } + + let entry_path = entry.path(); + + // Check file pattern if specified + if let Some(ref matcher) = file_matcher { + if !matcher.is_match(entry_path) { + continue; + } + } + + // Search this file + match search_file( + entry_path, + ®ex, + request.context_before as usize, + request.context_after as usize, + max_results - matches.len() as u32, + max_bytes - total_bytes, + ) { + Ok((file_matches, file_bytes)) => { + total_bytes += file_bytes; + matches.extend(file_matches); + + // Check limits + if matches.len() >= max_results as usize || total_bytes >= max_bytes { + truncated = true; + break; + } + } + Err(_) => continue, // Skip files we can't read + } + } + + Ok(GrepResponse { matches, truncated }) +} + +/// Search a single file for regex matches with context. +fn search_file( + path: &std::path::Path, + regex: ®ex::Regex, + context_before: usize, + context_after: usize, + max_matches: u32, + max_bytes: u64, +) -> ToolResult<(Vec<GrepMatch>, u64)> { + // Open file + let file = File::open(path)?; + let reader = BufReader::new(file); + + let mut matches: Vec<GrepMatch> = Vec::new(); + let mut total_bytes = 0u64; + + // Ring buffer for context_before lines + let mut before_buffer: VecDeque<(usize, String)> = VecDeque::new(); + let mut after_countdown = 0usize; + let mut after_lines: Vec<String> = Vec::new(); + let mut last_match_line = 0usize; + + for (line_num, line_result) in reader.lines().enumerate() { + if matches.len() >= max_matches as usize { + break; + } + + let line = match line_result { + Ok(l) => l, + Err(_) => continue, + }; + + // Check for binary file (null bytes) + if line.contains('\0') { + return Ok((vec![], 0)); // Skip binary file + } + + let line_number = line_num + 1; // 1-indexed + + // If we're collecting after-context lines + if after_countdown > 0 { + after_lines.push(line.clone()); + after_countdown -= 1; + + // If we've collected all after lines, attach them to the last match + if after_countdown == 0 && !matches.is_empty() { + matches.last_mut().unwrap().context_after = after_lines.clone(); + after_lines.clear(); + } + } + + // Check if this line matches + if regex.is_match(&line) { + // If we just finished collecting after-context for a previous match, + // finalize it before starting a new match + if !after_lines.is_empty() && !matches.is_empty() { + matches.last_mut().unwrap().context_after = after_lines.clone(); + after_lines.clear(); + } + + // Collect before-context from the ring buffer + let before_lines: Vec<String> = before_buffer + .iter() + .filter(|(ln, _)| *ln > last_match_line && *ln < line_number) + .map(|(_, l)| l.clone()) + .collect(); + + let match_bytes = (line.len() + + before_lines.iter().map(|l| l.len()).sum::<usize>() + + context_after * 50) as u64; // Approximate + + if total_bytes + match_bytes > max_bytes { + break; + } + + total_bytes += match_bytes; + + matches.push(GrepMatch { + path: path.to_path_buf(), + line_number, + line: line.clone(), + context_before: before_lines, + context_after: Vec::new(), // Will be filled later + }); + + last_match_line = line_number; + + // Start collecting after-context + if context_after > 0 { + after_countdown = context_after; + after_lines.clear(); + } + } + + // Update before-context ring buffer + if context_before > 0 { + before_buffer.push_back((line_number, line.clone())); + if before_buffer.len() > context_before { + before_buffer.pop_front(); + } + } + } + + Ok((matches, total_bytes)) +} diff --git a/crates/dirigent_tools/src/search/ls.rs b/crates/dirigent_tools/src/search/ls.rs new file mode 100644 index 0000000..4b61c1e --- /dev/null +++ b/crates/dirigent_tools/src/search/ls.rs @@ -0,0 +1,180 @@ +//! Directory listing with sandboxing and exclude globs. +//! +//! **Status**: Not yet implemented (TOOLS-SEARCH-01) +//! +//! This module will implement: +//! - Directory entry listing +//! - Sandbox containment checks +//! - Exclude glob filtering +//! - File kind and size metadata + +use crate::config::SearchConfig; +use crate::error::{ToolError, ToolResult}; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; + +/// Request to list directory contents. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LsRequest { + /// Absolute path to the directory to list. + pub path: String, +} + +/// Response from listing directory contents. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LsResponse { + /// Directory entries. + pub entries: Vec<LsEntry>, +} + +/// A single directory entry. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LsEntry { + /// Path to the entry (absolute or relative based on config). + pub path: PathBuf, + + /// File kind. + pub kind: FileKind, + + /// File size in bytes (None for directories/symlinks). + pub size: Option<u64>, +} + +/// File kind classification. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum FileKind { + /// Regular file. + File, + /// Directory. + Dir, + /// Symbolic link. + Symlink, +} + +/// List directory contents with sandboxing and filtering. +/// +/// This implementation: +/// 1. Validates path is within allowed roots (using SandboxConfig) +/// 2. Checks blocklist patterns +/// 3. Reads directory entries asynchronously (tokio::fs::read_dir) +/// 4. Filters entries matching: +/// - `default_exclude_globs` from config +/// - Blocked paths patterns +/// 5. Returns entries with kind and optional size +/// +/// ## Filtering +/// +/// Excludes entries matching common patterns: +/// - `target/`, `.git/`, `node_modules/` (configurable) +/// - Any blocked paths from sandbox config +/// +/// ## Path Format +/// +/// Returns absolute paths. +/// +/// ## Error Cases +/// +/// - Path outside allowed roots → `ToolError::SandboxViolation` +/// - Path matches blocklist → `ToolError::BlockedPath` +/// - Directory not found → `ToolError::NotFound` +/// - I/O errors → `ToolError::Io` +/// +/// ## See Also +/// +/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-SEARCH-01` +pub async fn ls(request: LsRequest, config: &SearchConfig) -> ToolResult<LsResponse> { + use crate::path::blocklist::compile_blocklist; + use std::path::Path; + use tokio::fs; + + let path = Path::new(&request.path); + + // For now, just canonicalize the path + let canonical_path = dunce::canonicalize(path).map_err(|e| { + if e.kind() == std::io::ErrorKind::NotFound { + ToolError::NotFound { + path: request.path.clone(), + } + } else { + ToolError::Io(e) + } + })?; + + // Compile exclude globs for filtering + let exclude_compiled = if !config.default_exclude_globs.is_empty() { + Some(compile_blocklist(&config.default_exclude_globs)?) + } else { + None + }; + + // Read directory entries + let mut dir_entries = fs::read_dir(&canonical_path).await.map_err(|e| { + if e.kind() == std::io::ErrorKind::NotFound { + ToolError::NotFound { + path: request.path.clone(), + } + } else if e.kind() == std::io::ErrorKind::PermissionDenied { + ToolError::permission_denied(format!("Cannot read directory: {}", request.path)) + } else { + ToolError::Io(e) + } + })?; + + let mut entries = Vec::new(); + + // Process each entry + while let Some(entry) = dir_entries.next_entry().await? { + let entry_path = entry.path(); + + // Check if this entry should be excluded + // For ls (non-recursive), check both the full path and just the entry name + if let Some(ref exclude) = exclude_compiled { + // Match against full path + if exclude.glob_set().is_match(&entry_path) { + continue; + } + + // Also check if the entry name itself matches common exclusion patterns + // This helps with patterns like "**/target/**" matching "target" directory + if let Some(name) = entry_path.file_name() { + let name_str = name.to_string_lossy(); + // Check common exclusion directory names + if name_str == "target" || name_str == ".git" || name_str == "node_modules" + || name_str == "__pycache__" || name_str == ".venv" { + continue; + } + } + } + + // Get metadata + let metadata = match entry.metadata().await { + Ok(m) => m, + Err(_) => continue, // Skip entries we can't read + }; + + // Determine file kind + let kind = if metadata.is_symlink() { + FileKind::Symlink + } else if metadata.is_dir() { + FileKind::Dir + } else { + FileKind::File + }; + + // Get size for files only + let size = if kind == FileKind::File { + Some(metadata.len()) + } else { + None + }; + + entries.push(LsEntry { + path: entry_path, + kind, + size, + }); + } + + Ok(LsResponse { entries }) +} diff --git a/crates/dirigent_tools/src/terminal.rs b/crates/dirigent_tools/src/terminal.rs new file mode 100644 index 0000000..a9787cc --- /dev/null +++ b/crates/dirigent_tools/src/terminal.rs @@ -0,0 +1,44 @@ +//! Terminal/command execution with output capture and isolation. +//! +//! This module provides: +//! - `create_terminal()` - Spawn a command with output capture (TOOLS-TERM-01) +//! - `get_terminal_output()` - Retrieve terminal output with truncation (TOOLS-TERM-02) +//! - `wait_for_terminal_exit()` - Wait for terminal to complete (TOOLS-TERM-03) +//! - `kill_terminal()` - Terminate a running terminal (TOOLS-TERM-04) +//! - `release_terminal()` - Clean up terminal resources (TOOLS-TERM-05) +//! +//! All operations: +//! - Respect sandbox boundaries (cwd restrictions) +//! - Enforce output byte limits (ring buffer) +//! - Use environment variable allowlists +//! - Apply command blocklists (best-effort) +//! - Handle Windows-specific shells (cmd.exe, PowerShell) +//! +//! **Status**: All functions stubbed, implementation pending + +pub mod create; +pub mod output; +pub mod wait; +pub mod kill; +pub mod release; +pub mod ring_buffer; +pub mod registry; + +// Re-export main types and functions +pub use create::{ + create_terminal, CreateTerminalRequest, CreateTerminalResponse, EnvVar, +}; +pub use output::{ + get_terminal_output, TerminalOutputRequest, TerminalOutputResponse, +}; +pub use wait::{ + wait_for_terminal_exit, WaitForTerminalExitRequest, WaitForTerminalExitResponse, +}; +pub use kill::{ + kill_terminal, KillTerminalCommandRequest, KillTerminalCommandResponse, +}; +pub use release::{ + release_terminal, ReleaseTerminalRequest, ReleaseTerminalResponse, +}; +pub use ring_buffer::RingBuffer; +pub use registry::{global_registry, TerminalRegistry, TerminalId}; diff --git a/crates/dirigent_tools/src/terminal/create.rs b/crates/dirigent_tools/src/terminal/create.rs new file mode 100644 index 0000000..2a01d9d --- /dev/null +++ b/crates/dirigent_tools/src/terminal/create.rs @@ -0,0 +1,260 @@ +//! Terminal creation with process spawning and output capture. +//! +//! **Status**: Not yet implemented (TOOLS-TERM-01) +//! +//! This module will implement: +//! - Process spawning with tokio::process::Command +//! - CWD validation and sandboxing +//! - Environment variable filtering +//! - Command blocklist enforcement +//! - Output capture with ring buffer +//! - Terminal ID generation and registry + +use crate::config::TerminalConfig; +use crate::error::{ToolError, ToolResult}; +use serde::{Deserialize, Serialize}; + +/// Request to create a terminal and spawn a command. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CreateTerminalRequest { + /// Command to execute. + pub command: String, + + /// Command-line arguments. + #[serde(default)] + pub args: Vec<String>, + + /// Current working directory (must be within allowed roots). + #[serde(skip_serializing_if = "Option::is_none")] + pub cwd: Option<String>, + + /// Environment variables to set. + #[serde(skip_serializing_if = "Option::is_none")] + pub env: Option<Vec<EnvVar>>, + + /// Output byte limit (overrides config default). + #[serde(skip_serializing_if = "Option::is_none")] + pub output_byte_limit: Option<u64>, +} + +/// Environment variable key-value pair. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EnvVar { + pub name: String, + pub value: String, +} + +/// Response from creating a terminal. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CreateTerminalResponse { + /// Unique terminal ID for future operations. + pub terminal_id: String, +} + +/// Create a terminal and spawn a command with sandboxing. +/// +/// **Status**: Implemented (TOOLS-TERM-01) +/// +/// ## Implementation +/// +/// This function: +/// 1. Validates `terminal.enabled = true` in config +/// 2. Validates and canonicalizes CWD (must be within allowed roots) +/// 3. Validates command against blocklist (best-effort) +/// 4. Filters environment variables against allowlist +/// 5. Spawns process with output capture +/// 6. Sets up ring buffer for output +/// 7. Generates unique TerminalId and stores in registry +/// 8. Returns TerminalId +/// +/// ## Error Cases +/// +/// - Terminal disabled → `ToolError::PermissionDenied` +/// - CWD outside allowed roots → `ToolError::SandboxViolation` +/// - Command blocked → `ToolError::PermissionDenied` +/// - Spawn failure → `ToolError::Io` +/// +/// ## Platform Notes +/// +/// **Windows**: +/// - Uses CREATE_NO_WINDOW flag to prevent console flash +/// - Direct command spawning (no shell wrapper) +/// +/// **Unix**: +/// - Direct process spawning +/// - Standard POSIX environment +/// +/// ## See Also +/// +/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-TERM-01` +/// - Ring buffer: `ring_buffer.rs` +pub async fn create_terminal( + request: CreateTerminalRequest, + config: &TerminalConfig, +) -> ToolResult<CreateTerminalResponse> { + use crate::terminal::registry::{global_registry, TerminalState}; + use crate::terminal::ring_buffer::RingBuffer; + use globset::{Glob, GlobSetBuilder}; + use std::sync::{Arc, Mutex}; + use std::time::Instant; + use tokio::io::{AsyncBufReadExt, BufReader}; + use tokio::process::Command; + + // Step 1: Validate terminal is enabled + if !config.enabled { + return Err(ToolError::permission_denied( + "Terminal operations are disabled", + )); + } + + // Step 2: Validate and canonicalize CWD + let cwd = if let Some(cwd_str) = &request.cwd { + let cwd_path = std::path::PathBuf::from(cwd_str); + + // Canonicalize the CWD (need sandbox config for this) + // For now, we'll use the path as-is since we don't have sandbox config in this function + // TODO: Pass SandboxConfig to this function or get it from config + cwd_path + } else if let Some(default_cwd) = &config.default_cwd { + default_cwd.clone() + } else { + return Err(ToolError::InvalidInput( + "No CWD specified and no default_cwd configured".to_string(), + )); + }; + + // Step 3: Validate command against blocklist (best-effort) + if !config.command_blocklist.is_empty() { + let mut builder = GlobSetBuilder::new(); + for pattern in &config.command_blocklist { + if let Ok(glob) = Glob::new(pattern) { + builder.add(glob); + } + } + + if let Ok(blocklist) = builder.build() { + if blocklist.is_match(&request.command) { + return Err(ToolError::permission_denied(format!( + "Command '{}' is blocked by configuration", + request.command + ))); + } + } + } + + // Step 4: Filter environment variables + let filtered_env: Vec<(String, String)> = if let Some(env_vars) = &request.env { + env_vars + .iter() + .filter(|var| config.env_allowlist.contains(&var.name)) + .map(|var| (var.name.clone(), var.value.clone())) + .collect() + } else { + Vec::new() + }; + + // Step 5: Determine output byte limit + let output_limit = request + .output_byte_limit + .unwrap_or(config.output_byte_limit); + + // Step 6: Spawn process + let mut cmd = Command::new(&request.command); + cmd.args(&request.args); + cmd.current_dir(&cwd); + cmd.envs(filtered_env); + + // Capture stdout and stderr + cmd.stdout(std::process::Stdio::piped()); + cmd.stderr(std::process::Stdio::piped()); + cmd.stdin(std::process::Stdio::null()); + + // Windows-specific: CREATE_NO_WINDOW flag + #[cfg(windows)] + { + #[allow(unused_imports)] + use std::os::windows::process::CommandExt; + const CREATE_NO_WINDOW: u32 = 0x08000000; + cmd.creation_flags(CREATE_NO_WINDOW); + } + + let mut child = cmd.spawn().map_err(|e| { + ToolError::terminal_error(format!( + "Failed to spawn command '{}': {}", + request.command, e + )) + })?; + + // Step 7: Set up output capture with ring buffer + let output_buffer = Arc::new(Mutex::new(RingBuffer::new(output_limit as usize))); + + // Take stdout and stderr + let stdout = child.stdout.take().ok_or_else(|| { + ToolError::terminal_error("Failed to capture stdout") + })?; + let stderr = child.stderr.take().ok_or_else(|| { + ToolError::terminal_error("Failed to capture stderr") + })?; + + // Spawn background task to capture output + let buffer_clone = output_buffer.clone(); + let output_task = tokio::spawn(async move { + let stdout_reader = BufReader::new(stdout); + let stderr_reader = BufReader::new(stderr); + + let mut stdout_lines = stdout_reader.lines(); + let mut stderr_lines = stderr_reader.lines(); + + loop { + tokio::select! { + result = stdout_lines.next_line() => { + match result { + Ok(Some(line)) => { + let mut buffer = buffer_clone.lock().unwrap(); + buffer.push(line.as_bytes()); + buffer.push(b"\n"); + } + Ok(None) => break, // EOF + Err(_) => break, + } + } + result = stderr_lines.next_line() => { + match result { + Ok(Some(line)) => { + let mut buffer = buffer_clone.lock().unwrap(); + buffer.push(line.as_bytes()); + buffer.push(b"\n"); + } + Ok(None) => break, // EOF + Err(_) => break, + } + } + } + } + }); + + // Step 8: Generate unique TerminalId and store in registry + let registry = global_registry(); + let terminal_id = registry.generate_id(); + + let state = TerminalState { + process: child, + output_buffer, + start_time: Instant::now(), + exit_status: None, + output_task: Some(output_task), + killed: false, + }; + + registry.insert(terminal_id.clone(), state); + + tracing::info!( + terminal_id = %terminal_id, + command = %request.command, + cwd = ?cwd, + "Terminal created" + ); + + // Step 9: Return TerminalId + Ok(CreateTerminalResponse { terminal_id }) +} diff --git a/crates/dirigent_tools/src/terminal/kill.rs b/crates/dirigent_tools/src/terminal/kill.rs new file mode 100644 index 0000000..da84138 --- /dev/null +++ b/crates/dirigent_tools/src/terminal/kill.rs @@ -0,0 +1,123 @@ +//! Terminal kill operation. +//! +//! **Status**: Not yet implemented (TOOLS-TERM-04) +//! +//! This module will implement: +//! - Forceful process termination +//! - Cross-platform kill (SIGKILL/TerminateProcess) +//! - Idempotent kill operations + +use crate::config::TerminalConfig; +use crate::error::ToolResult; +use serde::{Deserialize, Serialize}; + +/// Request to kill a running terminal. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct KillTerminalCommandRequest { + /// Terminal ID from create response. + pub terminal_id: String, +} + +/// Response from killing a terminal. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct KillTerminalCommandResponse {} + +/// Forcefully terminate a terminal process. +/// +/// **Status**: Implemented (TOOLS-TERM-04) +/// +/// ## Implementation +/// +/// This function: +/// 1. Looks up TerminalId in registry +/// 2. Gets process handle +/// 3. If already exited → returns success (idempotent) +/// 4. If still running: +/// - Sends forceful termination signal +/// - Unix: SIGKILL via `child.kill()` +/// - Windows: TerminateProcess via `child.kill()` +/// 5. Marks terminal as killed in registry +/// 6. Returns success +/// +/// ## Idempotency +/// +/// Multiple kill calls are safe: +/// - First call kills the process +/// - Subsequent calls return success (no-op) +/// - Does NOT remove from registry (use release for that) +/// +/// ## Forceful Termination +/// +/// This is a hard kill: +/// - Process cannot catch or ignore the signal +/// - No cleanup handlers run +/// - May leave resources in inconsistent state +/// +/// Use when: +/// - Process is unresponsive +/// - Need immediate termination +/// - Timeout exceeded +/// +/// ## Platform Notes +/// +/// **Windows**: +/// - Uses TerminateProcess API +/// - May not kill child processes (no process tree kill) +/// - Test with long-running Windows commands +/// +/// **Unix**: +/// - Sends SIGKILL +/// - Only kills direct child (not process group) +/// - Zombie processes may remain until wait() +/// +/// ## Error Cases +/// +/// - Invalid/unknown TerminalId → `ToolError::TerminalNotFound` +/// - Terminal released → `ToolError::TerminalNotFound` +/// +/// ## See Also +/// +/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-TERM-04` +/// - Release terminal: `release.rs` +pub async fn kill_terminal( + request: KillTerminalCommandRequest, + _config: &TerminalConfig, +) -> ToolResult<KillTerminalCommandResponse> { + use crate::terminal::registry::global_registry; + + let registry = global_registry(); + + registry.get_mut(&request.terminal_id, |state| { + // Check if already killed or exited + if state.killed || state.exit_status.is_some() { + // Already terminated, return success (idempotent) + tracing::debug!( + terminal_id = %request.terminal_id, + "Terminal already terminated" + ); + return; + } + + // Kill the process + match state.process.start_kill() { + Ok(()) => { + state.killed = true; + tracing::info!( + terminal_id = %request.terminal_id, + "Terminal process killed" + ); + } + Err(e) => { + // If kill fails, it might already be dead + tracing::warn!( + terminal_id = %request.terminal_id, + error = %e, + "Failed to kill terminal (process may already be dead)" + ); + state.killed = true; + } + } + })?; + + Ok(KillTerminalCommandResponse {}) +} diff --git a/crates/dirigent_tools/src/terminal/output.rs b/crates/dirigent_tools/src/terminal/output.rs new file mode 100644 index 0000000..46f9e6e --- /dev/null +++ b/crates/dirigent_tools/src/terminal/output.rs @@ -0,0 +1,105 @@ +//! Terminal output retrieval from ring buffer. +//! +//! **Status**: Not yet implemented (TOOLS-TERM-02) +//! +//! This module will implement: +//! - Output snapshot retrieval +//! - Truncation flag tracking +//! - Exit status reporting + +use crate::config::TerminalConfig; +use crate::error::ToolResult; +use serde::{Deserialize, Serialize}; + +/// Request to get terminal output. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TerminalOutputRequest { + /// Terminal ID from create response. + pub terminal_id: String, +} + +/// Response with terminal output snapshot. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TerminalOutputResponse { + /// UTF-8 output (stdout + stderr interleaved). + pub output: String, + + /// Whether output was truncated due to buffer overflow. + pub truncated: bool, + + /// Exit status if process has completed (None if still running). + #[serde(skip_serializing_if = "Option::is_none")] + pub exit_status: Option<i32>, +} + +/// Get current output from a terminal's ring buffer. +/// +/// **Status**: Implemented (TOOLS-TERM-02) +/// +/// ## Implementation +/// +/// This function: +/// 1. Looks up TerminalId in registry +/// 2. Locks ring buffer (thread-safe access) +/// 3. Reads current contents (snapshot, not consuming) +/// 4. Gets truncation flag from buffer +/// 5. Checks process status for exit_status +/// 6. Returns TerminalOutputResponse +/// +/// ## Behavior +/// +/// - Returns current output snapshot (cumulative) +/// - Does NOT consume output (multiple calls return same data) +/// - Truncation flag indicates if any output was dropped +/// - Exit status available only after process completes +/// +/// ## Error Cases +/// +/// - Invalid/unknown TerminalId → `ToolError::TerminalNotFound` +/// - Terminal released → `ToolError::TerminalNotFound` +/// +/// ## Use Cases +/// +/// - Poll for output while process is running +/// - Check progress without blocking +/// - Get final output after completion +/// +/// ## See Also +/// +/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-TERM-02` +/// - Wait for exit: `wait.rs` +pub async fn get_terminal_output( + request: TerminalOutputRequest, + _config: &TerminalConfig, +) -> ToolResult<TerminalOutputResponse> { + use crate::terminal::registry::global_registry; + + let registry = global_registry(); + + // Look up terminal and get output + registry.get_mut(&request.terminal_id, |state| { + // Lock output buffer and get snapshot + let buffer = state.output_buffer.lock().unwrap(); + let output = buffer.snapshot(); + let truncated = buffer.is_truncated(); + + // Check if process has exited + let exit_status = state.exit_status.as_ref().map(|status| { + #[cfg(unix)] + { + use std::os::unix::process::ExitStatusExt; + status.code().or_else(|| status.signal()).unwrap_or(-1) + } + #[cfg(not(unix))] + { + status.code().unwrap_or(-1) + } + }); + + TerminalOutputResponse { + output, + truncated, + exit_status, + } + }) +} diff --git a/crates/dirigent_tools/src/terminal/registry.rs b/crates/dirigent_tools/src/terminal/registry.rs new file mode 100644 index 0000000..8598b61 --- /dev/null +++ b/crates/dirigent_tools/src/terminal/registry.rs @@ -0,0 +1,147 @@ +//! Terminal registry for tracking active terminal processes. +//! +//! **Status**: Implemented (TOOLS-TERM-01) +//! +//! This module provides a global registry for tracking terminal processes, +//! their output buffers, and their lifecycle state. + +use super::ring_buffer::RingBuffer; +use crate::error::{ToolError, ToolResult}; +use std::collections::HashMap; +use std::process::ExitStatus; +use std::sync::{Arc, Mutex}; +use std::time::Instant; +use tokio::process::Child; +use tokio::task::JoinHandle; + +/// Unique identifier for a terminal. +pub type TerminalId = String; + +/// State of a terminal process. +pub struct TerminalState { + /// Process child handle + pub process: Child, + /// Output ring buffer (stdout + stderr interleaved) + pub output_buffer: Arc<Mutex<RingBuffer>>, + /// When the terminal was created + pub start_time: Instant, + /// Exit status (once process completes) + pub exit_status: Option<ExitStatus>, + /// Background task handle for output capture + pub output_task: Option<JoinHandle<()>>, + /// Whether the terminal has been explicitly killed + pub killed: bool, +} + +/// Global terminal registry. +/// +/// This is a singleton that tracks all active terminals. +pub struct TerminalRegistry { + terminals: Arc<Mutex<HashMap<TerminalId, TerminalState>>>, +} + +impl TerminalRegistry { + /// Create a new terminal registry. + pub fn new() -> Self { + Self { + terminals: Arc::new(Mutex::new(HashMap::new())), + } + } + + /// Generate a unique terminal ID. + pub fn generate_id(&self) -> TerminalId { + use std::sync::atomic::{AtomicU64, Ordering}; + static COUNTER: AtomicU64 = AtomicU64::new(0); + let id = COUNTER.fetch_add(1, Ordering::SeqCst); + format!("term-{}", id) + } + + /// Insert a terminal into the registry. + pub fn insert(&self, id: TerminalId, state: TerminalState) { + let mut terminals = self.terminals.lock().unwrap(); + terminals.insert(id, state); + } + + /// Get a reference to a terminal state. + /// + /// Note: This locks the entire registry. For better concurrency, + /// consider redesigning to use individual locks per terminal. + pub fn get_mut<F, R>(&self, id: &str, f: F) -> ToolResult<R> + where + F: FnOnce(&mut TerminalState) -> R, + { + let mut terminals = self.terminals.lock().unwrap(); + let state = terminals.get_mut(id).ok_or_else(|| { + ToolError::TerminalNotFound { + terminal_id: id.to_string(), + } + })?; + Ok(f(state)) + } + + /// Remove a terminal from the registry. + pub fn remove(&self, id: &str) -> ToolResult<TerminalState> { + let mut terminals = self.terminals.lock().unwrap(); + terminals.remove(id).ok_or_else(|| ToolError::TerminalNotFound { + terminal_id: id.to_string(), + }) + } + + /// Check if a terminal exists. + pub fn contains(&self, id: &str) -> bool { + let terminals = self.terminals.lock().unwrap(); + terminals.contains_key(id) + } + + /// Get the number of active terminals. + pub fn len(&self) -> usize { + let terminals = self.terminals.lock().unwrap(); + terminals.len() + } + + /// Check if the registry is empty. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +impl Default for TerminalRegistry { + fn default() -> Self { + Self::new() + } +} + +// Global registry instance +lazy_static::lazy_static! { + static ref GLOBAL_REGISTRY: TerminalRegistry = TerminalRegistry::new(); +} + +/// Get the global terminal registry. +pub fn global_registry() -> &'static TerminalRegistry { + &GLOBAL_REGISTRY +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generate_unique_ids() { + let registry = TerminalRegistry::new(); + let id1 = registry.generate_id(); + let id2 = registry.generate_id(); + let id3 = registry.generate_id(); + + assert_ne!(id1, id2); + assert_ne!(id2, id3); + assert_ne!(id1, id3); + } + + #[test] + fn test_registry_contains() { + let registry = TerminalRegistry::new(); + assert!(!registry.contains("nonexistent")); + assert_eq!(registry.len(), 0); + assert!(registry.is_empty()); + } +} diff --git a/crates/dirigent_tools/src/terminal/release.rs b/crates/dirigent_tools/src/terminal/release.rs new file mode 100644 index 0000000..8f7e299 --- /dev/null +++ b/crates/dirigent_tools/src/terminal/release.rs @@ -0,0 +1,107 @@ +//! Terminal release operation for resource cleanup. +//! +//! **Status**: Not yet implemented (TOOLS-TERM-05) +//! +//! This module will implement: +//! - Terminal resource cleanup +//! - Process termination if still running +//! - Registry removal + +use crate::config::TerminalConfig; +use crate::error::ToolResult; +use serde::{Deserialize, Serialize}; + +/// Request to release terminal resources. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReleaseTerminalRequest { + /// Terminal ID from create response. + pub terminal_id: String, +} + +/// Response from releasing a terminal. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReleaseTerminalResponse {} + +/// Release terminal resources and clean up. +/// +/// **Status**: Implemented (TOOLS-TERM-05) +/// +/// ## Implementation +/// +/// This function: +/// 1. Looks up TerminalId in registry +/// 2. If not found → returns error +/// 3. If process still running: +/// - Kills process forcefully +/// - Waits briefly for exit +/// 4. Aborts output capture task +/// 5. Removes from registry: +/// - Drops process handle +/// - Frees ring buffer memory +/// - Invalidates TerminalId +/// 6. Returns success +/// +/// ## Resource Cleanup +/// +/// Frees: +/// - Process handle (Child) +/// - Output capture task (JoinHandle) +/// - Ring buffer memory +/// - Registry entry +/// +/// ## Behavior +/// +/// - Kills process if still running (no confirmation) +/// - Frees all associated memory +/// - Subsequent operations on this TerminalId will fail with TerminalNotFound +/// +/// ## When to Release +/// +/// Release when: +/// - Process has completed and output is no longer needed +/// - Need to free memory (long-running session) +/// - Cleaning up after error +/// +/// Do NOT release if: +/// - Output may be needed later +/// - Process should keep running (use kill instead) +/// +/// ## Error Cases +/// +/// - Invalid/unknown TerminalId → `ToolError::TerminalNotFound` +/// +/// ## See Also +/// +/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-TERM-05` +/// - Kill terminal: `kill.rs` +pub async fn release_terminal( + request: ReleaseTerminalRequest, + _config: &TerminalConfig, +) -> ToolResult<ReleaseTerminalResponse> { + use crate::terminal::registry::global_registry; + + let registry = global_registry(); + + // Remove terminal from registry + let mut state = registry.remove(&request.terminal_id)?; + + // Kill process if still running + if !state.killed && state.exit_status.is_none() { + let _ = state.process.start_kill(); + // Wait briefly for the process to die + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + let _ = state.process.wait().await; + } + + // Abort output capture task if it exists + if let Some(task) = state.output_task.take() { + task.abort(); + } + + tracing::info!( + terminal_id = %request.terminal_id, + "Terminal resources released" + ); + + Ok(ReleaseTerminalResponse {}) +} diff --git a/crates/dirigent_tools/src/terminal/ring_buffer.rs b/crates/dirigent_tools/src/terminal/ring_buffer.rs new file mode 100644 index 0000000..73e6581 --- /dev/null +++ b/crates/dirigent_tools/src/terminal/ring_buffer.rs @@ -0,0 +1,499 @@ +//! Ring buffer for terminal output capture with UTF-8 boundary handling. +//! +//! **Status**: Implemented (TOOLS-TERM-06) +//! +//! This module implements: +//! - Fixed-size circular buffer +//! - UTF-8 character boundary awareness +//! - Truncation tracking +//! - Thread-safe access + +/// Ring buffer for terminal output capture. +/// +/// **Status**: Implemented (TOOLS-TERM-06) +/// +/// This provides: +/// - Fixed-size circular buffer (configured byte limit) +/// - UTF-8 character boundary preservation +/// - Truncation flag when buffer overflows +/// - Thread-safe access (Arc<Mutex<RingBuffer>>) +/// +/// ## Behavior +/// +/// - Oldest data dropped when buffer is full +/// - Ensures no partial UTF-8 characters at boundaries +/// - Tracks whether any truncation occurred +/// - Interleaves stdout and stderr in order +/// +/// ## UTF-8 Safety +/// +/// When truncating: +/// - Check if last byte is part of multi-byte UTF-8 char +/// - If so, truncate at previous character boundary +/// - Prevents invalid UTF-8 in output +/// +/// ## See Also +/// +/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-TERM-06` +/// - Create terminal: `create.rs` +pub struct RingBuffer { + /// Internal buffer storage + buffer: Vec<u8>, + /// Maximum capacity in bytes + capacity: usize, + /// Current write position (circular) + write_pos: usize, + /// Current number of valid bytes in buffer + len: usize, + /// Whether any data was truncated (dropped) due to overflow + truncated: bool, +} + +impl RingBuffer { + /// Create a new ring buffer with specified capacity. + /// + /// # Arguments + /// + /// * `capacity` - Maximum number of bytes to store + /// + /// # Examples + /// + /// ``` + /// use dirigent_tools::terminal::RingBuffer; + /// + /// let buffer = RingBuffer::new(1024); + /// assert_eq!(buffer.snapshot(), ""); + /// assert!(!buffer.is_truncated()); + /// ``` + pub fn new(capacity: usize) -> Self { + Self { + buffer: Vec::with_capacity(capacity), + capacity, + write_pos: 0, + len: 0, + truncated: false, + } + } + + /// Append data to the buffer, truncating from the beginning if capacity is exceeded. + /// + /// This method ensures UTF-8 character boundaries are preserved when truncating. + /// + /// # Arguments + /// + /// * `data` - Bytes to append (may contain partial UTF-8 sequences) + /// + /// # Examples + /// + /// ``` + /// use dirigent_tools::terminal::RingBuffer; + /// + /// let mut buffer = RingBuffer::new(10); + /// buffer.push(b"Hello"); + /// assert_eq!(buffer.snapshot(), "Hello"); + /// + /// buffer.push(b" World!"); + /// // Buffer truncated to fit capacity, respecting UTF-8 boundaries + /// assert!(buffer.is_truncated()); + /// ``` + pub fn push(&mut self, data: &[u8]) { + if data.is_empty() { + return; + } + + // If the buffer is not yet at capacity, we can just append + if self.len < self.capacity { + let available = self.capacity - self.len; + let to_append = data.len().min(available); + + if self.buffer.len() < self.capacity { + // Buffer hasn't been fully allocated yet + self.buffer.extend_from_slice(&data[..to_append]); + } else { + // Buffer is allocated, write to the circular position + for &byte in &data[..to_append] { + self.buffer[self.write_pos] = byte; + self.write_pos = (self.write_pos + 1) % self.capacity; + } + } + + self.len += to_append; + + // If we couldn't append all the data, handle the overflow + if to_append < data.len() { + self.push_with_overflow(&data[to_append..]); + } + } else { + // Buffer is full, need to overwrite old data + self.push_with_overflow(data); + } + } + + /// Push data when buffer is at capacity (overwrites old data). + fn push_with_overflow(&mut self, data: &[u8]) { + if data.is_empty() { + return; + } + + self.truncated = true; + + // Ensure buffer is fully allocated + if self.buffer.len() < self.capacity { + self.buffer.resize(self.capacity, 0); + } + + // If data is larger than capacity, only keep the last `capacity` bytes + let data_to_write = if data.len() >= self.capacity { + &data[data.len() - self.capacity..] + } else { + data + }; + + // Write data circularly + for &byte in data_to_write { + self.buffer[self.write_pos] = byte; + self.write_pos = (self.write_pos + 1) % self.capacity; + } + + // Update length (stays at capacity) + self.len = self.capacity; + } + + /// Get current buffer contents as a UTF-8 string. + /// + /// This returns a snapshot of the current contents. Invalid UTF-8 sequences + /// are handled gracefully by truncating at character boundaries. + /// + /// # Returns + /// + /// A UTF-8 string containing the buffer contents. If the buffer contains + /// invalid UTF-8 at the boundary, it will be truncated to the last valid + /// character boundary. + /// + /// # Examples + /// + /// ``` + /// use dirigent_tools::terminal::RingBuffer; + /// + /// let mut buffer = RingBuffer::new(1024); + /// buffer.push(b"Hello, World!"); + /// assert_eq!(buffer.snapshot(), "Hello, World!"); + /// ``` + pub fn snapshot(&self) -> String { + if self.len == 0 { + return String::new(); + } + + // Reconstruct the linear buffer from the circular buffer + let linear = if self.len < self.capacity { + // Buffer not full yet, data is at the beginning + &self.buffer[..self.len] + } else { + // Buffer is full, need to reconstruct in correct order + let mut temp = Vec::with_capacity(self.capacity); + let start_pos = self.write_pos; + + for i in 0..self.capacity { + let pos = (start_pos + i) % self.capacity; + temp.push(self.buffer[pos]); + } + + // Find UTF-8 boundary and return owned data + let boundary = find_char_boundary(&temp, temp.len()); + return String::from_utf8_lossy(&temp[..boundary]).into_owned(); + }; + + // Find a valid UTF-8 character boundary + let boundary = find_char_boundary(linear, linear.len()); + String::from_utf8_lossy(&linear[..boundary]).into_owned() + } + + /// Check if any data has been truncated due to buffer overflow. + /// + /// # Returns + /// + /// `true` if data was dropped, `false` otherwise. + /// + /// # Examples + /// + /// ``` + /// use dirigent_tools::terminal::RingBuffer; + /// + /// let mut buffer = RingBuffer::new(5); + /// buffer.push(b"Hello"); + /// assert!(!buffer.is_truncated()); + /// + /// buffer.push(b" World!"); + /// assert!(buffer.is_truncated()); + /// ``` + pub fn is_truncated(&self) -> bool { + self.truncated + } + + /// Clear the buffer and reset the truncation flag. + /// + /// # Examples + /// + /// ``` + /// use dirigent_tools::terminal::RingBuffer; + /// + /// let mut buffer = RingBuffer::new(1024); + /// buffer.push(b"Hello"); + /// buffer.clear(); + /// assert_eq!(buffer.snapshot(), ""); + /// assert!(!buffer.is_truncated()); + /// ``` + pub fn clear(&mut self) { + self.write_pos = 0; + self.len = 0; + self.truncated = false; + self.buffer.clear(); + } + + /// Get the current length of valid data in the buffer. + pub fn len(&self) -> usize { + self.len + } + + /// Check if the buffer is empty. + pub fn is_empty(&self) -> bool { + self.len == 0 + } +} + +/// Find the nearest character boundary at or before `start` position. +/// +/// This ensures we don't cut in the middle of a UTF-8 multi-byte character. +/// +/// # Arguments +/// +/// * `buf` - Byte buffer to scan +/// * `start` - Position to start scanning backwards from +/// +/// # Returns +/// +/// The position of a valid UTF-8 character boundary <= `start`. +/// +/// # UTF-8 Encoding Rules +/// +/// - Single-byte char: `0xxxxxxx` (0x00-0x7F) +/// - Continuation byte: `10xxxxxx` (0x80-0xBF) +/// - Start of 2-byte char: `110xxxxx` (0xC0-0xDF) +/// - Start of 3-byte char: `1110xxxx` (0xE0-0xEF) +/// - Start of 4-byte char: `11110xxx` (0xF0-0xF7) +fn find_char_boundary(buf: &[u8], start: usize) -> usize { + if start == 0 || buf.is_empty() { + return 0; + } + + let start = start.min(buf.len()); + + // Scan backwards to find a valid character start + for i in (0..start).rev() { + let byte = buf[i]; + + // Check if this is a valid character start (not a continuation byte) + if byte & 0b1100_0000 != 0b1000_0000 { + // This is either ASCII (0xxxxxxx) or a multi-byte start (11xxxxxx) + // Verify we have enough bytes for a complete character + let char_len = if byte & 0b1000_0000 == 0 { + 1 // ASCII + } else if byte & 0b1110_0000 == 0b1100_0000 { + 2 // 2-byte char + } else if byte & 0b1111_0000 == 0b1110_0000 { + 3 // 3-byte char + } else if byte & 0b1111_1000 == 0b1111_0000 { + 4 // 4-byte char + } else { + // Invalid UTF-8 start byte, skip it + continue; + }; + + // Check if we have enough bytes remaining + if i + char_len <= start { + // Validate that all continuation bytes are present + let mut valid = true; + for j in 1..char_len { + if i + j >= buf.len() || buf[i + j] & 0b1100_0000 != 0b1000_0000 { + valid = false; + break; + } + } + + if valid { + return i + char_len; + } + } + } + } + + // If we couldn't find a valid boundary, return 0 + 0 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_new_buffer() { + let buffer = RingBuffer::new(1024); + assert_eq!(buffer.snapshot(), ""); + assert!(!buffer.is_truncated()); + assert_eq!(buffer.len(), 0); + assert!(buffer.is_empty()); + } + + #[test] + fn test_push_simple() { + let mut buffer = RingBuffer::new(1024); + buffer.push(b"Hello"); + assert_eq!(buffer.snapshot(), "Hello"); + assert!(!buffer.is_truncated()); + assert_eq!(buffer.len(), 5); + } + + #[test] + fn test_push_multiple() { + let mut buffer = RingBuffer::new(1024); + buffer.push(b"Hello"); + buffer.push(b" "); + buffer.push(b"World"); + assert_eq!(buffer.snapshot(), "Hello World"); + assert!(!buffer.is_truncated()); + } + + #[test] + fn test_overflow_truncates() { + let mut buffer = RingBuffer::new(10); + buffer.push(b"Hello"); + assert!(!buffer.is_truncated()); + + buffer.push(b" World!"); + assert!(buffer.is_truncated()); + + // Should keep the last 10 bytes + let result = buffer.snapshot(); + assert!(result.len() <= 10); + assert!(result.ends_with("World!")); + } + + #[test] + fn test_utf8_boundary_preservation() { + let mut buffer = RingBuffer::new(10); + + // Push UTF-8 string with multi-byte characters (emoji) + // "Hello😀" is 10 bytes total (Hello=5, 😀=4, but we only have 10 capacity) + buffer.push("Hello".as_bytes()); + buffer.push("😀".as_bytes()); // This should fit exactly + + let result = buffer.snapshot(); + assert!(result == "Hello😀" || result == "Hello"); // Depending on implementation + assert!(!buffer.is_truncated() || result.len() <= 10); + } + + #[test] + fn test_utf8_truncation() { + let mut buffer = RingBuffer::new(8); + + // Push string that will cause truncation in the middle of multi-byte char + buffer.push("Hello😀World".as_bytes()); + + // Result should be valid UTF-8 (no partial emoji) + let result = buffer.snapshot(); + assert!(std::str::from_utf8(result.as_bytes()).is_ok()); + assert!(buffer.is_truncated()); + } + + #[test] + fn test_clear() { + let mut buffer = RingBuffer::new(1024); + buffer.push(b"Hello"); + buffer.clear(); + + assert_eq!(buffer.snapshot(), ""); + assert!(!buffer.is_truncated()); + assert_eq!(buffer.len(), 0); + } + + #[test] + fn test_empty_push() { + let mut buffer = RingBuffer::new(1024); + buffer.push(b""); + assert_eq!(buffer.snapshot(), ""); + assert!(!buffer.is_truncated()); + } + + #[test] + fn test_large_data_at_once() { + let mut buffer = RingBuffer::new(100); + let large_data = vec![b'A'; 500]; + + buffer.push(&large_data); + + assert!(buffer.is_truncated()); + assert_eq!(buffer.len(), 100); + + // Should contain only 'A's + let result = buffer.snapshot(); + assert!(result.chars().all(|c| c == 'A')); + assert!(result.len() <= 100); + } + + #[test] + fn test_find_char_boundary_ascii() { + let data = b"Hello"; + assert_eq!(find_char_boundary(data, 5), 5); + assert_eq!(find_char_boundary(data, 3), 3); + assert_eq!(find_char_boundary(data, 0), 0); + } + + #[test] + fn test_find_char_boundary_utf8() { + // "Hello😀" - emoji is 4 bytes: F0 9F 98 80 + let data = "Hello😀".as_bytes(); + let total_len = data.len(); // 5 + 4 = 9 + + // Should find boundaries correctly + assert_eq!(find_char_boundary(data, total_len), total_len); + assert_eq!(find_char_boundary(data, 5), 5); // After "Hello" + + // Middle of emoji should backtrack to before emoji + assert_eq!(find_char_boundary(data, 6), 5); // 1st continuation byte + assert_eq!(find_char_boundary(data, 7), 5); // 2nd continuation byte + assert_eq!(find_char_boundary(data, 8), 5); // 3rd continuation byte + } + + #[test] + fn test_find_char_boundary_multi_utf8() { + // Multiple multi-byte characters + let data = "日本語".as_bytes(); // 3 chars, 9 bytes (3 each) + + assert_eq!(find_char_boundary(data, 9), 9); // End + assert_eq!(find_char_boundary(data, 6), 6); // After 2nd char + assert_eq!(find_char_boundary(data, 3), 3); // After 1st char + + // Middle of 2nd character + assert_eq!(find_char_boundary(data, 4), 3); + assert_eq!(find_char_boundary(data, 5), 3); + + // Middle of 3rd character + assert_eq!(find_char_boundary(data, 7), 6); + assert_eq!(find_char_boundary(data, 8), 6); + } + + #[test] + fn test_circular_buffer_behavior() { + let mut buffer = RingBuffer::new(5); + + buffer.push(b"ABCDE"); + assert_eq!(buffer.snapshot(), "ABCDE"); + + buffer.push(b"FG"); + assert!(buffer.is_truncated()); + + let result = buffer.snapshot(); + assert_eq!(result.len(), 5); + assert!(result.ends_with("G")); + } +} diff --git a/crates/dirigent_tools/src/terminal/wait.rs b/crates/dirigent_tools/src/terminal/wait.rs new file mode 100644 index 0000000..cec9f99 --- /dev/null +++ b/crates/dirigent_tools/src/terminal/wait.rs @@ -0,0 +1,160 @@ +//! Terminal wait-for-exit operation. +//! +//! **Status**: Not yet implemented (TOOLS-TERM-03) +//! +//! This module will implement: +//! - Blocking wait for process completion +//! - Runtime timeout enforcement +//! - Exit status return + +use crate::config::TerminalConfig; +use crate::error::{ToolError, ToolResult}; +use serde::{Deserialize, Serialize}; + +/// Request to wait for terminal to exit. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WaitForTerminalExitRequest { + /// Terminal ID from create response. + pub terminal_id: String, +} + +/// Response with exit status. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WaitForTerminalExitResponse { + /// Process exit status code. + pub exit_status: i32, +} + +/// Wait for terminal process to complete. +/// +/// **Status**: Implemented (TOOLS-TERM-03) +/// +/// ## Implementation +/// +/// This function: +/// 1. Looks up TerminalId in registry +/// 2. Gets process handle +/// 3. If already exited → returns cached exit status immediately +/// 4. If still running: +/// - Awaits process completion (tokio child.wait()) +/// - Enforces `max_runtime_secs` timeout via tokio::time::timeout +/// - If timeout → kills process and returns error +/// 5. Caches exit status in registry +/// 6. Returns WaitForTerminalExitResponse +/// +/// ## Timeout Behavior +/// +/// - Uses `max_runtime_secs` from TerminalConfig +/// - On timeout: +/// - Process is killed (SIGKILL/TerminateProcess) +/// - Returns `ToolError::TerminalError` with timeout message +/// - Terminal remains in registry (can still get output) +/// +/// ## Blocking vs Polling +/// +/// - `wait_for_exit()` → Blocks until completion or timeout +/// - `get_output()` → Returns immediately with current output +/// +/// Use `wait_for_exit()` when you want to ensure completion before proceeding. +/// +/// ## Error Cases +/// +/// - Invalid/unknown TerminalId → `ToolError::TerminalNotFound` +/// - Terminal released → `ToolError::TerminalNotFound` +/// - Timeout exceeded → `ToolError::TerminalError` +/// +/// ## See Also +/// +/// - Task spec: `docs/building/04_acp_client/04_tasks_02_tools_and_sandboxing.md#TOOLS-TERM-03` +/// - Get output: `output.rs` +pub async fn wait_for_terminal_exit( + request: WaitForTerminalExitRequest, + config: &TerminalConfig, +) -> ToolResult<WaitForTerminalExitResponse> { + use crate::terminal::registry::global_registry; + use std::time::Duration; + + let registry = global_registry(); + let terminal_id = request.terminal_id.clone(); + + // Check if already exited + let already_exited = registry.get_mut(&terminal_id, |state| { + state.exit_status.as_ref().map(|status| { + #[cfg(unix)] + { + use std::os::unix::process::ExitStatusExt; + status.code().or_else(|| status.signal()).unwrap_or(-1) + } + #[cfg(not(unix))] + { + status.code().unwrap_or(-1) + } + }) + })?; + + if let Some(exit_status) = already_exited { + return Ok(WaitForTerminalExitResponse { exit_status }); + } + + // Process is still running, need to wait for it + // We need to take ownership of the process to wait on it + // This is tricky with the registry design, so we'll use a different approach: + // We'll poll the process status and check exit_status + + let timeout_duration = Duration::from_secs(config.max_runtime_secs); + let start_time = std::time::Instant::now(); + + loop { + // Check if timeout exceeded + if start_time.elapsed() >= timeout_duration { + // Kill the process due to timeout + registry.get_mut(&terminal_id, |state| { + let _ = state.process.start_kill(); + state.killed = true; + })?; + + return Err(ToolError::terminal_error(format!( + "Terminal timed out after {} seconds", + config.max_runtime_secs + ))); + } + + // Try to get exit status (non-blocking check) + let exit_status_result = registry.get_mut(&terminal_id, |state| { + // Try to check if process has exited + match state.process.try_wait() { + Ok(Some(status)) => { + // Process has exited + state.exit_status = Some(status); + let exit_code = { + #[cfg(unix)] + { + use std::os::unix::process::ExitStatusExt; + status.code().or_else(|| status.signal()).unwrap_or(-1) + } + #[cfg(not(unix))] + { + status.code().unwrap_or(-1) + } + }; + Some(exit_code) + } + Ok(None) => { + // Process is still running + None + } + Err(_e) => { + // Error checking status, treat as still running + None + } + } + })?; + + if let Some(exit_status) = exit_status_result { + return Ok(WaitForTerminalExitResponse { exit_status }); + } + + // Wait a bit before checking again + tokio::time::sleep(Duration::from_millis(100)).await; + } +} diff --git a/crates/dirigent_tools/src/tool/context.rs b/crates/dirigent_tools/src/tool/context.rs new file mode 100644 index 0000000..bb5afd2 --- /dev/null +++ b/crates/dirigent_tools/src/tool/context.rs @@ -0,0 +1,76 @@ +//! Per-call scoping context shared by every harness layer. + +use crate::config::{PermissionConfig, SandboxConfig}; +use crate::permission::check::PermissionContext; +use crate::tool::{ClientKind, ProtocolKind}; +use std::path::PathBuf; +use std::sync::Arc; + +/// Per-call scoping context. Passed to every layer below the registry. +/// +/// `connector_id` and `session_id` mirror what the existing +/// [`PermissionContext`] uses; `client_kind` and `protocol` are the new +/// override surface for per-client / per-protocol behaviour. +#[derive(Clone)] +pub struct ToolContext { + pub connector_id: Arc<str>, + pub session_id: Option<Arc<str>>, + pub client_kind: ClientKind, + pub protocol: ProtocolKind, + pub workspace_root: PathBuf, + pub sandbox: Arc<SandboxConfig>, + pub permission: Arc<PermissionConfig>, + pub permission_context: Arc<PermissionContext>, +} + +impl ToolContext { + /// Test/builder helper. Real callers compose this from connector state. + pub fn for_test( + connector_id: impl Into<Arc<str>>, + client_kind: ClientKind, + protocol: ProtocolKind, + workspace_root: PathBuf, + sandbox: SandboxConfig, + permission: PermissionConfig, + permission_context: PermissionContext, + ) -> Self { + Self { + connector_id: connector_id.into(), + session_id: None, + client_kind, + protocol, + workspace_root, + sandbox: Arc::new(sandbox), + permission: Arc::new(permission), + permission_context: Arc::new(permission_context), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::WhitelistConfig; + use crate::permission::whitelist::CompiledWhitelist; + + #[test] + fn context_builds_with_test_helper() { + let whitelist = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap(); + let perm_ctx = PermissionContext::new( + "conn-1".to_string(), + None, + whitelist, + ); + let ctx = ToolContext::for_test( + "conn-1", + ClientKind::claude(), + ProtocolKind::acp(), + PathBuf::from("/tmp"), + SandboxConfig::default(), + PermissionConfig::default(), + perm_ctx, + ); + assert_eq!(&*ctx.connector_id, "conn-1"); + assert_eq!(ctx.client_kind, ClientKind::claude()); + } +} diff --git a/crates/dirigent_tools/src/tool/erase.rs b/crates/dirigent_tools/src/tool/erase.rs new file mode 100644 index 0000000..cdee3d6 --- /dev/null +++ b/crates/dirigent_tools/src/tool/erase.rs @@ -0,0 +1,197 @@ +//! `Tool` trait + object-safe `AnyTool` + `Erased<T>` adapter. + +use crate::tool::{ToolContext, ToolEventSink, ToolKind}; +use async_trait::async_trait; +use schemars::{JsonSchema, schema_for}; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use tokio::sync::{mpsc, oneshot}; + +/// Streaming-aware tool input. +/// +/// Tools opt into streaming via [`Tool::supports_input_streaming`]. If they +/// do not, the harness buffers and always provides [`ToolInput::Final`]. +pub enum ToolInput<T> { + Final(T), + Partial { + partial: mpsc::UnboundedReceiver<serde_json::Value>, + final_input: oneshot::Receiver<T>, + }, +} + +/// JSON-shaped variant used by the object-safe `AnyTool` trait. +pub enum AnyToolInput { + Final(serde_json::Value), + Partial { + partial: mpsc::UnboundedReceiver<serde_json::Value>, + final_input: oneshot::Receiver<serde_json::Value>, + }, +} + +/// Strongly-typed tool implementation. +#[async_trait] +pub trait Tool: Send + Sync + 'static { + type Input: for<'de> Deserialize<'de> + Serialize + JsonSchema + Send + 'static; + type Output: Serialize + Send + 'static; + + const NAME: &'static str; + + fn kind() -> ToolKind; + + fn input_schema() -> serde_json::Value { + serde_json::to_value(schema_for!(Self::Input)).expect("schema_for must serialise") + } + + fn supports_input_streaming() -> bool { false } + + async fn run( + self: Arc<Self>, + input: ToolInput<Self::Input>, + events: ToolEventSink, + ctx: &ToolContext, + ) -> Result<Self::Output, Self::Output>; + + fn erase(self: Arc<Self>) -> Arc<dyn AnyTool> where Self: Sized { Arc::new(Erased(self)) } +} + +/// Object-safe variant. The registry stores `Arc<dyn AnyTool>`. +#[async_trait] +pub trait AnyTool: Send + Sync + 'static { + fn name(&self) -> &'static str; + fn kind(&self) -> ToolKind; + fn input_schema(&self) -> serde_json::Value; + fn supports_input_streaming(&self) -> bool; + + async fn run( + self: Arc<Self>, + input: AnyToolInput, + events: ToolEventSink, + ctx: &ToolContext, + ) -> Result<serde_json::Value, serde_json::Value>; +} + +/// Adapter from a typed `Tool` to `AnyTool`. +pub struct Erased<T: Tool>(pub Arc<T>); + +#[async_trait] +impl<T: Tool> AnyTool for Erased<T> { + fn name(&self) -> &'static str { T::NAME } + fn kind(&self) -> ToolKind { T::kind() } + fn input_schema(&self) -> serde_json::Value { T::input_schema() } + fn supports_input_streaming(&self) -> bool { T::supports_input_streaming() } + + async fn run( + self: Arc<Self>, + input: AnyToolInput, + events: ToolEventSink, + ctx: &ToolContext, + ) -> Result<serde_json::Value, serde_json::Value> { + let typed = match input { + AnyToolInput::Final(v) => { + let parsed: T::Input = serde_json::from_value(v).map_err(|e| { + serde_json::json!({ "error": format!("invalid input: {e}") }) + })?; + ToolInput::Final(parsed) + } + AnyToolInput::Partial { partial, final_input: _ } => { + // For v1: tools that opt into streaming receive Partial; the + // typed-final-input wiring is added when a streaming tool ships. + // Until then, only Final is fed by the dispatcher. + let _ = partial; + return Err(serde_json::json!({ + "error": "streaming inputs are not yet wired in v1 dispatcher" + })); + } + }; + let inner = self.0.clone(); + let result = inner.run(typed, events, ctx).await; + match result { + Ok(o) => Ok(serde_json::to_value(o).unwrap_or(serde_json::Value::Null)), + Err(o) => Err(serde_json::to_value(o).unwrap_or(serde_json::Value::Null)), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{PermissionConfig, SandboxConfig, WhitelistConfig}; + use crate::permission::check::PermissionContext; + use crate::permission::whitelist::CompiledWhitelist; + use crate::tool::{ClientKind, ProtocolKind}; + use std::path::PathBuf; + + #[derive(Serialize, Deserialize, JsonSchema)] + struct EchoInput { msg: String } + + #[derive(Serialize, Deserialize)] + struct EchoOutput { echoed: String } + + struct EchoTool; + + #[async_trait] + impl Tool for EchoTool { + type Input = EchoInput; + type Output = EchoOutput; + const NAME: &'static str = "echo"; + fn kind() -> ToolKind { ToolKind::Other } + + async fn run( + self: Arc<Self>, + input: ToolInput<Self::Input>, + _events: ToolEventSink, + _ctx: &ToolContext, + ) -> Result<Self::Output, Self::Output> { + let i = match input { ToolInput::Final(i) => i, _ => unreachable!() }; + Ok(EchoOutput { echoed: i.msg }) + } + } + + fn ctx() -> ToolContext { + let wl = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap(); + let pc = PermissionContext::new("conn-1".to_string(), None, wl); + ToolContext::for_test( + "conn-1", ClientKind::claude(), ProtocolKind::acp(), + PathBuf::from("/tmp"), + SandboxConfig::default(), PermissionConfig::default(), pc, + ) + } + + #[tokio::test] + async fn typed_tool_runs_and_returns_output() { + let tool: Arc<EchoTool> = Arc::new(EchoTool); + let any: Arc<dyn AnyTool> = tool.erase(); + let (sink, _rx) = ToolEventSink::new(); + + let result = any.run( + AnyToolInput::Final(serde_json::json!({ "msg": "hello" })), + sink, + &ctx(), + ).await.unwrap(); + + assert_eq!(result["echoed"], "hello"); + } + + #[tokio::test] + async fn invalid_input_returns_structured_error() { + let tool: Arc<EchoTool> = Arc::new(EchoTool); + let any: Arc<dyn AnyTool> = tool.erase(); + let (sink, _rx) = ToolEventSink::new(); + + let err = any.run( + AnyToolInput::Final(serde_json::json!({ "wrong": 1 })), + sink, + &ctx(), + ).await.unwrap_err(); + + assert!(err["error"].as_str().unwrap().contains("invalid input")); + } + + #[test] + fn name_and_kind_round_trip_through_erase() { + let tool: Arc<EchoTool> = Arc::new(EchoTool); + let any: Arc<dyn AnyTool> = tool.erase(); + assert_eq!(any.name(), "echo"); + assert_eq!(any.kind(), ToolKind::Other); + } +} diff --git a/crates/dirigent_tools/src/tool/events.rs b/crates/dirigent_tools/src/tool/events.rs new file mode 100644 index 0000000..e04ecb7 --- /dev/null +++ b/crates/dirigent_tools/src/tool/events.rs @@ -0,0 +1,134 @@ +//! Neutral tool events. Connector adapters translate to wire types. + +use crate::tool::ToolKind; +use bytes::Bytes; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use tokio::sync::mpsc; + +/// Opaque permission-request id, allocated by the dispatcher. +#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)] +pub struct PermissionRequestId(Arc<str>); + +impl PermissionRequestId { + /// Construct a new permission-request id from any string-like value. + pub fn new(value: impl Into<Arc<str>>) -> Self { + Self(value.into()) + } + + /// Borrow the inner id as a string slice. + pub fn as_str(&self) -> &str { + &self.0 + } +} + +/// Where a tool is operating (file path + optional line). +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ToolLocation { + pub path: String, + pub line: Option<u32>, +} + +/// Result content shape. Mirrors what most providers accept as a tool result. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ToolResultContent { + Text { text: Arc<str> }, + Json { value: serde_json::Value }, + Image { mime: Arc<str>, #[serde(with = "serde_bytes_arc")] data: Bytes }, + Parts { parts: Vec<ToolResultContent> }, +} + +impl ToolResultContent { + pub fn text(s: impl Into<Arc<str>>) -> Self { Self::Text { text: s.into() } } +} + +/// Events emitted by a running tool. Transport-agnostic. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ToolEvent { + Started { title: Arc<str>, kind: ToolKind, location: Option<ToolLocation> }, + TitleUpdate { title: Arc<str>, location: Option<ToolLocation> }, + PartialOutput { content: ToolResultContent }, + Status { message: Arc<str> }, + PermissionRequested { request_id: PermissionRequestId, summary: Arc<str> }, + Completed, + Failed, +} + +/// Sink a tool emits events into. Cheap to clone. +#[derive(Clone, Debug)] +pub struct ToolEventSink { + tx: mpsc::UnboundedSender<ToolEvent>, +} + +impl ToolEventSink { + pub fn new() -> (Self, mpsc::UnboundedReceiver<ToolEvent>) { + let (tx, rx) = mpsc::unbounded_channel(); + (Self { tx }, rx) + } + + /// Best-effort emit. Drops the event if the receiver is gone. + pub fn emit(&self, event: ToolEvent) { + let _ = self.tx.send(event); + } +} + +mod serde_bytes_arc { + use bytes::Bytes; + use serde::{Deserialize, Deserializer, Serialize, Serializer}; + + pub fn serialize<S: Serializer>(b: &Bytes, s: S) -> Result<S::Ok, S::Error> { + serde_bytes::Bytes::new(b.as_ref()).serialize(s) + } + pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Bytes, D::Error> { + let v: Vec<u8> = serde_bytes::ByteBuf::deserialize(d)?.into_vec(); + Ok(Bytes::from(v)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn sink_round_trips_event() { + let (sink, mut rx) = ToolEventSink::new(); + sink.emit(ToolEvent::Status { message: "hi".into() }); + let got = rx.recv().await.unwrap(); + match got { + ToolEvent::Status { message } => assert_eq!(&*message, "hi"), + _ => panic!("wrong variant"), + } + } + + #[test] + fn result_content_text_helper() { + match ToolResultContent::text("hello") { + ToolResultContent::Text { text } => assert_eq!(&*text, "hello"), + _ => panic!(), + } + } + + #[test] + fn tool_event_serde_round_trip() { + let ev = ToolEvent::PartialOutput { content: ToolResultContent::text("x") }; + let json = serde_json::to_string(&ev).unwrap(); + let _back: ToolEvent = serde_json::from_str(&json).unwrap(); + } + + #[test] + fn permission_request_id_constructor_and_accessor() { + let id = PermissionRequestId::new("abc"); + assert_eq!(id.as_str(), "abc"); + } + + #[test] + fn permission_request_id_serde_is_transparent_string() { + let id = PermissionRequestId::new("foo"); + let json = serde_json::to_string(&id).unwrap(); + assert_eq!(json, "\"foo\""); + let back: PermissionRequestId = serde_json::from_str(&json).unwrap(); + assert_eq!(back, id); + } +} diff --git a/crates/dirigent_tools/src/tool/kinds.rs b/crates/dirigent_tools/src/tool/kinds.rs new file mode 100644 index 0000000..8680cdd --- /dev/null +++ b/crates/dirigent_tools/src/tool/kinds.rs @@ -0,0 +1,135 @@ +//! Open-form client/protocol identifiers and tool category enum. + +use serde::{Deserialize, Serialize}; +use std::sync::{Arc, OnceLock}; + +/// Open newtype identifying the upstream client family (Claude, Codex, etc.). +/// +/// Use the provided constants for known clients; use [`ClientKind::custom`] +/// for anything else. Comparison is by inner string equality. +/// +/// The well-known constructors (`claude`, `codex`, `gemini`, `opencode`) +/// return cached values backed by a process-wide `OnceLock`, so calling them +/// repeatedly is a cheap `Arc` clone — no per-call heap allocation. +#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)] +pub struct ClientKind(Arc<str>); + +impl ClientKind { + pub fn claude() -> Self { + static CELL: OnceLock<ClientKind> = OnceLock::new(); + CELL.get_or_init(|| Self(Arc::from("claude"))).clone() + } + pub fn codex() -> Self { + static CELL: OnceLock<ClientKind> = OnceLock::new(); + CELL.get_or_init(|| Self(Arc::from("codex"))).clone() + } + pub fn gemini() -> Self { + static CELL: OnceLock<ClientKind> = OnceLock::new(); + CELL.get_or_init(|| Self(Arc::from("gemini"))).clone() + } + pub fn opencode() -> Self { + static CELL: OnceLock<ClientKind> = OnceLock::new(); + CELL.get_or_init(|| Self(Arc::from("opencode"))).clone() + } + + pub fn custom(name: impl Into<Arc<str>>) -> Self { Self(name.into()) } + pub fn as_str(&self) -> &str { &self.0 } +} + +/// Open newtype identifying the wire protocol (ACP, OpenCode, native). +/// +/// As with [`ClientKind`], the well-known constructors return cached values +/// backed by a process-wide `OnceLock`. +#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)] +pub struct ProtocolKind(Arc<str>); + +impl ProtocolKind { + pub fn acp() -> Self { + static CELL: OnceLock<ProtocolKind> = OnceLock::new(); + CELL.get_or_init(|| Self(Arc::from("acp"))).clone() + } + pub fn opencode() -> Self { + static CELL: OnceLock<ProtocolKind> = OnceLock::new(); + CELL.get_or_init(|| Self(Arc::from("opencode"))).clone() + } + pub fn native() -> Self { + static CELL: OnceLock<ProtocolKind> = OnceLock::new(); + CELL.get_or_init(|| Self(Arc::from("native"))).clone() + } + + pub fn custom(name: impl Into<Arc<str>>) -> Self { Self(name.into()) } + pub fn as_str(&self) -> &str { &self.0 } +} + +/// Coarse category of a tool. Connectors use this to pick wire-level hints +/// (e.g. ACP `ToolKind::Edit` triggers diff rendering on the client). +#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ToolKind { + Read, + Edit, + Search, + Execute, + Fetch, + Think, + Other, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn client_kind_constants_are_distinct() { + assert_ne!(ClientKind::claude(), ClientKind::codex()); + assert_eq!(ClientKind::claude(), ClientKind::claude()); + } + + #[test] + fn client_kind_custom_round_trips() { + let k = ClientKind::custom("aider"); + assert_eq!(k.as_str(), "aider"); + assert_eq!(k, ClientKind::custom("aider")); + } + + #[test] + fn protocol_kind_distinct() { + assert_ne!(ProtocolKind::acp(), ProtocolKind::opencode()); + } + + #[test] + fn tool_kind_serde_round_trip() { + let json = serde_json::to_string(&ToolKind::Edit).unwrap(); + assert_eq!(json, "\"edit\""); + let back: ToolKind = serde_json::from_str(&json).unwrap(); + assert_eq!(back, ToolKind::Edit); + } + + #[test] + fn well_known_client_kinds_are_cached() { + let a = ClientKind::claude(); + let b = ClientKind::claude(); + assert_eq!(a, b); + // Proves the cache works: same Arc, no per-call allocation. + assert!(Arc::ptr_eq(&a.0, &b.0)); + + // Sanity-check another well-known kind too. + let c1 = ClientKind::codex(); + let c2 = ClientKind::codex(); + assert!(Arc::ptr_eq(&c1.0, &c2.0)); + + // Custom values are *not* cached — each call allocates fresh. + let x = ClientKind::custom("aider"); + let y = ClientKind::custom("aider"); + assert_eq!(x, y); + assert!(!Arc::ptr_eq(&x.0, &y.0)); + } + + #[test] + fn well_known_protocol_kinds_are_cached() { + let a = ProtocolKind::acp(); + let b = ProtocolKind::acp(); + assert_eq!(a, b); + assert!(Arc::ptr_eq(&a.0, &b.0)); + } +} diff --git a/crates/dirigent_tools/src/tool/macros.rs b/crates/dirigent_tools/src/tool/macros.rs new file mode 100644 index 0000000..2424444 --- /dev/null +++ b/crates/dirigent_tools/src/tool/macros.rs @@ -0,0 +1,106 @@ +//! Compile-time built-in tool registration macro. +//! +//! ```ignore +//! crate::tools! { ReadTool, GrepTool, EditTool } +//! ``` +//! +//! Produces: +//! - `pub const ALL_BUILT_IN_TOOL_NAMES: &[&str] = &[...];` +//! - `pub fn built_in_tools() -> impl Iterator<Item = std::sync::Arc<dyn AnyTool>>` +//! - A compile-time uniqueness check on `T::NAME`. + +#[macro_export] +macro_rules! tools { + ($($tool:ty),* $(,)?) => { + pub const ALL_BUILT_IN_TOOL_NAMES: &[&str] = &[ + $( <$tool as $crate::tool::Tool>::NAME, )* + ]; + + const _: () = { + const fn str_eq(a: &str, b: &str) -> bool { + let a = a.as_bytes(); + let b = b.as_bytes(); + if a.len() != b.len() { return false; } + let mut i = 0; + while i < a.len() { + if a[i] != b[i] { return false; } + i += 1; + } + true + } + const NAMES: &[&str] = ALL_BUILT_IN_TOOL_NAMES; + let mut i = 0; + while i < NAMES.len() { + let mut j = i + 1; + while j < NAMES.len() { + if str_eq(NAMES[i], NAMES[j]) { + panic!("Duplicate built-in tool name"); + } + j += 1; + } + i += 1; + } + }; + + pub fn built_in_tools() -> Vec<std::sync::Arc<dyn $crate::tool::AnyTool>> { + vec![ + $( + { + let t: std::sync::Arc<$tool> = std::sync::Arc::new(<$tool>::default()); + <$tool as $crate::tool::Tool>::erase(t) + }, + )* + ] + } + }; +} + +#[cfg(test)] +mod tests { + use crate::tool::{Tool, ToolContext, ToolEventSink, ToolInput, ToolKind}; + use async_trait::async_trait; + use schemars::JsonSchema; + use serde::{Deserialize, Serialize}; + use std::sync::Arc; + + #[derive(Default)] + struct AlphaTool; + #[derive(Default)] + struct BetaTool; + + #[derive(Serialize, Deserialize, JsonSchema)] + struct Empty {} + + macro_rules! impl_tool { + ($t:ty, $name:literal) => { + #[async_trait] + impl Tool for $t { + type Input = Empty; + type Output = Empty; + const NAME: &'static str = $name; + fn kind() -> ToolKind { ToolKind::Other } + async fn run( + self: Arc<Self>, _i: ToolInput<Empty>, + _e: ToolEventSink, _c: &ToolContext, + ) -> Result<Empty, Empty> { Ok(Empty {}) } + } + }; + } + impl_tool!(AlphaTool, "alpha"); + impl_tool!(BetaTool, "beta"); + + crate::tools! { AlphaTool, BetaTool } + + #[test] + fn macro_lists_names() { + assert_eq!(ALL_BUILT_IN_TOOL_NAMES, &["alpha", "beta"]); + } + + #[test] + fn macro_constructs_erased_tools() { + let v = built_in_tools(); + assert_eq!(v.len(), 2); + assert_eq!(v[0].name(), "alpha"); + assert_eq!(v[1].name(), "beta"); + } +} diff --git a/crates/dirigent_tools/src/tool/mod.rs b/crates/dirigent_tools/src/tool/mod.rs new file mode 100644 index 0000000..b322855 --- /dev/null +++ b/crates/dirigent_tools/src/tool/mod.rs @@ -0,0 +1,14 @@ +//! Tool trait, registry-facing types, and per-call context. + +pub mod kinds; +pub mod events; +pub mod context; +pub mod erase; +pub mod macros; + +pub use kinds::{ClientKind, ProtocolKind, ToolKind}; +pub use events::{ + PermissionRequestId, ToolEvent, ToolEventSink, ToolLocation, ToolResultContent, +}; +pub use context::ToolContext; +pub use erase::{AnyTool, AnyToolInput, Erased, Tool, ToolInput}; diff --git a/crates/dirigent_tools/src/tools/mod.rs b/crates/dirigent_tools/src/tools/mod.rs new file mode 100644 index 0000000..cac72bf --- /dev/null +++ b/crates/dirigent_tools/src/tools/mod.rs @@ -0,0 +1,5 @@ +//! Built-in tool implementations registered against the `Tool` trait. + +pub mod read; + +pub use read::ReadTool; diff --git a/crates/dirigent_tools/src/tools/read.rs b/crates/dirigent_tools/src/tools/read.rs new file mode 100644 index 0000000..2b7014e --- /dev/null +++ b/crates/dirigent_tools/src/tools/read.rs @@ -0,0 +1,151 @@ +//! `ReadTool`: built-in trait wrapper around `crate::fs::read_text_file`. + +use crate::fs::read::{read_text_file, ReadTextFileRequest, ReadTextFileResponse}; +use crate::tool::{ + Tool, ToolContext, ToolEvent, ToolEventSink, ToolInput, ToolKind, ToolLocation, + ToolResultContent, +}; +use async_trait::async_trait; +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; + +#[derive(Serialize, Deserialize, JsonSchema, Clone)] +pub struct ReadInput { + /// Absolute path to the file to read. + pub path: String, + /// Optional 1-indexed start line. + #[serde(default)] pub line: Option<usize>, + /// Optional max number of lines to return. + #[serde(default)] pub limit: Option<usize>, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ReadOutput { + Ok { content: String }, + Err { error: String }, +} + +impl ReadOutput { + fn err(msg: impl Into<String>) -> Self { ReadOutput::Err { error: msg.into() } } +} + +#[derive(Default)] +pub struct ReadTool; + +#[async_trait] +impl Tool for ReadTool { + type Input = ReadInput; + type Output = ReadOutput; + const NAME: &'static str = "read"; + + fn kind() -> ToolKind { ToolKind::Read } + + async fn run( + self: Arc<Self>, + input: ToolInput<Self::Input>, + events: ToolEventSink, + ctx: &ToolContext, + ) -> Result<Self::Output, Self::Output> { + let i = match input { + ToolInput::Final(i) => i, + _ => return Err(ReadOutput::err("streaming inputs not supported by read")), + }; + + events.emit(ToolEvent::Started { + title: format!("Read {}", i.path).into(), + kind: ToolKind::Read, + location: Some(ToolLocation { path: i.path.clone(), line: i.line.map(|n| n as u32) }), + }); + + let req = ReadTextFileRequest { path: i.path.clone(), line: i.line, limit: i.limit }; + let res: ReadTextFileResponse = read_text_file(req, ctx.sandbox.as_ref()).await + .map_err(|e| ReadOutput::err(e.to_string()))?; + + events.emit(ToolEvent::PartialOutput { + content: ToolResultContent::text(res.content.clone()), + }); + events.emit(ToolEvent::Completed); + Ok(ReadOutput::Ok { content: res.content }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{PermissionConfig, SandboxConfig, WhitelistConfig}; + use crate::permission::check::PermissionContext; + use crate::permission::whitelist::CompiledWhitelist; + use crate::tool::{ClientKind, ProtocolKind}; + use std::io::Write; + use tempfile::TempDir; + + fn ctx_for(root: &std::path::Path) -> ToolContext { + let wl = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap(); + let pc = PermissionContext::new("c".to_string(), None, wl); + let mut sandbox = SandboxConfig::default(); + sandbox.allowed_roots = vec![root.to_path_buf()]; + ToolContext::for_test( + "c", ClientKind::claude(), ProtocolKind::acp(), + root.to_path_buf(), sandbox, PermissionConfig::default(), pc, + ) + } + + #[tokio::test] + async fn reads_file_through_trait() { + let dir = TempDir::new().unwrap(); + let path = dir.path().join("hello.txt"); + std::fs::File::create(&path).unwrap().write_all(b"hi\nthere\n").unwrap(); + + let tool = Arc::new(ReadTool); + let (sink, mut rx) = ToolEventSink::new(); + let result = Tool::run( + tool, + ToolInput::Final(ReadInput { + path: path.to_string_lossy().into_owned(), + line: None, limit: None, + }), + sink, + &ctx_for(dir.path()), + ).await.unwrap(); + + match result { + ReadOutput::Ok { content } => { + assert!(content.contains("hi")); + assert!(content.contains("there")); + } + ReadOutput::Err { error } => panic!("expected ok, got err: {error}"), + } + + // At least Started + PartialOutput + Completed should have fired. + let mut count = 0; + while let Ok(_ev) = rx.try_recv() { count += 1; } + assert!(count >= 3, "expected >=3 events, got {count}"); + } + + #[tokio::test] + async fn returns_structured_error_on_sandbox_violation() { + let dir = TempDir::new().unwrap(); + let other = TempDir::new().unwrap(); + let outside = other.path().join("nope.txt"); + std::fs::File::create(&outside).unwrap().write_all(b"x").unwrap(); + + let tool = Arc::new(ReadTool); + let (sink, _rx) = ToolEventSink::new(); + let result = Tool::run( + tool, + ToolInput::Final(ReadInput { + path: outside.to_string_lossy().into_owned(), + line: None, limit: None, + }), + sink, + &ctx_for(dir.path()), + ).await.unwrap_err(); + + match result { + ReadOutput::Err { error } => assert!(!error.is_empty()), + ReadOutput::Ok { .. } => panic!("expected sandbox error"), + } + } +} diff --git a/crates/dirigent_tools/tests/common/mod.rs b/crates/dirigent_tools/tests/common/mod.rs new file mode 100644 index 0000000..4556a8c --- /dev/null +++ b/crates/dirigent_tools/tests/common/mod.rs @@ -0,0 +1,108 @@ +//! Common test utilities for dirigent_tools. +//! +//! This module provides shared test utilities used across all test files. + +use std::path::{Path, PathBuf}; +use tempfile::TempDir; + +/// Create a temporary directory for testing. +/// +/// The directory is automatically cleaned up when the TempDir is dropped. +pub fn create_temp_dir() -> TempDir { + tempfile::tempdir().expect("Failed to create temp directory") +} + +/// Create a temporary directory with a specific prefix. +pub fn create_temp_dir_with_prefix(prefix: &str) -> TempDir { + tempfile::Builder::new() + .prefix(prefix) + .tempdir() + .expect("Failed to create temp directory with prefix") +} + +/// Create a file in a directory with given content. +pub fn create_test_file(dir: &Path, filename: &str, content: &str) -> PathBuf { + let file_path = dir.join(filename); + std::fs::write(&file_path, content).expect("Failed to write test file"); + file_path +} + +/// Read a file and return its content. +pub fn read_file_content(path: &Path) -> String { + std::fs::read_to_string(path).expect("Failed to read file") +} + +/// Create a sandboxed test environment with configured allowed roots. +pub struct SandboxedTestEnv { + // RAII: TempDir must be kept alive to prevent premature cleanup + pub _temp_dir: TempDir, + pub allowed_root: PathBuf, + pub blocked_dir: PathBuf, + pub outside_dir: TempDir, +} + +impl SandboxedTestEnv { + /// Create a new sandboxed test environment. + pub fn new() -> Self { + let temp_dir = create_temp_dir(); + let allowed_root = temp_dir.path().to_path_buf(); + let blocked_dir = allowed_root.join("blocked"); + let outside_dir = create_temp_dir_with_prefix("outside"); + + std::fs::create_dir_all(&blocked_dir).expect("Failed to create blocked directory"); + + Self { + _temp_dir: temp_dir, + allowed_root, + blocked_dir, + outside_dir, + } + } + + /// Create a file inside the allowed root. + pub fn create_allowed_file(&self, filename: &str, content: &str) -> PathBuf { + create_test_file(&self.allowed_root, filename, content) + } + + /// Create a file outside the allowed root. + pub fn create_outside_file(&self, filename: &str, content: &str) -> PathBuf { + create_test_file(self.outside_dir.path(), filename, content) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_create_temp_dir() { + let temp_dir = create_temp_dir(); + assert!(temp_dir.path().exists()); + } + + #[test] + fn test_create_test_file() { + let temp_dir = create_temp_dir(); + let file_path = create_test_file(temp_dir.path(), "test.txt", "Hello, world!"); + assert!(file_path.exists()); + assert_eq!(read_file_content(&file_path), "Hello, world!"); + } + + #[test] + fn test_sandboxed_test_env() { + let env = SandboxedTestEnv::new(); + assert!(env.allowed_root.exists()); + assert!(env.blocked_dir.exists()); + assert!(env.outside_dir.path().exists()); + + let allowed_file = env.create_allowed_file("test.txt", "allowed"); + let outside_file = env.create_outside_file("test.txt", "outside"); + + assert!(allowed_file.exists()); + assert!(outside_file.exists()); + assert_ne!( + allowed_file.parent().unwrap(), + outside_file.parent().unwrap() + ); + } +} diff --git a/crates/dirigent_tools/tests/fermata_dep.rs b/crates/dirigent_tools/tests/fermata_dep.rs new file mode 100644 index 0000000..766a625 --- /dev/null +++ b/crates/dirigent_tools/tests/fermata_dep.rs @@ -0,0 +1,14 @@ +use dirigent_tools::policy::{Decision, Op, Policy}; +use std::fs; +use tempfile::TempDir; + +#[test] +fn dirigent_tools_can_use_fermata_policy() { + let tmp = TempDir::new().unwrap(); + fs::write(tmp.path().join(".botignore"), ".env\n").unwrap(); + let target = tmp.path().join(".env"); + fs::write(&target, "").unwrap(); + let policy = Policy::load(tmp.path()).unwrap(); + let d = policy.check(Op::Read, &target).unwrap(); + assert!(matches!(d, Decision::Deny(_))); +} diff --git a/crates/dirigent_tools/tests/file_operations.rs b/crates/dirigent_tools/tests/file_operations.rs new file mode 100644 index 0000000..2150c92 --- /dev/null +++ b/crates/dirigent_tools/tests/file_operations.rs @@ -0,0 +1,541 @@ +//! Integration tests for file operations (read, write, diff, edit). +//! +//! These tests use real filesystem operations with temporary directories +//! to verify end-to-end functionality including sandboxing and error handling. + +use dirigent_tools::config::{EolPolicy, PermissionConfig, SandboxConfig}; +use dirigent_tools::error::ToolError; +use dirigent_tools::fs::diff::generate_diff; +use dirigent_tools::fs::edit::{edit_file, EditFileRequest, EditOperation}; +use dirigent_tools::fs::read::{read_text_file, ReadTextFileRequest}; +use dirigent_tools::fs::write::{write_text_file, WriteTextFileRequest}; +use dirigent_tools::permission::check::PermissionContext; +use dirigent_tools::permission::whitelist::CompiledWhitelist; +use std::path::PathBuf; +use tempfile::TempDir; + +/// Create a test configuration with a single allowed root. +fn test_config(allowed_root: PathBuf) -> SandboxConfig { + let mut config = SandboxConfig::default(); + // Ensure the root is canonical + let canonical_root = dunce::canonicalize(&allowed_root) + .unwrap_or_else(|_| panic!("Failed to canonicalize test root: {:?}", allowed_root)); + config.allowed_roots = vec![canonical_root]; + config.write_enabled = true; + config.read_enabled = true; + config +} + +/// Create a test permission config with YOLO mode (no prompts). +fn test_permission_config() -> PermissionConfig { + PermissionConfig::default() +} + +/// Create a test permission context for a test connector. +fn test_permission_context() -> PermissionContext { + let whitelist = CompiledWhitelist::compile(&Default::default()).unwrap(); + PermissionContext::new( + "test-connector".to_string(), + Some("test-session".to_string()), + whitelist, + ) +} + +// ============================================================================= +// Read Operation Tests +// ============================================================================= + +#[tokio::test] +async fn test_read_entire_file() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("test.txt"); + + // Create test file + std::fs::write(&file_path, "line1\nline2\nline3").unwrap(); + + let config = test_config(temp_dir.path().to_path_buf()); + let request = ReadTextFileRequest { + path: file_path.to_string_lossy().to_string(), + line: None, + limit: None, + }; + + let response = read_text_file(request, &config).await.unwrap(); + assert_eq!(response.content, "line1\nline2\nline3"); +} + +#[tokio::test] +async fn test_read_with_line_limit() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("test.txt"); + + // Create test file + std::fs::write(&file_path, "line1\nline2\nline3\nline4\nline5").unwrap(); + + let config = test_config(temp_dir.path().to_path_buf()); + + // Read from line 2, limit 2 lines + let request = ReadTextFileRequest { + path: file_path.to_string_lossy().to_string(), + line: Some(2), + limit: Some(2), + }; + + let response = read_text_file(request, &config).await.unwrap(); + assert_eq!(response.content, "line2\nline3"); +} + +#[tokio::test] +async fn test_read_only_limit() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("test.txt"); + + std::fs::write(&file_path, "line1\nline2\nline3\nline4").unwrap(); + + let config = test_config(temp_dir.path().to_path_buf()); + let request = ReadTextFileRequest { + path: file_path.to_string_lossy().to_string(), + line: None, + limit: Some(2), + }; + + let response = read_text_file(request, &config).await.unwrap(); + assert_eq!(response.content, "line1\nline2"); +} + +#[tokio::test] +async fn test_read_file_not_found() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("nonexistent.txt"); + + let config = test_config(temp_dir.path().to_path_buf()); + let request = ReadTextFileRequest { + path: file_path.to_string_lossy().to_string(), + line: None, + limit: None, + }; + + let result = read_text_file(request, &config).await; + assert!(matches!(result, Err(ToolError::NotFound { .. }))); +} + +#[tokio::test] +async fn test_read_non_utf8() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("binary.bin"); + + // Write invalid UTF-8 + std::fs::write(&file_path, &[0xFF, 0xFE, 0xFD]).unwrap(); + + let config = test_config(temp_dir.path().to_path_buf()); + let request = ReadTextFileRequest { + path: file_path.to_string_lossy().to_string(), + line: None, + limit: None, + }; + + let result = read_text_file(request, &config).await; + assert!(matches!(result, Err(ToolError::EncodingUnsupported { .. }))); +} + +#[tokio::test] +async fn test_read_outside_sandbox() { + let temp_dir = TempDir::new().unwrap(); + let other_temp_dir = TempDir::new().unwrap(); + let file_path = other_temp_dir.path().join("test.txt"); + + std::fs::write(&file_path, "content").unwrap(); + + // Config only allows temp_dir, not other_temp_dir + let config = test_config(temp_dir.path().to_path_buf()); + let request = ReadTextFileRequest { + path: file_path.to_string_lossy().to_string(), + line: None, + limit: None, + }; + + let result = read_text_file(request, &config).await; + assert!(matches!(result, Err(ToolError::SandboxViolation { .. }))); +} + +#[tokio::test] +async fn test_read_disabled() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("test.txt"); + std::fs::write(&file_path, "content").unwrap(); + + let mut config = test_config(temp_dir.path().to_path_buf()); + config.read_enabled = false; + + let request = ReadTextFileRequest { + path: file_path.to_string_lossy().to_string(), + line: None, + limit: None, + }; + + let result = read_text_file(request, &config).await; + assert!(matches!(result, Err(ToolError::PermissionDenied { .. }))); +} + +// ============================================================================= +// Write Operation Tests +// ============================================================================= + +#[tokio::test] +async fn test_write_new_file() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("new.txt"); + + let config = test_config(temp_dir.path().to_path_buf()); + let request = WriteTextFileRequest { + path: file_path.to_string_lossy().to_string(), + content: "hello world".to_string(), + }; + + write_text_file(request, &config, &test_permission_config(), &test_permission_context()).await.unwrap(); + + // Verify file was written + let content = std::fs::read_to_string(&file_path).unwrap(); + assert_eq!(content, "hello world"); +} + +#[tokio::test] +async fn test_write_overwrite_existing() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("existing.txt"); + + std::fs::write(&file_path, "old content").unwrap(); + + let config = test_config(temp_dir.path().to_path_buf()); + let request = WriteTextFileRequest { + path: file_path.to_string_lossy().to_string(), + content: "new content".to_string(), + }; + + write_text_file(request, &config, &test_permission_config(), &test_permission_context()).await.unwrap(); + + let content = std::fs::read_to_string(&file_path).unwrap(); + assert_eq!(content, "new content"); +} + +#[tokio::test] +async fn test_write_create_parent_directories() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("subdir").join("nested").join("file.txt"); + + let config = test_config(temp_dir.path().to_path_buf()); + let request = WriteTextFileRequest { + path: file_path.to_string_lossy().to_string(), + content: "nested content".to_string(), + }; + + write_text_file(request, &config, &test_permission_config(), &test_permission_context()).await.unwrap(); + + let content = std::fs::read_to_string(&file_path).unwrap(); + assert_eq!(content, "nested content"); +} + +#[tokio::test] +async fn test_write_eol_normalization_lf() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("eol_lf.txt"); + + let mut config = test_config(temp_dir.path().to_path_buf()); + config.eol_policy = EolPolicy::Lf; + + let request = WriteTextFileRequest { + path: file_path.to_string_lossy().to_string(), + content: "line1\r\nline2\rline3\n".to_string(), + }; + + write_text_file(request, &config, &test_permission_config(), &test_permission_context()).await.unwrap(); + + let content = std::fs::read_to_string(&file_path).unwrap(); + assert_eq!(content, "line1\nline2\nline3\n"); +} + +#[tokio::test] +async fn test_write_eol_normalization_crlf() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("eol_crlf.txt"); + + let mut config = test_config(temp_dir.path().to_path_buf()); + config.eol_policy = EolPolicy::Crlf; + + let request = WriteTextFileRequest { + path: file_path.to_string_lossy().to_string(), + content: "line1\nline2\rline3\r\n".to_string(), + }; + + write_text_file(request, &config, &test_permission_config(), &test_permission_context()).await.unwrap(); + + let content = std::fs::read_to_string(&file_path).unwrap(); + assert_eq!(content, "line1\r\nline2\r\nline3\r\n"); +} + +#[tokio::test] +async fn test_write_size_limit_exceeded() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("large.txt"); + + let mut config = test_config(temp_dir.path().to_path_buf()); + config.max_write_bytes = 10; + + let request = WriteTextFileRequest { + path: file_path.to_string_lossy().to_string(), + content: "this is way too much content".to_string(), + }; + + let result = write_text_file(request, &config, &test_permission_config(), &test_permission_context()).await; + assert!(matches!(result, Err(ToolError::FileTooLarge { .. }))); +} + +#[tokio::test] +async fn test_write_outside_sandbox() { + let temp_dir = TempDir::new().unwrap(); + let other_temp_dir = TempDir::new().unwrap(); + let file_path = other_temp_dir.path().join("test.txt"); + + let config = test_config(temp_dir.path().to_path_buf()); + let request = WriteTextFileRequest { + path: file_path.to_string_lossy().to_string(), + content: "content".to_string(), + }; + + let result = write_text_file(request, &config, &test_permission_config(), &test_permission_context()).await; + assert!(matches!(result, Err(ToolError::SandboxViolation { .. }))); +} + +#[tokio::test] +async fn test_write_disabled() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("test.txt"); + + let mut config = test_config(temp_dir.path().to_path_buf()); + config.write_enabled = false; + + let request = WriteTextFileRequest { + path: file_path.to_string_lossy().to_string(), + content: "content".to_string(), + }; + + let result = write_text_file(request, &config, &test_permission_config(), &test_permission_context()).await; + assert!(matches!(result, Err(ToolError::PermissionDenied { .. }))); +} + +// ============================================================================= +// Diff Generation Tests +// ============================================================================= + +#[test] +fn test_diff_new_file() { + let old = ""; + let new = "line1\nline2\nline3"; + let path = PathBuf::from("test.txt"); + + let diff = generate_diff(old, new, &path); + + assert!(diff.contains("--- /dev/null")); + assert!(diff.contains("+++ test.txt")); + assert!(diff.contains("+line1")); + assert!(diff.contains("+line2")); + assert!(diff.contains("+line3")); +} + +#[test] +fn test_diff_deleted_file() { + let old = "line1\nline2\nline3"; + let new = ""; + let path = PathBuf::from("test.txt"); + + let diff = generate_diff(old, new, &path); + + assert!(diff.contains("--- test.txt")); + assert!(diff.contains("+++ /dev/null")); + assert!(diff.contains("-line1")); + assert!(diff.contains("-line2")); + assert!(diff.contains("-line3")); +} + +#[test] +fn test_diff_modification() { + let old = "line1\nline2\nline3"; + let new = "line1\nmodified\nline3"; + let path = PathBuf::from("test.txt"); + + let diff = generate_diff(old, new, &path); + + assert!(diff.contains("--- test.txt")); + assert!(diff.contains("+++ test.txt")); + assert!(diff.contains("-line2")); + assert!(diff.contains("+modified")); +} + +// ============================================================================= +// Edit Operation Tests +// ============================================================================= + +#[tokio::test] +async fn test_edit_replace_once() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("edit.txt"); + + std::fs::write(&file_path, "foo bar foo baz").unwrap(); + + let config = test_config(temp_dir.path().to_path_buf()); + let request = EditFileRequest { + path: file_path.to_string_lossy().to_string(), + edits: vec![EditOperation::Replace { + old_text: "foo".to_string(), + new_text: "qux".to_string(), + replace_all: false, + }], + }; + + let response = edit_file(request, &config, &test_permission_config(), &test_permission_context()).await.unwrap(); + + // Verify file was edited + let content = std::fs::read_to_string(&file_path).unwrap(); + assert_eq!(content, "qux bar foo baz"); + + // Verify diff was generated + assert!(response.diff.contains("-foo")); + assert!(response.diff.contains("+qux")); +} + +#[tokio::test] +async fn test_edit_replace_all() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("edit.txt"); + + std::fs::write(&file_path, "foo bar foo baz foo").unwrap(); + + let config = test_config(temp_dir.path().to_path_buf()); + let request = EditFileRequest { + path: file_path.to_string_lossy().to_string(), + edits: vec![EditOperation::Replace { + old_text: "foo".to_string(), + new_text: "qux".to_string(), + replace_all: true, + }], + }; + + edit_file(request, &config, &test_permission_config(), &test_permission_context()).await.unwrap(); + + let content = std::fs::read_to_string(&file_path).unwrap(); + assert_eq!(content, "qux bar qux baz qux"); +} + +#[tokio::test] +async fn test_edit_multiple_operations() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("edit.txt"); + + std::fs::write(&file_path, "hello world\nhello rust").unwrap(); + + let config = test_config(temp_dir.path().to_path_buf()); + let request = EditFileRequest { + path: file_path.to_string_lossy().to_string(), + edits: vec![ + EditOperation::Replace { + old_text: "hello".to_string(), + new_text: "goodbye".to_string(), + replace_all: true, + }, + EditOperation::Replace { + old_text: "world".to_string(), + new_text: "universe".to_string(), + replace_all: false, + }, + ], + }; + + edit_file(request, &config, &test_permission_config(), &test_permission_context()).await.unwrap(); + + let content = std::fs::read_to_string(&file_path).unwrap(); + assert_eq!(content, "goodbye universe\ngoodbye rust"); +} + +#[tokio::test] +async fn test_edit_nonexistent_file() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("nonexistent.txt"); + + let config = test_config(temp_dir.path().to_path_buf()); + let request = EditFileRequest { + path: file_path.to_string_lossy().to_string(), + edits: vec![EditOperation::Replace { + old_text: "foo".to_string(), + new_text: "bar".to_string(), + replace_all: false, + }], + }; + + let result = edit_file(request, &config, &test_permission_config(), &test_permission_context()).await; + assert!(matches!(result, Err(ToolError::NotFound { .. }))); +} + +#[tokio::test] +async fn test_edit_outside_sandbox() { + let temp_dir = TempDir::new().unwrap(); + let other_temp_dir = TempDir::new().unwrap(); + let file_path = other_temp_dir.path().join("test.txt"); + + std::fs::write(&file_path, "content").unwrap(); + + let config = test_config(temp_dir.path().to_path_buf()); + let request = EditFileRequest { + path: file_path.to_string_lossy().to_string(), + edits: vec![EditOperation::Replace { + old_text: "content".to_string(), + new_text: "modified".to_string(), + replace_all: false, + }], + }; + + let result = edit_file(request, &config, &test_permission_config(), &test_permission_context()).await; + assert!(matches!(result, Err(ToolError::SandboxViolation { .. }))); +} + +// ============================================================================= +// Windows-Specific Tests +// ============================================================================= + +#[cfg(windows)] +#[tokio::test] +async fn test_read_write_windows_line_endings() { + let temp_dir = TempDir::new().unwrap(); + let file_path = temp_dir.path().join("windows.txt"); + + // Write file with CRLF + std::fs::write(&file_path, "line1\r\nline2\r\nline3").unwrap(); + + let config = test_config(temp_dir.path().to_path_buf()); + + // Read file + let read_request = ReadTextFileRequest { + path: file_path.to_string_lossy().to_string(), + line: None, + limit: None, + }; + + let response = read_text_file(read_request, &config).await.unwrap(); + // Content should preserve CRLF when read + assert!(response.content.contains("\r\n")); + + // Write with LF normalization + let mut write_config = config.clone(); + write_config.eol_policy = EolPolicy::Lf; + + let write_request = WriteTextFileRequest { + path: file_path.to_string_lossy().to_string(), + content: response.content, + }; + + write_text_file(write_request, &write_config, &test_permission_config(), &test_permission_context()).await.unwrap(); + + // Verify normalization happened + let final_content = std::fs::read_to_string(&file_path).unwrap(); + assert!(!final_content.contains("\r\n")); + assert!(final_content.contains("\n")); +} diff --git a/crates/dirigent_tools/tests/lib_test.rs b/crates/dirigent_tools/tests/lib_test.rs new file mode 100644 index 0000000..bebcaec --- /dev/null +++ b/crates/dirigent_tools/tests/lib_test.rs @@ -0,0 +1,47 @@ +//! Basic compilation and infrastructure tests for dirigent_tools. + +mod common; + +use common::*; + +#[test] +fn test_package_compiles() { + // Basic smoke test to ensure the package compiles + assert!(true); +} + +#[test] +fn test_temp_dir_creation() { + let temp_dir = create_temp_dir(); + assert!(temp_dir.path().exists()); + assert!(temp_dir.path().is_dir()); +} + +#[test] +fn test_test_file_creation() { + let temp_dir = create_temp_dir(); + let file_path = create_test_file(temp_dir.path(), "test.txt", "test content"); + assert!(file_path.exists()); + assert_eq!(read_file_content(&file_path), "test content"); +} + +#[test] +fn test_sandboxed_env_creation() { + let env = SandboxedTestEnv::new(); + assert!(env.allowed_root.exists()); + assert!(env.blocked_dir.exists()); + assert!(env.outside_dir.path().exists()); +} + +#[test] +fn test_fixture_files_exist() { + // Verify test fixtures are present + let fixtures_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("tests") + .join("test_fixtures"); + + assert!(fixtures_dir.join("sample_text.txt").exists()); + assert!(fixtures_dir.join("unicode_text.txt").exists()); + assert!(fixtures_dir.join("large_text.txt").exists()); + assert!(fixtures_dir.join("README.md").exists()); +} diff --git a/crates/dirigent_tools/tests/path_normalization.rs b/crates/dirigent_tools/tests/path_normalization.rs new file mode 100644 index 0000000..9e5e6e5 --- /dev/null +++ b/crates/dirigent_tools/tests/path_normalization.rs @@ -0,0 +1,190 @@ +//! Integration tests for path normalization and containment. +//! +//! These tests use temporary directories to test real filesystem operations. + +use dirigent_tools::config::SandboxConfig; +use dirigent_tools::path::{canonicalize_path, check_containment, validate_path, SymlinkPolicy}; +use dirigent_tools::ToolError; +use std::fs; +use std::path::PathBuf; +use tempfile::TempDir; + +#[test] +fn test_validate_path_with_real_filesystem() { + // Create a temp directory structure + let temp_dir = TempDir::new().unwrap(); + let project_dir = temp_dir.path().join("project"); + let src_dir = project_dir.join("src"); + fs::create_dir_all(&src_dir).unwrap(); + + // Create a test file + let test_file = src_dir.join("main.rs"); + fs::write(&test_file, "fn main() {}").unwrap(); + + // Create config + let mut config = SandboxConfig::default(); + config.allowed_roots = vec![project_dir.clone()]; + config.blocked_paths = vec!["**/.env".to_string()]; + + // Test: Valid path within allowed root + let result = validate_path(test_file.to_str().unwrap(), &config); + assert!(result.is_ok()); + + // Test: Path at the root level (should be allowed) + let root_file = project_dir.join("README.md"); + fs::write(&root_file, "# Project").unwrap(); + let result = validate_path(root_file.to_str().unwrap(), &config); + assert!(result.is_ok()); + + // Test: Path outside allowed roots + let outside_path = temp_dir.path().join("outside.txt"); + fs::write(&outside_path, "outside").unwrap(); + let result = validate_path(outside_path.to_str().unwrap(), &config); + assert!(result.is_err()); + assert!(matches!(result, Err(ToolError::SandboxViolation { .. }))); +} + +#[test] +fn test_validate_path_blocked() { + let temp_dir = TempDir::new().unwrap(); + let project_dir = temp_dir.path().join("project"); + fs::create_dir_all(&project_dir).unwrap(); + + // Create a .env file + let env_file = project_dir.join(".env"); + fs::write(&env_file, "SECRET=value").unwrap(); + + // Create config with blocklist + let mut config = SandboxConfig::default(); + config.allowed_roots = vec![project_dir.clone()]; + config.blocked_paths = vec!["**/.env".to_string()]; + + // Test: Blocked path + let result = validate_path(env_file.to_str().unwrap(), &config); + assert!(result.is_err()); + assert!(matches!(result, Err(ToolError::BlockedPath { .. }))); +} + +#[test] +fn test_validate_path_non_existent_for_write() { + let temp_dir = TempDir::new().unwrap(); + let project_dir = temp_dir.path().join("project"); + fs::create_dir_all(&project_dir).unwrap(); + + // Non-existent file (for write operations) + let new_file = project_dir.join("new_file.txt"); + + let mut config = SandboxConfig::default(); + config.allowed_roots = vec![project_dir.clone()]; + + // Test: Non-existent file within allowed root + let result = validate_path(new_file.to_str().unwrap(), &config); + assert!(result.is_ok()); +} + +#[test] +fn test_canonicalize_path_with_real_filesystem() { + let temp_dir = TempDir::new().unwrap(); + let project_dir = temp_dir.path().join("project"); + fs::create_dir_all(&project_dir).unwrap(); + + let test_file = project_dir.join("test.txt"); + fs::write(&test_file, "content").unwrap(); + + let policy = SymlinkPolicy::default(); + + // Test: Canonicalize existing file + let canonical = canonicalize_path(&test_file, &policy).unwrap(); + assert!(canonical.is_absolute()); + assert!(canonical.exists()); + + // Verify no ".." components + assert!(!canonical.to_string_lossy().contains("..")); +} + +#[test] +fn test_containment_with_real_filesystem() { + let temp_dir = TempDir::new().unwrap(); + let root_dir = temp_dir.path().join("root"); + let subdir = root_dir.join("subdir"); + fs::create_dir_all(&subdir).unwrap(); + + let file = subdir.join("file.txt"); + fs::write(&file, "content").unwrap(); + + // Canonicalize paths + let canonical_root = dunce::canonicalize(&root_dir).unwrap(); + let canonical_file = dunce::canonicalize(&file).unwrap(); + + // Test: File is contained in root + let result = check_containment(&canonical_file, &[canonical_root.clone()]); + assert!(result.is_ok()); + + // Test: Root is not strictly contained in itself + let result = check_containment(&canonical_root, &[canonical_root.clone()]); + assert!(result.is_err()); +} + +#[cfg(unix)] +#[test] +fn test_symlink_handling_unix() { + use std::os::unix::fs::symlink; + + let temp_dir = TempDir::new().unwrap(); + let allowed_dir = temp_dir.path().join("allowed"); + let outside_dir = temp_dir.path().join("outside"); + fs::create_dir_all(&allowed_dir).unwrap(); + fs::create_dir_all(&outside_dir).unwrap(); + + // Create a file outside the allowed root + let outside_file = outside_dir.join("secret.txt"); + fs::write(&outside_file, "secret").unwrap(); + + // Create a symlink inside allowed root pointing outside + let symlink_path = allowed_dir.join("link_to_secret"); + symlink(&outside_file, &symlink_path).unwrap(); + + let mut config = SandboxConfig::default(); + config.allowed_roots = vec![dunce::canonicalize(&allowed_dir).unwrap()]; + config.allow_symlink_escape = false; // Don't allow escapes + + // Test: Symlink escape should be blocked + let result = validate_path(symlink_path.to_str().unwrap(), &config); + assert!(result.is_err()); +} + +#[cfg(windows)] +#[test] +fn test_windows_reserved_device_names() { + let policy = SymlinkPolicy::default(); + + // Reserved device names should be rejected + let reserved_names = vec!["CON", "PRN", "AUX", "NUL", "COM1", "COM2", "LPT1", "LPT2"]; + + for name in reserved_names { + let path = PathBuf::from(format!("C:\\{}", name)); + let result = canonicalize_path(&path, &policy); + assert!(result.is_err(), "Should reject reserved name: {}", name); + } +} + +#[cfg(windows)] +#[test] +fn test_windows_path_forms() { + // Test various Windows path forms with temp directory + let temp_dir = TempDir::new().unwrap(); + let test_file = temp_dir.path().join("test.txt"); + fs::write(&test_file, "content").unwrap(); + + let policy = SymlinkPolicy::default(); + + // Test: Standard absolute path + let canonical = canonicalize_path(&test_file, &policy).unwrap(); + assert!(canonical.is_absolute()); + + // Test: Drive letter normalization (uppercase) + let path_str = canonical.to_string_lossy(); + if path_str.len() >= 2 && path_str.chars().nth(1) == Some(':') { + assert!(path_str.chars().nth(0).unwrap().is_ascii_uppercase()); + } +} diff --git a/crates/dirigent_tools/tests/permission_integration.rs b/crates/dirigent_tools/tests/permission_integration.rs new file mode 100644 index 0000000..49ba2a2 --- /dev/null +++ b/crates/dirigent_tools/tests/permission_integration.rs @@ -0,0 +1,437 @@ +//! Integration tests for the permission system (TOOLS-PERM-01 through TOOLS-PERM-04). +//! +//! This test suite validates: +//! - Permission modes (ask, whitelist, yolo) +//! - Decision caching with TTL and scope +//! - Whitelist pattern matching +//! - ACP integration (stubbed for now) +//! - Integration with file operations + +use dirigent_tools::config::{ + DecisionScope, PermissionConfig, PermissionMode, SandboxConfig, WhitelistConfig, +}; +use dirigent_tools::error::ToolError; +use dirigent_tools::fs::write::{write_text_file, WriteTextFileRequest}; +use dirigent_tools::permission::check::{check_permission, PermissionContext}; +use dirigent_tools::permission::cache::{CacheKey, PermissionDecision}; +use dirigent_tools::permission::whitelist::{CompiledWhitelist, PermissionOperation}; +use std::time::Duration; +use tempfile::TempDir; + +/// Create a test sandbox configuration. +fn create_test_sandbox(temp_dir: &TempDir) -> SandboxConfig { + let mut config = SandboxConfig { + allowed_roots: vec![temp_dir.path().to_path_buf()], + blocked_paths: vec![], + allow_symlink_escape: false, + follow_symlinks_within_roots: true, + read_enabled: true, + write_enabled: true, + max_read_bytes: 1_000_000, + max_write_bytes: 1_000_000, + eol_policy: dirigent_tools::config::EolPolicy::Preserve, + encoding: "utf-8".to_string(), + }; + config.normalize_roots(); + config +} + +#[tokio::test] +async fn test_yolo_mode_allows_everything() { + let temp_dir = TempDir::new().unwrap(); + let config = PermissionConfig { + mode: PermissionMode::Yolo, + remember_decisions: true, + remember_ttl_secs: 300, + scope: DecisionScope::PerConnector, + whitelist: WhitelistConfig::default(), + }; + + let whitelist = CompiledWhitelist::compile(&config.whitelist).unwrap(); + let context = PermissionContext::new( + "test-connector".to_string(), + Some("test-session".to_string()), + whitelist, + ); + + // All operations should be allowed in yolo mode + let operations = vec![ + PermissionOperation::Read { + path: temp_dir.path().join("file.txt").display().to_string(), + }, + PermissionOperation::Write { + path: temp_dir.path().join("file.txt").display().to_string(), + }, + PermissionOperation::Execute { + command: "dangerous_command".to_string(), + cwd: temp_dir.path().display().to_string(), + }, + ]; + + for operation in operations { + let decision = check_permission(&operation, &context, &config).await.unwrap(); + assert_eq!(decision, PermissionDecision::Allowed); + } +} + +#[tokio::test] +async fn test_whitelist_mode_auto_approves_matching() { + let temp_dir = TempDir::new().unwrap(); + let temp_path = temp_dir.path().display().to_string(); + + let config = PermissionConfig { + mode: PermissionMode::Whitelist, + remember_decisions: true, + remember_ttl_secs: 300, + scope: DecisionScope::PerConnector, + whitelist: WhitelistConfig { + write_paths: vec![format!("{}/**", temp_path)], + execute_commands: vec!["cargo".to_string(), "npm".to_string()], + }, + }; + + let whitelist = CompiledWhitelist::compile(&config.whitelist).unwrap(); + let context = PermissionContext::new( + "test-connector".to_string(), + None, + whitelist, + ); + + // Read should always be allowed + let read_op = PermissionOperation::Read { + path: "/any/path/file.txt".to_string(), + }; + let decision = check_permission(&read_op, &context, &config).await.unwrap(); + assert_eq!(decision, PermissionDecision::Allowed); + + // Write within whitelist should be allowed + let write_ok = PermissionOperation::Write { + path: temp_dir.path().join("file.txt").display().to_string(), + }; + let decision = check_permission(&write_ok, &context, &config).await.unwrap(); + assert_eq!(decision, PermissionDecision::Allowed); + + // Execute whitelisted command should be allowed + let exec_ok = PermissionOperation::Execute { + command: "cargo".to_string(), + cwd: temp_dir.path().display().to_string(), + }; + let decision = check_permission(&exec_ok, &context, &config).await.unwrap(); + assert_eq!(decision, PermissionDecision::Allowed); +} + +#[tokio::test] +async fn test_decision_cache_per_connector() { + // Note: This test manually adds cache entries to verify cache separation. + // In production, cache entries are only added when users select "always" options in ACP prompts. + + let config = PermissionConfig { + mode: PermissionMode::Ask, + remember_decisions: true, + remember_ttl_secs: 300, + scope: DecisionScope::PerConnector, + whitelist: WhitelistConfig::default(), + }; + + let whitelist = CompiledWhitelist::compile(&config.whitelist).unwrap(); + let context1 = PermissionContext::new( + "connector-1".to_string(), + Some("session-1".to_string()), + whitelist.clone(), + ); + let context2 = PermissionContext::new( + "connector-2".to_string(), + Some("session-2".to_string()), + whitelist, + ); + + // Manually add cache entries to test separation + { + let mut cache1 = context1.cache.lock().unwrap(); + let key = CacheKey::write("/test/path", "connector-1", Some("session-1"), DecisionScope::PerConnector); + cache1.insert(key, PermissionDecision::Allowed, Duration::from_secs(300)); + } + + assert_eq!(context1.cache_size(), 1); + + // Second connector should have separate cache + assert_eq!(context2.cache_size(), 0); + + { + let mut cache2 = context2.cache.lock().unwrap(); + let key = CacheKey::write("/test/path", "connector-2", Some("session-2"), DecisionScope::PerConnector); + cache2.insert(key, PermissionDecision::Allowed, Duration::from_secs(300)); + } + + assert_eq!(context2.cache_size(), 1); +} + +#[tokio::test] +async fn test_decision_cache_per_session() { + // Note: This test manually adds cache entries to verify per-session scoping. + // In production, cache entries are only added when users select "always" options in ACP prompts. + + let config = PermissionConfig { + mode: PermissionMode::Ask, + remember_decisions: true, + remember_ttl_secs: 300, + scope: DecisionScope::PerSession, + whitelist: WhitelistConfig::default(), + }; + + let whitelist = CompiledWhitelist::compile(&config.whitelist).unwrap(); + + // Same connector, different sessions - they share the same cache object but have different keys + let context1 = PermissionContext::new( + "connector-1".to_string(), + Some("session-1".to_string()), + whitelist.clone(), + ); + let context2 = PermissionContext::new( + "connector-1".to_string(), + Some("session-2".to_string()), + whitelist, + ); + + // Manually add entries for each session with per-session scope + { + let mut cache1 = context1.cache.lock().unwrap(); + let key1 = CacheKey::write("/test/path", "connector-1", Some("session-1"), DecisionScope::PerSession); + cache1.insert(key1, PermissionDecision::Allowed, Duration::from_secs(300)); + } + + { + let mut cache2 = context2.cache.lock().unwrap(); + let key2 = CacheKey::write("/test/path", "connector-1", Some("session-2"), DecisionScope::PerSession); + cache2.insert(key2, PermissionDecision::Allowed, Duration::from_secs(300)); + } + + // Each context has its own cache object, so they each have 1 entry + assert_eq!(context1.cache_size(), 1); + assert_eq!(context2.cache_size(), 1); +} + +#[tokio::test] +async fn test_cache_ttl_expiration() { + // Note: This test manually adds a cache entry with short TTL to test expiration. + + let whitelist = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap(); + let context = PermissionContext::new( + "test-connector".to_string(), + None, + whitelist, + ); + + // Manually add entry with very short TTL + { + let mut cache = context.cache.lock().unwrap(); + let key = CacheKey::write("/test/path", "test-connector", None, DecisionScope::PerConnector); + cache.insert(key.clone(), PermissionDecision::Allowed, Duration::from_millis(1)); + } + + // Cache entry should be present + let initial_size = context.cache_size(); + assert_eq!(initial_size, 1); + + // Wait for expiration + tokio::time::sleep(Duration::from_millis(10)).await; + + // Try to get the entry (should be expired and removed) + { + let mut cache = context.cache.lock().unwrap(); + let key = CacheKey::write("/test/path", "test-connector", None, DecisionScope::PerConnector); + let result = cache.get(&key); + assert_eq!(result, None); // Should be expired + } + + // Cache should now be empty + assert_eq!(context.cache_size(), 0); +} + +#[tokio::test] +async fn test_write_file_with_permission_yolo() { + let temp_dir = TempDir::new().unwrap(); + let sandbox_config = create_test_sandbox(&temp_dir); + + let permission_config = PermissionConfig { + mode: PermissionMode::Yolo, + remember_decisions: true, + remember_ttl_secs: 300, + scope: DecisionScope::PerConnector, + whitelist: WhitelistConfig::default(), + }; + + let whitelist = CompiledWhitelist::compile(&permission_config.whitelist).unwrap(); + let permission_context = PermissionContext::new( + "test-connector".to_string(), + None, + whitelist, + ); + + let file_path = temp_dir.path().join("test_file.txt"); + let request = WriteTextFileRequest { + path: file_path.display().to_string(), + content: "Hello, world!".to_string(), + }; + + // Should succeed in yolo mode + let result = write_text_file( + request, + &sandbox_config, + &permission_config, + &permission_context, + ) + .await; + + assert!(result.is_ok()); + assert!(file_path.exists()); + + let content = std::fs::read_to_string(&file_path).unwrap(); + assert_eq!(content, "Hello, world!"); +} + +#[tokio::test] +async fn test_write_file_with_permission_whitelist() { + let temp_dir = TempDir::new().unwrap(); + let sandbox_config = create_test_sandbox(&temp_dir); + let temp_path = temp_dir.path().display().to_string(); + + let permission_config = PermissionConfig { + mode: PermissionMode::Whitelist, + remember_decisions: true, + remember_ttl_secs: 300, + scope: DecisionScope::PerConnector, + whitelist: WhitelistConfig { + write_paths: vec![format!("{}/**", temp_path)], + execute_commands: vec![], + }, + }; + + let whitelist = CompiledWhitelist::compile(&permission_config.whitelist).unwrap(); + let permission_context = PermissionContext::new( + "test-connector".to_string(), + None, + whitelist, + ); + + let file_path = temp_dir.path().join("test_file.txt"); + let request = WriteTextFileRequest { + path: file_path.display().to_string(), + content: "Whitelisted write".to_string(), + }; + + // Should succeed because path matches whitelist + let result = write_text_file( + request, + &sandbox_config, + &permission_config, + &permission_context, + ) + .await; + + assert!(result.is_ok()); + assert!(file_path.exists()); + + let content = std::fs::read_to_string(&file_path).unwrap(); + assert_eq!(content, "Whitelisted write"); +} + +#[tokio::test] +async fn test_cache_key_equality() { + // Test that cache keys are correctly generated for different scopes + let key1 = CacheKey::write( + "/path/to/file", + "conn-1", + Some("session-1"), + DecisionScope::PerConnector, + ); + let key2 = CacheKey::write( + "/path/to/file", + "conn-1", + Some("session-1"), + DecisionScope::PerConnector, + ); + let key3 = CacheKey::write( + "/path/to/file", + "conn-1", + Some("session-2"), + DecisionScope::PerConnector, + ); + + // Same connector and path should produce equal keys for per-connector scope + assert_eq!(key1, key2); + // Different sessions should still be equal for per-connector scope + assert_eq!(key1, key3); + + // Different scope should produce different keys + let key4 = CacheKey::write( + "/path/to/file", + "conn-1", + Some("session-1"), + DecisionScope::PerSession, + ); + assert_ne!(key1, key4); +} + +#[test] +fn test_whitelist_compilation() { + let config = WhitelistConfig { + write_paths: vec![ + "C:/work/**".to_string(), + "/home/user/**".to_string(), + ], + execute_commands: vec![ + "cargo".to_string(), + "npm*".to_string(), + ], + }; + + let whitelist = CompiledWhitelist::compile(&config); + assert!(whitelist.is_ok()); + + let whitelist = whitelist.unwrap(); + assert!(whitelist.has_write_patterns()); + assert!(whitelist.has_execute_patterns()); +} + +#[test] +fn test_invalid_whitelist_pattern() { + let config = WhitelistConfig { + write_paths: vec!["[invalid".to_string()], // Unclosed bracket + execute_commands: vec![], + }; + + let result = CompiledWhitelist::compile(&config); + assert!(result.is_err()); + + match result { + Err(ToolError::InvalidConfig(_)) => { + // Expected error + } + _ => panic!("Expected InvalidConfig error"), + } +} + +#[tokio::test] +async fn test_permission_context_sharing() { + let whitelist = CompiledWhitelist::compile(&WhitelistConfig::default()).unwrap(); + let context = PermissionContext::new( + "test-connector".to_string(), + None, + whitelist, + ); + + // Clone should share the same cache + let context_clone = context.clone(); + + // Add entry via original context + { + let mut cache = context.cache.lock().unwrap(); + let key = CacheKey::write("/path", "test", None, DecisionScope::PerConnector); + cache.insert(key, PermissionDecision::Allowed, Duration::from_secs(300)); + } + + // Clone should see the same entry + assert_eq!(context.cache_size(), 1); + assert_eq!(context_clone.cache_size(), 1); +} diff --git a/crates/dirigent_tools/tests/search_operations.rs b/crates/dirigent_tools/tests/search_operations.rs new file mode 100644 index 0000000..921e278 --- /dev/null +++ b/crates/dirigent_tools/tests/search_operations.rs @@ -0,0 +1,421 @@ +//! Integration tests for search operations (ls, glob, grep). +//! +//! These tests verify: +//! - Directory listing with filtering +//! - Glob pattern matching +//! - Regex content search with context +//! - Result limits enforcement +//! - Binary file detection +//! - Windows path support + +use dirigent_tools::config::SearchConfig; +use dirigent_tools::search::{ + glob_search, grep_search, ls, FileKind, GlobRequest, GrepRequest, LsRequest, +}; +use std::fs; +use std::io::Write; +use tempfile::TempDir; + +/// Helper to create test directory structure +fn create_test_structure() -> TempDir { + let temp_dir = TempDir::new().unwrap(); + let base = temp_dir.path(); + + // Create directory structure + fs::create_dir(base.join("src")).unwrap(); + fs::create_dir(base.join("tests")).unwrap(); + fs::create_dir(base.join("target")).unwrap(); // Should be excluded by default + fs::create_dir(base.join(".git")).unwrap(); // Should be excluded by default + + // Create some files + fs::write(base.join("README.md"), "# Test Project\n").unwrap(); + fs::write(base.join("Cargo.toml"), "[package]\nname = \"test\"\n").unwrap(); + fs::write( + base.join("src/main.rs"), + "fn main() {\n println!(\"Hello, world!\");\n}\n", + ) + .unwrap(); + fs::write( + base.join("src/lib.rs"), + "pub fn add(a: i32, b: i32) -> i32 {\n a + b\n}\n", + ) + .unwrap(); + fs::write( + base.join("tests/integration.rs"), + "#[test]\nfn test_add() {\n assert_eq!(add(2, 2), 4);\n}\n", + ) + .unwrap(); + + // Create a file in excluded directory (should not appear in searches) + fs::write(base.join("target/debug.log"), "debug info\n").unwrap(); + + temp_dir +} + +#[tokio::test] +async fn test_ls_basic() { + let temp_dir = create_test_structure(); + let base = temp_dir.path(); + + let config = SearchConfig::default(); + let request = LsRequest { + path: base.to_string_lossy().to_string(), + }; + + let response = ls(request, &config).await.unwrap(); + + // Should have files and directories, but not .git or target (excluded by default) + assert!(response.entries.len() >= 4); // README, Cargo.toml, src, tests + + // Check we have both files and directories + let has_file = response.entries.iter().any(|e| e.kind == FileKind::File); + let has_dir = response.entries.iter().any(|e| e.kind == FileKind::Dir); + assert!(has_file); + assert!(has_dir); + + // Check that excluded directories are not present + let has_target = response + .entries + .iter() + .any(|e| e.path.file_name().unwrap() == "target"); + let has_git = response + .entries + .iter() + .any(|e| e.path.file_name().unwrap() == ".git"); + assert!(!has_target, "target/ should be excluded by default"); + assert!(!has_git, ".git/ should be excluded by default"); +} + +#[tokio::test] +async fn test_ls_file_sizes() { + let temp_dir = create_test_structure(); + let base = temp_dir.path(); + + let config = SearchConfig::default(); + let request = LsRequest { + path: base.to_string_lossy().to_string(), + }; + + let response = ls(request, &config).await.unwrap(); + + // Files should have sizes, directories should not + for entry in &response.entries { + match entry.kind { + FileKind::File => assert!(entry.size.is_some(), "Files should have sizes"), + FileKind::Dir => assert!(entry.size.is_none(), "Directories should not have sizes"), + FileKind::Symlink => {} // Symlinks may or may not have sizes + } + } +} + +#[tokio::test] +async fn test_glob_basic() { + let temp_dir = create_test_structure(); + let base = temp_dir.path(); + + let config = SearchConfig::default(); + + // Search for all Rust files + let request = GlobRequest { + path: base.to_string_lossy().to_string(), + pattern: "**/*.rs".to_string(), + exclude: None, + max_results: None, + }; + + let response = glob_search(request, &config).await.unwrap(); + + // Should find main.rs, lib.rs, and integration.rs (3 files) + assert_eq!(response.matches.len(), 3); + assert!(!response.truncated); + + // Verify all matches end with .rs + for path in &response.matches { + assert!(path.extension().unwrap() == "rs"); + } +} + +#[tokio::test] +async fn test_glob_pattern_variations() { + let temp_dir = create_test_structure(); + let base = temp_dir.path(); + + let config = SearchConfig::default(); + + // Test single-level wildcard + let request = GlobRequest { + path: base.to_string_lossy().to_string(), + pattern: "*.md".to_string(), + exclude: None, + max_results: None, + }; + + let response = glob_search(request, &config).await.unwrap(); + assert_eq!(response.matches.len(), 1); // README.md + assert!(!response.truncated); + + // Test specific directory + let request = GlobRequest { + path: base.to_string_lossy().to_string(), + pattern: "src/*.rs".to_string(), + exclude: None, + max_results: None, + }; + + let response = glob_search(request, &config).await.unwrap(); + assert_eq!(response.matches.len(), 2); // main.rs, lib.rs + assert!(!response.truncated); +} + +#[tokio::test] +async fn test_glob_max_results() { + let temp_dir = create_test_structure(); + let base = temp_dir.path(); + + let config = SearchConfig::default(); + + // Limit to 2 results + let request = GlobRequest { + path: base.to_string_lossy().to_string(), + pattern: "**/*.rs".to_string(), + exclude: None, + max_results: Some(2), + }; + + let response = glob_search(request, &config).await.unwrap(); + assert_eq!(response.matches.len(), 2); + assert!(response.truncated); // Should be truncated since there are 3 .rs files +} + +#[tokio::test] +async fn test_glob_exclude_patterns() { + let temp_dir = create_test_structure(); + let base = temp_dir.path(); + + let config = SearchConfig::default(); + + // Search for all .rs files but exclude tests + let request = GlobRequest { + path: base.to_string_lossy().to_string(), + pattern: "**/*.rs".to_string(), + exclude: Some(vec!["**/tests/**".to_string()]), + max_results: None, + }; + + let response = glob_search(request, &config).await.unwrap(); + assert_eq!(response.matches.len(), 2); // Only main.rs and lib.rs, not integration.rs + assert!(!response.truncated); +} + +#[tokio::test] +async fn test_grep_basic() { + let temp_dir = create_test_structure(); + let base = temp_dir.path(); + + let config = SearchConfig::default(); + + // Search for "main" in all files + let request = GrepRequest { + path: base.to_string_lossy().to_string(), + pattern: "main".to_string(), + file_pattern: None, + ignore_case: false, + context_before: 0, + context_after: 0, + max_results: None, + }; + + let response = grep_search(request, &config).await.unwrap(); + + // Should find "main" in src/main.rs + assert!(response.matches.len() >= 1); + assert!(!response.truncated); + + // Verify line numbers are 1-indexed + for m in &response.matches { + assert!(m.line_number > 0); + } +} + +#[tokio::test] +async fn test_grep_case_insensitive() { + let temp_dir = create_test_structure(); + let base = temp_dir.path(); + + let config = SearchConfig::default(); + + // Search for "HELLO" (case insensitive) + let request = GrepRequest { + path: base.to_string_lossy().to_string(), + pattern: "HELLO".to_string(), + file_pattern: Some("**/*.rs".to_string()), + ignore_case: true, + context_before: 0, + context_after: 0, + max_results: None, + }; + + let response = grep_search(request, &config).await.unwrap(); + + // Should find "Hello" in main.rs + assert!(response.matches.len() >= 1); +} + +#[tokio::test] +async fn test_grep_context_lines() { + let temp_dir = create_test_structure(); + let base = temp_dir.path(); + + // Create a file with more content + fs::write( + base.join("src/context_test.rs"), + "line 1\nline 2\nMATCH HERE\nline 4\nline 5\n", + ) + .unwrap(); + + let config = SearchConfig::default(); + + // Search with context + let request = GrepRequest { + path: base.to_string_lossy().to_string(), + pattern: "MATCH".to_string(), + file_pattern: Some("**/context_test.rs".to_string()), + ignore_case: false, + context_before: 2, + context_after: 2, + max_results: None, + }; + + let response = grep_search(request, &config).await.unwrap(); + + assert_eq!(response.matches.len(), 1); + let m = &response.matches[0]; + + // Should have 2 lines before and 2 lines after + assert_eq!(m.context_before.len(), 2); + assert_eq!(m.context_after.len(), 2); + assert_eq!(m.context_before[0], "line 1"); + assert_eq!(m.context_before[1], "line 2"); + assert_eq!(m.context_after[0], "line 4"); + assert_eq!(m.context_after[1], "line 5"); +} + +#[tokio::test] +async fn test_grep_max_results() { + let temp_dir = create_test_structure(); + let base = temp_dir.path(); + + // Create a file with multiple matches + fs::write( + base.join("many_matches.txt"), + "match\nmatch\nmatch\nmatch\nmatch\n", + ) + .unwrap(); + + let config = SearchConfig::default(); + + // Limit to 3 results + let request = GrepRequest { + path: base.to_string_lossy().to_string(), + pattern: "match".to_string(), + file_pattern: None, + ignore_case: false, + context_before: 0, + context_after: 0, + max_results: Some(3), + }; + + let response = grep_search(request, &config).await.unwrap(); + + assert_eq!(response.matches.len(), 3); + assert!(response.truncated); +} + +#[tokio::test] +async fn test_grep_binary_file_skip() { + let temp_dir = TempDir::new().unwrap(); + let base = temp_dir.path(); + + // Create a binary file with null bytes + let binary_path = base.join("binary.bin"); + let mut file = fs::File::create(&binary_path).unwrap(); + file.write_all(&[0x00, 0x01, 0x02, 0x03]).unwrap(); + + // Create a text file + fs::write(base.join("text.txt"), "some text\n").unwrap(); + + let config = SearchConfig::default(); + + // Search for any character + let request = GrepRequest { + path: base.to_string_lossy().to_string(), + pattern: ".".to_string(), + file_pattern: None, + ignore_case: false, + context_before: 0, + context_after: 0, + max_results: None, + }; + + let response = grep_search(request, &config).await.unwrap(); + + // Should only find matches in text.txt, not binary.bin + assert!(response.matches.len() > 0); + for m in &response.matches { + assert!( + m.path.to_string_lossy().contains("text.txt"), + "Should not match binary files" + ); + } +} + +#[tokio::test] +async fn test_grep_regex_patterns() { + let temp_dir = create_test_structure(); + let base = temp_dir.path(); + + let config = SearchConfig::default(); + + // Test regex with character classes + let request = GrepRequest { + path: base.to_string_lossy().to_string(), + pattern: r"\d+".to_string(), // Match numbers + file_pattern: None, + ignore_case: false, + context_before: 0, + context_after: 0, + max_results: None, + }; + + let response = grep_search(request, &config).await.unwrap(); + + // Should find numbers in test files (e.g., "2, 2" in assert) + assert!(response.matches.len() > 0); +} + +#[cfg(windows)] +#[tokio::test] +async fn test_windows_paths() { + let temp_dir = create_test_structure(); + let base = temp_dir.path(); + + let config = SearchConfig::default(); + + // Test with Windows path separators + let path_str = base.to_string_lossy().to_string(); + + // Test ls + let request = LsRequest { path: path_str.clone() }; + let response = ls(request, &config).await.unwrap(); + assert!(response.entries.len() > 0); + + // Test glob + let request = GlobRequest { + path: path_str.clone(), + pattern: "**\\*.rs".to_string(), // Windows-style pattern + exclude: None, + max_results: None, + }; + let response = glob_search(request, &config).await.unwrap(); + assert!(response.matches.len() > 0); +} diff --git a/crates/dirigent_tools/tests/terminal_integration.rs b/crates/dirigent_tools/tests/terminal_integration.rs new file mode 100644 index 0000000..d6d7958 --- /dev/null +++ b/crates/dirigent_tools/tests/terminal_integration.rs @@ -0,0 +1,296 @@ +//! Integration tests for terminal operations. + +use dirigent_tools::config::TerminalConfig; +use dirigent_tools::terminal::{ + create_terminal, get_terminal_output, kill_terminal, release_terminal, + wait_for_terminal_exit, CreateTerminalRequest, KillTerminalCommandRequest, + ReleaseTerminalRequest, TerminalOutputRequest, WaitForTerminalExitRequest, +}; + +#[tokio::test] +async fn test_terminal_echo_command() { + let config = TerminalConfig { + enabled: true, + default_cwd: Some(std::env::current_dir().unwrap()), + env_allowlist: vec![], + command_blocklist: vec![], + output_byte_limit: 10_000, + max_runtime_secs: 30, + }; + + // Create terminal with echo command + #[cfg(windows)] + let request = CreateTerminalRequest { + command: "cmd".to_string(), + args: vec!["/C".to_string(), "echo".to_string(), "Hello World".to_string()], + cwd: None, + env: None, + output_byte_limit: None, + }; + + #[cfg(not(windows))] + let request = CreateTerminalRequest { + command: "echo".to_string(), + args: vec!["Hello World".to_string()], + cwd: None, + env: None, + output_byte_limit: None, + }; + + let create_response = create_terminal(request, &config).await.unwrap(); + let terminal_id = create_response.terminal_id; + + // Wait for command to complete + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + + // Get output + let output_response = get_terminal_output( + TerminalOutputRequest { + terminal_id: terminal_id.clone(), + }, + &config, + ) + .await + .unwrap(); + + assert!(output_response.output.contains("Hello World")); + + // Release terminal + let _ = release_terminal( + ReleaseTerminalRequest { + terminal_id: terminal_id.clone(), + }, + &config, + ) + .await; +} + +#[tokio::test] +async fn test_terminal_wait_for_exit() { + let config = TerminalConfig { + enabled: true, + default_cwd: Some(std::env::current_dir().unwrap()), + env_allowlist: vec![], + command_blocklist: vec![], + output_byte_limit: 10_000, + max_runtime_secs: 30, + }; + + // Create terminal with a quick command + #[cfg(windows)] + let request = CreateTerminalRequest { + command: "cmd".to_string(), + args: vec!["/C".to_string(), "exit".to_string(), "0".to_string()], + cwd: None, + env: None, + output_byte_limit: None, + }; + + #[cfg(not(windows))] + let request = CreateTerminalRequest { + command: "true".to_string(), + args: vec![], + cwd: None, + env: None, + output_byte_limit: None, + }; + + let create_response = create_terminal(request, &config).await.unwrap(); + let terminal_id = create_response.terminal_id; + + // Wait for exit + let wait_response = wait_for_terminal_exit( + WaitForTerminalExitRequest { + terminal_id: terminal_id.clone(), + }, + &config, + ) + .await + .unwrap(); + + assert_eq!(wait_response.exit_status, 0); + + // Release terminal + let _ = release_terminal( + ReleaseTerminalRequest { + terminal_id: terminal_id.clone(), + }, + &config, + ) + .await; +} + +#[tokio::test] +async fn test_terminal_kill() { + let config = TerminalConfig { + enabled: true, + default_cwd: Some(std::env::current_dir().unwrap()), + env_allowlist: vec![], + command_blocklist: vec![], + output_byte_limit: 10_000, + max_runtime_secs: 30, + }; + + // Create terminal with a long-running command + #[cfg(windows)] + let request = CreateTerminalRequest { + command: "cmd".to_string(), + args: vec!["/C".to_string(), "timeout".to_string(), "/t".to_string(), "10".to_string()], + cwd: None, + env: None, + output_byte_limit: None, + }; + + #[cfg(not(windows))] + let request = CreateTerminalRequest { + command: "sleep".to_string(), + args: vec!["10".to_string()], + cwd: None, + env: None, + output_byte_limit: None, + }; + + let create_response = create_terminal(request, &config).await.unwrap(); + let terminal_id = create_response.terminal_id; + + // Wait a bit to ensure process is running + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + + // Kill the terminal + let kill_response = kill_terminal( + KillTerminalCommandRequest { + terminal_id: terminal_id.clone(), + }, + &config, + ) + .await; + + assert!(kill_response.is_ok()); + + // Release terminal + let _ = release_terminal( + ReleaseTerminalRequest { + terminal_id: terminal_id.clone(), + }, + &config, + ) + .await; +} + +#[tokio::test] +async fn test_terminal_disabled() { + let config = TerminalConfig { + enabled: false, + default_cwd: Some(std::env::current_dir().unwrap()), + env_allowlist: vec![], + command_blocklist: vec![], + output_byte_limit: 10_000, + max_runtime_secs: 30, + }; + + let request = CreateTerminalRequest { + command: "echo".to_string(), + args: vec!["test".to_string()], + cwd: None, + env: None, + output_byte_limit: None, + }; + + let result = create_terminal(request, &config).await; + assert!(result.is_err()); +} + +#[tokio::test] +async fn test_terminal_command_blocklist() { + let config = TerminalConfig { + enabled: true, + default_cwd: Some(std::env::current_dir().unwrap()), + env_allowlist: vec![], + command_blocklist: vec!["rm".to_string(), "del".to_string()], + output_byte_limit: 10_000, + max_runtime_secs: 30, + }; + + let request = CreateTerminalRequest { + command: "rm".to_string(), + args: vec!["-rf".to_string(), "/".to_string()], + cwd: None, + env: None, + output_byte_limit: None, + }; + + let result = create_terminal(request, &config).await; + assert!(result.is_err()); +} + +#[tokio::test] +async fn test_terminal_output_truncation() { + let config = TerminalConfig { + enabled: true, + default_cwd: Some(std::env::current_dir().unwrap()), + env_allowlist: vec![], + command_blocklist: vec![], + output_byte_limit: 100, // Very small buffer + max_runtime_secs: 30, + }; + + // Create terminal that generates lots of output + #[cfg(windows)] + let request = CreateTerminalRequest { + command: "cmd".to_string(), + args: vec![ + "/C".to_string(), + "for".to_string(), + "/L".to_string(), + "%i".to_string(), + "in".to_string(), + "(1,1,100)".to_string(), + "do".to_string(), + "@echo".to_string(), + "Line %i".to_string(), + ], + cwd: None, + env: None, + output_byte_limit: Some(100), + }; + + #[cfg(not(windows))] + let request = CreateTerminalRequest { + command: "sh".to_string(), + args: vec![ + "-c".to_string(), + "for i in $(seq 1 100); do echo Line $i; done".to_string(), + ], + cwd: None, + env: None, + output_byte_limit: Some(100), + }; + + let create_response = create_terminal(request, &config).await.unwrap(); + let terminal_id = create_response.terminal_id; + + // Wait for command to complete + tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await; + + // Get output + let output_response = get_terminal_output( + TerminalOutputRequest { + terminal_id: terminal_id.clone(), + }, + &config, + ) + .await + .unwrap(); + + // Buffer should be truncated + assert!(output_response.truncated || output_response.output.len() <= 100); + + // Release terminal + let _ = release_terminal( + ReleaseTerminalRequest { + terminal_id: terminal_id.clone(), + }, + &config, + ) + .await; +} diff --git a/crates/dirigent_tools/tests/test_fixtures/README.md b/crates/dirigent_tools/tests/test_fixtures/README.md new file mode 100644 index 0000000..9e7e1bd --- /dev/null +++ b/crates/dirigent_tools/tests/test_fixtures/README.md @@ -0,0 +1,17 @@ +# Test Fixtures + +This directory contains test fixtures for dirigent_tools tests. + +## Files + +- **sample_text.txt** - Simple multi-line text file +- **unicode_text.txt** - File with Unicode characters (emoji, CJK, special chars) +- **large_text.txt** - File for testing size limits (initially small, can be expanded in tests) +- **binary_sample.bin** - Binary file (non-UTF-8) - to be created programmatically in tests + +## Usage + +These fixtures are used by integration tests in `packages/dirigent_tools/tests/`. + +Test utilities in `tests/common/mod.rs` provide helpers for creating temporary copies +and working with these fixtures. diff --git a/crates/dirigent_tools/tests/test_fixtures/large_text.txt b/crates/dirigent_tools/tests/test_fixtures/large_text.txt new file mode 100644 index 0000000..633cb6b --- /dev/null +++ b/crates/dirigent_tools/tests/test_fixtures/large_text.txt @@ -0,0 +1 @@ +This is a large text file for testing size limits. diff --git a/crates/dirigent_tools/tests/test_fixtures/sample_text.txt b/crates/dirigent_tools/tests/test_fixtures/sample_text.txt new file mode 100644 index 0000000..00acc67 --- /dev/null +++ b/crates/dirigent_tools/tests/test_fixtures/sample_text.txt @@ -0,0 +1,5 @@ +This is a sample text file for testing. +It has multiple lines. +Line 3 +Line 4 +Line 5 diff --git a/crates/dirigent_tools/tests/test_fixtures/unicode_text.txt b/crates/dirigent_tools/tests/test_fixtures/unicode_text.txt new file mode 100644 index 0000000..8ac037c --- /dev/null +++ b/crates/dirigent_tools/tests/test_fixtures/unicode_text.txt @@ -0,0 +1,7 @@ +Unicode test file with various characters: + +ASCII: Hello, World! +Emoji: 👋 🌍 🚀 ✨ +CJK: 你好世界 こんにちは世界 +Special: Ñoño Crème brûlée +Math: ∑ ∫ ∂ ∇ ∞ diff --git a/crates/dirigent_zed/CLAUDE.md b/crates/dirigent_zed/CLAUDE.md new file mode 100644 index 0000000..a65858d --- /dev/null +++ b/crates/dirigent_zed/CLAUDE.md @@ -0,0 +1,64 @@ +# Package: dirigent_zed + +Zed editor integration for Dirigent -- detection, agent discovery, binary resolution. + +## Quick Facts +- **Type**: Library +- **Main Entry**: src/lib.rs +- **Dependencies**: dirigent_config, dirs, serde, serde_json, thiserror, tracing +- **Status**: Initial implementation + +## Purpose + +Detects Zed editor installations on the current system, discovers configured +ACP agents from Zed's `settings.json`, and resolves downloaded binary paths +from the Zed data directory. + +## Key Types + +- `ZedChannel` -- Release channel enum (Stable, Preview, Nightly, Dev) +- `ZedAgent` -- Discovered agent with name, type, binary path, env overrides +- `AgentServerType` -- Registry, Custom, or Extension +- `ZedInstallation` -- Detected installation with channel, paths, and agents + +## Module Organization + +- **`paths.rs`** -- Platform path resolution for Zed config/data directories +- **`agents.rs`** -- Agent discovery from settings.json, JSONC stripping, binary resolution +- **`detection.rs`** -- High-level installation detection combining paths and agents + +## Platform Paths + +| Platform | Config Dir | Data Dir | +|----------|-----------|----------| +| Linux | `$XDG_CONFIG_HOME/zed` | `$XDG_DATA_HOME/zed` | +| macOS | `~/.config/zed` | `~/Library/Application Support/Zed` | +| Windows | `%APPDATA%\Zed` | `%LOCALAPPDATA%\Zed` | + +## Usage + +```rust +let installations = dirigent_zed::detect_installations(); +for inst in &installations { + for agent in &inst.agents { + if let Some(ref binary) = agent.binary_path { + println!("{}: {}", agent.name, binary.display()); + } + } +} +``` + +## Testing + +```bash +cargo test -p dirigent_zed +``` + +## Related Packages + +- **dirigent_config** -- Dirigent's own path resolution (dependency) +- **dirigent_core** -- Will consume this crate for Zed connector integration (future) + +## Research + +See `docs/research/zed-integration.md` for detailed platform paths and detection strategies. diff --git a/crates/dirigent_zed/Cargo.toml b/crates/dirigent_zed/Cargo.toml new file mode 100644 index 0000000..c576f5f --- /dev/null +++ b/crates/dirigent_zed/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "dirigent_zed" +version = "0.1.0" +edition = "2021" +description = "Zed editor integration for Dirigent — detection, agent discovery, binary resolution" + +[dependencies] +dirigent_config = { path = "../dirigent_config" } +dirs = "5" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +thiserror = "2.0" +tracing = "0.1" + +[dev-dependencies] +tempfile = "3" diff --git a/crates/dirigent_zed/src/agents.rs b/crates/dirigent_zed/src/agents.rs new file mode 100644 index 0000000..bf0dc61 --- /dev/null +++ b/crates/dirigent_zed/src/agents.rs @@ -0,0 +1,1145 @@ +//! Agent discovery from Zed settings and binary resolution. +//! +//! Reads Zed's `settings.json` (which may contain JSONC comments), extracts +//! `agent_servers` configuration, and resolves downloaded binary paths from +//! the data directory. + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::path::{Path, PathBuf}; + +/// Agent server type as defined in Zed settings. +/// +/// Defaults to `Registry` when the `"type"` field is omitted in settings.json, +/// matching Zed's own behavior where omitting type means it's a registry agent. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)] +#[serde(rename_all = "lowercase")] +pub enum AgentServerType { + #[default] + Registry, + Custom, + Extension, +} + +impl std::fmt::Display for AgentServerType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AgentServerType::Registry => f.write_str("registry"), + AgentServerType::Custom => f.write_str("custom"), + AgentServerType::Extension => f.write_str("extension"), + } + } +} + +/// A discovered Zed agent with its configuration and resolved binary path. +#[derive(Debug, Clone)] +pub struct ZedAgent { + /// Agent name (key in `agent_servers` map). + pub name: String, + /// Server type: registry, custom, or extension. + pub agent_type: AgentServerType, + /// Resolved binary path (populated by `resolve_binary_paths`). + pub binary_path: Option<PathBuf>, + /// Environment variable overrides from settings. + pub env_overrides: HashMap<String, String>, + /// Display name from the ACP registry (e.g. "Claude Agent", "Codex CLI"). + pub display_name: Option<String>, + /// Description from the ACP registry. + pub description: Option<String>, + /// Command arguments from the registry distribution config (e.g. `["--acp"]`). + pub args: Vec<String>, + /// Path to a locally cached icon file (SVG) from the registry. + pub icon_local_path: Option<PathBuf>, + /// Icon URL from the registry CDN. + pub icon_url: Option<String>, +} + +/// Raw serde model for an entry in the `agent_servers` map in settings.json. +#[derive(Debug, Deserialize)] +struct AgentServerEntry { + #[serde(rename = "type", default)] + server_type: AgentServerType, + #[serde(default)] + env: HashMap<String, String>, + // We ignore other fields like default_mode, default_model, command, args, etc. + // for now; we only need type and env for discovery. +} + +/// Raw serde model for the top-level settings.json (only the fields we care about). +#[derive(Debug, Deserialize)] +struct ZedSettings { + #[serde(default)] + agent_servers: HashMap<String, AgentServerEntry>, +} + +/// Discover agents from Zed's `settings.json` in the given config directory. +/// +/// Reads `{config_dir}/settings.json`, strips JSONC comments, and parses +/// the `agent_servers` key into a list of `ZedAgent` values. +/// +/// Returns an empty vec if the file doesn't exist or can't be parsed. +pub fn discover_agents_from_settings(config_dir: &Path) -> Vec<ZedAgent> { + let settings_path = config_dir.join("settings.json"); + + let content = match std::fs::read_to_string(&settings_path) { + Ok(c) => c, + Err(e) => { + tracing::debug!( + path = %settings_path.display(), + error = %e, + "Could not read Zed settings.json" + ); + return Vec::new(); + } + }; + + let stripped = strip_jsonc_comments(&content); + + let settings: ZedSettings = match serde_json::from_str(&stripped) { + Ok(s) => s, + Err(e) => { + tracing::warn!( + path = %settings_path.display(), + error = %e, + "Failed to parse Zed settings.json" + ); + return Vec::new(); + } + }; + + settings + .agent_servers + .into_iter() + .map(|(name, entry)| ZedAgent { + name, + agent_type: entry.server_type, + binary_path: None, + env_overrides: entry.env, + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }) + .collect() +} + +/// Discover agents from the `external_agents/` directory in the Zed data dir. +/// +/// Zed downloads registry agents to `{data_dir}/external_agents/{package_name}/`. +/// These may not appear in `settings.json` if the user hasn't customized them. +/// This function scans the directory and returns agents for any packages not +/// already present in the `existing_agents` list (matched by fuzzy name). +pub fn discover_agents_from_external_dir( + data_dir: &Path, + existing_agents: &[ZedAgent], +) -> Vec<ZedAgent> { + let agents_dir = data_dir.join("external_agents"); + + if !agents_dir.is_dir() { + return Vec::new(); + } + + let existing_names: Vec<String> = existing_agents + .iter() + .map(|a| a.name.to_lowercase()) + .collect(); + + let mut discovered = Vec::new(); + + let entries = match std::fs::read_dir(&agents_dir) { + Ok(e) => e, + Err(_) => return Vec::new(), + }; + + for entry in entries.flatten() { + let path = entry.path(); + if !path.is_dir() { + continue; + } + + let dir_name = entry.file_name().to_string_lossy().to_string(); + + // Skip the registry directory — it contains cached registry metadata + // and downloaded binaries indexed by URL hash, not agent configurations. + if dir_name == "registry" { + tracing::debug!("Skipping registry directory in external_agents"); + continue; + } + + let dir_lower = dir_name.to_lowercase(); + + // Check if this directory already matches an existing agent from settings. + let already_covered = existing_names.iter().any(|existing| { + dir_lower == *existing + || dir_lower.contains(existing) + || existing.contains(&dir_lower) + }); + + if already_covered { + continue; + } + + tracing::debug!( + dir = %dir_name, + "Discovered external agent not in settings" + ); + + discovered.push(ZedAgent { + name: dir_name, + agent_type: AgentServerType::Registry, + binary_path: None, + env_overrides: HashMap::new(), + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }); + } + + discovered +} + +/// Resolve binary paths for registry agents from the Zed data directory. +/// +/// For each agent with `agent_type == Registry`, looks for downloaded binaries +/// under `{data_dir}/external_agents/`. Zed stores agents directly under +/// `external_agents/{package_name}/{version}/` — package names often differ +/// from settings keys (e.g., settings key `"claude"` maps to directory +/// `claude-agent-acp` or `claude-code-acp`). +/// +/// The resolution strategy: +/// 1. Try exact directory name match first +/// 2. Scan all directories in `external_agents/` for fuzzy matches +/// (directory name contains agent name, or agent name contains directory name) +/// 3. Within matched directories, find the latest version subdirectory +/// 4. Look for native executables and Node.js bin entries +pub fn resolve_binary_paths(agents: &mut [ZedAgent], data_dir: &Path) { + let agents_dir = data_dir.join("external_agents"); + + if !agents_dir.is_dir() { + tracing::debug!( + path = %agents_dir.display(), + "external_agents directory not found" + ); + return; + } + + // Collect all directories in external_agents/ for fuzzy matching. + let available_dirs: Vec<(String, PathBuf)> = std::fs::read_dir(&agents_dir) + .ok() + .into_iter() + .flatten() + .flatten() + .filter_map(|entry| { + let path = entry.path(); + if path.is_dir() { + let name = entry.file_name().to_string_lossy().to_string(); + Some((name, path)) + } else { + None + } + }) + .collect(); + + for agent in agents.iter_mut() { + if agent.agent_type != AgentServerType::Registry { + continue; + } + + // Find matching directories for this agent name. + let matching_dirs = find_matching_agent_dirs(&agent.name, &available_dirs); + + if matching_dirs.is_empty() { + tracing::debug!( + agent = %agent.name, + "No matching directory found in external_agents" + ); + continue; + } + + // When multiple dirs match (e.g., claude-agent-acp and claude-code-acp), + // pick the one with the most recently modified version directory. + let mut best_binary: Option<(PathBuf, std::time::SystemTime)> = None; + + for agent_dir in &matching_dirs { + let version_dir = match find_latest_version_dir(agent_dir) { + Some(d) => d, + None => continue, + }; + + let dir_mtime = std::fs::metadata(&version_dir) + .and_then(|m| m.modified()) + .unwrap_or(std::time::SystemTime::UNIX_EPOCH); + + if let Some(binary) = find_nodejs_bin(&agent.name, &version_dir) + .or_else(|| find_native_binary(&version_dir)) + { + match &best_binary { + Some((_, best_time)) if dir_mtime <= *best_time => {} + _ => { + best_binary = Some((binary, dir_mtime)); + } + } + } + } + + if let Some((binary, _)) = best_binary { + tracing::debug!( + agent = %agent.name, + binary = %binary.display(), + "Resolved agent binary path" + ); + agent.binary_path = Some(binary); + } + } +} + +/// Enrich discovered agents with metadata from the parsed registry. +/// +/// For each agent, looks up the corresponding registry entry using fuzzy name +/// matching and populates display name, description, args, icon path, and icon URL. +pub fn enrich_agents_from_registry( + agents: &mut [ZedAgent], + registry: &std::collections::HashMap<String, crate::registry::RegistryAgentInfo>, +) { + for agent in agents.iter_mut() { + if let Some(info) = crate::registry::find_registry_match(&agent.name, registry) { + if agent.display_name.is_none() { + agent.display_name = Some(info.display_name.clone()); + } + if agent.description.is_none() { + agent.description = Some(info.description.clone()); + } + if agent.args.is_empty() { + agent.args = info.args.clone(); + } + if agent.icon_local_path.is_none() { + agent.icon_local_path = info.icon_local_path.clone(); + } + if agent.icon_url.is_none() { + agent.icon_url = info.icon_url.clone(); + } + + tracing::debug!( + agent = %agent.name, + registry_id = %info.id, + display_name = %info.display_name, + "Enriched agent from registry" + ); + } + } +} + +/// Find directories in `external_agents/` that match an agent settings name. +/// +/// Matching strategy (in priority order): +/// 1. Exact match: directory name equals agent name +/// 2. Directory name contains agent name (e.g., "claude-agent-acp" contains "claude") +/// 3. Agent name contains directory name (e.g., "claude-acp" contains "claude") +fn find_matching_agent_dirs<'a>( + agent_name: &str, + available_dirs: &'a [(String, PathBuf)], +) -> Vec<&'a Path> { + let agent_lower = agent_name.to_lowercase(); + let mut exact = Vec::new(); + let mut contains = Vec::new(); + + for (dir_name, dir_path) in available_dirs { + let dir_lower = dir_name.to_lowercase(); + if dir_lower == agent_lower { + exact.push(dir_path.as_path()); + } else if dir_lower.contains(&agent_lower) || agent_lower.contains(&dir_lower) { + contains.push(dir_path.as_path()); + } + } + + if !exact.is_empty() { + exact + } else { + contains + } +} + +/// Find the most recently modified version directory inside `agent_dir`. +/// +/// Zed uses various version directory formats: `v0.9.2`, `0.20.0`, `v_abc123`. +/// We accept any subdirectory as a potential version directory. +fn find_latest_version_dir(agent_dir: &Path) -> Option<PathBuf> { + let read_dir = std::fs::read_dir(agent_dir).ok()?; + + let mut best: Option<(PathBuf, std::time::SystemTime)> = None; + + for entry in read_dir.flatten() { + let path = entry.path(); + if !path.is_dir() { + continue; + } + // Skip hidden directories and node_modules at this level. + let name = entry.file_name(); + let name_str = name.to_string_lossy(); + if name_str.starts_with('.') || name_str == "node_modules" { + continue; + } + + let modified = entry.metadata().and_then(|m| m.modified()).ok(); + if let Some(mod_time) = modified { + match &best { + Some((_, best_time)) if mod_time > *best_time => { + best = Some((path, mod_time)); + } + None => { + best = Some((path, mod_time)); + } + _ => {} + } + } else { + if best.is_none() { + best = Some((path, std::time::SystemTime::UNIX_EPOCH)); + } + } + } + + best.map(|(path, _)| path) +} + +/// Find a Node.js executable in `node_modules/.bin/` within a version directory. +/// +/// Many Zed agents are npm packages. Their executables are symlinked in +/// `node_modules/.bin/`. We prefer a binary whose file stem matches the agent +/// name (case-insensitive, with common ACP suffixes stripped). If no match is +/// found, we fall back to the first executable -- this covers agents whose +/// binary name differs from their directory name. +/// +/// The name-matching preference is important because npm packages often install +/// dependency binaries alongside the agent binary (e.g., `acorn`, `glob`), and +/// without matching we might pick the wrong one. +fn find_nodejs_bin(agent_name: &str, version_dir: &Path) -> Option<PathBuf> { + let bin_dir = version_dir.join("node_modules").join(".bin"); + if !bin_dir.is_dir() { + return None; + } + + let read_dir = std::fs::read_dir(&bin_dir).ok()?; + + let agent_lower = agent_name.to_lowercase(); + let agent_core = crate::registry::strip_acp_suffixes(&agent_lower); + + let mut first_executable: Option<PathBuf> = None; + + for entry in read_dir.flatten() { + let path = entry.path(); + if path.is_dir() { + continue; + } + if !is_executable(&path) { + continue; + } + + // Check if this binary's stem matches the agent name. + if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) { + let stem_lower = stem.to_lowercase(); + if stem_lower == agent_lower || stem_lower == agent_core { + return Some(path); + } + } + + if first_executable.is_none() { + first_executable = Some(path); + } + } + + first_executable +} + +/// Find a native executable binary inside a version directory. +/// +/// Looks for executable files directly in the version directory, recursing +/// one level for extracted archives. Skips `node_modules/` and hidden dirs. +fn find_native_binary(dir: &Path) -> Option<PathBuf> { + let read_dir = std::fs::read_dir(dir).ok()?; + + for entry in read_dir.flatten() { + let path = entry.path(); + let name = entry.file_name(); + let name_str = name.to_string_lossy(); + + if path.is_dir() { + // Skip node_modules and hidden directories. + if name_str == "node_modules" || name_str.starts_with('.') { + continue; + } + // Recurse one level for extracted archives. + if let Some(binary) = find_native_binary(&path) { + return Some(binary); + } + continue; + } + + // Skip non-binary files (package.json, lock files, etc.) + if name_str.ends_with(".json") || name_str.ends_with(".lock") { + continue; + } + + if is_executable(&path) { + return Some(path); + } + } + + None +} + +/// Check if a file is executable. +#[cfg(unix)] +fn is_executable(path: &Path) -> bool { + use std::os::unix::fs::PermissionsExt; + path.metadata() + .map(|m| m.permissions().mode() & 0o111 != 0) + .unwrap_or(false) +} + +#[cfg(not(unix))] +fn is_executable(path: &Path) -> bool { + path.extension() + .map(|ext| ext == "exe" || ext == "cmd" || ext == "bat") + .unwrap_or(false) +} + +/// Strip JSONC comments from input text. +/// +/// Handles: +/// - Line comments (`// ...`) +/// - Block comments (`/* ... */`) +/// - Does NOT strip inside quoted strings +/// - Handles escaped quotes inside strings +pub fn strip_jsonc_comments(input: &str) -> String { + let mut output = String::with_capacity(input.len()); + let chars: Vec<char> = input.chars().collect(); + let len = chars.len(); + let mut i = 0; + + while i < len { + let ch = chars[i]; + + // Inside a JSON string — pass through verbatim, respecting escape sequences. + if ch == '"' { + output.push(ch); + i += 1; + while i < len { + let c = chars[i]; + output.push(c); + if c == '\\' { + // Escaped character: push the next char unconditionally. + i += 1; + if i < len { + output.push(chars[i]); + } + } else if c == '"' { + break; + } + i += 1; + } + i += 1; + continue; + } + + // Check for line comment `//` + if ch == '/' && i + 1 < len && chars[i + 1] == '/' { + // Skip until end of line. + i += 2; + while i < len && chars[i] != '\n' { + i += 1; + } + continue; + } + + // Check for block comment `/* ... */` + if ch == '/' && i + 1 < len && chars[i + 1] == '*' { + i += 2; + while i + 1 < len { + if chars[i] == '*' && chars[i + 1] == '/' { + i += 2; + break; + } + // Preserve newlines so line numbers stay meaningful for error messages. + if chars[i] == '\n' { + output.push('\n'); + } + i += 1; + } + continue; + } + + output.push(ch); + i += 1; + } + + output +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + + /// Returns a binary file name appropriate for the platform. + /// On Windows, appends `.exe`; on Unix, returns the name as-is. + fn platform_binary(name: &str) -> String { + if cfg!(windows) { + format!("{}.exe", name) + } else { + name.to_string() + } + } + + #[test] + fn test_strip_jsonc_line_comments() { + let input = r#"{ + "key": "value", // this is a comment + "other": 42 // another comment +}"#; + let stripped = strip_jsonc_comments(input); + assert!(!stripped.contains("// this is")); + assert!(!stripped.contains("// another")); + let parsed: serde_json::Value = serde_json::from_str(&stripped).unwrap(); + assert_eq!(parsed["key"], "value"); + assert_eq!(parsed["other"], 42); + } + + #[test] + fn test_strip_jsonc_block_comments() { + let input = r#"{ + /* block comment */ + "key": "value", + "other": /* inline block */ 42 +}"#; + let stripped = strip_jsonc_comments(input); + assert!(!stripped.contains("block comment")); + assert!(!stripped.contains("inline block")); + let parsed: serde_json::Value = serde_json::from_str(&stripped).unwrap(); + assert_eq!(parsed["key"], "value"); + assert_eq!(parsed["other"], 42); + } + + #[test] + fn test_strip_jsonc_preserves_strings() { + let input = r#"{ + "url": "https://example.com/path", // comment after url + "comment_like": "this has // inside and /* block */ too" +}"#; + let stripped = strip_jsonc_comments(input); + // The comment after url should be stripped. + assert!(!stripped.contains("// comment after url")); + // But the strings should be intact. + let parsed: serde_json::Value = serde_json::from_str(&stripped).unwrap(); + assert_eq!(parsed["url"], "https://example.com/path"); + assert_eq!( + parsed["comment_like"], + "this has // inside and /* block */ too" + ); + } + + #[test] + fn test_strip_jsonc_escaped_quotes() { + let input = r#"{ + "escaped": "he said \"hello\" // not a comment", + "real": 1 // real comment +}"#; + let stripped = strip_jsonc_comments(input); + assert!(!stripped.contains("// real comment")); + let parsed: serde_json::Value = serde_json::from_str(&stripped).unwrap(); + assert_eq!( + parsed["escaped"], + r#"he said "hello" // not a comment"# + ); + assert_eq!(parsed["real"], 1); + } + + #[test] + fn test_parse_agent_servers() { + let dir = tempfile::tempdir().unwrap(); + let settings_content = r#"{ + // Agent configuration + "agent_servers": { + "claude-acp": { + "type": "registry", + "default_mode": "plan", + "env": { + "CLAUDE_CODE_EXECUTABLE": "/usr/local/bin/claude" + } + }, + "codex-acp": { + "type": "registry", + "default_model": "o4-mini" + }, + "My Custom Agent": { + "type": "custom", + "command": "node", + "args": ["~/projects/agent/index.js", "--acp"], + "env": {} + } + } +}"#; + let settings_path = dir.path().join("settings.json"); + let mut f = std::fs::File::create(&settings_path).unwrap(); + f.write_all(settings_content.as_bytes()).unwrap(); + + let agents = discover_agents_from_settings(dir.path()); + assert_eq!(agents.len(), 3); + + let claude = agents.iter().find(|a| a.name == "claude-acp").unwrap(); + assert_eq!(claude.agent_type, AgentServerType::Registry); + assert_eq!( + claude.env_overrides.get("CLAUDE_CODE_EXECUTABLE"), + Some(&"/usr/local/bin/claude".to_string()) + ); + + let codex = agents.iter().find(|a| a.name == "codex-acp").unwrap(); + assert_eq!(codex.agent_type, AgentServerType::Registry); + assert!(codex.env_overrides.is_empty()); + + let custom = agents + .iter() + .find(|a| a.name == "My Custom Agent") + .unwrap(); + assert_eq!(custom.agent_type, AgentServerType::Custom); + } + + #[test] + fn test_discover_missing_file() { + let dir = tempfile::tempdir().unwrap(); + let agents = discover_agents_from_settings(dir.path()); + assert!(agents.is_empty()); + } + + #[test] + fn test_discover_no_agent_servers_key() { + let dir = tempfile::tempdir().unwrap(); + let settings_path = dir.path().join("settings.json"); + std::fs::write(&settings_path, r#"{ "theme": "dark" }"#).unwrap(); + + let agents = discover_agents_from_settings(dir.path()); + assert!(agents.is_empty()); + } + + #[test] + fn test_resolve_binary_paths_no_data() { + let dir = tempfile::tempdir().unwrap(); + let mut agents = vec![ZedAgent { + name: "claude-acp".to_string(), + agent_type: AgentServerType::Registry, + binary_path: None, + env_overrides: HashMap::new(), + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }]; + + resolve_binary_paths(&mut agents, dir.path()); + assert!(agents[0].binary_path.is_none()); + } + + #[test] + fn test_resolve_native_binary_exact_name() { + // Simulates: external_agents/codex/v0.9.2/codex-acp (native ELF) + let dir = tempfile::tempdir().unwrap(); + let version_dir = dir + .path() + .join("external_agents") + .join("codex") + .join("v0.9.2"); + std::fs::create_dir_all(&version_dir).unwrap(); + + let binary_path = version_dir.join(platform_binary("codex-acp")); + std::fs::write(&binary_path, b"#!/bin/sh\necho hello").unwrap(); + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(&binary_path, std::fs::Permissions::from_mode(0o755)) + .unwrap(); + } + + let mut agents = vec![ZedAgent { + name: "codex".to_string(), + agent_type: AgentServerType::Registry, + binary_path: None, + env_overrides: HashMap::new(), + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }]; + + resolve_binary_paths(&mut agents, dir.path()); + assert!(agents[0].binary_path.is_some()); + assert!(agents[0] + .binary_path + .as_ref() + .unwrap() + .to_string_lossy() + .contains("codex-acp")); + } + + #[test] + fn test_resolve_nodejs_bin() { + // Simulates: external_agents/claude-agent-acp/0.20.0/node_modules/.bin/claude-agent-acp + let dir = tempfile::tempdir().unwrap(); + let bin_dir = dir + .path() + .join("external_agents") + .join("claude-agent-acp") + .join("0.20.0") + .join("node_modules") + .join(".bin"); + std::fs::create_dir_all(&bin_dir).unwrap(); + + let binary_path = bin_dir.join(platform_binary("claude-agent-acp")); + std::fs::write(&binary_path, b"#!/usr/bin/env node\nconsole.log('hi')").unwrap(); + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(&binary_path, std::fs::Permissions::from_mode(0o755)) + .unwrap(); + } + + // Settings key is "claude" but directory is "claude-agent-acp" — fuzzy match. + let mut agents = vec![ZedAgent { + name: "claude".to_string(), + agent_type: AgentServerType::Registry, + binary_path: None, + env_overrides: HashMap::new(), + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }]; + + resolve_binary_paths(&mut agents, dir.path()); + assert!( + agents[0].binary_path.is_some(), + "Should resolve Node.js bin via fuzzy directory match" + ); + assert!(agents[0] + .binary_path + .as_ref() + .unwrap() + .to_string_lossy() + .contains("claude-agent-acp")); + } + + #[test] + fn test_resolve_semver_version_dirs() { + // Zed uses semver version dirs like "0.20.0", not just "v_" prefixed. + let dir = tempfile::tempdir().unwrap(); + let version_dir = dir + .path() + .join("external_agents") + .join("test-agent") + .join("1.2.3"); + std::fs::create_dir_all(&version_dir).unwrap(); + + let binary_path = version_dir.join(platform_binary("test-agent")); + std::fs::write(&binary_path, b"#!/bin/sh\necho hello").unwrap(); + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(&binary_path, std::fs::Permissions::from_mode(0o755)) + .unwrap(); + } + + let mut agents = vec![ZedAgent { + name: "test-agent".to_string(), + agent_type: AgentServerType::Registry, + binary_path: None, + env_overrides: HashMap::new(), + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }]; + + resolve_binary_paths(&mut agents, dir.path()); + assert!( + agents[0].binary_path.is_some(), + "Should find binary in semver-named version directory" + ); + } + + #[test] + fn test_resolve_picks_latest_version() { + // Two version dirs, the newer one should win. + let dir = tempfile::tempdir().unwrap(); + let agents_base = dir.path().join("external_agents").join("my-agent"); + + let old_dir = agents_base.join("0.1.0"); + std::fs::create_dir_all(&old_dir).unwrap(); + let old_bin = old_dir.join(platform_binary("my-agent")); + std::fs::write(&old_bin, b"old").unwrap(); + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(&old_bin, std::fs::Permissions::from_mode(0o755)).unwrap(); + } + + // Sleep briefly to ensure different mtime + std::thread::sleep(std::time::Duration::from_millis(50)); + + let new_dir = agents_base.join("0.2.0"); + std::fs::create_dir_all(&new_dir).unwrap(); + let new_bin = new_dir.join(platform_binary("my-agent")); + std::fs::write(&new_bin, b"new").unwrap(); + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(&new_bin, std::fs::Permissions::from_mode(0o755)).unwrap(); + } + + let mut agents = vec![ZedAgent { + name: "my-agent".to_string(), + agent_type: AgentServerType::Registry, + binary_path: None, + env_overrides: HashMap::new(), + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }]; + + resolve_binary_paths(&mut agents, dir.path()); + assert!(agents[0].binary_path.is_some()); + let resolved = agents[0].binary_path.as_ref().unwrap(); + assert!( + resolved.to_string_lossy().contains("0.2.0"), + "Should pick the newer version directory, got: {}", + resolved.display() + ); + } + + #[test] + fn test_resolve_skips_custom_agents() { + let dir = tempfile::tempdir().unwrap(); + let mut agents = vec![ZedAgent { + name: "my-agent".to_string(), + agent_type: AgentServerType::Custom, + binary_path: None, + env_overrides: HashMap::new(), + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }]; + + resolve_binary_paths(&mut agents, dir.path()); + assert!(agents[0].binary_path.is_none()); + } + + #[test] + fn test_fuzzy_match_dir_contains_agent_name() { + // Directory "claude-code-acp" should match agent name "claude" + let dirs = vec![ + ("claude-code-acp".to_string(), PathBuf::from("/fake/claude-code-acp")), + ("codex".to_string(), PathBuf::from("/fake/codex")), + ]; + let matches = find_matching_agent_dirs("claude", &dirs); + assert_eq!(matches.len(), 1); + assert!(matches[0].to_string_lossy().contains("claude-code-acp")); + } + + #[test] + fn test_fuzzy_match_exact_preferred() { + // Exact match "codex" should be preferred over fuzzy "codex-extra" + let dirs = vec![ + ("codex".to_string(), PathBuf::from("/fake/codex")), + ("codex-extra".to_string(), PathBuf::from("/fake/codex-extra")), + ]; + let matches = find_matching_agent_dirs("codex", &dirs); + assert_eq!(matches.len(), 1); + assert!(matches[0].to_string_lossy().contains("/codex")); + } + + #[test] + fn test_fuzzy_match_multiple_fuzzy_results() { + // "claude" matches both "claude-agent-acp" and "claude-code-acp" + let dirs = vec![ + ("claude-agent-acp".to_string(), PathBuf::from("/fake/claude-agent-acp")), + ("claude-code-acp".to_string(), PathBuf::from("/fake/claude-code-acp")), + ("codex".to_string(), PathBuf::from("/fake/codex")), + ]; + let matches = find_matching_agent_dirs("claude", &dirs); + assert_eq!(matches.len(), 2); + } + + #[test] + fn test_agent_server_type_display() { + assert_eq!(AgentServerType::Registry.to_string(), "registry"); + assert_eq!(AgentServerType::Custom.to_string(), "custom"); + assert_eq!(AgentServerType::Extension.to_string(), "extension"); + } + + #[test] + fn test_agent_server_type_default_is_registry() { + assert_eq!(AgentServerType::default(), AgentServerType::Registry); + } + + #[test] + fn test_parse_missing_type_defaults_to_registry() { + let dir = tempfile::tempdir().unwrap(); + let settings = r#"{ + "agent_servers": { + "codex": { + "command": "codex-acp", + "args": [], + "env": {} + } + } + }"#; + std::fs::write(dir.path().join("settings.json"), settings).unwrap(); + + let agents = discover_agents_from_settings(dir.path()); + assert_eq!(agents.len(), 1); + assert_eq!(agents[0].name, "codex"); + assert_eq!(agents[0].agent_type, AgentServerType::Registry); + } + + #[test] + fn test_discover_external_agents_not_in_settings() { + let data_dir = tempfile::tempdir().unwrap(); + let ext_dir = data_dir.path().join("external_agents"); + + // Create directories simulating Zed-downloaded agents + std::fs::create_dir_all(ext_dir.join("claude-agent-acp")).unwrap(); + std::fs::create_dir_all(ext_dir.join("claude-code-acp")).unwrap(); + std::fs::create_dir_all(ext_dir.join("codex")).unwrap(); + std::fs::create_dir_all(ext_dir.join("gemini")).unwrap(); + + // Existing agents from settings: only "codex" + let existing = vec![ZedAgent { + name: "codex".to_string(), + agent_type: AgentServerType::Custom, + binary_path: None, + env_overrides: HashMap::new(), + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }]; + + let discovered = discover_agents_from_external_dir(data_dir.path(), &existing); + + // "codex" is already in settings (exact match), so it's excluded. + // "claude-agent-acp" and "claude-code-acp" are new, and "gemini" is new. + assert_eq!(discovered.len(), 3); + + let names: Vec<&str> = discovered.iter().map(|a| a.name.as_str()).collect(); + assert!(names.contains(&"claude-agent-acp")); + assert!(names.contains(&"claude-code-acp")); + assert!(names.contains(&"gemini")); + + // All should be Registry type + for agent in &discovered { + assert_eq!(agent.agent_type, AgentServerType::Registry); + } + } + + #[test] + fn test_discover_external_agents_fuzzy_excludes() { + let data_dir = tempfile::tempdir().unwrap(); + let ext_dir = data_dir.path().join("external_agents"); + + std::fs::create_dir_all(ext_dir.join("claude-agent-acp")).unwrap(); + + // "claude" in settings should fuzzy-match "claude-agent-acp" + let existing = vec![ZedAgent { + name: "claude".to_string(), + agent_type: AgentServerType::Registry, + binary_path: None, + env_overrides: HashMap::new(), + display_name: None, + description: None, + args: Vec::new(), + icon_local_path: None, + icon_url: None, + }]; + + let discovered = discover_agents_from_external_dir(data_dir.path(), &existing); + assert!( + discovered.is_empty(), + "claude-agent-acp should be excluded by fuzzy match with 'claude'" + ); + } + + #[test] + fn test_discover_external_agents_no_dir() { + let data_dir = tempfile::tempdir().unwrap(); + // No external_agents/ directory exists + let discovered = discover_agents_from_external_dir(data_dir.path(), &[]); + assert!(discovered.is_empty()); + } + + #[test] + fn test_discover_excludes_registry_dir() { + let tmp = tempfile::tempdir().unwrap(); + let agents_dir = tmp.path().join("external_agents"); + std::fs::create_dir_all(agents_dir.join("registry")).unwrap(); + std::fs::create_dir_all(agents_dir.join("some-agent")).unwrap(); + + let discovered = discover_agents_from_external_dir(tmp.path(), &[]); + let names: Vec<&str> = discovered.iter().map(|a| a.name.as_str()).collect(); + assert!(!names.contains(&"registry"), "registry/ should be excluded"); + assert!(names.contains(&"some-agent")); + } + + #[test] + fn test_find_nodejs_bin_prefers_agent_name() { + // Simulates the Gemini scenario: multiple binaries in .bin/, + // the agent binary should be preferred over dependency binaries. + let dir = tempfile::tempdir().unwrap(); + let bin_dir = dir + .path() + .join("external_agents") + .join("gemini") + .join("0.23.0") + .join("node_modules") + .join(".bin"); + std::fs::create_dir_all(&bin_dir).unwrap(); + + // Create dependency binaries that sort alphabetically before "gemini" + for name in &["acorn", "esparse", "extract-zip", "gemini", "glob"] { + let path = bin_dir.join(platform_binary(name)); + std::fs::write(&path, b"#!/usr/bin/env node\n").unwrap(); + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o755)).unwrap(); + } + } + + let version_dir = dir + .path() + .join("external_agents") + .join("gemini") + .join("0.23.0"); + + let result = find_nodejs_bin("gemini", &version_dir); + assert!(result.is_some(), "Should find a binary"); + let resolved = result.unwrap(); + let stem = resolved.file_stem().unwrap().to_string_lossy().to_string(); + assert_eq!( + stem, "gemini", + "Should prefer 'gemini' binary over dependency binaries, got: {}", + resolved.display() + ); + } +} diff --git a/crates/dirigent_zed/src/detection.rs b/crates/dirigent_zed/src/detection.rs new file mode 100644 index 0000000..4e69d9a --- /dev/null +++ b/crates/dirigent_zed/src/detection.rs @@ -0,0 +1,158 @@ +//! Detection of Zed editor installations on the current system. +//! +//! Checks for the existence of Zed configuration directories and discovers +//! configured agents within each installation. + +use crate::agents::{self, ZedAgent}; +use crate::paths::{self, ZedChannel}; +use std::path::PathBuf; + +/// A detected Zed editor installation with its configuration and agents. +#[derive(Debug, Clone)] +pub struct ZedInstallation { + /// Release channel (Stable, Preview, Nightly, Dev). + pub channel: ZedChannel, + /// Path to the configuration directory (contains `settings.json`). + pub config_dir: PathBuf, + /// Path to the data directory (contains `external_agents/`). + pub data_dir: PathBuf, + /// Agents discovered from settings and resolved binary paths. + pub agents: Vec<ZedAgent>, +} + +/// Detect Zed installations on this system. +/// +/// Checks for the existence of `settings.json` in the Zed config directory. +/// Currently, Zed uses a single config directory for all channels (unlike +/// some editors that have per-channel directories). We report one installation +/// per detected config directory. +/// +/// For each found installation: +/// 1. Discovers agents from `settings.json` (`agent_servers` key) +/// 2. Resolves downloaded binary paths from the data directory +pub fn detect_installations() -> Vec<ZedInstallation> { + let config_dir = match paths::zed_config_dir() { + Some(d) => d, + None => { + tracing::debug!("Could not determine Zed config directory for this platform"); + return Vec::new(); + } + }; + + let data_dir = match paths::zed_data_dir() { + Some(d) => d, + None => { + tracing::debug!("Could not determine Zed data directory for this platform"); + return Vec::new(); + } + }; + + let settings_path = config_dir.join("settings.json"); + if !settings_path.exists() { + tracing::debug!( + path = %settings_path.display(), + "Zed settings.json not found — no installation detected" + ); + return Vec::new(); + } + + tracing::info!( + config = %config_dir.display(), + data = %data_dir.display(), + "Detected Zed installation" + ); + + let mut found_agents = agents::discover_agents_from_settings(&config_dir); + + // Also discover agents from external_agents/ that aren't in settings. + let extra_agents = agents::discover_agents_from_external_dir(&data_dir, &found_agents); + found_agents.extend(extra_agents); + + agents::resolve_binary_paths(&mut found_agents, &data_dir); + + // Enrich agents with registry metadata (display names, descriptions, args, icons). + let registry = crate::registry::parse_registry(&data_dir); + agents::enrich_agents_from_registry(&mut found_agents, ®istry); + + // Zed currently uses a single config dir for all channels. We report it + // as Stable by default. If we later learn how to distinguish channels + // (e.g., via a marker file or binary path), we can refine this. + vec![ZedInstallation { + channel: ZedChannel::Stable, + config_dir, + data_dir, + agents: found_agents, + }] +} + +/// Detect installations using explicit paths (useful for testing or overrides). +pub fn detect_installation_at(config_dir: PathBuf, data_dir: PathBuf) -> Option<ZedInstallation> { + let settings_path = config_dir.join("settings.json"); + if !settings_path.exists() { + return None; + } + + let mut found_agents = agents::discover_agents_from_settings(&config_dir); + let extra_agents = agents::discover_agents_from_external_dir(&data_dir, &found_agents); + found_agents.extend(extra_agents); + agents::resolve_binary_paths(&mut found_agents, &data_dir); + + let registry = crate::registry::parse_registry(&data_dir); + agents::enrich_agents_from_registry(&mut found_agents, ®istry); + + Some(ZedInstallation { + channel: ZedChannel::Stable, + config_dir, + data_dir, + agents: found_agents, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_detect_installation_at_missing_settings() { + let dir = tempfile::tempdir().unwrap(); + let result = detect_installation_at(dir.path().to_path_buf(), dir.path().to_path_buf()); + assert!(result.is_none()); + } + + #[test] + fn test_detect_installation_at_with_settings() { + let config_dir = tempfile::tempdir().unwrap(); + let data_dir = tempfile::tempdir().unwrap(); + + let settings = r#"{ + "agent_servers": { + "claude-acp": { "type": "registry" } + } + }"#; + std::fs::write(config_dir.path().join("settings.json"), settings).unwrap(); + + let installation = + detect_installation_at(config_dir.path().to_path_buf(), data_dir.path().to_path_buf()); + assert!(installation.is_some()); + + let inst = installation.unwrap(); + assert_eq!(inst.channel, ZedChannel::Stable); + assert_eq!(inst.agents.len(), 1); + assert_eq!(inst.agents[0].name, "claude-acp"); + } + + #[test] + fn test_detect_installation_at_empty_settings() { + let config_dir = tempfile::tempdir().unwrap(); + let data_dir = tempfile::tempdir().unwrap(); + + std::fs::write(config_dir.path().join("settings.json"), "{}").unwrap(); + + let installation = + detect_installation_at(config_dir.path().to_path_buf(), data_dir.path().to_path_buf()); + assert!(installation.is_some()); + + let inst = installation.unwrap(); + assert!(inst.agents.is_empty()); + } +} diff --git a/crates/dirigent_zed/src/lib.rs b/crates/dirigent_zed/src/lib.rs new file mode 100644 index 0000000..40f3f61 --- /dev/null +++ b/crates/dirigent_zed/src/lib.rs @@ -0,0 +1,31 @@ +//! Zed editor integration for Dirigent. +//! +//! Provides detection of Zed installations, discovery of configured ACP agents, +//! and binary path resolution for agent servers managed by Zed. +//! +//! # Usage +//! +//! ```rust,no_run +//! use dirigent_zed::{detect_installations, ZedAgent, AgentServerType}; +//! +//! let installations = detect_installations(); +//! for inst in &installations { +//! println!("Zed {} at {}", inst.channel, inst.config_dir.display()); +//! for agent in &inst.agents { +//! println!(" Agent: {} ({})", agent.name, agent.agent_type); +//! if let Some(ref path) = agent.binary_path { +//! println!(" Binary: {}", path.display()); +//! } +//! } +//! } +//! ``` + +pub mod agents; +pub mod detection; +pub mod paths; +pub mod registry; + +pub use agents::{AgentServerType, ZedAgent}; +pub use detection::{detect_installation_at, detect_installations, ZedInstallation}; +pub use paths::ZedChannel; +pub use registry::{parse_registry, find_registry_match, RegistryAgentInfo}; diff --git a/crates/dirigent_zed/src/paths.rs b/crates/dirigent_zed/src/paths.rs new file mode 100644 index 0000000..46718ac --- /dev/null +++ b/crates/dirigent_zed/src/paths.rs @@ -0,0 +1,147 @@ +//! Platform path resolution for Zed editor directories. +//! +//! Resolves configuration and data directories for each Zed release channel +//! across Linux, macOS, and Windows. + +use std::path::PathBuf; + +/// Zed release channel. Each channel has independent config and data directories. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ZedChannel { + Stable, + Preview, + Nightly, + Dev, +} + +impl ZedChannel { + /// All known release channels. + pub fn all() -> &'static [ZedChannel] { + &[ + ZedChannel::Stable, + ZedChannel::Preview, + ZedChannel::Nightly, + ZedChannel::Dev, + ] + } + + /// Socket/identifier name used in IPC paths. + pub fn socket_name(&self) -> &'static str { + match self { + ZedChannel::Stable => "stable", + ZedChannel::Preview => "preview", + ZedChannel::Nightly => "nightly", + ZedChannel::Dev => "dev", + } + } + + /// Display name for the channel. + pub fn display_name(&self) -> &'static str { + match self { + ZedChannel::Stable => "Stable", + ZedChannel::Preview => "Preview", + ZedChannel::Nightly => "Nightly", + ZedChannel::Dev => "Dev", + } + } +} + +impl std::fmt::Display for ZedChannel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.display_name()) + } +} + +/// Resolve the Zed configuration directory for this platform. +/// +/// | Platform | Path | +/// |----------|------| +/// | Linux | `$XDG_CONFIG_HOME/zed` (default: `~/.config/zed`) | +/// | macOS | `~/.config/zed` | +/// | Windows | `%APPDATA%\Zed` | +pub fn zed_config_dir() -> Option<PathBuf> { + #[cfg(target_os = "macos")] + { + dirs::home_dir().map(|d| d.join(".config").join("zed")) + } + #[cfg(target_os = "windows")] + { + dirs::config_dir().map(|d| d.join("Zed")) + } + #[cfg(not(any(target_os = "macos", target_os = "windows")))] + { + dirs::config_dir().map(|d| d.join("zed")) + } +} + +/// Resolve the Zed data directory for this platform. +/// +/// | Platform | Path | +/// |----------|------| +/// | Linux | `$XDG_DATA_HOME/zed` (default: `~/.local/share/zed`) | +/// | macOS | `~/Library/Application Support/Zed` | +/// | Windows | `%LOCALAPPDATA%\Zed` | +pub fn zed_data_dir() -> Option<PathBuf> { + #[cfg(target_os = "windows")] + { + dirs::data_local_dir().map(|d| d.join("Zed")) + } + #[cfg(not(target_os = "windows"))] + { + dirs::data_dir().map(|d| d.join("zed")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_all_channels() { + let channels = ZedChannel::all(); + assert_eq!(channels.len(), 4); + assert!(channels.contains(&ZedChannel::Stable)); + assert!(channels.contains(&ZedChannel::Preview)); + assert!(channels.contains(&ZedChannel::Nightly)); + assert!(channels.contains(&ZedChannel::Dev)); + } + + #[test] + fn test_socket_names() { + assert_eq!(ZedChannel::Stable.socket_name(), "stable"); + assert_eq!(ZedChannel::Preview.socket_name(), "preview"); + assert_eq!(ZedChannel::Nightly.socket_name(), "nightly"); + assert_eq!(ZedChannel::Dev.socket_name(), "dev"); + } + + #[test] + fn test_display_names() { + assert_eq!(ZedChannel::Stable.to_string(), "Stable"); + assert_eq!(ZedChannel::Dev.to_string(), "Dev"); + } + + #[test] + fn test_config_dir_is_some() { + // On any supported platform, this should resolve. + let dir = zed_config_dir(); + assert!(dir.is_some(), "zed_config_dir should resolve on this platform"); + let path = dir.unwrap(); + let path_str = path.to_string_lossy(); + assert!( + path_str.contains("zed") || path_str.contains("Zed"), + "config dir should contain 'zed': {path_str}" + ); + } + + #[test] + fn test_data_dir_is_some() { + let dir = zed_data_dir(); + assert!(dir.is_some(), "zed_data_dir should resolve on this platform"); + let path = dir.unwrap(); + let path_str = path.to_string_lossy(); + assert!( + path_str.contains("zed") || path_str.contains("Zed"), + "data dir should contain 'zed': {path_str}" + ); + } +} diff --git a/crates/dirigent_zed/src/registry.rs b/crates/dirigent_zed/src/registry.rs new file mode 100644 index 0000000..5dcb333 --- /dev/null +++ b/crates/dirigent_zed/src/registry.rs @@ -0,0 +1,546 @@ +//! Registry metadata parsing for Zed ACP agents. +//! +//! Reads the local `registry.json` file that Zed caches from the ACP registry CDN. +//! Provides enrichment data (display names, descriptions, command args, icon paths) +//! for discovered agents. + +use serde::Deserialize; +use std::collections::HashMap; +use std::path::{Path, PathBuf}; + +/// Metadata about an agent from the ACP registry. +#[derive(Debug, Clone)] +pub struct RegistryAgentInfo { + /// Registry identifier (e.g. "claude-acp", "codex-acp"). + pub id: String, + /// Human-friendly display name (e.g. "Claude Agent", "Codex CLI"). + pub display_name: String, + /// Short description of the agent. + pub description: String, + /// Icon URL from the CDN. + pub icon_url: Option<String>, + /// Path to the locally cached icon file (SVG), if it exists. + pub icon_local_path: Option<PathBuf>, + /// Version string from the registry. + pub version: String, + /// Command arguments from the distribution config. + /// + /// For npx-distributed agents this may include flags like `["--acp"]`. + /// For binary-distributed agents this is the platform-specific `cmd` value. + pub args: Vec<String>, + /// The command to run, extracted from the platform-appropriate distribution. + /// + /// For binary distributions this is the `cmd` field (e.g. `"./codex-acp"`). + /// For npx distributions this is the npx package specifier. + pub command: Option<String>, + /// Environment variables from the distribution config. + pub env: HashMap<String, String>, +} + +// --------------------------------------------------------------------------- +// Raw serde models for registry.json +// --------------------------------------------------------------------------- + +#[derive(Debug, Deserialize)] +struct RegistryFile { + #[serde(default)] + agents: Vec<RawRegistryAgent>, +} + +#[derive(Debug, Deserialize)] +struct RawRegistryAgent { + id: String, + name: String, + #[serde(default)] + version: String, + #[serde(default)] + description: String, + #[serde(default)] + icon: Option<String>, + #[serde(default)] + distribution: Option<RawDistribution>, +} + +#[derive(Debug, Deserialize)] +struct RawDistribution { + #[serde(default)] + binary: Option<HashMap<String, RawBinaryTarget>>, + #[serde(default)] + npx: Option<RawNpxDistribution>, + #[serde(default)] + uvx: Option<RawUvxDistribution>, +} + +#[derive(Debug, Deserialize)] +struct RawBinaryTarget { + #[serde(default)] + cmd: Option<String>, +} + +#[derive(Debug, Deserialize)] +struct RawNpxDistribution { + #[serde(default)] + package: Option<String>, + #[serde(default)] + args: Vec<String>, + #[serde(default)] + env: HashMap<String, String>, +} + +#[derive(Debug, Deserialize)] +struct RawUvxDistribution { + #[serde(default)] + package: Option<String>, + #[serde(default)] + args: Vec<String>, + #[serde(default)] + env: HashMap<String, String>, +} + +/// Determine the current platform key used in the registry's binary distribution map. +/// +/// Returns keys like `"windows-x86_64"`, `"linux-aarch64"`, `"darwin-aarch64"`, etc. +fn current_platform_key() -> Option<&'static str> { + #[cfg(all(target_os = "windows", target_arch = "x86_64"))] + { + Some("windows-x86_64") + } + #[cfg(all(target_os = "windows", target_arch = "aarch64"))] + { + Some("windows-aarch64") + } + #[cfg(all(target_os = "macos", target_arch = "aarch64"))] + { + Some("darwin-aarch64") + } + #[cfg(all(target_os = "macos", target_arch = "x86_64"))] + { + Some("darwin-x86_64") + } + #[cfg(all(target_os = "linux", target_arch = "x86_64"))] + { + Some("linux-x86_64") + } + #[cfg(all(target_os = "linux", target_arch = "aarch64"))] + { + Some("linux-aarch64") + } + #[cfg(not(any( + all(target_os = "windows", target_arch = "x86_64"), + all(target_os = "windows", target_arch = "aarch64"), + all(target_os = "macos", target_arch = "aarch64"), + all(target_os = "macos", target_arch = "x86_64"), + all(target_os = "linux", target_arch = "x86_64"), + all(target_os = "linux", target_arch = "aarch64"), + )))] + { + None + } +} + +/// Parse the local Zed registry.json and return a map of agent id -> metadata. +/// +/// The registry file lives at `{data_dir}/external_agents/registry/registry.json`. +/// Icons are cached at `{data_dir}/external_agents/registry/icons/{id}.svg`. +/// +/// Returns an empty map if the file doesn't exist or can't be parsed. +pub fn parse_registry(data_dir: &Path) -> HashMap<String, RegistryAgentInfo> { + let registry_dir = data_dir.join("external_agents").join("registry"); + let registry_path = registry_dir.join("registry.json"); + + let content = match std::fs::read_to_string(®istry_path) { + Ok(c) => c, + Err(e) => { + tracing::debug!( + path = %registry_path.display(), + error = %e, + "Could not read Zed registry.json" + ); + return HashMap::new(); + } + }; + + let registry: RegistryFile = match serde_json::from_str(&content) { + Ok(r) => r, + Err(e) => { + tracing::warn!( + path = %registry_path.display(), + error = %e, + "Failed to parse Zed registry.json" + ); + return HashMap::new(); + } + }; + + let icons_dir = registry_dir.join("icons"); + let platform_key = current_platform_key(); + + let mut map = HashMap::with_capacity(registry.agents.len()); + + for agent in registry.agents { + let icon_local_path = { + let candidate = icons_dir.join(format!("{}.svg", agent.id)); + if candidate.exists() { + Some(candidate) + } else { + None + } + }; + + let (command, args, env) = + extract_distribution_info(agent.distribution.as_ref(), platform_key); + + let info = RegistryAgentInfo { + id: agent.id.clone(), + display_name: agent.name, + description: agent.description, + icon_url: agent.icon, + icon_local_path, + version: agent.version, + args, + command, + env, + }; + + map.insert(agent.id, info); + } + + tracing::debug!( + count = map.len(), + "Parsed Zed registry with {} agents", + map.len() + ); + + map +} + +/// Extract command, args, and env from the distribution config. +/// +/// Priority: binary (platform-specific) > npx > uvx. +fn extract_distribution_info( + distribution: Option<&RawDistribution>, + platform_key: Option<&str>, +) -> (Option<String>, Vec<String>, HashMap<String, String>) { + let dist = match distribution { + Some(d) => d, + None => return (None, Vec::new(), HashMap::new()), + }; + + // Prefer binary distribution for the current platform. + if let Some(ref binary) = dist.binary { + if let Some(key) = platform_key { + if let Some(target) = binary.get(key) { + let cmd = target.cmd.clone(); + return (cmd, Vec::new(), HashMap::new()); + } + } + } + + // Fall back to npx distribution. + if let Some(ref npx) = dist.npx { + return ( + npx.package.clone(), + npx.args.clone(), + npx.env.clone(), + ); + } + + // Fall back to uvx distribution. + if let Some(ref uvx) = dist.uvx { + return ( + uvx.package.clone(), + uvx.args.clone(), + uvx.env.clone(), + ); + } + + (None, Vec::new(), HashMap::new()) +} + +/// Look up a registry entry by matching an agent name or directory name to a registry id. +/// +/// The matching strategy: +/// 1. Exact match on registry id +/// 2. Substring: registry id contained in agent name, or vice versa +/// 3. Core-name match: strip common ACP suffixes and compare base names +/// (e.g. "claude-agent-acp" and "claude-acp" both have core name "claude") +pub fn find_registry_match<'a>( + agent_name: &str, + registry: &'a HashMap<String, RegistryAgentInfo>, +) -> Option<&'a RegistryAgentInfo> { + let name_lower = agent_name.to_lowercase(); + + // 1. Exact match on registry id. + if let Some(info) = registry.get(&name_lower) { + return Some(info); + } + + // 2. Substring match: registry id contained in agent name, or vice versa. + let mut best: Option<&'a RegistryAgentInfo> = None; + let mut best_len = 0; + + for (id, info) in registry { + let id_lower = id.to_lowercase(); + if name_lower.contains(&id_lower) || id_lower.contains(&name_lower) { + if id_lower.len() > best_len { + best = Some(info); + best_len = id_lower.len(); + } + } + } + + if best.is_some() { + return best; + } + + // 3. Core-name match: strip ACP-related suffixes and compare. + let agent_core = strip_acp_suffixes(&name_lower); + for (id, info) in registry { + let id_lower = id.to_lowercase(); + let id_core = strip_acp_suffixes(&id_lower); + if !agent_core.is_empty() && agent_core == id_core { + return Some(info); + } + } + + None +} + +/// Strip common ACP-related suffixes to extract the core agent name. +/// +/// For example: +/// - "claude-agent-acp" -> "claude" +/// - "claude-acp" -> "claude" +/// - "claude-code-acp" -> "claude" +/// - "codex" -> "codex" +pub fn strip_acp_suffixes(name: &str) -> &str { + // Strip known suffixes in order of specificity (longest first). + for suffix in &["-agent-acp", "-code-acp", "-acp", "-cli"] { + if let Some(core) = name.strip_suffix(suffix) { + return core; + } + } + name +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Write; + + fn sample_registry_json() -> &'static str { + r#"{ + "version": "1.0.0", + "agents": [ + { + "id": "claude-acp", + "name": "Claude Agent", + "version": "0.24.2", + "description": "ACP wrapper for Anthropic's Claude", + "icon": "https://cdn.agentclientprotocol.com/registry/v1/latest/claude-acp.svg", + "distribution": { + "npx": { + "package": "@agentclientprotocol/claude-agent-acp@0.24.2" + } + } + }, + { + "id": "codex-acp", + "name": "Codex CLI", + "version": "0.10.0", + "description": "ACP adapter for OpenAI's coding assistant", + "icon": "https://cdn.agentclientprotocol.com/registry/v1/latest/codex-acp.svg", + "distribution": { + "binary": { + "linux-x86_64": { + "archive": "https://example.com/codex.tar.gz", + "cmd": "./codex-acp" + }, + "windows-x86_64": { + "archive": "https://example.com/codex.zip", + "cmd": "./codex-acp.exe" + }, + "darwin-aarch64": { + "archive": "https://example.com/codex-mac.tar.gz", + "cmd": "./codex-acp" + } + }, + "npx": { + "package": "@zed-industries/codex-acp@0.10.0" + } + } + }, + { + "id": "auggie", + "name": "Auggie CLI", + "version": "0.21.0", + "description": "Augment Code's powerful software agent", + "distribution": { + "npx": { + "package": "@augmentcode/auggie@0.21.0", + "args": ["--acp"], + "env": { "AUGMENT_DISABLE_AUTO_UPDATE": "1" } + } + } + } + ] +}"# + } + + #[test] + fn test_parse_registry_basic() { + let dir = tempfile::tempdir().unwrap(); + let registry_dir = dir + .path() + .join("external_agents") + .join("registry"); + std::fs::create_dir_all(®istry_dir).unwrap(); + + let mut f = + std::fs::File::create(registry_dir.join("registry.json")).unwrap(); + f.write_all(sample_registry_json().as_bytes()).unwrap(); + + let map = parse_registry(dir.path()); + assert_eq!(map.len(), 3); + + let claude = map.get("claude-acp").unwrap(); + assert_eq!(claude.display_name, "Claude Agent"); + assert_eq!(claude.description, "ACP wrapper for Anthropic's Claude"); + assert!(claude.icon_url.is_some()); + // No local icon file created in test. + assert!(claude.icon_local_path.is_none()); + + let auggie = map.get("auggie").unwrap(); + assert_eq!(auggie.args, vec!["--acp"]); + assert_eq!( + auggie.env.get("AUGMENT_DISABLE_AUTO_UPDATE"), + Some(&"1".to_string()) + ); + } + + #[test] + fn test_parse_registry_with_local_icon() { + let dir = tempfile::tempdir().unwrap(); + let registry_dir = dir + .path() + .join("external_agents") + .join("registry"); + let icons_dir = registry_dir.join("icons"); + std::fs::create_dir_all(&icons_dir).unwrap(); + + std::fs::File::create(registry_dir.join("registry.json")) + .unwrap() + .write_all(sample_registry_json().as_bytes()) + .unwrap(); + + // Create a fake icon file. + std::fs::write(icons_dir.join("claude-acp.svg"), "<svg/>").unwrap(); + + let map = parse_registry(dir.path()); + let claude = map.get("claude-acp").unwrap(); + assert!(claude.icon_local_path.is_some()); + assert!(claude + .icon_local_path + .as_ref() + .unwrap() + .to_string_lossy() + .contains("claude-acp.svg")); + } + + #[test] + fn test_parse_registry_missing_file() { + let dir = tempfile::tempdir().unwrap(); + let map = parse_registry(dir.path()); + assert!(map.is_empty()); + } + + #[test] + fn test_parse_registry_invalid_json() { + let dir = tempfile::tempdir().unwrap(); + let registry_dir = dir + .path() + .join("external_agents") + .join("registry"); + std::fs::create_dir_all(®istry_dir).unwrap(); + std::fs::write(registry_dir.join("registry.json"), "not json").unwrap(); + + let map = parse_registry(dir.path()); + assert!(map.is_empty()); + } + + #[test] + fn test_find_registry_match_exact() { + let dir = tempfile::tempdir().unwrap(); + let registry_dir = dir + .path() + .join("external_agents") + .join("registry"); + std::fs::create_dir_all(®istry_dir).unwrap(); + std::fs::File::create(registry_dir.join("registry.json")) + .unwrap() + .write_all(sample_registry_json().as_bytes()) + .unwrap(); + + let map = parse_registry(dir.path()); + + // Exact match. + let info = find_registry_match("claude-acp", &map).unwrap(); + assert_eq!(info.display_name, "Claude Agent"); + } + + #[test] + fn test_find_registry_match_fuzzy() { + let dir = tempfile::tempdir().unwrap(); + let registry_dir = dir + .path() + .join("external_agents") + .join("registry"); + std::fs::create_dir_all(®istry_dir).unwrap(); + std::fs::File::create(registry_dir.join("registry.json")) + .unwrap() + .write_all(sample_registry_json().as_bytes()) + .unwrap(); + + let map = parse_registry(dir.path()); + + // Agent name "claude" should fuzzy-match "claude-acp". + let info = find_registry_match("claude", &map).unwrap(); + assert_eq!(info.display_name, "Claude Agent"); + + // Directory name "claude-agent-acp" should fuzzy-match "claude-acp". + let info2 = find_registry_match("claude-agent-acp", &map).unwrap(); + assert_eq!(info2.display_name, "Claude Agent"); + } + + #[test] + fn test_find_registry_match_no_match() { + let map = HashMap::new(); + assert!(find_registry_match("nonexistent", &map).is_none()); + } + + #[test] + fn test_binary_distribution_platform_cmd() { + let dir = tempfile::tempdir().unwrap(); + let registry_dir = dir + .path() + .join("external_agents") + .join("registry"); + std::fs::create_dir_all(®istry_dir).unwrap(); + std::fs::File::create(registry_dir.join("registry.json")) + .unwrap() + .write_all(sample_registry_json().as_bytes()) + .unwrap(); + + let map = parse_registry(dir.path()); + let codex = map.get("codex-acp").unwrap(); + + // On any supported platform, the binary distribution should produce a command. + // The exact value depends on the compile target. + if current_platform_key().is_some() { + assert!( + codex.command.is_some(), + "codex-acp should have a command from binary distribution" + ); + } + } +} diff --git a/crates/opencode_client/CLAUDE.md b/crates/opencode_client/CLAUDE.md new file mode 100644 index 0000000..46e859b --- /dev/null +++ b/crates/opencode_client/CLAUDE.md @@ -0,0 +1,30 @@ +# Package: opencode_client + +Rust client library for interacting with the OpenCode.ai API. + +## Quick Facts +- **Type**: Library +- **Main Entry**: src/lib.rs +- **Dependencies**: reqwest, serde, serde_json, chrono + +## Key Files +- `src/lib.rs` - Public API exports +- `src/types.rs` - OpenCode API type definitions (Session, Message, Part, etc.) +- `src/client.rs` - HTTP client implementation with optional logging callbacks + +## Main Exports +- `OpenCodeClient` - Main API client with methods: list_sessions, list_messages, send_message +- `Session` - Session metadata and configuration +- `Message` - User or Assistant message (tagged enum) +- `MessageWithParts` - Message info + content parts +- `Part` - Text, Reasoning, or workflow parts (StepStart, StepFinish, Tool) +- `ClientError` - Error types: Http, Request, Serialization +- `LogCallback` - Type alias for logging callbacks + +## Related +- Used by: web, mobile (future), desktop (future) +- Independent: Can be used in any Rust project + +## Documentation +- README: ./README.md +- API spec: ../../docs/api/opencode.md diff --git a/crates/opencode_client/Cargo.toml b/crates/opencode_client/Cargo.toml new file mode 100644 index 0000000..a2ff51f --- /dev/null +++ b/crates/opencode_client/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "opencode_client" +version = "0.1.0" +edition = "2021" + +[dependencies] +chrono = "0.4" +futures = "0.3" +reqwest = { version = "0.12", features = ["json", "stream"] } +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +tokio = { version = "1", features = ["sync"] } + +[target.'cfg(target_arch = "wasm32")'.dependencies] +js-sys = "0.3" +wasm-bindgen = "0.2" +wasm-bindgen-futures = "0.4" +web-sys = { version = "0.3", features = [ + "Event", + "EventSource", + "MessageEvent", + "console" +] } diff --git a/crates/opencode_client/README.md b/crates/opencode_client/README.md new file mode 100644 index 0000000..80fc826 --- /dev/null +++ b/crates/opencode_client/README.md @@ -0,0 +1,183 @@ +# opencode_client + +A Rust client library for interacting with the OpenCode.ai API. + +## Purpose + +This package provides a type-safe, ergonomic Rust interface to the OpenCode.ai API, enabling any Rust application to interact with OpenCode sessions and messages. It's designed to be reusable across different UI frameworks and platforms (web, mobile, desktop, CLI). + +**⚠️ Important: API Ownership** + +**ALL OpenCode.ai API functionality lives in this package.** This includes: +- ✅ REST API endpoints (sessions, messages, files, etc.) +- 🚧 Server-Sent Events (SSE) for real-time streaming (planned) +- 🔮 WebSocket connections (if needed in future) +- 🔮 Authentication flows +- 🔮 Rate limiting and retry logic + +UI packages (`web`, `mobile`, `desktop`) should remain thin presentation layers that consume this client library. Never implement API calls directly in UI code. + +## Features + +- **Type-safe API**: Full Rust type definitions for all OpenCode API structures +- **Async/await**: Built on `reqwest` for non-blocking HTTP operations +- **Optional logging**: Flexible logging callbacks for integration with any logging system +- **WASM compatible**: Works in browser environments via WebAssembly +- **Zero UI dependencies**: Pure business logic, usable anywhere + +## Usage + +### Basic Example + +```rust +use opencode_client::OpenCodeClient; + +#[tokio::main] +async fn main() -> Result<(), Box<dyn std::error::Error>> { + let client = OpenCodeClient::new("http://localhost:12225"); + + // List all sessions + let sessions = client.list_sessions().await?; + println!("Found {} sessions", sessions.len()); + + // Get messages from a session + if let Some(session) = sessions.first() { + let messages = client.list_messages(&session.id).await?; + println!("Session has {} messages", messages.len()); + } + + Ok(()) +} +``` + +### With Logging Callbacks + +```rust +use opencode_client::OpenCodeClient; +use std::sync::Arc; + +let client = OpenCodeClient::new("http://localhost:12225") + .with_logging( + Arc::new(|cat, msg| println!("[INFO] {}: {}", cat, msg)), + Arc::new(|cat, msg| println!("[SUCCESS] {}: {}", cat, msg)), + Arc::new(|cat, msg| eprintln!("[ERROR] {}: {}", cat, msg)), + ); + +// Now all API calls will log through your callbacks +let sessions = client.list_sessions().await?; +``` + +### Sending Messages + +```rust +let message_with_parts = client + .send_message("session_id", "Hello, world!".to_string()) + .await?; + +// Access the message info and parts +println!("Message ID: {}", get_message_id(&message_with_parts.info)); +for part in &message_with_parts.parts { + match part { + Part::Text(text) => println!("Text: {}", text.text), + Part::Reasoning(reasoning) => println!("Reasoning: {}", reasoning.text), + Part::Tool(tool) => println!("Tool: {} ({})", tool.tool, tool.state), + _ => {} + } +} +``` + +## Architecture + +The client is organized into three main modules: + +- **types.rs**: All OpenCode API type definitions (Session, Message, Part, etc.) +- **client.rs**: HTTP client implementation with CRUD operations +- **lib.rs**: Public API exports + +The client uses tagged enums for discriminated unions (e.g., `Message::User` vs `Message::Assistant`) and serde for JSON serialization/deserialization. + +## API Reference + +For detailed API documentation, see: +- [OpenCode API Documentation](../../docs/api/opencode.md) +- Or run: `cargo doc --open -p opencode_client` + +## Key Types + +- **OpenCodeClient**: Main client struct +- **Session**: Session metadata (id, title, project, timestamps) +- **Message**: Tagged enum for User or Assistant messages + - `UserMessage`: User-sent messages with optional summary + - `AssistantMessage`: AI responses with tokens, cost, system prompt, and metadata +- **MessageWithParts**: Complete message with content parts (info + parts array) +- **Part**: Message content enum: + - `Text(TextPart)`: Regular text content + - `Reasoning(ReasoningPart)`: AI reasoning/thinking process + - `Tool(ToolPart)`: Tool execution with state tracking + - `StepStart(GenericPart)`: Step boundary marker + - `StepFinish(GenericPart)`: Step completion marker + - `Unknown`: Future-proof fallback +- **ToolPart**: Tool execution details with: + - `tool`: Tool name (string) + - `call_id`: Unique call identifier + - `state`: ToolState enum (Pending, Running, Completed, Error) +- **ToolState**: Execution state with status-specific data: + - `Pending`: Waiting to start + - `Running { input, title?, metadata?, time }`: Currently executing + - `Completed { input, output, title, metadata, time, attachments? }`: Successful completion + - `Error { input, error, metadata?, time }`: Failed execution +- **ClientError**: Error types (Http, Request, Serialization) + +## Dependencies + +- **reqwest**: Async HTTP client (with JSON support) +- **serde/serde_json**: JSON serialization +- **chrono**: Timestamp handling +- **wasm-bindgen/web-sys**: WASM compatibility (target-specific) + +## Related Packages + +- **web**: Uses this client for browser-based UI +- **mobile** (future): Will use this client for mobile apps +- **desktop** (future): Will use this client for desktop apps + +## Development + +```bash +# Build +cargo build -p opencode_client + +# Test +cargo test -p opencode_client + +# Documentation +cargo doc --open -p opencode_client +``` + +## Roadmap + +See detailed implementation plans: +- [SSE Implementation](../../docs/building/00/sse_implementation.md) - Real-time event streaming +- [OpenCode API Reference](../../docs/building/general/opencode_api.md) - Complete API documentation + +### Upcoming Features + +- **SSE Event Streaming** 🚧 + - Real-time message part updates + - Live session state changes + - Connection resilience and reconnection + - See: `docs/building/sse_implementation.md` + +- **File Operations** 📋 + - Read files from workspace + - Search files and symbols + - Track file status + +- **Advanced Session Management** 📋 + - Create/delete sessions + - Fork sessions + - Session diffs and summaries + +## Known Issues & Differences + +See [OpenCode API Documentation](../../docs/building/general/opencode_api.md) for details on differences between the official API spec and actual implementation (e.g., optional fields, missing fields, additional part types). diff --git a/crates/opencode_client/src/client.rs b/crates/opencode_client/src/client.rs new file mode 100644 index 0000000..e818528 --- /dev/null +++ b/crates/opencode_client/src/client.rs @@ -0,0 +1,337 @@ +//! OpenCode API Client +//! +//! HTTP client for interacting with opencode.ai API + +use crate::sse::{SseClient, SseError, SseStream}; +use crate::types::{MessageWithParts, Session}; +use serde::Serialize; +use std::sync::Arc; + +/// Logging callback type for API client events +pub type LogCallback = Arc<dyn Fn(&str, &str) + Send + Sync>; + +#[derive(Clone)] +pub struct OpenCodeClient { + base_url: String, + client: reqwest::Client, + log_info: Option<LogCallback>, + log_success: Option<LogCallback>, + log_error: Option<LogCallback>, +} + +impl OpenCodeClient { + /// Create a new OpenCode API client + pub fn new(base_url: impl Into<String>) -> Self { + Self { + base_url: base_url.into(), + client: reqwest::Client::new(), + log_info: None, + log_success: None, + log_error: None, + } + } + + /// Set logging callbacks for the client + pub fn with_logging( + mut self, + log_info: LogCallback, + log_success: LogCallback, + log_error: LogCallback, + ) -> Self { + self.log_info = Some(log_info); + self.log_success = Some(log_success); + self.log_error = Some(log_error); + self + } + + fn log_info(&self, category: &str, message: &str) { + if let Some(logger) = &self.log_info { + logger(category, message); + } + } + + fn log_success(&self, category: &str, message: &str) { + if let Some(logger) = &self.log_success { + logger(category, message); + } + } + + fn log_error(&self, category: &str, message: &str) { + if let Some(logger) = &self.log_error { + logger(category, message); + } + } + + /// List all sessions + pub async fn list_sessions(&self) -> Result<Vec<Session>, ClientError> { + let url = format!("{}/session", self.base_url); + self.log_info("API", &format!("GET {}", url)); + + let response = self.client.get(&url).send().await?; + let status = response.status(); + + if !status.is_success() { + let error_body = response + .text() + .await + .unwrap_or_else(|_| String::from("(no body)")); + self.log_error("API", &format!("GET {} failed: {}", url, status)); + self.log_error("API", &format!("Response body: {}", error_body)); + return Err(ClientError::Http(status)); + } + + let text = response.text().await?; + self.log_info("API", &format!("Response: {} bytes", text.len())); + + match serde_json::from_str::<Vec<Session>>(&text) { + Ok(sessions) => { + self.log_success("API", &format!("Loaded {} sessions", sessions.len())); + Ok(sessions) + } + Err(e) => { + self.log_error("API", &format!("Failed to decode sessions: {}", e)); + self.log_error("API", &format!("Response body: {}", text)); + Err(ClientError::Serialization(e)) + } + } + } + + /// Get a specific session by ID + pub async fn get_session(&self, session_id: &str) -> Result<Session, ClientError> { + let url = format!("{}/session/{}", self.base_url, session_id); + let response = self.client.get(&url).send().await?; + + if !response.status().is_success() { + return Err(ClientError::Http(response.status())); + } + + let session = response.json::<Session>().await?; + Ok(session) + } + + /// List messages in a session + pub async fn list_messages( + &self, + session_id: &str, + ) -> Result<Vec<MessageWithParts>, ClientError> { + let url = format!("{}/session/{}/message", self.base_url, session_id); + self.log_info("API", &format!("GET {}", url)); + + let response = self.client.get(&url).send().await?; + let status = response.status(); + + if !status.is_success() { + let error_body = response + .text() + .await + .unwrap_or_else(|_| String::from("(no body)")); + self.log_error("API", &format!("GET {} failed: {}", url, status)); + self.log_error("API", &format!("Response body: {}", error_body)); + return Err(ClientError::Http(status)); + } + + let text = response.text().await?; + self.log_info("API", &format!("Response: {} bytes", text.len())); + + // Try to parse as array first to get individual message values + match serde_json::from_str::<serde_json::Value>(&text) { + Ok(serde_json::Value::Array(message_values)) => { + let mut messages = Vec::new(); + for (index, msg_value) in message_values.iter().enumerate() { + match serde_json::from_value::<MessageWithParts>(msg_value.clone()) { + Ok(msg) => messages.push(msg), + Err(e) => { + self.log_error( + "API", + &format!("Failed to decode message at index {}: {}", index, e), + ); + self.log_error( + "API", + &format!( + "Problematic message JSON: {}", + serde_json::to_string_pretty(msg_value) + .unwrap_or_else(|_| msg_value.to_string()) + ), + ); + return Err(ClientError::Serialization(e)); + } + } + } + self.log_success("API", &format!("Loaded {} messages", messages.len())); + Ok(messages) + } + Ok(_) => { + self.log_error("API", "Response is not an array"); + self.log_error("API", &format!("Response body: {}", text)); + // Create a dummy error by trying to parse invalid data + let dummy_err = serde_json::from_str::<Vec<MessageWithParts>>("null").unwrap_err(); + Err(ClientError::Serialization(dummy_err)) + } + Err(e) => { + self.log_error("API", &format!("Failed to parse JSON: {}", e)); + self.log_error( + "API", + &format!( + "Response body (first 1000 chars): {}", + if text.len() > 1000 { + &text[..1000] + } else { + &text + } + ), + ); + Err(ClientError::Serialization(e)) + } + } + } + + /// Send a chat message to a session + pub async fn send_message( + &self, + session_id: &str, + text: String, + ) -> Result<MessageWithParts, ClientError> { + #[derive(Debug, Serialize)] + struct ChatPart { + #[serde(rename = "type")] + part_type: String, + text: String, + } + + #[derive(Debug, Serialize)] + struct ChatInput { + parts: Vec<ChatPart>, + } + + let url = format!("{}/session/{}/message", self.base_url, session_id); + self.log_info("API", &format!("POST {} (text: {} chars)", url, text.len())); + + let input = ChatInput { + parts: vec![ChatPart { + part_type: "text".to_string(), + text, + }], + }; + + let response = self.client.post(&url).json(&input).send().await?; + let status = response.status(); + + if !status.is_success() { + let error_body = response + .text() + .await + .unwrap_or_else(|_| String::from("(no body)")); + self.log_error("API", &format!("POST {} failed: {}", url, status)); + self.log_error("API", &format!("Response body: {}", error_body)); + return Err(ClientError::Http(status)); + } + + let response_text = response.text().await?; + self.log_info("API", &format!("Response: {} bytes", response_text.len())); + + match serde_json::from_str::<MessageWithParts>(&response_text) { + Ok(message) => { + self.log_success("API", "Message sent successfully"); + Ok(message) + } + Err(e) => { + self.log_error("API", &format!("Failed to decode message: {}", e)); + self.log_error("API", &format!("Response body: {}", response_text)); + Err(ClientError::Serialization(e)) + } + } + } + + /// Subscribe to server-sent events (SSE) for real-time updates + /// + /// Returns a stream of events including message updates, session changes, etc. + /// The stream will automatically handle the SSE protocol and deserialize events. + /// + /// # Example + /// + /// ```no_run + /// # use opencode_client::OpenCodeClient; + /// # use futures::stream::StreamExt; + /// # async fn example() -> Result<(), Box<dyn std::error::Error>> { + /// let client = OpenCodeClient::new("http://localhost:12225"); + /// let mut stream = client.subscribe_events()?; + /// + /// while let Some(event_result) = stream.next().await { + /// match event_result { + /// Ok(event) => println!("Received event: {:?}", event), + /// Err(e) => eprintln!("Event error: {}", e), + /// } + /// } + /// # Ok(()) + /// # } + /// ``` + pub fn subscribe_events(&self) -> Result<SseStream, SseError> { + self.log_info("SSE", &format!("Connecting to {}/event", self.base_url)); + + let sse_client = SseClient::new(&self.base_url); + let stream = sse_client.connect()?; + + self.log_success("SSE", "Connected to event stream"); + Ok(stream) + } + + /// Abort an in-progress message generation in a session + /// + /// Sends a POST request to `/session/:id/abort` to cancel the current generation. + /// + /// # Returns + /// + /// Returns `Ok(true)` if the abort was successful, `Ok(false)` if there was nothing to abort, + /// or an error if the request failed. + pub async fn abort_session(&self, session_id: &str) -> Result<bool, ClientError> { + let url = format!("{}/session/{}/abort", self.base_url, session_id); + self.log_info("API", &format!("POST {} (abort)", url)); + + let response = self.client.post(&url).send().await?; + let status = response.status(); + + if !status.is_success() { + let error_body = response + .text() + .await + .unwrap_or_else(|_| String::from("(no body)")); + self.log_error("API", &format!("POST {} failed: {}", url, status)); + self.log_error("API", &format!("Response body: {}", error_body)); + return Err(ClientError::Http(status)); + } + + self.log_success("API", "Generation aborted successfully"); + Ok(true) + } +} + +#[derive(Debug)] +pub enum ClientError { + Http(reqwest::StatusCode), + Request(reqwest::Error), + Serialization(serde_json::Error), +} + +impl From<reqwest::Error> for ClientError { + fn from(err: reqwest::Error) -> Self { + ClientError::Request(err) + } +} + +impl From<serde_json::Error> for ClientError { + fn from(err: serde_json::Error) -> Self { + ClientError::Serialization(err) + } +} + +impl std::fmt::Display for ClientError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ClientError::Http(status) => write!(f, "HTTP error: {}", status), + ClientError::Request(err) => write!(f, "Request error: {}", err), + ClientError::Serialization(err) => write!(f, "Serialization error: {}", err), + } + } +} + +impl std::error::Error for ClientError {} diff --git a/crates/opencode_client/src/lib.rs b/crates/opencode_client/src/lib.rs new file mode 100644 index 0000000..0ce1ace --- /dev/null +++ b/crates/opencode_client/src/lib.rs @@ -0,0 +1,34 @@ +//! OpenCode API Client Library +//! +//! A Rust client for interacting with the OpenCode.ai API. +//! +//! # Example +//! +//! ```no_run +//! use opencode_client::OpenCodeClient; +//! +//! # async fn example() -> Result<(), Box<dyn std::error::Error>> { +//! let client = OpenCodeClient::new("http://localhost:12225"); +//! +//! // List all sessions +//! let sessions = client.list_sessions().await?; +//! +//! // Get messages from a session +//! let messages = client.list_messages("session_id").await?; +//! +//! // Send a message +//! let response = client.send_message("session_id", "Hello!".to_string()).await?; +//! # Ok(()) +//! # } +//! ``` + +pub mod client; +pub mod sse; +pub mod types; + +pub use client::{ClientError, LogCallback, OpenCodeClient}; +pub use sse::{ConnectionState, SseClient, SseError, SseStream}; +pub use types::{ + AssistantMessage, AssistantMessageTime, Event, Message, MessageTime, MessageWithParts, Part, + ReasoningPart, Session, TextPart, ToolPart, ToolState, UserMessage, +}; diff --git a/crates/opencode_client/src/sse.rs b/crates/opencode_client/src/sse.rs new file mode 100644 index 0000000..b3db6be --- /dev/null +++ b/crates/opencode_client/src/sse.rs @@ -0,0 +1,299 @@ +//! SSE (Server-Sent Events) client for OpenCode API +//! +//! Provides cross-platform SSE streaming for real-time event updates + +use crate::types::Event; +use futures::stream::Stream; +use std::pin::Pin; + +/// SSE-specific errors +#[derive(Debug)] +pub enum SseError { + Network(String), + Parse(String), + Closed, +} + +impl std::fmt::Display for SseError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SseError::Network(msg) => write!(f, "Network error: {}", msg), + SseError::Parse(msg) => write!(f, "Parse error: {}", msg), + SseError::Closed => write!(f, "Connection closed"), + } + } +} + +impl std::error::Error for SseError {} + +/// Connection state +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ConnectionState { + Connecting, + Connected, + Disconnected, + Error, +} + +/// Stream of SSE events +pub type SseStream = Pin<Box<dyn Stream<Item = Result<Event, SseError>> + Send>>; + +/// SSE Client +pub struct SseClient { + base_url: String, +} + +impl SseClient { + /// Create new SSE client + pub fn new(base_url: impl Into<String>) -> Self { + Self { + base_url: base_url.into(), + } + } + + /// Connect to SSE endpoint and return event stream + pub fn connect(&self) -> Result<SseStream, SseError> { + let url = format!("{}/event", self.base_url); + + #[cfg(target_arch = "wasm32")] + { + wasm::connect_sse(&url) + } + + #[cfg(not(target_arch = "wasm32"))] + { + native::connect_sse(&url) + } + } +} + +#[cfg(target_arch = "wasm32")] +mod wasm { + use super::*; + use futures::channel::mpsc; + use wasm_bindgen::prelude::*; + use wasm_bindgen::JsCast; + use web_sys::{EventSource, MessageEvent}; + + pub fn connect_sse(url: &str) -> Result<SseStream, SseError> { + web_sys::console::log_1(&format!("SSE: Creating EventSource for {}", url).into()); + + let event_source = EventSource::new(url) + .map_err(|e| SseError::Network(format!("EventSource creation failed: {:?}", e)))?; + + web_sys::console::log_1( + &format!( + "SSE: EventSource created, readyState: {}", + event_source.ready_state() + ) + .into(), + ); + + let (tx, rx) = mpsc::unbounded::<Result<Event, SseError>>(); + + setup_listeners(&event_source, tx)?; + + Ok(Box::pin(rx)) + } + + fn setup_listeners( + source: &EventSource, + mut tx: mpsc::UnboundedSender<Result<Event, SseError>>, + ) -> Result<(), SseError> { + // Add open event listener + { + let cb = Closure::wrap(Box::new(move |_e: web_sys::Event| { + web_sys::console::log_1(&"SSE: Connection opened".into()); + }) as Box<dyn FnMut(web_sys::Event)>); + + source + .add_event_listener_with_callback("open", cb.as_ref().unchecked_ref()) + .map_err(|e| SseError::Network(format!("Open listener setup failed: {:?}", e)))?; + + cb.forget(); + } + + // Add error event listener + { + let cb = Closure::wrap(Box::new(move |e: web_sys::Event| { + web_sys::console::error_1(&format!("SSE: Error event: {:?}", e).into()); + }) as Box<dyn FnMut(web_sys::Event)>); + + source + .add_event_listener_with_callback("error", cb.as_ref().unchecked_ref()) + .map_err(|e| SseError::Network(format!("Error listener setup failed: {:?}", e)))?; + + cb.forget(); + } + + let events = [ + "message.updated", + "message.removed", + "message.part.updated", + "message.part.removed", + "session.updated", + "session.created", + "session.deleted", + "session.compacted", + "server.connected", + "session.idle", + "session.error", + "permission.updated", + "file.edited", + "todo.updated", + ]; + + web_sys::console::log_1( + &format!("SSE: Setting up listeners for {} event types", events.len()).into(), + ); + + for event_type in events { + let tx_clone = tx.clone(); + let evt = event_type.to_string(); + let evt_for_log = evt.clone(); + + let cb = Closure::wrap(Box::new(move |e: MessageEvent| { + web_sys::console::log_1(&format!("SSE: Received {} event", evt_for_log).into()); + + if let Some(data) = e.data().as_string() { + web_sys::console::log_1(&format!("SSE: Event data: {}", data).into()); + + // OpenCode sends the event type INSIDE the data JSON, not in the event: field + // So we parse the data directly instead of wrapping it + match serde_json::from_str(&data) { + Ok(event) => { + web_sys::console::log_1( + &format!("SSE: Parsed event successfully").into(), + ); + let _ = tx_clone.unbounded_send(Ok(event)); + } + Err(e) => { + web_sys::console::error_1(&format!("SSE: Parse error: {}", e).into()); + let _ = tx_clone.unbounded_send(Err(SseError::Parse(e.to_string()))); + } + } + } else { + web_sys::console::warn_1( + &format!("SSE: {} event has no data", evt_for_log).into(), + ); + } + }) as Box<dyn FnMut(MessageEvent)>); + + source + .add_event_listener_with_callback(event_type, cb.as_ref().unchecked_ref()) + .map_err(|e| SseError::Network(format!("Listener setup failed: {:?}", e)))?; + + cb.forget(); + } + + web_sys::console::log_1(&"SSE: All listeners setup complete".into()); + Ok(()) + } +} + +#[cfg(not(target_arch = "wasm32"))] +mod native { + use super::*; + use futures::channel::mpsc; + use futures::stream::StreamExt; + use futures::SinkExt; + + pub fn connect_sse(url: &str) -> Result<SseStream, SseError> { + let url = url.to_string(); + let (tx, rx) = mpsc::unbounded::<Result<Event, SseError>>(); + + tokio::spawn(async move { + let _ = stream_events(url, tx).await; + }); + + Ok(Box::pin(rx)) + } + + async fn stream_events( + url: String, + mut tx: mpsc::UnboundedSender<Result<Event, SseError>>, + ) -> Result<(), SseError> { + let response = reqwest::get(&url) + .await + .map_err(|e| SseError::Network(e.to_string()))?; + + let mut stream = response.bytes_stream(); + let mut buffer = String::new(); + + while let Some(chunk_result) = stream.next().await { + let chunk = chunk_result.map_err(|e| SseError::Network(e.to_string()))?; + buffer.push_str(&String::from_utf8_lossy(&chunk)); + + while let Some((event, remaining)) = parse_sse(&buffer) { + buffer = remaining; + if tx.send(Ok(event)).await.is_err() { + return Ok(()); + } + } + } + + let _ = tx.send(Err(SseError::Closed)).await; + Ok(()) + } + + /// Filter system prompts from JSON data for cleaner logging + fn filter_system_prompt(data: &str) -> String { + // Try to parse as JSON and filter out system prompts + match serde_json::from_str::<serde_json::Value>(data) { + Ok(mut json) => { + if let Some(properties) = json.get_mut("properties") { + if let Some(info) = properties.get_mut("info") { + if let Some(system) = info.get_mut("system") { + if let Some(arr) = system.as_array() { + if !arr.is_empty() { + *system = serde_json::json!(format!("[FILTERED {} items]", arr.len())); + } + } + } + } + } + serde_json::to_string(&json).unwrap_or_else(|_| data.to_string()) + } + Err(_) => data.to_string() + } + } + + fn parse_sse(buffer: &str) -> Option<(Event, String)> { + let mut event_type = String::new(); + let mut data_lines = Vec::new(); + let mut pos = 0; + + for line in buffer.lines() { + pos += line.len() + 1; + + if line.is_empty() { + if !data_lines.is_empty() { + let data = data_lines.join("\n"); + + // Filter system prompts for cleaner logging + let filtered_data = filter_system_prompt(&data); + eprintln!("[OpenCode SSE] event_type='{}', data={}", event_type, filtered_data); + + // OpenCode sends the event type INSIDE the data JSON, not in the event: field + // So we parse the data directly instead of wrapping it + match serde_json::from_str(&data) { + Ok(event) => { + eprintln!("[OpenCode SSE] Parsed event successfully"); + return Some((event, buffer[pos..].to_string())); + } + Err(e) => { + eprintln!("[OpenCode SSE] Parse error: {} for data: {}", e, filtered_data); + return None; + } + } + } + } else if let Some(val) = line.strip_prefix("event:") { + event_type = val.trim().to_string(); + } else if let Some(val) = line.strip_prefix("data:") { + data_lines.push(val.trim().to_string()); + } + } + + None + } +} diff --git a/crates/opencode_client/src/types.rs b/crates/opencode_client/src/types.rs new file mode 100644 index 0000000..7fe5958 --- /dev/null +++ b/crates/opencode_client/src/types.rs @@ -0,0 +1,572 @@ +//! OpenCode API Types +//! +//! Rust representations of the OpenCode API types based on opencode_types.ts + +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Session { + pub id: String, + #[serde(rename = "projectID")] + pub project_id: String, + pub directory: String, + #[serde(rename = "parentID", skip_serializing_if = "Option::is_none")] + pub parent_id: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option<SessionSummary>, + #[serde(skip_serializing_if = "Option::is_none")] + pub share: Option<SessionShare>, + pub title: String, + pub version: String, + pub time: SessionTime, + #[serde(skip_serializing_if = "Option::is_none")] + pub revert: Option<SessionRevert>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct SessionSummary { + pub diffs: Vec<FileDiff>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct SessionShare { + pub url: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct SessionTime { + pub created: u64, + pub updated: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub compacting: Option<u64>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct SessionRevert { + #[serde(rename = "messageID")] + pub message_id: String, + #[serde(rename = "partID", skip_serializing_if = "Option::is_none")] + pub part_id: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub snapshot: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub diff: Option<String>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct FileDiff { + #[serde(alias = "path")] + pub file: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub additions: Option<u32>, + #[serde(skip_serializing_if = "Option::is_none")] + pub deletions: Option<u32>, + #[serde(skip_serializing_if = "Option::is_none")] + pub before: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub after: Option<String>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "role")] +pub enum Message { + #[serde(rename = "user")] + User(UserMessage), + #[serde(rename = "assistant")] + Assistant(AssistantMessage), +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct UserMessage { + pub id: String, + #[serde(rename = "sessionID")] + pub session_id: String, + pub time: MessageTime, + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option<UserMessageSummary>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct UserMessageSummary { + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub body: Option<String>, + pub diffs: Vec<FileDiff>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct AssistantMessage { + pub id: String, + #[serde(rename = "sessionID")] + pub session_id: String, + pub time: AssistantMessageTime, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option<MessageError>, + #[serde(default)] + pub system: Vec<String>, + #[serde(rename = "parentID", skip_serializing_if = "Option::is_none")] + pub parent_id: Option<String>, + #[serde(rename = "modelID", skip_serializing_if = "Option::is_none")] + pub model_id: Option<String>, + #[serde(rename = "providerID", skip_serializing_if = "Option::is_none")] + pub provider_id: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub mode: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub path: Option<MessagePath>, + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option<bool>, + #[serde(default)] + pub cost: f64, + #[serde(default)] + pub tokens: TokenUsage, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct MessageTime { + pub created: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct AssistantMessageTime { + pub created: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub completed: Option<u64>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct MessagePath { + #[serde(skip_serializing_if = "Option::is_none")] + pub cwd: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub root: Option<String>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] +pub struct TokenUsage { + pub input: u64, + pub output: u64, + pub reasoning: u64, + #[serde(default)] + pub cache: CacheUsage, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)] +pub struct CacheUsage { + pub read: u64, + pub write: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "name")] +pub enum MessageError { + ProviderAuthError { data: ProviderAuthErrorData }, + UnknownError { data: UnknownErrorData }, + MessageOutputLengthError, + MessageAbortedError { data: MessageAbortedErrorData }, + ApiError { data: ApiErrorData }, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct ProviderAuthErrorData { + #[serde(rename = "providerID")] + pub provider_id: String, + pub message: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct UnknownErrorData { + pub message: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct MessageAbortedErrorData { + pub message: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct ApiErrorData { + pub message: String, + #[serde(rename = "statusCode", skip_serializing_if = "Option::is_none")] + pub status_code: Option<u16>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "type")] +pub enum Part { + #[serde(rename = "text")] + Text(TextPart), + #[serde(rename = "reasoning")] + Reasoning(ReasoningPart), + #[serde(rename = "step-start")] + StepStart(GenericPart), + #[serde(rename = "step-finish")] + StepFinish(GenericPart), + #[serde(rename = "tool")] + Tool(ToolPart), + #[serde(rename = "file")] + File(FilePart), + #[serde(rename = "snapshot")] + Snapshot(SnapshotPart), + #[serde(rename = "patch")] + Patch(PatchPart), + #[serde(rename = "agent")] + Agent(AgentPart), + #[serde(rename = "retry")] + Retry(RetryPart), + #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct PartTime { + pub start: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub end: Option<u64>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct GenericPart { + pub id: String, + #[serde(rename = "sessionID")] + pub session_id: String, + #[serde(rename = "messageID")] + pub message_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct TextPart { + pub id: String, + #[serde(rename = "sessionID")] + pub session_id: String, + #[serde(rename = "messageID")] + pub message_id: String, + pub text: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub synthetic: Option<bool>, + #[serde(skip_serializing_if = "Option::is_none")] + pub time: Option<PartTime>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct ReasoningPart { + pub id: String, + #[serde(rename = "sessionID")] + pub session_id: String, + #[serde(rename = "messageID")] + pub message_id: String, + pub text: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub time: Option<PartTime>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "status")] +pub enum ToolState { + #[serde(rename = "pending")] + Pending, + #[serde(rename = "running")] + Running { + input: serde_json::Value, + #[serde(skip_serializing_if = "Option::is_none")] + title: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + metadata: Option<serde_json::Value>, + time: PartTime, + }, + #[serde(rename = "completed")] + Completed { + input: serde_json::Value, + output: String, + title: String, + metadata: serde_json::Value, + time: PartTime, + #[serde(skip_serializing_if = "Option::is_none")] + attachments: Option<Vec<serde_json::Value>>, + }, + #[serde(rename = "error")] + Error { + input: serde_json::Value, + error: String, + #[serde(skip_serializing_if = "Option::is_none")] + metadata: Option<serde_json::Value>, + time: PartTime, + }, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct ToolPart { + pub id: String, + #[serde(rename = "sessionID")] + pub session_id: String, + #[serde(rename = "messageID")] + pub message_id: String, + #[serde(rename = "callID")] + pub call_id: String, + pub tool: String, + pub state: ToolState, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option<serde_json::Value>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct FilePart { + pub id: String, + #[serde(rename = "sessionID")] + pub session_id: String, + #[serde(rename = "messageID")] + pub message_id: String, + pub mime: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub filename: Option<String>, + pub url: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub source: Option<serde_json::Value>, // Simplified - can expand later +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct SnapshotPart { + pub id: String, + #[serde(rename = "sessionID")] + pub session_id: String, + #[serde(rename = "messageID")] + pub message_id: String, + pub snapshot: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct PatchPart { + pub id: String, + #[serde(rename = "sessionID")] + pub session_id: String, + #[serde(rename = "messageID")] + pub message_id: String, + pub hash: String, + pub files: Vec<String>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct AgentPart { + pub id: String, + #[serde(rename = "sessionID")] + pub session_id: String, + #[serde(rename = "messageID")] + pub message_id: String, + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub source: Option<AgentSource>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct AgentSource { + pub value: String, + pub start: u64, + pub end: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct RetryPart { + pub id: String, + #[serde(rename = "sessionID")] + pub session_id: String, + #[serde(rename = "messageID")] + pub message_id: String, + pub attempt: u32, + pub error: ApiErrorData, // Reusing existing type + pub time: RetryTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct RetryTime { + pub created: u64, +} + +/// Response containing message info and its parts +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct MessageWithParts { + pub info: Message, + pub parts: Vec<Part>, +} + +// ============================================================================ +// SSE Event Types +// ============================================================================ + +/// SSE Event types from OpenCode API +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(tag = "type")] +pub enum Event { + #[serde(rename = "server.connected")] + ServerConnected { properties: serde_json::Value }, + #[serde(rename = "session.created")] + SessionCreated { properties: SessionEventInfo }, + #[serde(rename = "session.updated")] + SessionUpdated { properties: SessionEventInfo }, + #[serde(rename = "session.deleted")] + SessionDeleted { properties: SessionEventInfo }, + #[serde(rename = "session.compacted")] + SessionCompacted { properties: SessionIdOnly }, + #[serde(rename = "session.idle")] + SessionIdle { properties: SessionIdOnly }, + #[serde(rename = "session.error")] + SessionError { properties: SessionErrorInfo }, + #[serde(rename = "message.updated")] + MessageUpdated { properties: MessageEventInfo }, + #[serde(rename = "message.removed")] + MessageRemoved { properties: MessageRemovedInfo }, + #[serde(rename = "message.part.updated")] + MessagePartUpdated { properties: MessagePartEventInfo }, + #[serde(rename = "message.part.removed")] + MessagePartRemoved { properties: MessagePartRemovedInfo }, + #[serde(rename = "permission.updated")] + PermissionUpdated { properties: Permission }, + #[serde(rename = "permission.replied")] + PermissionReplied { properties: PermissionReplyInfo }, + #[serde(rename = "file.edited")] + FileEdited { properties: FileEditedInfo }, + #[serde(rename = "file.watcher.updated")] + FileWatcherUpdated { properties: FileWatcherInfo }, + #[serde(rename = "todo.updated")] + TodoUpdated { properties: TodoEventInfo }, + #[serde(rename = "lsp.client.diagnostics")] + LspClientDiagnostics { properties: LspDiagnosticsInfo }, + #[serde(rename = "installation.updated")] + InstallationUpdated { properties: InstallationInfo }, + #[serde(rename = "ide.installed")] + IdeInstalled { properties: IdeInfo }, + #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct SessionEventInfo { + pub info: Session, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct SessionIdOnly { + #[serde(rename = "sessionID")] + pub session_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct SessionErrorInfo { + #[serde(rename = "sessionID", skip_serializing_if = "Option::is_none")] + pub session_id: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option<MessageError>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct MessageEventInfo { + pub info: Message, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct MessageRemovedInfo { + #[serde(rename = "sessionID")] + pub session_id: String, + #[serde(rename = "messageID")] + pub message_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct MessagePartEventInfo { + pub part: Part, + #[serde(skip_serializing_if = "Option::is_none")] + pub delta: Option<String>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct MessagePartRemovedInfo { + #[serde(rename = "sessionID")] + pub session_id: String, + #[serde(rename = "messageID")] + pub message_id: String, + #[serde(rename = "partID")] + pub part_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Permission { + pub id: String, + #[serde(rename = "type")] + pub permission_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub pattern: Option<serde_json::Value>, // Can be string or array + #[serde(rename = "sessionID")] + pub session_id: String, + #[serde(rename = "messageID")] + pub message_id: String, + #[serde(rename = "callID", skip_serializing_if = "Option::is_none")] + pub call_id: Option<String>, + pub title: String, + pub metadata: serde_json::Value, + pub time: PermissionTime, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct PermissionTime { + pub created: u64, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct PermissionReplyInfo { + #[serde(rename = "sessionID")] + pub session_id: String, + #[serde(rename = "permissionID")] + pub permission_id: String, + pub response: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct FileEditedInfo { + pub file: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct FileWatcherInfo { + pub file: String, + pub event: FileWatcherEvent, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum FileWatcherEvent { + Add, + Change, + Unlink, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct TodoEventInfo { + #[serde(rename = "sessionID")] + pub session_id: String, + pub todos: Vec<Todo>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct Todo { + pub content: String, + pub status: String, + pub priority: String, + pub id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct LspDiagnosticsInfo { + #[serde(rename = "serverID")] + pub server_id: String, + pub path: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct InstallationInfo { + pub version: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub struct IdeInfo { + pub ide: String, +} diff --git a/dirigent.svg b/dirigent.svg new file mode 100644 index 0000000..71b961c --- /dev/null +++ b/dirigent.svg @@ -0,0 +1 @@ +<svg width="16" height="16" viewBox="0 0 32 32" xmlns="http://www.w3.org/2000/svg"> <!-- main crescent --> <path d="M10 5.8 C20 5.8 27 10.7 27 16 C27 23.2 20 27 10 27 C18.5 25.7 23 22 23 16 C23 12.2 18.5 8.5 10 5.8Z" fill="currentColor"/> <!-- subtle rounded thickening along lower curve --> <path d="M11 22 C18 25.5 20.5 22.5 22.8 19.5 C23.5 23.5 19 26.5 10 27 C8.5 25.5 9.6 23.5 11 22Z" fill="currentColor"/> <!-- larger, slightly lowered dot --> <circle cx="13.9" cy="16.9" r="4" fill="currentColor"/> </svg>