sync from monorepo @ 2452e92e

This commit is contained in:
2026-05-08 01:59:04 +02:00
commit b03dc15371
459 changed files with 129586 additions and 0 deletions
+124
View File
@@ -0,0 +1,124 @@
# Package: dirigent_acp_api
ACP Server implementation for accepting incoming ACP connections from external agents.
## Quick Facts
- **Type**: Library
- **Main Entry**: src/lib.rs
- **Dependencies**: axum, tokio, serde, tracing, uuid, async-trait, dirigent_protocol
- **Status**: Core structure complete, integration with CoreRuntime pending
## Overview
The `dirigent_acp_api` package implements an ACP (Agent-Client Protocol) server that allows Dirigent to accept incoming connections from external ACP clients like Claude Code or custom agents. This enables session sharing, remote orchestration, and multi-client collaboration.
## Architecture
### Core Components
- **config.rs** - Server configuration types (`AcpServerConfig`)
- **error.rs** - Error types (`AcpServerError`, `JsonRpcErrorObject`)
- **jsonrpc.rs** - JSON-RPC 2.0 types and parsing
- **rpc.rs** - RPC handler and method dispatch
- **session_manager.rs** - Session/client tracking (TODO)
- **sse.rs** - SSE notification system (TODO)
- **event_bridge.rs** - Event translation (TODO)
### Key Types
```rust
pub struct AcpServerConfig {
pub enabled: bool, // Enable/disable server
pub port: u16, // Listen port (default: 3001)
pub allowed_origins: Option<Vec<String>>, // CORS origins
pub max_connections: usize, // Connection limit (default: 100)
}
```
### ConnectorOperations Trait
The RPC handler uses a trait abstraction to avoid circular dependencies with dirigent_core:
```rust
#[async_trait]
pub trait ConnectorOperations: Send + Sync {
async fn create_session(&self, connector_id: &str) -> Result<String>;
async fn load_session(&self, connector_id: &str, session_id: &str) -> Result<Session>;
async fn send_prompt(&self, connector_id: &str, session_id: &str, prompt: &str) -> Result<String>;
// ... more methods
}
```
## API Endpoints
### POST `/rpc`
JSON-RPC 2.0 endpoint supporting:
- `initialize` - Client handshake
- `session/new` - Create session
- `session/load` - Load existing session
- `session/prompt` - Send prompt
- `session/cancel` - Cancel generation
- `session/close` - Close session
### GET `/events`
Server-Sent Events for streaming notifications:
- `acp/messageChunk` - Streaming content
- `acp/messageComplete` - Generation complete
- `acp/sessionIdle` - Ready for input
### GET `/health`
Health check endpoint.
## Configuration UI
The ACP Server is configured via the web UI at **Configuration > ACP Server**:
- Enable/disable toggle
- Port configuration
- Max connections limit
- Allowed origins (CORS)
- Default connector selection
- Connected clients management
Server functions in `crates/api/src/acp_server.rs` bridge the UI and this package.
## Implementation Status
**Completed:**
- Configuration types (`AcpServerConfig`)
- Error types (`AcpServerError`)
- JSON-RPC types and parsing
- RPC handler structure with ConnectorOperations trait
**Pending:**
- Session Manager implementation
- SSE Notifier implementation
- Event Bridge implementation
- Axum router integration
- Web server integration
## Key Files
| File | Description |
|------|-------------|
| `src/lib.rs` | Module exports and router creation |
| `src/config.rs` | AcpServerConfig with validation |
| `src/error.rs` | Error types and codes |
| `src/jsonrpc.rs` | JSON-RPC 2.0 implementation |
| `src/rpc.rs` | RPC handler and method dispatch |
## Related Packages
- **dirigent_core** - Provides CoreHandle implementation of ConnectorOperations
- **dirigent_protocol** - Shared event and message types
- **api** - Server functions for UI configuration
- **web** - Configuration UI components
## Documentation
- **Architecture**: `docs/architecture/acp_server.md`
- **Configuration**: `docs/configuration/acp-connectors.md`
- **Tasks**: `docs/building/07_acp_serve/02_acp_server_tasks.md`
+30
View File
@@ -0,0 +1,30 @@
[package]
name = "dirigent_acp_api"
version = "0.1.0"
edition = "2021"
[lib]
path = "src/lib.rs"
[dependencies]
anyhow = "1.0"
async-trait = "0.1"
# ACP (Agent-Client Protocol) dependencies
axum = "0.8"
# ACP Server dependencies (Phase 2)
chrono = { version = "0.4", features = ["serde"] }
# Workspace dependencies
# Note: dirigent_protocol is used for Event types
dirigent_protocol = { path = "../dirigent_protocol" }
# Note: dirigent_core is NOT a direct dependency to avoid circular dependency.
# CoreHandle is passed into the ACP server at runtime via generics or trait objects.
# The mode/model mapping logic is duplicated here for legacy mode support.
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0.1", features = ["sync"] }
tower = "0.5"
tower-http = { version = "0.6", features = ["cors"] }
# ACP Server dependencies (Phase 1)
tracing = "0.1"
uuid = { version = "1.0", features = ["serde", "v4", "v7"] }
+5
View File
@@ -0,0 +1,5 @@
This crate should expose an ACP API for dirigent to be used by another ACP Client.
Funny Integration Test would be to use Dirigent using Dirigent (but that would need probably a dummy acp agent to be used by that)
This will however also require dirigent to "pass through" functionality, it believes to be responsible for.
@@ -0,0 +1,438 @@
//! Agent Request Tracker
//!
//! This module provides infrastructure for tracking agent-initiated requests
//! that require responses from HTTP clients. It's used to implement bidirectional
//! JSON-RPC communication in the ACP Server.
//!
//! ## Use Case
//!
//! When an agent (like Claude) sends a request to the client (like a permission
//! prompt), the ACP Server needs to:
//! 1. Forward the request to the client via SSE
//! 2. Wait for the client's response via HTTP POST
//! 3. Deliver the response back to the agent
//!
//! The `AgentRequestTracker` manages the pending requests and provides a way
//! to correlate responses with their corresponding requests.
//!
//! ## Example Flow
//!
//! ```rust,ignore
//! use dirigent_acp_api::agent_requests::AgentRequestTracker;
//! use serde_json::json;
//!
//! // Create tracker
//! let tracker = AgentRequestTracker::new();
//!
//! // Agent sends request - register it and get receiver
//! let request_id = json!(0);
//! let client_id = "client-123";
//! let receiver = tracker.register(client_id, request_id.clone());
//!
//! // Forward request to client via SSE...
//!
//! // Client responds via HTTP POST - complete the request
//! let response = json!({"selectedOptionId": "allow"});
//! tracker.complete(client_id, request_id, response)?;
//!
//! // The receiver now gets the response
//! let response_value = receiver.await?;
//! ```
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use anyhow::{anyhow, Result};
use serde_json::Value;
use tokio::sync::oneshot;
use tracing::{debug, warn};
/// Tracks pending agent requests awaiting client responses
///
/// This struct provides thread-safe storage for correlating agent requests
/// with their eventual client responses. It uses oneshot channels to deliver
/// responses to waiting tasks.
///
/// ## Thread Safety
///
/// The tracker is designed to be cloned and shared across async tasks.
/// Internal state is protected by `Arc<Mutex<...>>` for thread-safe access.
#[derive(Debug, Clone)]
pub struct AgentRequestTracker {
/// Maps (client_id, request_id_string) to oneshot sender for delivering response
///
/// The key is a tuple of client ID and request ID (as string) to uniquely
/// identify each pending request. The value is a oneshot sender that will
/// be used to deliver the response when it arrives.
pending: Arc<Mutex<HashMap<(String, String), oneshot::Sender<Value>>>>,
}
impl Default for AgentRequestTracker {
fn default() -> Self {
Self::new()
}
}
impl AgentRequestTracker {
/// Create a new agent request tracker
///
/// Returns an empty tracker ready to register pending requests.
pub fn new() -> Self {
Self {
pending: Arc::new(Mutex::new(HashMap::new())),
}
}
/// Register a pending agent request and return a receiver for the response
///
/// This method creates a oneshot channel, stores the sender in the pending
/// requests map, and returns the receiver. The caller can await on the
/// receiver to get the client's response.
///
/// # Parameters
///
/// - `client_id`: The ID of the client that should respond to this request
/// - `request_id`: The request ID from the agent (JSON-RPC id field)
///
/// # Returns
///
/// A oneshot receiver that will receive the client's response when
/// `complete()` is called with matching client_id and request_id.
///
/// # Example
///
/// ```rust,ignore
/// let receiver = tracker.register("client-123", json!(0));
///
/// // Later, when client responds...
/// tracker.complete("client-123", json!(0), response)?;
///
/// // The receiver gets the response
/// let response = receiver.await?;
/// ```
pub fn register(&self, client_id: &str, request_id: Value) -> oneshot::Receiver<Value> {
let (tx, rx) = oneshot::channel();
let key = (client_id.to_string(), request_id.to_string());
let mut pending = self.pending.lock().expect("Lock poisoned");
pending.insert(key.clone(), tx);
debug!(
client_id = %client_id,
request_id = %request_id,
"Registered pending agent request"
);
rx
}
/// Complete a pending agent request with the client's response
///
/// This method looks up the pending request, sends the response through
/// the oneshot channel, and removes it from the pending map.
///
/// # Parameters
///
/// - `client_id`: The ID of the client sending the response
/// - `request_id`: The request ID from the original agent request
/// - `response`: The client's response (JSON-RPC response object)
///
/// # Returns
///
/// - `Ok(())` if the request was found and the response was delivered
/// - `Err` if the request_id was not found in pending requests
///
/// # Errors
///
/// Returns an error if:
/// - The request_id is not found in pending requests (may have timed out)
/// - The receiver has been dropped (unlikely but possible)
///
/// # Example
///
/// ```rust,ignore
/// // Client POSTs response to /acp/agent_response
/// let response = json!({
/// "jsonrpc": "2.0",
/// "id": 0,
/// "result": {"selectedOptionId": "allow"}
/// });
///
/// tracker.complete("client-123", json!(0), response)?;
/// ```
pub fn complete(&self, client_id: &str, request_id: Value, response: Value) -> Result<()> {
let key = (client_id.to_string(), request_id.to_string());
let mut pending = self.pending.lock().expect("Lock poisoned");
if let Some(sender) = pending.remove(&key) {
debug!(
client_id = %client_id,
request_id = %request_id,
"Completing pending agent request"
);
// Send the response through the oneshot channel
sender.send(response).map_err(|_| {
anyhow!(
"Failed to send response for request {}: receiver dropped",
request_id
)
})?;
Ok(())
} else {
warn!(
client_id = %client_id,
request_id = %request_id,
"Attempted to complete non-existent agent request (may have timed out)"
);
Err(anyhow!(
"Request ID {} not found for client {}",
request_id,
client_id
))
}
}
/// Timeout a pending agent request
///
/// This method removes a pending request from the map and logs a timeout
/// warning. It should be called when a request has been pending for too
/// long (e.g., 30 seconds) without a client response.
///
/// # Parameters
///
/// - `client_id`: The ID of the client that was supposed to respond
/// - `request_id`: The request ID that timed out
///
/// # Note
///
/// This method does not send any error response through the channel.
/// The receiver will get a `RecvError` when it tries to receive, which
/// the caller should interpret as a timeout.
///
/// # Example
///
/// ```rust,ignore
/// use tokio::time::{timeout, Duration};
///
/// let receiver = tracker.register("client-123", json!(0));
///
/// // Wait up to 30 seconds for response
/// match timeout(Duration::from_secs(30), receiver).await {
/// Ok(Ok(response)) => {
/// // Got response
/// }
/// Ok(Err(_)) => {
/// // Channel closed (timeout was called)
/// tracker.timeout("client-123", json!(0));
/// }
/// Err(_) => {
/// // Timeout elapsed
/// tracker.timeout("client-123", json!(0));
/// }
/// }
/// ```
pub fn timeout(&self, client_id: &str, request_id: Value) {
let key = (client_id.to_string(), request_id.to_string());
let mut pending = self.pending.lock().expect("Lock poisoned");
if pending.remove(&key).is_some() {
warn!(
client_id = %client_id,
request_id = %request_id,
"Agent request timed out (30s elapsed without client response)"
);
}
}
/// Get the number of pending requests
///
/// This method returns the total number of requests currently awaiting
/// client responses across all clients.
pub fn pending_count(&self) -> usize {
let pending = self.pending.lock().expect("Lock poisoned");
pending.len()
}
/// Get the number of pending requests for a specific client
///
/// # Parameters
///
/// - `client_id`: The ID of the client to query
pub fn client_pending_count(&self, client_id: &str) -> usize {
let pending = self.pending.lock().expect("Lock poisoned");
pending
.keys()
.filter(|(cid, _)| cid == client_id)
.count()
}
/// Clear all pending requests (used when shutting down or on client disconnect)
///
/// This method removes all pending requests from the tracker. The oneshot
/// senders are dropped, which will cause their receivers to get `RecvError`.
///
/// # Parameters
///
/// - `client_id`: Optional client ID to clear only that client's pending requests.
/// If None, clears all pending requests.
pub fn clear(&self, client_id: Option<&str>) {
let mut pending = self.pending.lock().expect("Lock poisoned");
match client_id {
Some(id) => {
let keys_to_remove: Vec<_> = pending
.keys()
.filter(|(cid, _)| cid == id)
.cloned()
.collect();
for key in keys_to_remove {
pending.remove(&key);
}
debug!(
client_id = %id,
"Cleared all pending agent requests for client"
);
}
None => {
let count = pending.len();
pending.clear();
debug!(
count = count,
"Cleared all pending agent requests"
);
}
}
}
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[tokio::test]
async fn test_register_and_complete() {
let tracker = AgentRequestTracker::new();
let client_id = "client-123";
let request_id = json!(0);
let response = json!({"result": "success"});
// Register request
let receiver = tracker.register(client_id, request_id.clone());
assert_eq!(tracker.pending_count(), 1);
// Complete request
let result = tracker.complete(client_id, request_id, response.clone());
assert!(result.is_ok());
assert_eq!(tracker.pending_count(), 0);
// Receiver should get the response
let received = receiver.await.unwrap();
assert_eq!(received, response);
}
#[tokio::test]
async fn test_complete_non_existent_request() {
let tracker = AgentRequestTracker::new();
let client_id = "client-123";
let request_id = json!(999);
let response = json!({"result": "success"});
// Try to complete non-existent request
let result = tracker.complete(client_id, request_id, response);
assert!(result.is_err());
}
#[tokio::test]
async fn test_timeout() {
let tracker = AgentRequestTracker::new();
let client_id = "client-123";
let request_id = json!(0);
// Register request
let receiver = tracker.register(client_id, request_id.clone());
assert_eq!(tracker.pending_count(), 1);
// Timeout the request
tracker.timeout(client_id, request_id);
assert_eq!(tracker.pending_count(), 0);
// Receiver should get error (channel closed)
let result = receiver.await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_multiple_pending_requests() {
let tracker = AgentRequestTracker::new();
let client1 = "client-1";
let client2 = "client-2";
// Register multiple requests
let _rx1 = tracker.register(client1, json!(0));
let _rx2 = tracker.register(client1, json!(1));
let _rx3 = tracker.register(client2, json!(0));
assert_eq!(tracker.pending_count(), 3);
assert_eq!(tracker.client_pending_count(client1), 2);
assert_eq!(tracker.client_pending_count(client2), 1);
}
#[tokio::test]
async fn test_clear_all() {
let tracker = AgentRequestTracker::new();
let _rx1 = tracker.register("client-1", json!(0));
let _rx2 = tracker.register("client-2", json!(0));
assert_eq!(tracker.pending_count(), 2);
tracker.clear(None);
assert_eq!(tracker.pending_count(), 0);
}
#[tokio::test]
async fn test_clear_client() {
let tracker = AgentRequestTracker::new();
let client1 = "client-1";
let client2 = "client-2";
let _rx1 = tracker.register(client1, json!(0));
let _rx2 = tracker.register(client1, json!(1));
let _rx3 = tracker.register(client2, json!(0));
assert_eq!(tracker.pending_count(), 3);
// Clear only client1's requests
tracker.clear(Some(client1));
assert_eq!(tracker.pending_count(), 1);
assert_eq!(tracker.client_pending_count(client1), 0);
assert_eq!(tracker.client_pending_count(client2), 1);
}
#[test]
fn test_tracker_clone() {
let tracker = AgentRequestTracker::new();
let tracker_clone = tracker.clone();
// Both should point to same underlying state
let _rx = tracker.register("client-1", json!(0));
assert_eq!(tracker_clone.pending_count(), 1);
}
}
+275
View File
@@ -0,0 +1,275 @@
//! Configuration types for the ACP Server
//!
//! This module defines configuration types for the ACP Server, including
//! server settings, connection limits, and CORS configuration.
use serde::{Deserialize, Serialize};
/// Default port for the ACP Server
pub const DEFAULT_PORT: u16 = 3001;
/// Default maximum number of concurrent connections
pub const DEFAULT_MAX_CONNECTIONS: usize = 100;
/// Configuration for the ACP Server
///
/// This struct contains all configurable options for the ACP Server,
/// including network settings, security options, and resource limits.
///
/// **Note**: This config is used when starting an actual TCP server
/// (separate port mode only). For integrated mode (mounting at /acp),
/// the port field is still required but represents which port was chosen
/// for the separate server path. Higher-level configs (dirigent_core, api)
/// use `Option<u16>` to represent the integrated vs separate distinction.
///
/// TODO: Consider moving the AcpPortConfig enum from web package to core
/// and using it here to make the integrated/separate distinction explicit.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AcpServerConfig {
/// Whether the ACP Server is enabled
///
/// When disabled, the server will not accept incoming connections.
/// Default: false
#[serde(default)]
pub enabled: bool,
/// The port to listen on for incoming connections
///
/// This is always a concrete port number because this config is used
/// to start an actual TCP server. Use Option<u16> in higher-level configs
/// to represent integrated mode (None) vs separate mode (Some(port)).
///
/// Default: 3001
#[serde(default = "default_port")]
pub port: u16,
/// List of allowed origins for CORS
///
/// When Some, only requests from these origins are allowed.
/// When None, all origins are allowed (use with caution).
///
/// Example: ["http://localhost:3000", "https://app.example.com"]
#[serde(default)]
pub allowed_origins: Option<Vec<String>>,
/// Maximum number of concurrent client connections
///
/// New connections will be rejected when this limit is reached.
/// Default: 100
#[serde(default = "default_max_connections")]
pub max_connections: usize,
}
/// Returns the default port value
fn default_port() -> u16 {
DEFAULT_PORT
}
/// Returns the default max connections value
fn default_max_connections() -> usize {
DEFAULT_MAX_CONNECTIONS
}
impl Default for AcpServerConfig {
fn default() -> Self {
Self {
enabled: false,
port: DEFAULT_PORT,
allowed_origins: None,
max_connections: DEFAULT_MAX_CONNECTIONS,
}
}
}
impl AcpServerConfig {
/// Create a new configuration with default values
pub fn new() -> Self {
Self::default()
}
/// Create a configuration with the server enabled
pub fn enabled() -> Self {
Self {
enabled: true,
..Default::default()
}
}
/// Create a configuration with a specific port
pub fn with_port(port: u16) -> Self {
Self {
port,
..Default::default()
}
}
/// Set whether the server is enabled
pub fn set_enabled(mut self, enabled: bool) -> Self {
self.enabled = enabled;
self
}
/// Set the port
pub fn set_port(mut self, port: u16) -> Self {
self.port = port;
self
}
/// Set the allowed origins
pub fn set_allowed_origins(mut self, origins: Option<Vec<String>>) -> Self {
self.allowed_origins = origins;
self
}
/// Set the maximum number of connections
pub fn set_max_connections(mut self, max: usize) -> Self {
self.max_connections = max;
self
}
/// Check if the configuration is valid
pub fn validate(&self) -> Result<(), String> {
if self.port == 0 {
return Err("Port cannot be 0".to_string());
}
if self.max_connections == 0 {
return Err("max_connections must be at least 1".to_string());
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_default_config() {
let config = AcpServerConfig::default();
assert!(!config.enabled);
assert_eq!(config.port, DEFAULT_PORT);
assert!(config.allowed_origins.is_none());
assert_eq!(config.max_connections, DEFAULT_MAX_CONNECTIONS);
}
#[test]
fn test_enabled_config() {
let config = AcpServerConfig::enabled();
assert!(config.enabled);
assert_eq!(config.port, DEFAULT_PORT);
}
#[test]
fn test_with_port() {
let config = AcpServerConfig::with_port(8080);
assert!(!config.enabled);
assert_eq!(config.port, 8080);
}
#[test]
fn test_builder_pattern() {
let config = AcpServerConfig::new()
.set_enabled(true)
.set_port(4000)
.set_allowed_origins(Some(vec!["http://localhost:3000".to_string()]))
.set_max_connections(50);
assert!(config.enabled);
assert_eq!(config.port, 4000);
assert_eq!(
config.allowed_origins,
Some(vec!["http://localhost:3000".to_string()])
);
assert_eq!(config.max_connections, 50);
}
#[test]
fn test_validation_valid() {
let config = AcpServerConfig::default();
assert!(config.validate().is_ok());
}
#[test]
fn test_validation_invalid_port() {
let config = AcpServerConfig {
port: 0,
..Default::default()
};
assert!(config.validate().is_err());
}
#[test]
fn test_validation_invalid_max_connections() {
let config = AcpServerConfig {
max_connections: 0,
..Default::default()
};
assert!(config.validate().is_err());
}
#[test]
fn test_serialization() {
let config = AcpServerConfig {
enabled: true,
port: 3001,
allowed_origins: Some(vec!["http://localhost:3000".to_string()]),
max_connections: 100,
};
let json = serde_json::to_string(&config).unwrap();
assert!(json.contains("\"enabled\":true"));
assert!(json.contains("\"port\":3001"));
assert!(json.contains("\"allowed_origins\""));
assert!(json.contains("\"max_connections\":100"));
// Deserialize back
let parsed: AcpServerConfig = serde_json::from_str(&json).unwrap();
assert_eq!(parsed, config);
}
#[test]
fn test_deserialization_with_defaults() {
// Minimal JSON with defaults
let json = r#"{"enabled":true}"#;
let config: AcpServerConfig = serde_json::from_str(json).unwrap();
assert!(config.enabled);
assert_eq!(config.port, DEFAULT_PORT);
assert!(config.allowed_origins.is_none());
assert_eq!(config.max_connections, DEFAULT_MAX_CONNECTIONS);
}
#[test]
fn test_deserialization_empty() {
// Empty JSON should use all defaults
let json = "{}";
let config: AcpServerConfig = serde_json::from_str(json).unwrap();
assert!(!config.enabled);
assert_eq!(config.port, DEFAULT_PORT);
assert!(config.allowed_origins.is_none());
assert_eq!(config.max_connections, DEFAULT_MAX_CONNECTIONS);
}
#[test]
fn test_equality() {
let config1 = AcpServerConfig::default();
let config2 = AcpServerConfig::default();
assert_eq!(config1, config2);
let config3 = AcpServerConfig::enabled();
assert_ne!(config1, config3);
}
#[test]
fn test_clone() {
let config = AcpServerConfig::enabled()
.set_port(8080)
.set_allowed_origins(Some(vec!["origin".to_string()]));
let cloned = config.clone();
assert_eq!(config, cloned);
}
}
+362
View File
@@ -0,0 +1,362 @@
//! Error types for the ACP Server
//!
//! This module defines error types used throughout the ACP Server implementation,
//! including conversions to JSON-RPC error format for client responses.
use std::fmt;
use serde::{Deserialize, Serialize};
/// Standard JSON-RPC error codes as defined in the specification
pub mod error_codes {
/// Parse error - Invalid JSON was received by the server
pub const PARSE_ERROR: i32 = -32700;
/// Invalid Request - The JSON sent is not a valid Request object
pub const INVALID_REQUEST: i32 = -32600;
/// Method not found - The method does not exist / is not available
pub const METHOD_NOT_FOUND: i32 = -32601;
/// Invalid params - Invalid method parameter(s)
pub const INVALID_PARAMS: i32 = -32602;
/// Internal error - Internal JSON-RPC error
pub const INTERNAL_ERROR: i32 = -32603;
// Server errors reserved for implementation-defined errors (-32000 to -32099)
/// Session not found error
pub const SESSION_NOT_FOUND: i32 = -32001;
/// Connector not found error
pub const CONNECTOR_NOT_FOUND: i32 = -32002;
/// Invalid session error
pub const INVALID_SESSION: i32 = -32003;
/// Transport error
pub const TRANSPORT_ERROR: i32 = -32004;
/// Client not found error
pub const CLIENT_NOT_FOUND: i32 = -32005;
}
/// ACP Server error enum representing all possible error conditions
#[derive(Debug, Clone, PartialEq)]
pub enum AcpServerError {
/// The session ID provided is invalid or malformed
InvalidSession,
/// An RPC-related error occurred
///
/// Contains a description of the RPC error, such as method not found,
/// invalid params, or parse errors.
RpcError(String),
/// A transport-level error occurred
///
/// Contains a description of the transport error, such as connection
/// failures, timeout, or network issues.
TransportError(String),
/// The requested session was not found
///
/// The session ID does not correspond to any known session in the
/// session manager.
SessionNotFound,
/// The requested connector was not found
///
/// The connector ID does not correspond to any registered connector
/// in the runtime. Contains the connector ID that was not found.
ConnectorNotFound(String),
/// The connector is not in a ready state
///
/// The connector exists but is not available to handle requests
/// (e.g., still connecting, error state, or stopped).
/// Contains the connector ID that is not ready.
ConnectorNotReady(String),
/// Operation timed out
///
/// Contains a description of what timed out.
Timeout(String),
/// An internal server error occurred
///
/// Contains a description of the internal error. Used for unexpected
/// conditions that don't fit other categories.
Internal(String),
/// The requested client was not found
///
/// The client ID does not correspond to any connected client in the
/// session manager. Contains the client ID that was not found.
ClientNotFound(String),
}
impl fmt::Display for AcpServerError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
AcpServerError::InvalidSession => {
write!(f, "Invalid session ID provided")
}
AcpServerError::RpcError(msg) => {
write!(f, "RPC error: {}", msg)
}
AcpServerError::TransportError(msg) => {
write!(f, "Transport error: {}", msg)
}
AcpServerError::SessionNotFound => {
write!(f, "Session not found")
}
AcpServerError::ConnectorNotFound(id) => {
write!(f, "Connector not found: {}", id)
}
AcpServerError::ConnectorNotReady(id) => {
write!(f, "Connector not ready: {}", id)
}
AcpServerError::Timeout(msg) => {
write!(f, "Timeout: {}", msg)
}
AcpServerError::Internal(msg) => {
write!(f, "Internal error: {}", msg)
}
AcpServerError::ClientNotFound(id) => {
write!(f, "Client not found: {}", id)
}
}
}
}
impl std::error::Error for AcpServerError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
// None of our error variants wrap other errors currently
None
}
}
/// JSON-RPC error representation for wire format
///
/// This struct is used in JSON-RPC responses when an error occurs.
/// It follows the JSON-RPC 2.0 specification for error objects.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct JsonRpcErrorObject {
/// A Number that indicates the error type that occurred
pub code: i32,
/// A String providing a short description of the error
pub message: String,
/// Optional additional information about the error
#[serde(skip_serializing_if = "Option::is_none")]
pub data: Option<serde_json::Value>,
}
impl JsonRpcErrorObject {
/// Create a new JSON-RPC error object
pub fn new(code: i32, message: impl Into<String>) -> Self {
Self {
code,
message: message.into(),
data: None,
}
}
/// Create a new JSON-RPC error object with additional data
pub fn with_data(code: i32, message: impl Into<String>, data: serde_json::Value) -> Self {
Self {
code,
message: message.into(),
data: Some(data),
}
}
/// Create a parse error
pub fn parse_error(message: impl Into<String>) -> Self {
Self::new(error_codes::PARSE_ERROR, message)
}
/// Create an invalid request error
pub fn invalid_request(message: impl Into<String>) -> Self {
Self::new(error_codes::INVALID_REQUEST, message)
}
/// Create a method not found error
pub fn method_not_found(method: impl Into<String>) -> Self {
Self::new(
error_codes::METHOD_NOT_FOUND,
format!("Method not found: {}", method.into()),
)
}
/// Create an invalid params error
pub fn invalid_params(message: impl Into<String>) -> Self {
Self::new(error_codes::INVALID_PARAMS, message)
}
/// Create an internal error
pub fn internal_error(message: impl Into<String>) -> Self {
Self::new(error_codes::INTERNAL_ERROR, message)
}
}
impl From<AcpServerError> for JsonRpcErrorObject {
fn from(error: AcpServerError) -> Self {
match error {
AcpServerError::InvalidSession => {
JsonRpcErrorObject::new(error_codes::INVALID_SESSION, "Invalid session ID provided")
}
AcpServerError::RpcError(msg) => {
JsonRpcErrorObject::new(error_codes::INTERNAL_ERROR, msg)
}
AcpServerError::TransportError(msg) => {
JsonRpcErrorObject::new(error_codes::TRANSPORT_ERROR, msg)
}
AcpServerError::SessionNotFound => {
JsonRpcErrorObject::new(error_codes::SESSION_NOT_FOUND, "Session not found")
}
AcpServerError::ConnectorNotFound(id) => {
JsonRpcErrorObject::new(error_codes::CONNECTOR_NOT_FOUND, format!("Connector not found: {}", id))
}
AcpServerError::ConnectorNotReady(id) => {
JsonRpcErrorObject::new(error_codes::CONNECTOR_NOT_FOUND, format!("Connector not ready: {}", id))
}
AcpServerError::Timeout(msg) => {
JsonRpcErrorObject::new(error_codes::TRANSPORT_ERROR, format!("Timeout: {}", msg))
}
AcpServerError::Internal(msg) => JsonRpcErrorObject::internal_error(msg),
AcpServerError::ClientNotFound(id) => {
JsonRpcErrorObject::new(error_codes::CLIENT_NOT_FOUND, format!("Client not found: {}", id))
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_error_display() {
assert_eq!(
AcpServerError::InvalidSession.to_string(),
"Invalid session ID provided"
);
assert_eq!(
AcpServerError::RpcError("test".to_string()).to_string(),
"RPC error: test"
);
assert_eq!(
AcpServerError::TransportError("timeout".to_string()).to_string(),
"Transport error: timeout"
);
assert_eq!(
AcpServerError::SessionNotFound.to_string(),
"Session not found"
);
assert_eq!(
AcpServerError::ConnectorNotFound("test-conn".to_string()).to_string(),
"Connector not found: test-conn"
);
assert_eq!(
AcpServerError::ConnectorNotReady("test-conn".to_string()).to_string(),
"Connector not ready: test-conn"
);
assert_eq!(
AcpServerError::Timeout("request timed out".to_string()).to_string(),
"Timeout: request timed out"
);
assert_eq!(
AcpServerError::Internal("something went wrong".to_string()).to_string(),
"Internal error: something went wrong"
);
assert_eq!(
AcpServerError::ClientNotFound("client-123".to_string()).to_string(),
"Client not found: client-123"
);
}
#[test]
fn test_error_to_jsonrpc() {
let error: JsonRpcErrorObject = AcpServerError::InvalidSession.into();
assert_eq!(error.code, error_codes::INVALID_SESSION);
assert_eq!(error.message, "Invalid session ID provided");
let error: JsonRpcErrorObject = AcpServerError::SessionNotFound.into();
assert_eq!(error.code, error_codes::SESSION_NOT_FOUND);
let error: JsonRpcErrorObject = AcpServerError::ConnectorNotFound("conn1".to_string()).into();
assert_eq!(error.code, error_codes::CONNECTOR_NOT_FOUND);
assert!(error.message.contains("conn1"));
let error: JsonRpcErrorObject = AcpServerError::ConnectorNotReady("conn2".to_string()).into();
assert_eq!(error.code, error_codes::CONNECTOR_NOT_FOUND);
assert!(error.message.contains("conn2"));
let error: JsonRpcErrorObject = AcpServerError::Timeout("session creation".to_string()).into();
assert_eq!(error.code, error_codes::TRANSPORT_ERROR);
let error: JsonRpcErrorObject = AcpServerError::TransportError("net error".to_string()).into();
assert_eq!(error.code, error_codes::TRANSPORT_ERROR);
assert_eq!(error.message, "net error");
let error: JsonRpcErrorObject = AcpServerError::Internal("internal".to_string()).into();
assert_eq!(error.code, error_codes::INTERNAL_ERROR);
let error: JsonRpcErrorObject = AcpServerError::ClientNotFound("client-456".to_string()).into();
assert_eq!(error.code, error_codes::CLIENT_NOT_FOUND);
assert!(error.message.contains("client-456"));
}
#[test]
fn test_jsonrpc_error_serialization() {
let error = JsonRpcErrorObject::new(error_codes::PARSE_ERROR, "Invalid JSON");
let json = serde_json::to_string(&error).unwrap();
assert!(json.contains("-32700"));
assert!(json.contains("Invalid JSON"));
// With data
let error = JsonRpcErrorObject::with_data(
error_codes::INVALID_PARAMS,
"Missing field",
serde_json::json!({"field": "session_id"}),
);
let json = serde_json::to_string(&error).unwrap();
assert!(json.contains("session_id"));
}
#[test]
fn test_jsonrpc_error_factories() {
let error = JsonRpcErrorObject::parse_error("bad json");
assert_eq!(error.code, error_codes::PARSE_ERROR);
let error = JsonRpcErrorObject::invalid_request("missing jsonrpc field");
assert_eq!(error.code, error_codes::INVALID_REQUEST);
let error = JsonRpcErrorObject::method_not_found("session.unknown");
assert_eq!(error.code, error_codes::METHOD_NOT_FOUND);
assert!(error.message.contains("session.unknown"));
let error = JsonRpcErrorObject::invalid_params("session_id required");
assert_eq!(error.code, error_codes::INVALID_PARAMS);
let error = JsonRpcErrorObject::internal_error("panic");
assert_eq!(error.code, error_codes::INTERNAL_ERROR);
}
#[test]
fn test_error_is_error_trait() {
fn assert_is_error<T: std::error::Error>() {}
assert_is_error::<AcpServerError>();
}
#[test]
fn test_error_clone() {
let error = AcpServerError::Internal("test".to_string());
let cloned = error.clone();
assert_eq!(error, cloned);
}
}
File diff suppressed because it is too large Load Diff
+460
View File
@@ -0,0 +1,460 @@
//! JSON-RPC 2.0 types for the ACP Server
//!
//! This module implements JSON-RPC 2.0 request/response types according to
//! the specification at https://www.jsonrpc.org/specification.
//!
//! Key features:
//! - Support for both numeric and string IDs
//! - Batch request/response handling
//! - Proper serialization of null vs missing fields
use serde::{Deserialize, Serialize};
use crate::error::JsonRpcErrorObject;
/// JSON-RPC protocol version constant
pub const JSONRPC_VERSION: &str = "2.0";
/// JSON-RPC request/response identifier
///
/// According to the JSON-RPC 2.0 spec, an id can be a String, Number,
/// or Null. This type uses an untagged enum to handle both string and
/// number identifiers.
///
/// Note: The spec recommends not using Null as an id for requests.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum JsonRpcId {
/// Numeric identifier (integer)
Number(i64),
/// String identifier
String(String),
/// Null identifier (typically used in error responses for invalid requests)
Null,
}
impl From<i64> for JsonRpcId {
fn from(n: i64) -> Self {
JsonRpcId::Number(n)
}
}
impl From<String> for JsonRpcId {
fn from(s: String) -> Self {
JsonRpcId::String(s)
}
}
impl From<&str> for JsonRpcId {
fn from(s: &str) -> Self {
JsonRpcId::String(s.to_string())
}
}
/// A JSON-RPC 2.0 request object
///
/// Represents a remote procedure call with optional parameters.
/// The `id` field determines whether this is a request (with id) or
/// notification (without id).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcRequest {
/// JSON-RPC protocol version (must be "2.0")
pub jsonrpc: String,
/// A String containing the name of the method to be invoked
pub method: String,
/// Optional structured value that holds the parameter values
#[serde(skip_serializing_if = "Option::is_none")]
pub params: Option<serde_json::Value>,
/// Optional identifier established by the client
///
/// If absent, the request is a notification (no response expected)
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<JsonRpcId>,
}
impl JsonRpcRequest {
/// Create a new JSON-RPC request
pub fn new(method: impl Into<String>, params: Option<serde_json::Value>, id: JsonRpcId) -> Self {
Self {
jsonrpc: JSONRPC_VERSION.to_string(),
method: method.into(),
params,
id: Some(id),
}
}
/// Create a new JSON-RPC notification (request without id)
pub fn notification(method: impl Into<String>, params: Option<serde_json::Value>) -> Self {
Self {
jsonrpc: JSONRPC_VERSION.to_string(),
method: method.into(),
params,
id: None,
}
}
/// Check if this request is a notification (no id)
pub fn is_notification(&self) -> bool {
self.id.is_none()
}
/// Validate the request format
pub fn validate(&self) -> Result<(), String> {
if self.jsonrpc != JSONRPC_VERSION {
return Err(format!(
"Invalid JSON-RPC version: expected '{}', got '{}'",
JSONRPC_VERSION, self.jsonrpc
));
}
if self.method.is_empty() {
return Err("Method name cannot be empty".to_string());
}
// Methods starting with "rpc." are reserved for internal use
if self.method.starts_with("rpc.") {
return Err(format!(
"Method name '{}' is reserved (starts with 'rpc.')",
self.method
));
}
Ok(())
}
}
/// A JSON-RPC 2.0 response object
///
/// Contains either a result (success) or an error (failure), never both.
/// The id must match the corresponding request id.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcResponse {
/// JSON-RPC protocol version (must be "2.0")
pub jsonrpc: String,
/// The result of the call (on success)
///
/// This member is REQUIRED on success and MUST NOT exist on error.
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<serde_json::Value>,
/// The error object (on failure)
///
/// This member is REQUIRED on error and MUST NOT exist on success.
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<JsonRpcErrorObject>,
/// The identifier matching the request
///
/// If there was an error detecting the id in the Request object
/// (e.g. Parse error/Invalid Request), it MUST be Null.
pub id: JsonRpcId,
}
impl JsonRpcResponse {
/// Create a successful response
pub fn success(result: serde_json::Value, id: JsonRpcId) -> Self {
Self {
jsonrpc: JSONRPC_VERSION.to_string(),
result: Some(result),
error: None,
id,
}
}
/// Create an error response
pub fn error(error: JsonRpcErrorObject, id: JsonRpcId) -> Self {
Self {
jsonrpc: JSONRPC_VERSION.to_string(),
result: None,
error: Some(error),
id,
}
}
/// Create an error response with a null id (for parse errors)
pub fn error_with_null_id(error: JsonRpcErrorObject) -> Self {
Self::error(error, JsonRpcId::Null)
}
/// Check if this response represents success
pub fn is_success(&self) -> bool {
self.result.is_some() && self.error.is_none()
}
/// Check if this response represents an error
pub fn is_error(&self) -> bool {
self.error.is_some()
}
}
/// Represents either a single request or a batch of requests
///
/// The JSON-RPC 2.0 spec allows sending multiple requests in a single
/// JSON array for batch processing.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum JsonRpcRequestBatch {
/// A single request
Single(JsonRpcRequest),
/// A batch of requests
Batch(Vec<JsonRpcRequest>),
}
impl JsonRpcRequestBatch {
/// Check if this is an empty batch
pub fn is_empty(&self) -> bool {
match self {
JsonRpcRequestBatch::Single(_) => false,
JsonRpcRequestBatch::Batch(batch) => batch.is_empty(),
}
}
/// Get the number of requests
pub fn len(&self) -> usize {
match self {
JsonRpcRequestBatch::Single(_) => 1,
JsonRpcRequestBatch::Batch(batch) => batch.len(),
}
}
/// Convert to a vector of requests
pub fn into_vec(self) -> Vec<JsonRpcRequest> {
match self {
JsonRpcRequestBatch::Single(req) => vec![req],
JsonRpcRequestBatch::Batch(batch) => batch,
}
}
}
/// Represents either a single response or a batch of responses
///
/// The response format must match the request format: single request
/// gets single response, batch request gets batch response.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum JsonRpcResponseBatch {
/// A single response
Single(JsonRpcResponse),
/// A batch of responses
Batch(Vec<JsonRpcResponse>),
}
impl JsonRpcResponseBatch {
/// Create a batch response from a vector
///
/// Returns Single if there's exactly one response, otherwise Batch.
pub fn from_vec(responses: Vec<JsonRpcResponse>) -> Self {
if responses.len() == 1 {
JsonRpcResponseBatch::Single(responses.into_iter().next().unwrap())
} else {
JsonRpcResponseBatch::Batch(responses)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_jsonrpc_id_number() {
let id = JsonRpcId::Number(42);
let json = serde_json::to_string(&id).unwrap();
assert_eq!(json, "42");
let parsed: JsonRpcId = serde_json::from_str(&json).unwrap();
assert_eq!(parsed, id);
}
#[test]
fn test_jsonrpc_id_string() {
let id = JsonRpcId::String("abc-123".to_string());
let json = serde_json::to_string(&id).unwrap();
assert_eq!(json, "\"abc-123\"");
let parsed: JsonRpcId = serde_json::from_str(&json).unwrap();
assert_eq!(parsed, id);
}
#[test]
fn test_jsonrpc_id_null() {
let id = JsonRpcId::Null;
let json = serde_json::to_string(&id).unwrap();
assert_eq!(json, "null");
let parsed: JsonRpcId = serde_json::from_str(&json).unwrap();
assert_eq!(parsed, id);
}
#[test]
fn test_jsonrpc_id_from_conversions() {
let id: JsonRpcId = 42i64.into();
assert_eq!(id, JsonRpcId::Number(42));
let id: JsonRpcId = "test".into();
assert_eq!(id, JsonRpcId::String("test".to_string()));
let id: JsonRpcId = String::from("owned").into();
assert_eq!(id, JsonRpcId::String("owned".to_string()));
}
#[test]
fn test_request_creation() {
let req = JsonRpcRequest::new("session.new", Some(json!({"title": "Test"})), 1.into());
assert_eq!(req.jsonrpc, "2.0");
assert_eq!(req.method, "session.new");
assert!(req.params.is_some());
assert_eq!(req.id, Some(JsonRpcId::Number(1)));
assert!(!req.is_notification());
}
#[test]
fn test_notification_creation() {
let notif = JsonRpcRequest::notification("event.ping", None);
assert!(notif.is_notification());
assert_eq!(notif.id, None);
}
#[test]
fn test_request_validation() {
let valid = JsonRpcRequest::new("test.method", None, 1.into());
assert!(valid.validate().is_ok());
// Invalid version
let mut invalid_version = valid.clone();
invalid_version.jsonrpc = "1.0".to_string();
assert!(invalid_version.validate().is_err());
// Empty method
let mut empty_method = valid.clone();
empty_method.method = String::new();
assert!(empty_method.validate().is_err());
// Reserved method
let mut reserved = valid.clone();
reserved.method = "rpc.internal".to_string();
assert!(reserved.validate().is_err());
}
#[test]
fn test_request_serialization() {
let req = JsonRpcRequest::new(
"session.prompt",
Some(json!({"session_id": "abc", "content": "Hello"})),
"req-123".into(),
);
let json = serde_json::to_string(&req).unwrap();
assert!(json.contains("\"jsonrpc\":\"2.0\""));
assert!(json.contains("\"method\":\"session.prompt\""));
assert!(json.contains("\"id\":\"req-123\""));
// Deserialize back
let parsed: JsonRpcRequest = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.method, "session.prompt");
}
#[test]
fn test_response_success() {
let resp = JsonRpcResponse::success(json!({"session_id": "new-123"}), 1.into());
assert!(resp.is_success());
assert!(!resp.is_error());
assert_eq!(resp.result, Some(json!({"session_id": "new-123"})));
assert!(resp.error.is_none());
}
#[test]
fn test_response_error() {
let error = JsonRpcErrorObject::method_not_found("unknown.method");
let resp = JsonRpcResponse::error(error, 1.into());
assert!(!resp.is_success());
assert!(resp.is_error());
assert!(resp.result.is_none());
assert!(resp.error.is_some());
}
#[test]
fn test_response_serialization() {
// Success response
let success = JsonRpcResponse::success(json!({"ok": true}), 42.into());
let json = serde_json::to_string(&success).unwrap();
assert!(json.contains("\"result\""));
assert!(!json.contains("\"error\""));
// Error response
let error = JsonRpcResponse::error(
JsonRpcErrorObject::internal_error("Something broke"),
42.into(),
);
let json = serde_json::to_string(&error).unwrap();
assert!(!json.contains("\"result\""));
assert!(json.contains("\"error\""));
}
#[test]
fn test_batch_request_single() {
let req = JsonRpcRequest::new("test", None, 1.into());
let batch = JsonRpcRequestBatch::Single(req);
assert_eq!(batch.len(), 1);
assert!(!batch.is_empty());
let vec = batch.into_vec();
assert_eq!(vec.len(), 1);
}
#[test]
fn test_batch_request_multiple() {
let req1 = JsonRpcRequest::new("test1", None, 1.into());
let req2 = JsonRpcRequest::new("test2", None, 2.into());
let batch = JsonRpcRequestBatch::Batch(vec![req1, req2]);
assert_eq!(batch.len(), 2);
let vec = batch.into_vec();
assert_eq!(vec.len(), 2);
}
#[test]
fn test_batch_request_empty() {
let batch = JsonRpcRequestBatch::Batch(vec![]);
assert!(batch.is_empty());
assert_eq!(batch.len(), 0);
}
#[test]
fn test_batch_request_deserialization() {
// Single request
let single_json = r#"{"jsonrpc":"2.0","method":"test","id":1}"#;
let single: JsonRpcRequestBatch = serde_json::from_str(single_json).unwrap();
assert_eq!(single.len(), 1);
// Batch request
let batch_json = r#"[{"jsonrpc":"2.0","method":"test1","id":1},{"jsonrpc":"2.0","method":"test2","id":2}]"#;
let batch: JsonRpcRequestBatch = serde_json::from_str(batch_json).unwrap();
assert_eq!(batch.len(), 2);
}
#[test]
fn test_batch_response_from_vec() {
// Single response becomes Single variant
let responses = vec![JsonRpcResponse::success(json!(null), 1.into())];
let batch = JsonRpcResponseBatch::from_vec(responses);
matches!(batch, JsonRpcResponseBatch::Single(_));
// Multiple responses become Batch variant
let responses = vec![
JsonRpcResponse::success(json!(null), 1.into()),
JsonRpcResponse::success(json!(null), 2.into()),
];
let batch = JsonRpcResponseBatch::from_vec(responses);
matches!(batch, JsonRpcResponseBatch::Batch(_));
}
}
+116
View File
@@ -0,0 +1,116 @@
//! Dirigent ACP API
//!
//! This crate exposes an ACP (Agent-Client Protocol) API for Dirigent,
//! allowing other ACP clients to interact with Dirigent.
//!
//! ## Modules
//!
//! - [`config`] - Server configuration types
//! - [`error`] - Error types and JSON-RPC error conversion
//! - [`event_bridge`] - Event forwarding from source streams to SSE clients
//! - [`jsonrpc`] - JSON-RPC 2.0 request/response types
//! - [`router`] - Axum router and HTTP handlers
//! - [`rpc`] - JSON-RPC request handler and method dispatch
//! - [`session_manager`] - Session mapping and client connection tracking
//! - [`sse`] - SSE notifications for streaming events to clients
//!
//! ## Example
//!
//! ```rust,ignore
//! use dirigent_acp_api::{
//! AcpServerConfig, NoOpConnectorOperations,
//! router::{AcpServerState, create_acp_server_router},
//! };
//!
//! // Create server configuration
//! let config = AcpServerConfig::enabled()
//! .set_port(3001)
//! .set_max_connections(100);
//!
//! // Create server state
//! let state = AcpServerState::new(config);
//!
//! // Create the router with connector operations
//! let router = create_acp_server_router(state, NoOpConnectorOperations);
//!
//! // Run with axum
//! let listener = tokio::net::TcpListener::bind("0.0.0.0:3001").await?;
//! axum::serve(listener, router).await?;
//! ```
// Modules
pub mod agent_requests;
pub mod config;
pub mod error;
pub mod event_bridge;
pub mod jsonrpc;
pub mod router;
pub mod rpc;
pub mod session_manager;
pub mod sse;
// Re-exports for convenience
pub use agent_requests::AgentRequestTracker;
pub use config::AcpServerConfig;
pub use error::{AcpServerError, JsonRpcErrorObject};
pub use event_bridge::{EventBridge, EventBridgeConfig};
pub use jsonrpc::{
JsonRpcId, JsonRpcRequest, JsonRpcRequestBatch, JsonRpcResponse, JsonRpcResponseBatch,
JSONRPC_VERSION,
};
pub use router::{create_acp_server_router, AcpServerState, RouterState};
pub use rpc::{
ConnectorInfo, ConnectorOperations, NoOpConnectorOperations, RpcHandler, SessionInfo,
ACP_PROTOCOL_VERSION, SERVER_NAME,
};
pub use session_manager::{ClientConnection, ClientInfo, SessionManager, SessionMapping};
pub use sse::{AcpNotification, SseNotifier, translate_event};
use axum::{response::Json, routing::get, Router};
use serde::{Deserialize, Serialize};
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiInfo {
pub name: String,
pub version: String,
pub description: String,
}
/// Create the ACP API router
pub fn create_api_router() -> Router {
Router::new()
.route("/", get(api_info))
.route("/health", get(health_check))
}
/// Get API information
async fn api_info() -> Json<ApiInfo> {
Json(ApiInfo {
name: "dirigent_acp_api".to_string(),
version: env!("CARGO_PKG_VERSION").to_string(),
description: "ACP API for Dirigent - orchestrates agentic clients".to_string(),
})
}
/// Health check endpoint
async fn health_check() -> Json<serde_json::Value> {
Json(serde_json::json!({
"status": "healthy",
"message": "Dirigent ACP API is running"
}))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_api_info() {
let info = ApiInfo {
name: "test".to_string(),
version: "0.1.0".to_string(),
description: "test description".to_string(),
};
assert_eq!(info.name, "test");
}
}
+950
View File
@@ -0,0 +1,950 @@
//! Axum Router Integration for ACP Server
//!
//! This module provides the Axum router and HTTP handlers for the ACP Server.
//! It exposes a JSON-RPC endpoint at `/rpc` and an SSE events endpoint at `/events`.
//!
//! ## Architecture
//!
//! The router is generic over `ConnectorOperations`, allowing different implementations
//! to be used (e.g., one backed by `CoreHandle` for production, or `NoOpConnectorOperations`
//! for testing).
//!
//! ## Endpoints
//!
//! - `POST /rpc` - JSON-RPC 2.0 endpoint for method calls
//! - `GET /events?client_id=...` - SSE stream for real-time notifications
//! - `GET /health` - Health check endpoint
//!
//! ## Example
//!
//! ```rust,ignore
//! use dirigent_acp_api::router::{AcpServerState, create_acp_server_router};
//! use dirigent_acp_api::{AcpServerConfig, NoOpConnectorOperations};
//!
//! // Create state
//! let state = AcpServerState::new(AcpServerConfig::enabled());
//!
//! // Create router with no-op operations for testing
//! let router = create_acp_server_router(state, NoOpConnectorOperations);
//! ```
use std::convert::Infallible;
use std::sync::Arc;
use std::time::Instant;
use axum::{
extract::{Query, State},
http::{HeaderMap, StatusCode},
response::{
sse::{Event, KeepAlive, Sse},
IntoResponse, Json,
},
routing::{get, post},
Router,
};
use serde::{Deserialize, Serialize};
use tokio_stream::StreamExt;
use tower_http::cors::{Any, CorsLayer};
use tracing::{debug, info, trace, warn};
use crate::config::AcpServerConfig;
use crate::rpc::{ConnectorOperations, RpcHandler};
use crate::session_manager::SessionManager;
use crate::sse::SseNotifier;
// ============================================================================
// AcpServerState (T031)
// ============================================================================
/// Internal state shared across handlers
struct AcpServerStateInner {
/// Session manager for tracking client sessions
session_manager: SessionManager,
/// SSE notifier for broadcasting events to clients
sse_notifier: SseNotifier,
/// Agent request tracker for bidirectional request/response
agent_request_tracker: Arc<crate::agent_requests::AgentRequestTracker>,
/// Server configuration
config: AcpServerConfig,
}
/// Shared state for the ACP Server Axum handlers
///
/// This struct contains all the state needed by the HTTP handlers:
/// - Session manager for tracking client sessions and mappings
/// - SSE notifier for broadcasting events to connected clients
/// - Server configuration
///
/// The state is wrapped in `Arc` internally, making it cheap to clone and
/// share across async tasks and handlers.
///
/// Note: `RpcHandler` is created per-request with `ConnectorOperations` passed in,
/// as the connector operations implementation cannot be stored in the shared state
/// (it may contain non-Clone types like `CoreHandle`).
#[derive(Clone)]
pub struct AcpServerState {
inner: Arc<AcpServerStateInner>,
}
impl AcpServerState {
/// Create a new ACP server state with the given configuration
///
/// # Parameters
///
/// - `config`: The server configuration
///
/// # Example
///
/// ```rust
/// use dirigent_acp_api::router::AcpServerState;
/// use dirigent_acp_api::AcpServerConfig;
///
/// let state = AcpServerState::new(AcpServerConfig::enabled());
/// ```
pub fn new(config: AcpServerConfig) -> Self {
Self {
inner: Arc::new(AcpServerStateInner {
session_manager: SessionManager::new(),
sse_notifier: SseNotifier::new(),
agent_request_tracker: Arc::new(crate::agent_requests::AgentRequestTracker::new()),
config,
}),
}
}
/// Create a new ACP server state with custom components
///
/// This is useful for testing or when you need to share a session manager
/// or SSE notifier with other parts of the application.
///
/// # Parameters
///
/// - `session_manager`: The session manager instance
/// - `sse_notifier`: The SSE notifier instance
/// - `agent_request_tracker`: The agent request tracker instance
/// - `config`: The server configuration
pub fn with_components(
session_manager: SessionManager,
sse_notifier: SseNotifier,
agent_request_tracker: Arc<crate::agent_requests::AgentRequestTracker>,
config: AcpServerConfig,
) -> Self {
Self {
inner: Arc::new(AcpServerStateInner {
session_manager,
sse_notifier,
agent_request_tracker,
config,
}),
}
}
/// Get a reference to the session manager
pub fn session_manager(&self) -> &SessionManager {
&self.inner.session_manager
}
/// Get a reference to the SSE notifier
pub fn sse_notifier(&self) -> &SseNotifier {
&self.inner.sse_notifier
}
/// Get a reference to the agent request tracker
pub fn agent_request_tracker(&self) -> &Arc<crate::agent_requests::AgentRequestTracker> {
&self.inner.agent_request_tracker
}
/// Get a reference to the configuration
pub fn config(&self) -> &AcpServerConfig {
&self.inner.config
}
}
impl std::fmt::Debug for AcpServerState {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AcpServerState")
.field("session_count", &self.inner.session_manager.mapping_count())
.field("client_count", &self.inner.sse_notifier.client_count())
.field("config", &self.inner.config)
.finish()
}
}
// ============================================================================
// Router State with ConnectorOperations (T028)
// ============================================================================
/// Combined state for Axum handlers that includes both server state and connector operations
///
/// This struct combines the shared `AcpServerState` with a `ConnectorOperations`
/// implementation. It's used as the Axum state to provide handlers with access
/// to both session management and connector functionality.
#[derive(Clone)]
pub struct RouterState<C: ConnectorOperations + Clone + Send + Sync + 'static> {
/// The ACP server state (session manager, SSE notifier, config)
pub state: AcpServerState,
/// The connector operations implementation
pub connector_ops: C,
}
impl<C: ConnectorOperations + Clone + Send + Sync + 'static> RouterState<C> {
/// Create a new router state
pub fn new(state: AcpServerState, connector_ops: C) -> Self {
Self {
state,
connector_ops,
}
}
}
// ============================================================================
// Router Factory (T028)
// ============================================================================
/// Create the ACP server router with all endpoints configured
///
/// This function creates an Axum router with the following endpoints:
/// - `POST /rpc` - JSON-RPC 2.0 endpoint
/// - `GET /events` - SSE events stream
/// - `GET /health` - Health check
///
/// CORS middleware is applied based on the configuration.
///
/// # Type Parameters
///
/// - `C`: The connector operations implementation
///
/// # Parameters
///
/// - `state`: The ACP server state
/// - `connector_ops`: The connector operations implementation
///
/// # Example
///
/// ```rust,ignore
/// use dirigent_acp_api::router::{AcpServerState, create_acp_server_router};
/// use dirigent_acp_api::{AcpServerConfig, NoOpConnectorOperations};
///
/// let state = AcpServerState::new(AcpServerConfig::enabled());
/// let router = create_acp_server_router(state, NoOpConnectorOperations);
///
/// // Use with axum server
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3001").await?;
/// axum::serve(listener, router).await?;
/// ```
pub fn create_acp_server_router<C>(state: AcpServerState, connector_ops: C) -> Router
where
C: ConnectorOperations + Clone + Send + Sync + 'static,
{
// Build CORS layer based on configuration
let cors = build_cors_layer(&state.config());
// Create combined router state
let router_state = RouterState::new(state, connector_ops);
// Build the router (T023: added /agent_response route)
Router::new()
.route("/rpc", post(handle_rpc::<C>))
.route("/events", get(handle_sse::<C>))
.route("/health", get(handle_health::<C>))
.route("/agent_response", post(handle_agent_response::<C>))
.layer(cors)
.with_state(router_state)
}
/// Build CORS layer from configuration
fn build_cors_layer(config: &AcpServerConfig) -> CorsLayer {
let cors = CorsLayer::new()
.allow_methods([
axum::http::Method::GET,
axum::http::Method::POST,
axum::http::Method::OPTIONS,
])
.allow_headers(Any);
match &config.allowed_origins {
Some(origins) if !origins.is_empty() => {
// Parse origins and set specific allowed origins
let parsed_origins: Vec<_> = origins
.iter()
.filter_map(|o| o.parse().ok())
.collect();
if parsed_origins.is_empty() {
warn!("No valid origins in allowed_origins, allowing any origin");
cors.allow_origin(Any)
} else {
debug!("CORS configured with {} allowed origins", parsed_origins.len());
cors.allow_origin(parsed_origins)
}
}
_ => {
debug!("CORS configured to allow any origin");
cors.allow_origin(Any)
}
}
}
// ============================================================================
// Request/Response Types
// ============================================================================
/// Query parameters for the SSE events endpoint
#[derive(Debug, Deserialize)]
pub struct SseQuery {
/// The client ID (required for SSE subscription)
pub client_id: Option<String>,
}
/// Health check response
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthResponse {
/// Health status
pub status: String,
/// Server message
pub message: String,
/// Number of connected clients
pub clients: usize,
/// Number of active sessions
pub sessions: usize,
}
/// Error response for SSE endpoint
#[derive(Debug, Serialize, Deserialize)]
pub struct SseErrorResponse {
/// Error message
pub error: String,
/// Error code
pub code: String,
}
// ============================================================================
// Endpoint Handlers (T029, T030)
// ============================================================================
/// Handle POST /rpc requests (T029)
///
/// Extracts the JSON body, processes it through the RPC handler, and returns
/// the JSON-RPC response.
async fn handle_rpc<C>(
State(router_state): State<RouterState<C>>,
headers: HeaderMap,
body: String,
) -> impl IntoResponse
where
C: ConnectorOperations + Clone + Send + Sync + 'static,
{
debug!("Received RPC request: {} bytes", body.len());
// Extract client_id from X-Client-ID header
let client_id = headers
.get("X-Client-ID")
.and_then(|v| v.to_str().ok());
// Extract select_connector from X-Select-Connector header (sent during initialize)
let select_connector = headers
.get("X-Select-Connector")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
if let Some(id) = client_id {
info!(client_id = %id, "Received X-Client-ID header from RPC request");
debug!("RPC headers: {:?}", headers);
} else {
warn!("No X-Client-ID header provided in RPC request");
debug!("Available headers: {:?}", headers.keys().collect::<Vec<_>>());
}
// Log the incoming request
debug!("Received RPC request body: {}", serde_json::to_string(&body).unwrap_or_else(|_| "failed to serialize".to_string()));
// Create an RPC handler for this request
let handler = RpcHandler::new(
router_state.state.session_manager().clone(),
router_state.connector_ops.clone(),
router_state.state.sse_notifier().clone(),
router_state.state.agent_request_tracker().clone(),
);
// Process the request with the client_id from header
let response = handler.handle_request(&body, client_id).await;
// If this was an initialize request with X-Select-Connector header,
// store the preference after the client is registered
if let Some(ref connector) = select_connector {
// Check if this is an initialize response by looking for clientId in result
if let crate::jsonrpc::JsonRpcResponseBatch::Single(ref resp) = response {
if let Some(ref result) = resp.result {
if let Some(new_client_id) = result.get("clientId").and_then(|v| v.as_str()) {
router_state.state.session_manager().update_client_preferred_connector(
new_client_id,
Some(connector.clone()),
);
info!(
client_id = %new_client_id,
select_connector = %connector,
"Stored preferred connector from X-Select-Connector header"
);
}
}
}
}
// Log the response being sent
debug!("Sending RPC response: {}", serde_json::to_string(&response).unwrap_or_else(|_| "failed to serialize".to_string()));
// Return JSON response
Json(response)
}
/// Handle GET /events SSE requests (T030)
///
/// Creates an SSE stream subscription for the client. The client_id must be
/// provided as a query parameter.
async fn handle_sse<C>(
State(router_state): State<RouterState<C>>,
Query(query): Query<SseQuery>,
) -> Result<Sse<impl tokio_stream::Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<SseErrorResponse>)>
where
C: ConnectorOperations + Clone + Send + Sync + 'static,
{
// Validate client_id
let client_id = match query.client_id {
Some(id) if !id.is_empty() => id,
_ => {
warn!("SSE request missing client_id");
return Err((
StatusCode::BAD_REQUEST,
Json(SseErrorResponse {
error: "Missing required parameter: client_id".to_string(),
code: "MISSING_CLIENT_ID".to_string(),
}),
));
}
};
info!("SSE subscription requested for client: {}", client_id);
debug!(
"SSE notifier state before subscribe: client_count={}, subscribed_clients={:?}",
router_state.state.sse_notifier().client_count(),
router_state.state.sse_notifier().subscribed_clients()
);
// Subscribe to notifications
let notifier = router_state.state.sse_notifier();
let notification_stream = notifier.subscribe(&client_id);
info!(
"SSE client subscribed: client_id={}, total_clients={}, is_subscribed={}",
client_id,
notifier.client_count(),
notifier.is_subscribed(&client_id)
);
// Check if this client has any sessions that need tool updates
// This handles the race condition where session/new happens before SSE subscription
let session_manager = router_state.state.session_manager();
let client_sessions = session_manager.list_client_sessions(&client_id);
if !client_sessions.is_empty() {
debug!(
"Client {} has {} existing sessions, sending tool updates",
client_id,
client_sessions.len()
);
// For each session, get the connector and send tool updates
for session_id in client_sessions {
if let Some(mapping) = session_manager.get_mapping(&session_id) {
debug!(
"Fetching tool updates for session {} on connector {}",
session_id,
mapping.connector_id
);
// Get available commands from the connector
match router_state.connector_ops.get_connector_commands(&mapping.connector_id).await {
Ok(commands) => {
let update_params = crate::sse::SessionUpdateParams {
session_id: session_id.clone(),
update: crate::sse::SessionUpdateVariant::AvailableCommandsUpdate {
available_commands: commands.clone(),
},
event_type_override: None,
};
// Broadcast the tool update now that client is subscribed
match notifier.broadcast(&client_id, update_params) {
Ok(n) => {
info!(
"Sent deferred available_commands_update for session {}: {} commands to {} receivers",
session_id,
commands.len(),
n
);
}
Err(_) => {
warn!(
"Failed to send deferred available_commands_update for session {}",
session_id
);
}
}
}
Err(e) => {
warn!(
"Failed to get connector commands for session {}: {}",
session_id,
e
);
}
}
}
}
}
// Map the notification stream to SSE events
let client_id_for_log = client_id.clone();
let sse_stream = notification_stream.map(move |result| {
match result {
Ok(notification) => {
// T013: Start timing for SSE event creation
let event_start = Instant::now();
// Convert notification to SSE event
// Use to_sse_json() which handles raw events correctly
let event_type = notification.event_type();
let data = notification.to_sse_json();
// T011: Log before SSE event is written (T023: includes session_id for correlation)
trace!(
client_id = %client_id_for_log,
event_type = %event_type,
session_id = %notification.session_id,
data_len = data.len(),
"Writing SSE event to client stream"
);
debug!(
"SSE: Sending event to client {}: type={}, session_id={}, data_len={}",
client_id_for_log,
event_type,
notification.session_id,
data.len()
);
trace!("SSE event data: {}", data);
let event = Event::default()
.event(event_type)
.data(data);
// T012: Log after Event::default() construction (T023: includes session_id for correlation)
trace!(
client_id = %client_id_for_log,
event_type = %event_type,
session_id = %notification.session_id,
"SSE event constructed, sending to client"
);
// T013: Log SSE event write completion with timing (T023: includes session_id for correlation)
let elapsed_ms = event_start.elapsed().as_millis();
trace!(
client_id = %client_id_for_log,
event_type = %event_type,
session_id = %notification.session_id,
elapsed_ms = elapsed_ms,
"SSE event write completed"
);
// T014: Warn for slow SSE event writes (T023: includes session_id for correlation)
if elapsed_ms > 100 {
warn!(
elapsed_ms = elapsed_ms,
client_id = %client_id_for_log,
event_type = %event_type,
session_id = %notification.session_id,
"Slow SSE event write detected"
);
}
Ok(event)
}
Err(e) => {
// Broadcast stream error (e.g., lagged receiver)
// We send an error event but keep the stream open
warn!("SSE stream error for client {}: {:?}", client_id_for_log, e);
Ok(Event::default()
.event("error")
.data(format!(r#"{{"error":"Stream error: {:?}"}}"#, e)))
}
}
});
// Return SSE response with keep-alive
Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
}
/// Handle GET /health requests
///
/// Returns basic health information about the server.
async fn handle_health<C>(
State(router_state): State<RouterState<C>>,
) -> Json<HealthResponse>
where
C: ConnectorOperations + Clone + Send + Sync + 'static,
{
Json(HealthResponse {
status: "healthy".to_string(),
message: "Dirigent ACP Server is running".to_string(),
clients: router_state.state.sse_notifier().client_count(),
sessions: router_state.state.session_manager().mapping_count(),
})
}
/// Response for agent response endpoint
#[derive(Debug, Serialize, Deserialize)]
pub struct AgentResponseResult {
/// Status of the response
pub status: String,
}
/// Handle POST /agent_response requests (T019)
///
/// Accepts JSON-RPC responses from clients for pending agent requests.
/// When a client approves/denies a permission request, they POST the
/// response to this endpoint, which completes the pending request.
/// (Note: When nested at /acp, the full path becomes /acp/agent_response)
///
/// Expected body format:
/// ```json
/// {
/// "jsonrpc": "2.0",
/// "id": 0,
/// "result": { "selectedOptionId": "allow" }
/// }
/// ```
///
/// The response delivery to the connector happens through the oneshot channel
/// registered in the event bridge (Phase 2.4). The tracker's `complete()` method
/// sends the response value through the oneshot channel, and the event bridge task
/// (which registered the request) will receive it and send the `ConnectorCommand::AgentResponse`.
async fn handle_agent_response<C>(
State(router_state): State<RouterState<C>>,
headers: HeaderMap,
Json(response): Json<serde_json::Value>,
) -> Result<Json<AgentResponseResult>, (StatusCode, String)>
where
C: ConnectorOperations + Clone + Send + Sync + 'static,
{
// Extract client_id from X-Client-ID header (T020)
let client_id = match headers.get("X-Client-ID").and_then(|v| v.to_str().ok()) {
Some(id) if !id.is_empty() => id,
_ => {
warn!("Agent response request missing X-Client-ID header");
return Err((
StatusCode::BAD_REQUEST,
"Missing required header: X-Client-ID".to_string(),
));
}
};
debug!(
"Received agent response from client: {}",
client_id
);
trace!("Agent response body: {}", response);
// Extract request_id from JSON body (T020)
let request_id = match response.get("id") {
Some(id) => id.clone(),
None => {
warn!(
client_id = %client_id,
"Agent response missing 'id' field in body"
);
return Err((
StatusCode::BAD_REQUEST,
"Missing required field: id".to_string(),
));
}
};
info!(
client_id = %client_id,
request_id = %request_id,
"Processing agent response"
);
// Call AgentRequestTracker::complete() (T021)
let tracker = router_state.state.agent_request_tracker();
match tracker.complete(client_id, request_id.clone(), response.clone()) {
Ok(()) => {
info!(
client_id = %client_id,
request_id = %request_id,
"Agent response delivered successfully"
);
// Note (T022): The response delivery to the connector happens through
// the oneshot channel in the event bridge. When the event bridge handles
// an `Event::AgentRequest`, it registers the request with the tracker and
// gets a receiver. After we complete the request here, the receiver in the
// event bridge gets the response value and sends the `ConnectorCommand::AgentResponse`
// to the connector. This flow is implemented in Phase 2.4 (T024-T030).
Ok(Json(AgentResponseResult {
status: "ok".to_string(),
}))
}
Err(e) => {
warn!(
client_id = %client_id,
request_id = %request_id,
error = %e,
"Agent response failed: request not found (may have timed out)"
);
Err((
StatusCode::NOT_FOUND,
format!("Request ID {} not found or already completed", request_id),
))
}
}
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
use super::*;
use crate::jsonrpc::JsonRpcResponseBatch;
use crate::rpc::NoOpConnectorOperations;
use axum::{
body::Body,
http::{Request, StatusCode},
};
use tower::ServiceExt;
fn create_test_router() -> Router {
let state = AcpServerState::new(AcpServerConfig::enabled());
create_acp_server_router(state, NoOpConnectorOperations)
}
#[tokio::test]
async fn test_health_endpoint() {
let router = create_test_router();
let request = Request::builder()
.uri("/health")
.body(Body::empty())
.unwrap();
let response = router.oneshot(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.unwrap();
let health: HealthResponse = serde_json::from_slice(&body).unwrap();
assert_eq!(health.status, "healthy");
assert_eq!(health.clients, 0);
assert_eq!(health.sessions, 0);
}
#[tokio::test]
async fn test_rpc_initialize() {
let router = create_test_router();
let request_body = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
let request = Request::builder()
.method("POST")
.uri("/rpc")
.header("content-type", "application/json")
.body(Body::from(request_body))
.unwrap();
let response = router.oneshot(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.unwrap();
let resp: JsonRpcResponseBatch = serde_json::from_slice(&body).unwrap();
match resp {
JsonRpcResponseBatch::Single(r) => {
assert!(r.is_success());
let result = r.result.unwrap();
assert_eq!(result["agentInfo"]["name"], "dirigent-acp-server");
}
_ => panic!("Expected single response"),
}
}
#[tokio::test]
async fn test_rpc_invalid_json() {
let router = create_test_router();
let request = Request::builder()
.method("POST")
.uri("/rpc")
.header("content-type", "application/json")
.body(Body::from("not valid json"))
.unwrap();
let response = router.oneshot(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.unwrap();
let resp: JsonRpcResponseBatch = serde_json::from_slice(&body).unwrap();
match resp {
JsonRpcResponseBatch::Single(r) => {
assert!(r.is_error());
let error = r.error.unwrap();
assert_eq!(error.code, crate::error::error_codes::PARSE_ERROR);
}
_ => panic!("Expected single response"),
}
}
#[tokio::test]
async fn test_sse_missing_client_id() {
let router = create_test_router();
let request = Request::builder()
.uri("/events")
.body(Body::empty())
.unwrap();
let response = router.oneshot(request).await.unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.unwrap();
let error: SseErrorResponse = serde_json::from_slice(&body).unwrap();
assert_eq!(error.code, "MISSING_CLIENT_ID");
}
#[tokio::test]
async fn test_sse_with_client_id() {
let router = create_test_router();
let request = Request::builder()
.uri("/events?client_id=test-client")
.body(Body::empty())
.unwrap();
let response = router.oneshot(request).await.unwrap();
// Should return 200 OK with SSE content type
assert_eq!(response.status(), StatusCode::OK);
let content_type = response
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok());
assert!(content_type.is_some());
assert!(content_type.unwrap().contains("text/event-stream"));
}
#[test]
fn test_acp_server_state_new() {
let state = AcpServerState::new(AcpServerConfig::enabled());
assert!(state.config().enabled);
assert_eq!(state.session_manager().mapping_count(), 0);
assert_eq!(state.sse_notifier().client_count(), 0);
}
#[test]
fn test_acp_server_state_with_components() {
let session_manager = SessionManager::new();
let sse_notifier = SseNotifier::new();
let config = AcpServerConfig::enabled().set_port(4000);
// Pre-populate some state
session_manager.register_client(None);
let agent_request_tracker = Arc::new(crate::agent_requests::AgentRequestTracker::new());
let state = AcpServerState::with_components(
session_manager,
sse_notifier,
agent_request_tracker,
config,
);
assert!(state.config().enabled);
assert_eq!(state.config().port, 4000);
assert_eq!(state.session_manager().client_count(), 1);
}
#[test]
fn test_acp_server_state_clone() {
let state1 = AcpServerState::new(AcpServerConfig::enabled());
let state2 = state1.clone();
// Both should share the same underlying state
let client_id = state1.session_manager().register_client(None);
assert_eq!(state2.session_manager().client_count(), 1);
assert!(state2.session_manager().get_client(&client_id).is_some());
}
#[test]
fn test_acp_server_state_debug() {
let state = AcpServerState::new(AcpServerConfig::enabled());
let debug_str = format!("{:?}", state);
assert!(debug_str.contains("AcpServerState"));
assert!(debug_str.contains("session_count"));
assert!(debug_str.contains("client_count"));
}
#[test]
fn test_build_cors_layer_any_origin() {
let config = AcpServerConfig::default();
let _cors = build_cors_layer(&config);
// No assertion needed - just verify it doesn't panic
}
#[test]
fn test_build_cors_layer_specific_origins() {
let config = AcpServerConfig::default()
.set_allowed_origins(Some(vec![
"http://localhost:3000".to_string(),
"https://app.example.com".to_string(),
]));
let _cors = build_cors_layer(&config);
// No assertion needed - just verify it doesn't panic
}
#[test]
fn test_build_cors_layer_empty_origins() {
let config = AcpServerConfig::default()
.set_allowed_origins(Some(vec![]));
let _cors = build_cors_layer(&config);
// Should fall back to any origin
}
}
File diff suppressed because it is too large Load Diff
+185
View File
@@ -0,0 +1,185 @@
//! NoOp Connector Operations for Testing
//!
//! This module provides a stub implementation of `ConnectorOperations` that
//! can be used for testing the RPC handler without actual connector access.
use async_trait::async_trait;
use tracing::debug;
use crate::error::AcpServerError;
use super::types::{ConnectorInfo, ConnectorOperations, SessionInfo};
/// A no-op implementation of ConnectorOperations for testing
///
/// This stub provides minimal implementations that return success without
/// actually doing anything. Useful for unit testing the RPC dispatch logic.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoOpConnectorOperations;
#[async_trait]
impl ConnectorOperations for NoOpConnectorOperations {
async fn create_session(
&self,
connector_id: &str,
_cwd: Option<String>,
_ownership: dirigent_protocol::SessionOwnership,
) -> Result<SessionInfo, AcpServerError> {
Ok(SessionInfo {
session_id: uuid::Uuid::new_v4().to_string(),
title: Some("New Session".to_string()),
connector_id: connector_id.to_string(),
cwd: None,
created_at: chrono::Utc::now().to_rfc3339(),
models: None,
modes: None,
})
}
async fn load_session(
&self,
connector_id: &str,
session_id: &str,
_cwd: Option<String>,
_mcp_servers: Option<serde_json::Value>,
) -> Result<SessionInfo, AcpServerError> {
Ok(SessionInfo {
session_id: session_id.to_string(),
title: Some("Loaded Session".to_string()),
connector_id: connector_id.to_string(),
cwd: None,
created_at: chrono::Utc::now().to_rfc3339(),
models: None,
modes: None,
})
}
async fn send_message(
&self,
_connector_id: &str,
_session_id: &str,
text: String,
) -> Result<String, AcpServerError> {
debug!("NoOp: send_message - {} chars", text.len());
Ok("end_turn".to_string())
}
async fn cancel_generation(
&self,
_connector_id: &str,
_session_id: &str,
) -> Result<(), AcpServerError> {
debug!("NoOp: cancel_generation");
Ok(())
}
async fn list_connectors(&self) -> Result<Vec<ConnectorInfo>, AcpServerError> {
Ok(vec![ConnectorInfo {
id: "stub-connector".to_string(),
name: "Stub Connector".to_string(),
connector_type: "stub".to_string(),
available: true,
}])
}
async fn default_connector_id(&self) -> Option<String> {
Some("stub-connector".to_string())
}
async fn get_connector_commands(
&self,
_connector_id: &str,
) -> Result<Vec<crate::sse::SlashCommand>, AcpServerError> {
// Return a stub list of commands
Ok(vec![crate::sse::SlashCommand {
name: "echo".to_string(),
description: "Echo command (stub)".to_string(),
input: None,
}])
}
async fn send_agent_response(
&self,
connector_id: &str,
request_id: serde_json::Value,
_response: serde_json::Value,
) -> Result<(), AcpServerError> {
debug!(
"NoOp: send_agent_response - connector: {}, request: {}",
connector_id, request_id
);
Ok(())
}
async fn get_session_metadata(
&self,
_connector_id: &str,
_session_id: &str,
) -> Result<
(
Option<dirigent_protocol::SessionModelState>,
Option<dirigent_protocol::SessionModeState>,
),
AcpServerError,
> {
debug!("NoOp: get_session_metadata");
// Return None for both - no metadata available in stub
Ok((None, None))
}
async fn set_session_mode(
&self,
_connector_id: &str,
_session_id: &str,
mode_id: &str,
) -> Result<(), AcpServerError> {
debug!("NoOp: set_session_mode - mode_id: {}", mode_id);
Ok(())
}
async fn set_session_model(
&self,
_connector_id: &str,
_session_id: &str,
model_id: &str,
) -> Result<(), AcpServerError> {
debug!("NoOp: set_session_model - model_id: {}", model_id);
Ok(())
}
async fn get_connector_agent_type(
&self,
_connector_id: &str,
) -> Result<Option<String>, AcpServerError> {
debug!("NoOp: get_connector_agent_type");
Ok(None)
}
async fn list_sessions(
&self,
connector_id: &str,
) -> Result<Vec<SessionInfo>, AcpServerError> {
debug!("NoOp: list_sessions for connector: {}", connector_id);
Ok(vec![
SessionInfo {
session_id: "stub-session-1".to_string(),
title: Some("Stub Session".to_string()),
connector_id: connector_id.to_string(),
cwd: None,
created_at: chrono::Utc::now().to_rfc3339(),
models: None,
modes: None,
},
])
}
async fn resolve_session_connector(&self, _session_id: &str) -> Option<String> {
debug!("NoOp: resolve_session_connector");
None
}
async fn list_all_sessions(&self) -> Result<Vec<SessionInfo>, AcpServerError> {
debug!("NoOp: list_all_sessions");
Ok(vec![])
}
}
+705
View File
@@ -0,0 +1,705 @@
//! JSON-RPC Request/Response Types for ACP Server
//!
//! This module contains all the parameter and result types used in
//! the ACP JSON-RPC protocol.
use serde::{Deserialize, Serialize};
use crate::error::AcpServerError;
use crate::sse::ConfigOption;
// ============================================================================
// Session and Connector Info Types
// ============================================================================
/// Information about a session
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionInfo {
/// The session ID
pub session_id: String,
/// Optional title for the session
pub title: Option<String>,
/// The connector handling this session
pub connector_id: String,
/// Working directory for this session
pub cwd: Option<String>,
/// When the session was created (ISO 8601 format)
pub created_at: String,
/// Available models and current model (optional, connector-dependent)
pub models: Option<dirigent_protocol::SessionModelState>,
/// Available modes and current mode (optional, connector-dependent)
pub modes: Option<dirigent_protocol::SessionModeState>,
}
/// Information about a connector
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ConnectorInfo {
/// Unique identifier for this connector
pub id: String,
/// Human-readable name
pub name: String,
/// Type of connector (e.g., "opencode", "gateway")
pub connector_type: String,
/// Whether the connector is currently available
pub available: bool,
}
// ============================================================================
// Initialize Types
// ============================================================================
/// Parameters for the initialize request
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeParams {
/// Client capabilities
#[serde(default)]
pub capabilities: Option<serde_json::Value>,
/// Client name
#[serde(default)]
pub client_name: Option<String>,
/// Client version
#[serde(default)]
pub client_version: Option<String>,
}
/// Result of the initialize request (ACP spec compliant)
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeResult {
/// Protocol version (integer)
pub protocol_version: u32,
/// Agent capabilities
pub agent_capabilities: AgentCapabilities,
/// Agent info (name, title, version)
pub agent_info: AgentInfo,
/// Authentication methods (empty for now)
pub auth_methods: Vec<AuthMethod>,
/// Client ID for HTTP-based transports (required for SSE routing)
/// Note: This is not in the official ACP spec but is required for stateless
/// HTTP connections where the client needs an ID to subscribe to SSE events.
pub client_id: String,
}
/// Agent information
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AgentInfo {
/// Agent name (programmatic)
pub name: String,
/// Agent title (human-readable)
pub title: String,
/// Agent version
pub version: String,
}
/// Agent capabilities advertised to clients
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AgentCapabilities {
/// Whether session/load is supported
#[serde(skip_serializing_if = "Option::is_none")]
pub load_session: Option<bool>,
/// Whether session/list is supported
#[serde(skip_serializing_if = "Option::is_none")]
pub list_sessions: Option<bool>,
/// Session capabilities (resume, fork, etc.)
#[serde(skip_serializing_if = "Option::is_none")]
pub session_capabilities: Option<SessionCapabilities>,
/// Prompt capabilities (content types supported)
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_capabilities: Option<PromptCapabilities>,
/// MCP server support
#[serde(skip_serializing_if = "Option::is_none")]
pub mcp: Option<McpCapabilities>,
}
/// Extended session capabilities
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionCapabilities {
/// Whether session/list is supported (empty object = supported)
#[serde(skip_serializing_if = "Option::is_none")]
pub list: Option<serde_json::Value>,
/// Whether session/resume is supported (empty object = supported)
#[serde(skip_serializing_if = "Option::is_none")]
pub resume: Option<serde_json::Value>,
}
/// Prompt capabilities (content types)
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PromptCapabilities {
/// Image content support
#[serde(skip_serializing_if = "Option::is_none")]
pub image: Option<bool>,
/// Audio content support
#[serde(skip_serializing_if = "Option::is_none")]
pub audio: Option<bool>,
/// Embedded context support
#[serde(skip_serializing_if = "Option::is_none")]
pub embedded_context: Option<bool>,
}
/// MCP capabilities
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct McpCapabilities {
/// HTTP transport support
#[serde(skip_serializing_if = "Option::is_none")]
pub http: Option<bool>,
/// SSE transport support (deprecated)
#[serde(skip_serializing_if = "Option::is_none")]
pub sse: Option<bool>,
}
/// Authentication method
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AuthMethod {
/// Method ID
pub id: String,
/// Method name
pub name: String,
/// Method description
pub description: String,
}
impl Default for AgentCapabilities {
fn default() -> Self {
Self {
load_session: Some(true),
list_sessions: Some(true),
session_capabilities: Some(SessionCapabilities {
list: Some(serde_json::json!({})),
resume: Some(serde_json::json!({})),
}),
prompt_capabilities: Some(PromptCapabilities {
image: Some(true),
audio: None,
embedded_context: Some(true),
}),
mcp: Some(McpCapabilities {
http: Some(true),
sse: Some(true),
}),
}
}
}
// ============================================================================
// Session/New Types
// ============================================================================
/// Parameters for session/new request
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionNewParams {
/// Optional connector ID to create the session on
#[serde(default)]
pub connector_id: Option<String>,
/// Optional working directory for the session
#[serde(default)]
pub cwd: Option<String>,
/// Optional client-provided session ID
#[serde(default)]
pub session_id: Option<String>,
}
/// Result of session/new request
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionNewResult {
/// The client-facing session ID
pub session_id: String,
/// Optional title for the session
#[serde(skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
/// The connector handling this session
pub connector_id: String,
/// When the session was created (ISO 8601 format)
pub created_at: String,
/// Available models and current model (UNSTABLE in ACP spec)
#[serde(skip_serializing_if = "Option::is_none")]
pub models: Option<dirigent_protocol::SessionModelState>,
/// Available modes and current mode
#[serde(skip_serializing_if = "Option::is_none")]
pub modes: Option<dirigent_protocol::SessionModeState>,
/// Configuration options (modes, models as unified config)
/// This is the preferred way to expose configuration in ACP
#[serde(skip_serializing_if = "Option::is_none")]
pub config_options: Option<Vec<ConfigOption>>,
}
// ============================================================================
// Session/Load Types
// ============================================================================
/// Parameters for session/load request
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionLoadParams {
/// The session ID to load
pub session_id: String,
/// The connector ID where the session exists (optional; resolved automatically if omitted)
#[serde(default)]
pub connector_id: Option<String>,
/// Optional working directory (standard ACP field sent by clients like Zed)
#[serde(default)]
pub cwd: Option<String>,
/// Optional MCP server configurations (standard ACP field)
#[serde(default)]
pub mcp_servers: Option<Vec<serde_json::Value>>,
}
/// Result of session/load request
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionLoadResult {
/// The client-facing session ID
pub session_id: String,
/// Optional title for the session
#[serde(skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
/// The connector handling this session
pub connector_id: String,
/// When the session was created (ISO 8601 format)
pub created_at: String,
/// Available models and current model (UNSTABLE in ACP spec)
#[serde(skip_serializing_if = "Option::is_none")]
pub models: Option<dirigent_protocol::SessionModelState>,
/// Available modes and current mode
#[serde(skip_serializing_if = "Option::is_none")]
pub modes: Option<dirigent_protocol::SessionModeState>,
/// Configuration options (modes, models as unified config)
/// This is the preferred way to expose configuration in ACP
#[serde(skip_serializing_if = "Option::is_none")]
pub config_options: Option<Vec<ConfigOption>>,
}
// ============================================================================
// Session/List Types
// ============================================================================
/// Parameters for session/list request
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionListParams {
/// Optional connector ID to list sessions from
#[serde(default)]
pub connector_id: Option<String>,
/// Optional working directory filter
#[serde(default)]
pub cwd: Option<String>,
/// Optional pagination cursor from previous response
#[serde(default)]
pub cursor: Option<String>,
}
/// A session entry in a session/list response (ACP spec)
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionListEntry {
/// Unique session identifier
pub session_id: String,
/// Working directory for this session
pub cwd: String,
/// Human-readable title
#[serde(skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
/// Last activity timestamp (ISO 8601)
#[serde(skip_serializing_if = "Option::is_none")]
pub updated_at: Option<String>,
/// Agent-specific metadata
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "_meta")]
pub meta: Option<serde_json::Value>,
}
/// Result of session/list request
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionListResult {
/// List of available sessions
pub sessions: Vec<SessionListEntry>,
/// Pagination cursor for next page (absent when no more results)
#[serde(skip_serializing_if = "Option::is_none")]
pub next_cursor: Option<String>,
}
// ============================================================================
// Session/Resume Types
// ============================================================================
/// Parameters for session/resume request
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionResumeParams {
/// The session ID to resume
pub session_id: String,
/// The connector ID where the session exists (optional; resolved automatically if omitted)
#[serde(default)]
pub connector_id: Option<String>,
/// Optional working directory (standard ACP field sent by clients like Zed)
#[serde(default)]
pub cwd: Option<String>,
/// Optional MCP server configurations (standard ACP field)
#[serde(default)]
pub mcp_servers: Option<Vec<serde_json::Value>>,
}
/// Result of session/resume request
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionResumeResult {
/// The client-facing session ID
pub session_id: String,
/// Optional title for the session
#[serde(skip_serializing_if = "Option::is_none")]
pub title: Option<String>,
/// The connector handling this session
pub connector_id: String,
/// When the session was created (ISO 8601 format)
pub created_at: String,
/// Available models and current model
#[serde(skip_serializing_if = "Option::is_none")]
pub models: Option<dirigent_protocol::SessionModelState>,
/// Available modes and current mode
#[serde(skip_serializing_if = "Option::is_none")]
pub modes: Option<dirigent_protocol::SessionModeState>,
/// Configuration options
#[serde(skip_serializing_if = "Option::is_none")]
pub config_options: Option<Vec<crate::sse::ConfigOption>>,
}
// ============================================================================
// Session/Prompt Types
// ============================================================================
/// Parameters for session/prompt request
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionPromptParams {
/// The session ID to send the prompt to
pub session_id: String,
/// The prompt content (array of content blocks)
pub prompt: PromptContent,
}
/// Content for a prompt - either simple text or structured blocks
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PromptContent {
/// Simple text content
Text(String),
/// Structured content blocks
Blocks(Vec<ContentBlock>),
}
impl PromptContent {
/// Convert to plain text representation
pub fn to_text(&self) -> String {
match self {
PromptContent::Text(text) => text.clone(),
PromptContent::Blocks(blocks) => blocks
.iter()
.filter_map(|b| {
if let ContentBlock::Text { text } = b {
Some(text.clone())
} else {
None
}
})
.collect::<Vec<_>>()
.join("\n"),
}
}
}
/// A content block in a prompt
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentBlock {
/// Text content
#[serde(rename = "text")]
Text { text: String },
/// Image content (base64 encoded)
#[serde(rename = "image")]
Image { data: String, media_type: String },
}
/// Result of session/prompt request
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionPromptResult {
/// The reason the turn stopped
pub stop_reason: String,
}
// ============================================================================
// Session/Cancel Types
// ============================================================================
/// Parameters for session/cancel request
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionCancelParams {
/// The session ID to cancel
pub session_id: String,
}
/// Result of session/cancel request
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionCancelResult {
/// Whether the cancellation was successful
pub cancelled: bool,
}
// ============================================================================
// Session/Close Types
// ============================================================================
/// Parameters for session/close request
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionCloseParams {
/// The session ID to close
pub session_id: String,
}
/// Result of session/close request
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionCloseResult {
/// Whether the session was successfully closed
pub closed: bool,
}
// ============================================================================
// Session/SetMode and Session/SetModel Types
// ============================================================================
/// Parameters for session/set_mode request
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionSetModeParams {
/// The session ID to update
pub session_id: String,
/// The mode ID to switch to
pub mode_id: String,
}
/// Parameters for session/set_model request
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionSetModelParams {
/// The session ID to update
pub session_id: String,
/// The model ID to switch to
pub model_id: String,
}
/// Parameters for session/set_config_option request
///
/// This is the unified way to set configuration options (modes, models, etc.)
/// for clients using the new config_options system.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SessionSetConfigOptionParams {
/// The session ID to update
pub session_id: String,
/// The config option ID (e.g., "mode", "model")
pub config_id: String,
/// The value to set
pub value: String,
}
// ============================================================================
// ConnectorOperations Trait
// ============================================================================
/// Trait for connector operations that will be implemented by the web server
///
/// This trait abstracts the operations that require access to `CoreHandle`,
/// allowing the RPC handler to remain decoupled from `dirigent_core`.
/// The web server implements this trait to bridge between the ACP server
/// and the actual connector implementations.
#[async_trait::async_trait]
pub trait ConnectorOperations: Send + Sync {
/// Create a new session on the specified connector
async fn create_session(
&self,
connector_id: &str,
cwd: Option<String>,
ownership: dirigent_protocol::SessionOwnership,
) -> Result<SessionInfo, AcpServerError>;
/// Load an existing session from a connector
async fn load_session(
&self,
connector_id: &str,
session_id: &str,
cwd: Option<String>,
mcp_servers: Option<serde_json::Value>,
) -> Result<SessionInfo, AcpServerError>;
/// Send a message to a session and wait for completion
async fn send_message(
&self,
connector_id: &str,
session_id: &str,
text: String,
) -> Result<String, AcpServerError>;
/// Cancel active generation on a session
async fn cancel_generation(
&self,
connector_id: &str,
session_id: &str,
) -> Result<(), AcpServerError>;
/// List all available connectors
async fn list_connectors(&self) -> Result<Vec<ConnectorInfo>, AcpServerError>;
/// Get the default connector ID (if configured)
async fn default_connector_id(&self) -> Option<String>;
/// Get available commands/tools from a connector
async fn get_connector_commands(
&self,
connector_id: &str,
) -> Result<Vec<crate::sse::SlashCommand>, AcpServerError>;
/// Send an agent response back to a connector
async fn send_agent_response(
&self,
connector_id: &str,
request_id: serde_json::Value,
response: serde_json::Value,
) -> Result<(), AcpServerError>;
/// Get session metadata (models/modes) from a session
async fn get_session_metadata(
&self,
connector_id: &str,
session_id: &str,
) -> Result<
(
Option<dirigent_protocol::SessionModelState>,
Option<dirigent_protocol::SessionModeState>,
),
AcpServerError,
>;
/// Set the session mode
async fn set_session_mode(
&self,
connector_id: &str,
session_id: &str,
mode_id: &str,
) -> Result<(), AcpServerError>;
/// Set the session model
async fn set_session_model(
&self,
connector_id: &str,
session_id: &str,
model_id: &str,
) -> Result<(), AcpServerError>;
/// Get the agent type for a connector (for mode/model mapping)
async fn get_connector_agent_type(
&self,
connector_id: &str,
) -> Result<Option<String>, AcpServerError>;
/// List all sessions on a connector
async fn list_sessions(
&self,
connector_id: &str,
) -> Result<Vec<SessionInfo>, AcpServerError>;
/// Look up which connector owns a session by its session ID.
/// Uses archivist to find the connector. Returns connector_id (not UID).
/// Default: returns None (no archivist available).
async fn resolve_session_connector(&self, _session_id: &str) -> Option<String> {
None
}
/// List sessions across all connectors (archivist-backed).
/// Used when the resolved connector is a gateway to provide cross-connector view.
/// Default: returns empty vec (no archivist available).
async fn list_all_sessions(&self) -> Result<Vec<SessionInfo>, AcpServerError> {
Ok(vec![])
}
}
File diff suppressed because it is too large Load Diff
@@ -0,0 +1,435 @@
//! Content and Metadata Transformation Utilities
//!
//! This module provides helper functions for transforming content between
//! the internal Dirigent protocol and the ACP wire format.
use dirigent_protocol::ContentBlock;
use serde_json::json;
/// Extract 'kind' from tool call metadata (stored as acp_kind)
pub fn extract_kind(tool_call: &dirigent_protocol::ToolCall) -> Option<String> {
tool_call
.metadata
.as_ref()
.and_then(|m| m.get("acp_kind"))
.and_then(|k| k.as_str())
.map(String::from)
}
/// Convert ToolCallStatus to ACP string format
pub fn tool_call_status_to_string(status: &dirigent_protocol::ToolCallStatus) -> String {
match status {
dirigent_protocol::ToolCallStatus::Pending => "pending".to_string(),
dirigent_protocol::ToolCallStatus::Running => "in_progress".to_string(),
dirigent_protocol::ToolCallStatus::Completed => "completed".to_string(),
dirigent_protocol::ToolCallStatus::Error => "failed".to_string(),
}
}
/// Unwrap ToolCallContent wrappers to extract ContentBlocks for SSE output
///
/// The internal protocol uses ToolCallContent wrappers, but the SSE/ACP wire format
/// expects flat ContentBlock arrays. This function extracts Content variants only.
pub fn unwrap_tool_call_content(
content: &[dirigent_protocol::ToolCallContent],
) -> Vec<ContentBlock> {
content
.iter()
.filter_map(|wrapper| match wrapper {
dirigent_protocol::ToolCallContent::Content { content } => Some(content.clone()),
// Diff and Terminal variants are not yet supported in SSE output
_ => None,
})
.collect()
}
/// Rebuild _meta for ACP output (ensures claudeCode.toolName is present)
pub fn rebuild_meta(
meta: &Option<dirigent_protocol::Meta>,
tool_call: &dirigent_protocol::ToolCall,
) -> Option<serde_json::Value> {
let mut meta_obj = serde_json::Map::new();
// Add claudeCode.toolName
let mut claude_code = serde_json::Map::new();
claude_code.insert("toolName".to_string(), json!(tool_call.tool_name));
// Merge any existing _meta (preserves toolResponse, etc.)
if let Some(existing_meta) = meta {
// Convert Meta to JSON Value for merging
if let Ok(meta_value) = serde_json::to_value(existing_meta) {
if let Some(obj) = meta_value.as_object() {
for (k, v) in obj {
if k == "claudeCode" {
// Merge into our claudeCode object
if let Some(cc) = v.as_object() {
for (ck, cv) in cc {
// Don't override toolName we just set
if ck != "toolName" {
claude_code.insert(ck.clone(), cv.clone());
}
}
}
} else {
meta_obj.insert(k.clone(), v.clone());
}
}
}
}
}
meta_obj.insert("claudeCode".to_string(), json!(claude_code));
Some(json!(meta_obj))
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_extract_kind_present() {
let mut metadata = serde_json::Map::new();
metadata.insert("acp_kind".to_string(), json!("search"));
let tool_call = dirigent_protocol::ToolCall {
id: "call_123".to_string(),
tool_name: "bash".to_string(),
status: dirigent_protocol::ToolCallStatus::Pending,
content: vec![],
raw_input: None,
raw_output: None,
title: None,
error: None,
metadata: Some(json!(metadata)),
origin: None,
};
assert_eq!(extract_kind(&tool_call), Some("search".to_string()));
}
#[test]
fn test_extract_kind_missing_metadata() {
let tool_call = dirigent_protocol::ToolCall {
id: "call_123".to_string(),
tool_name: "bash".to_string(),
status: dirigent_protocol::ToolCallStatus::Pending,
content: vec![],
raw_input: None,
raw_output: None,
title: None,
error: None,
metadata: None,
origin: None,
};
assert_eq!(extract_kind(&tool_call), None);
}
#[test]
fn test_extract_kind_missing_field() {
let mut metadata = serde_json::Map::new();
metadata.insert("other_field".to_string(), json!("value"));
let tool_call = dirigent_protocol::ToolCall {
id: "call_123".to_string(),
tool_name: "bash".to_string(),
status: dirigent_protocol::ToolCallStatus::Pending,
content: vec![],
raw_input: None,
raw_output: None,
title: None,
error: None,
metadata: Some(json!(metadata)),
origin: None,
};
assert_eq!(extract_kind(&tool_call), None);
}
#[test]
fn test_extract_kind_malformed_metadata() {
// acp_kind is not a string
let mut metadata = serde_json::Map::new();
metadata.insert("acp_kind".to_string(), json!(123));
let tool_call = dirigent_protocol::ToolCall {
id: "call_123".to_string(),
tool_name: "bash".to_string(),
status: dirigent_protocol::ToolCallStatus::Pending,
content: vec![],
raw_input: None,
raw_output: None,
title: None,
error: None,
metadata: Some(json!(metadata)),
origin: None,
};
assert_eq!(extract_kind(&tool_call), None);
}
#[test]
fn test_tool_call_status_to_string_all_variants() {
assert_eq!(
tool_call_status_to_string(&dirigent_protocol::ToolCallStatus::Pending),
"pending"
);
assert_eq!(
tool_call_status_to_string(&dirigent_protocol::ToolCallStatus::Running),
"in_progress"
);
assert_eq!(
tool_call_status_to_string(&dirigent_protocol::ToolCallStatus::Completed),
"completed"
);
assert_eq!(
tool_call_status_to_string(&dirigent_protocol::ToolCallStatus::Error),
"failed"
);
}
#[test]
fn test_rebuild_meta_basic() {
let tool_call = dirigent_protocol::ToolCall {
id: "call_123".to_string(),
tool_name: "some_tool".to_string(),
status: dirigent_protocol::ToolCallStatus::Pending,
content: vec![],
raw_input: None,
raw_output: None,
title: None,
error: None,
metadata: None,
origin: None,
};
let meta = rebuild_meta(&None, &tool_call);
assert!(meta.is_some());
let meta_obj = meta.unwrap();
assert_eq!(meta_obj["claudeCode"]["toolName"], "some_tool");
}
#[test]
fn test_rebuild_meta_preserves_existing_fields() {
let existing_meta = dirigent_protocol::Meta {
provider: Some(dirigent_protocol::ProviderMeta {
name: "test_provider".to_string(),
original_ids: None,
raw_excerpt: None,
}),
extra: std::collections::HashMap::from([(
"customField".to_string(),
json!("customValue"),
)]),
};
let tool_call = dirigent_protocol::ToolCall {
id: "call_123".to_string(),
tool_name: "bash".to_string(),
status: dirigent_protocol::ToolCallStatus::Running,
content: vec![],
raw_input: None,
raw_output: None,
title: None,
error: None,
metadata: None,
origin: None,
};
let meta = rebuild_meta(&Some(existing_meta), &tool_call);
assert!(meta.is_some());
let meta_obj = meta.unwrap();
// Verify claudeCode.toolName is present
assert_eq!(meta_obj["claudeCode"]["toolName"], "bash");
// Verify existing fields are preserved
assert_eq!(meta_obj["provider"]["name"], "test_provider");
assert_eq!(meta_obj["customField"], "customValue");
}
#[test]
fn test_rebuild_meta_merges_claude_code_fields() {
let mut existing_meta = dirigent_protocol::Meta::default();
existing_meta.extra.insert(
"claudeCode".to_string(),
json!({
"toolResponse": "some response",
"otherField": "other value"
}),
);
let tool_call = dirigent_protocol::ToolCall {
id: "call_123".to_string(),
tool_name: "test_tool".to_string(),
status: dirigent_protocol::ToolCallStatus::Completed,
content: vec![],
raw_input: None,
raw_output: None,
title: None,
error: None,
metadata: None,
origin: None,
};
let meta = rebuild_meta(&Some(existing_meta), &tool_call);
assert!(meta.is_some());
let meta_obj = meta.unwrap();
// Verify toolName is present (we set it)
assert_eq!(meta_obj["claudeCode"]["toolName"], "test_tool");
// Verify existing claudeCode fields are preserved
assert_eq!(meta_obj["claudeCode"]["toolResponse"], "some response");
assert_eq!(meta_obj["claudeCode"]["otherField"], "other value");
}
#[test]
fn test_rebuild_meta_does_not_override_tool_name() {
// Existing meta has a different toolName
let mut existing_meta = dirigent_protocol::Meta::default();
existing_meta.extra.insert(
"claudeCode".to_string(),
json!({
"toolName": "wrong_tool",
"toolResponse": "response"
}),
);
let tool_call = dirigent_protocol::ToolCall {
id: "call_123".to_string(),
tool_name: "correct_tool".to_string(),
status: dirigent_protocol::ToolCallStatus::Completed,
content: vec![],
raw_input: None,
raw_output: None,
title: None,
error: None,
metadata: None,
origin: None,
};
let meta = rebuild_meta(&Some(existing_meta), &tool_call);
assert!(meta.is_some());
let meta_obj = meta.unwrap();
// Verify our toolName is used (not the one from existing meta)
assert_eq!(meta_obj["claudeCode"]["toolName"], "correct_tool");
// Verify other fields are preserved
assert_eq!(meta_obj["claudeCode"]["toolResponse"], "response");
}
#[test]
fn test_rebuild_meta_preserves_tool_response() {
// Test that rebuild_meta preserves toolResponse from incoming Claude updates
let mut existing_meta = dirigent_protocol::Meta::default();
existing_meta.extra.insert(
"claudeCode".to_string(),
json!({
"toolResponse": {
"mode": "content",
"numFiles": 0,
"filenames": [],
"content": "some output",
"numLines": 58,
"appliedLimit": 100
},
"toolName": "OldToolName"
}),
);
let tool_call = dirigent_protocol::ToolCall {
id: "call_456".to_string(),
tool_name: "Grep".to_string(),
status: dirigent_protocol::ToolCallStatus::Running,
content: vec![],
raw_input: None,
raw_output: None,
title: None,
error: None,
metadata: None,
origin: None,
};
let meta = rebuild_meta(&Some(existing_meta), &tool_call);
assert!(meta.is_some());
let meta_obj = meta.unwrap();
// toolName should be updated to current tool_call.tool_name
assert_eq!(meta_obj["claudeCode"]["toolName"], "Grep");
// toolResponse should be preserved completely
assert!(meta_obj["claudeCode"]["toolResponse"].is_object());
assert_eq!(meta_obj["claudeCode"]["toolResponse"]["mode"], "content");
assert_eq!(meta_obj["claudeCode"]["toolResponse"]["numFiles"], 0);
assert_eq!(meta_obj["claudeCode"]["toolResponse"]["numLines"], 58);
}
#[test]
fn test_rebuild_meta_round_trip_no_data_loss() {
// Test incoming → internal → outgoing preserves all fields
let mut existing_meta = dirigent_protocol::Meta::default();
existing_meta.extra.insert(
"claudeCode".to_string(),
json!({
"toolResponse": {
"mode": "content",
"numFiles": 3,
"filenames": ["file1.rs", "file2.rs", "file3.rs"],
"content": "grep results here",
"numLines": 42,
"appliedLimit": 100,
"customField": "should be preserved",
"nestedObject": {
"deep": "value"
}
},
"additionalField": "also preserved"
}),
);
existing_meta
.extra
.insert("customTopLevel".to_string(), json!("preserved too"));
let tool_call = dirigent_protocol::ToolCall {
id: "call_789".to_string(),
tool_name: "Grep".to_string(),
status: dirigent_protocol::ToolCallStatus::Completed,
content: vec![],
raw_input: None,
raw_output: None,
title: None,
error: None,
metadata: None,
origin: None,
};
let meta = rebuild_meta(&Some(existing_meta), &tool_call);
assert!(meta.is_some());
let meta_obj = meta.unwrap();
// Verify NO data loss
assert_eq!(meta_obj["claudeCode"]["toolName"], "Grep");
assert_eq!(meta_obj["claudeCode"]["toolResponse"]["mode"], "content");
assert_eq!(meta_obj["claudeCode"]["toolResponse"]["numFiles"], 3);
assert_eq!(meta_obj["claudeCode"]["toolResponse"]["numLines"], 42);
assert_eq!(
meta_obj["claudeCode"]["toolResponse"]["customField"],
"should be preserved"
);
assert_eq!(
meta_obj["claudeCode"]["toolResponse"]["nestedObject"]["deep"],
"value"
);
assert_eq!(meta_obj["claudeCode"]["additionalField"], "also preserved");
assert_eq!(meta_obj["customTopLevel"], "preserved too");
}
}
@@ -0,0 +1,453 @@
//! Event Translation Layer
//!
//! This module provides translation from Dirigent protocol Events to ACP notifications.
//! It handles the mapping between internal event representations and the wire format
//! expected by ACP clients.
use dirigent_protocol::{Event, SessionUpdate};
use super::content_transform::{
extract_kind, rebuild_meta, tool_call_status_to_string, unwrap_tool_call_content,
};
use super::types::{
models_to_config_option, modes_to_config_option, ConfigOption, ConfigOptionChoice,
ConfigOptionType, SessionUpdateParams, SessionUpdateVariant,
};
/// Translate a Dirigent protocol Event to ACP notifications
///
/// Not all Dirigent events need to be forwarded to ACP clients. This function
/// returns a vector of notifications for events that should be sent, and an empty
/// vector for events that should be filtered out.
///
/// ## Mapped Events
///
/// - `Event::SessionUpdate` with `SessionUpdate::AgentMessageChunk` -> `MessageChunk`
/// - `Event::SessionUpdate` with `SessionUpdate::AgentThoughtChunk` -> `MessageChunk` (with content_type)
/// - `Event::SessionUpdate` with `SessionUpdate::ToolCall` -> `ToolCallUpdate`
/// - `Event::SessionUpdate` with `SessionUpdate::ToolCallUpdate` -> `ToolCallUpdate`
/// - `Event::MessageCompleted` -> `MessageComplete`
/// - `Event::SessionIdle` -> `SessionIdle`
/// - `Event::SessionMetadataReceived` -> `ConfigOptionUpdate` (with modes and models)
/// - `Event::MessageFailed` -> `SessionError`
/// - `Event::Error` -> `SessionError` (system-wide error)
///
/// ## Filtered Events
///
/// - `Event::SessionsListed` - List operations don't need streaming
/// - `Event::SessionCreated` - Handled via RPC response
/// - `Event::SessionUpdated` - Metadata updates, not content
/// - `Event::SessionMetadataUpdated` - Metadata updates
/// - `Event::SessionDeleted` - Handled via RPC
/// - `Event::Connected` / `Event::Disconnected` - System events
/// - `Event::ConnectorCreated` / `Event::ConnectorRemoved` - System events
/// - `Event::MessagesListed` - List operations
/// - `Event::MessageStarted` - Initial message, content comes via chunks
///
/// # Parameters
///
/// - `event`: The Dirigent protocol event to translate
///
/// # Returns
///
/// Vec of `SessionUpdateParams` to be sent. Most events return 0 or 1 update.
pub fn translate_event(event: &Event) -> Vec<SessionUpdateParams> {
match event {
// SessionUpdate events - the main streaming content
Event::SessionUpdate {
session_id, update, ..
} => translate_session_update(session_id, update)
.map(|u| vec![u])
.unwrap_or_default(),
// Message completion
Event::MessageCompleted { message, .. } => vec![SessionUpdateParams {
session_id: message.session_id.clone(),
update: SessionUpdateVariant::MessageComplete {
message_id: Some(message.id.clone()),
},
..Default::default()
}],
// Session idle
Event::SessionIdle { session_id, .. } => vec![SessionUpdateParams {
session_id: session_id.clone(),
update: SessionUpdateVariant::SessionIdle {},
..Default::default()
}],
// Session metadata received - forward as BOTH config_option_update AND current_mode_update
//
// We emit both notification types to support:
// - New clients (acp-beta flag): Use config_option_update
// - Legacy clients: Use current_mode_update (mode only - no legacy model updates exist)
//
// Clients safely ignore notification types they don't understand.
// See: https://agentclientprotocol.com/rfds/session-config-options
Event::SessionMetadataReceived {
session_id,
models,
modes,
config_options: event_config_options,
..
} => {
let mut updates = vec![];
// If the event already has config_options from the agent, prefer those
// over converting from modes/models (agent-provided options are authoritative).
if let Some(agent_config_options) = event_config_options {
if !agent_config_options.is_empty() {
// Emit ConfigOptionUpdate with agent-provided options
updates.push(SessionUpdateParams {
session_id: session_id.clone(),
update: SessionUpdateVariant::ConfigOptionUpdate {
config_options: agent_config_options.iter().map(|co| ConfigOption {
id: co.id.clone(),
name: co.name.clone(),
description: co.description.clone(),
category: co.category.clone(),
option_type: match co.option_type {
dirigent_protocol::ConfigOptionType::Select => ConfigOptionType::Select,
},
current_value: co.current_value.clone(),
options: Some(co.options.iter().map(|v| ConfigOptionChoice {
value: v.value.clone(),
name: v.name.clone(),
description: v.description.clone(),
}).collect()),
}).collect(),
},
..Default::default()
});
// Also emit CurrentModeUpdate for legacy clients if a mode option exists
if let Some(mode_opt) = agent_config_options.iter().find(|co| co.id == "mode" || co.category.as_deref() == Some("mode")) {
updates.push(SessionUpdateParams {
session_id: session_id.clone(),
update: SessionUpdateVariant::CurrentModeUpdate {
mode_id: mode_opt.current_value.clone(),
},
..Default::default()
});
}
return updates;
}
}
// Fall back to building config_options from legacy modes/models fields
let mut config_options: Vec<ConfigOption> = vec![];
if let Some(modes_state) = modes {
config_options.push(modes_to_config_option(modes_state));
// Also emit CurrentModeUpdate for legacy clients
// Legacy protocol only supports mode updates (no model updates)
updates.push(SessionUpdateParams {
session_id: session_id.clone(),
update: SessionUpdateVariant::CurrentModeUpdate {
mode_id: modes_state.current_mode_id.clone(),
},
..Default::default()
});
}
if let Some(models_state) = models {
config_options.push(models_to_config_option(models_state));
// Note: Legacy protocol has no CurrentModelUpdate - only new ConfigOptionUpdate
}
if !config_options.is_empty() {
updates.push(SessionUpdateParams {
session_id: session_id.clone(),
update: SessionUpdateVariant::ConfigOptionUpdate { config_options },
..Default::default()
});
}
updates
}
// Message failure - not forwarded in new structure (no equivalent variant)
Event::MessageFailed { .. } => vec![],
// System-wide error - not forwarded in new structure (no equivalent variant)
Event::Error { .. } => vec![],
// Agent request - will be handled in Phase 2 (bidirectional request/response)
// For now, not forwarded (requires new SSE variant and response mechanism)
Event::AgentRequest { .. } => vec![],
// Events that should not be forwarded via this translation path
// (they may be handled elsewhere or are not relevant to ACP clients)
Event::SessionsListed { .. }
| Event::SessionCreated { .. }
| Event::SessionUpdated { .. }
| Event::SessionMetadataUpdated { .. }
| Event::SessionDeleted { .. }
| Event::SessionClosed { .. }
| Event::SessionSystemMessageSet { .. }
| Event::SessionError { .. }
| Event::SessionTransferred { .. }
| Event::ForwardingPanic { .. }
| Event::MessagesListed { .. }
| Event::MessageStarted { .. }
| Event::TurnComplete { .. }
| Event::ConnectorCreated { .. }
| Event::ConnectorRemoved { .. }
| Event::ConnectorStateChanged { .. }
| Event::Connected
| Event::Disconnected => vec![],
// ACP client connection events - these are handled by the UI directly
// via the main SSE endpoint, not translated to session updates
Event::AcpClientConnected { .. }
| Event::AcpClientDisconnected { .. }
| Event::AcpClientSessionOpened { .. }
| Event::AcpClientSessionRouted { .. } => vec![],
// Inspector events - not relevant for ACP session updates
Event::InspectorSnapshot { .. }
| Event::InspectorNodeRegistered { .. }
| Event::InspectorNodeRemoved { .. }
| Event::InspectorStateChanged { .. }
| Event::InspectorPropertiesUpdated { .. } => vec![],
// Archivist-to-frontend signal — not relevant for ACP clients
Event::SessionRegistered { .. } => vec![],
// System task events — internal UI concern, not relevant for ACP clients
Event::SystemTaskStatusChanged { .. } => vec![],
}
}
/// Translate a SessionUpdate to a SessionUpdateParams
fn translate_session_update(
session_id: &str,
update: &SessionUpdate,
) -> Option<SessionUpdateParams> {
match update {
// Agent message content chunk
SessionUpdate::AgentMessageChunk { content, .. } => Some(SessionUpdateParams {
session_id: session_id.to_string(),
update: SessionUpdateVariant::AgentMessageChunk {
content: content.clone(),
},
..Default::default()
}),
// Agent thought chunk - render as agent message (thoughts shown as agent text in UI)
SessionUpdate::AgentThoughtChunk { content, .. } => Some(SessionUpdateParams {
session_id: session_id.to_string(),
update: SessionUpdateVariant::AgentMessageChunk {
content: content.clone(),
},
..Default::default()
}),
// User message chunks are typically not streamed to other clients
// Filter them out for now
SessionUpdate::UserMessageChunk { .. } => None,
// Tool call initiated - forward to client
SessionUpdate::ToolCall {
tool_call, _meta, ..
} => Some(SessionUpdateParams {
session_id: session_id.to_string(),
update: SessionUpdateVariant::ToolCall {
tool_call_id: tool_call.id.clone(),
title: tool_call.title.clone(),
kind: extract_kind(tool_call),
raw_input: tool_call.raw_input.clone(),
status: Some(tool_call_status_to_string(&tool_call.status)),
content: unwrap_tool_call_content(&tool_call.content),
_meta: rebuild_meta(_meta, tool_call),
},
..Default::default()
}),
// Tool call update - forward to client
SessionUpdate::ToolCallUpdate {
tool_call, _meta, ..
} => Some(SessionUpdateParams {
session_id: session_id.to_string(),
update: SessionUpdateVariant::ToolCallUpdate {
tool_call_id: tool_call.id.clone(),
status: Some(tool_call_status_to_string(&tool_call.status)),
content: unwrap_tool_call_content(&tool_call.content),
raw_output: tool_call.raw_output.clone(),
error: tool_call.error.clone(),
_meta: rebuild_meta(_meta, tool_call),
},
..Default::default()
}),
// Unknown update type - forward as-is for transparent proxying
SessionUpdate::Unknown { data } => Some(SessionUpdateParams {
session_id: session_id.to_string(),
update: SessionUpdateVariant::Unknown { data: data.clone() },
..Default::default()
}),
}
}
#[cfg(test)]
mod tests {
use super::*;
use dirigent_protocol::{ContentBlock, Message, MessageRole, MessageStatus};
#[test]
fn test_translate_session_update_agent_chunk() {
let event = Event::SessionUpdate {
connector_id: "conn-1".to_string(),
session_id: "sess-1".to_string(),
update: SessionUpdate::AgentMessageChunk {
message_id: "msg-1".to_string(),
content: ContentBlock::Text {
text: "Hello".to_string(),
},
_meta: None,
},
};
let updates = translate_event(&event);
assert_eq!(updates.len(), 1);
let update_params = &updates[0];
assert_eq!(update_params.session_id, "sess-1");
match &update_params.update {
SessionUpdateVariant::AgentMessageChunk { content } => match content {
ContentBlock::Text { text } => {
assert_eq!(text, "Hello");
}
_ => panic!("Expected Text content"),
},
_ => panic!("Expected AgentMessageChunk variant"),
}
}
#[test]
fn test_translate_session_update_thought_chunk() {
let event = Event::SessionUpdate {
connector_id: "conn-1".to_string(),
session_id: "sess-1".to_string(),
update: SessionUpdate::AgentThoughtChunk {
message_id: "msg-1".to_string(),
content: ContentBlock::Text {
text: "Thinking...".to_string(),
},
_meta: None,
},
};
let updates = translate_event(&event);
assert_eq!(updates.len(), 1);
let update_params = &updates[0];
assert_eq!(update_params.session_id, "sess-1");
// Thought chunks are rendered as agent message chunks
match &update_params.update {
SessionUpdateVariant::AgentMessageChunk { content } => match content {
ContentBlock::Text { text } => {
assert_eq!(text, "Thinking...");
}
_ => panic!("Expected Text content"),
},
_ => panic!("Expected AgentMessageChunk variant"),
}
}
#[test]
fn test_translate_message_completed() {
let now = chrono::Utc::now();
let event = Event::MessageCompleted {
connector_id: "conn-1".to_string(),
message: Message {
id: "msg-1".to_string(),
session_id: "sess-1".to_string(),
role: MessageRole::Assistant,
created_at: now,
content: vec![],
status: MessageStatus::Completed,
metadata: None,
},
};
let updates = translate_event(&event);
assert_eq!(updates.len(), 1);
let update_params = &updates[0];
assert_eq!(update_params.session_id, "sess-1");
match &update_params.update {
SessionUpdateVariant::MessageComplete { message_id } => {
assert_eq!(message_id, &Some("msg-1".to_string()));
}
_ => panic!("Expected MessageComplete variant"),
}
}
#[test]
fn test_translate_session_idle() {
let event = Event::SessionIdle {
connector_id: "test-connector".to_string(),
session_id: "sess-1".to_string(),
};
let updates = translate_event(&event);
assert_eq!(updates.len(), 1);
let update_params = &updates[0];
assert_eq!(update_params.session_id, "sess-1");
match &update_params.update {
SessionUpdateVariant::SessionIdle {} => {
// Expected
}
_ => panic!("Expected SessionIdle variant"),
}
}
#[test]
fn test_translate_message_failed() {
let event = Event::MessageFailed {
message_id: "msg-1".to_string(),
error: "Connection lost".to_string(),
};
// MessageFailed events are now filtered (return empty vec)
let updates = translate_event(&event);
assert!(
updates.is_empty(),
"MessageFailed events should be filtered"
);
}
#[test]
fn test_translate_filtered_events() {
// These events should not produce notifications
let filtered_events = vec![
Event::Connected,
Event::Disconnected,
Event::SessionDeleted {
session_id: "s".to_string(),
},
Event::Error {
message: "System error".to_string(),
},
Event::MessageFailed {
message_id: "msg-1".to_string(),
error: "Failed".to_string(),
},
];
for event in filtered_events {
assert!(
translate_event(&event).is_empty(),
"Expected {:?} to be filtered",
event
);
}
}
}
+478
View File
@@ -0,0 +1,478 @@
//! SSE (Server-Sent Events) Notifications for ACP Server
//!
//! This module provides the SSE notification infrastructure for streaming events
//! to connected ACP clients. It handles per-client subscriptions using broadcast
//! channels and provides translation from Dirigent protocol events to ACP notifications.
//!
//! ## Architecture
//!
//! The module is organized into several sub-modules:
//!
//! - `types` - Data structures for SSE notifications (SessionUpdateParams, SessionUpdateVariant, etc.)
//! - `notifier` - SseNotifier for managing client subscriptions and broadcasting
//! - `event_translator` - Translation from Dirigent Events to ACP notifications
//! - `content_transform` - Helper functions for content and metadata transformation
//!
//! ## SSE Wire Format
//!
//! The SSE stream uses the following format:
//!
//! ```text
//! event: session/update
//! data: {"sessionId": "...", "update": {"sessionUpdate": "agent_message_chunk", ...}}
//!
//! event: session/update
//! data: {"sessionId": "...", "update": {"sessionUpdate": "message_complete", ...}}
//! ```
//!
//! ## Example
//!
//! ```rust,ignore
//! use dirigent_acp_api::sse::{SseNotifier, SessionUpdateParams, translate_event};
//!
//! // Create the notifier
//! let notifier = SseNotifier::new();
//!
//! // Subscribe a client
//! let stream = notifier.subscribe("client-123");
//!
//! // Translate and broadcast an event
//! let updates = translate_event(&event);
//! for update in updates {
//! notifier.broadcast("client-123", update);
//! }
//!
//! // Unsubscribe when done
//! notifier.unsubscribe("client-123");
//! ```
mod content_transform;
mod event_translator;
mod notifier;
mod types;
// Re-export public types
pub use content_transform::{
extract_kind, rebuild_meta, tool_call_status_to_string, unwrap_tool_call_content,
};
pub use event_translator::translate_event;
pub use notifier::SseNotifier;
pub use types::{
models_to_config_option, modes_to_config_option, AcpNotification, ConfigOption,
ConfigOptionChoice, ConfigOptionType, SessionUpdateParams, SessionUpdateVariant, SlashCommand,
};
// ============================================================================
// Tests (moved from original sse.rs)
// ============================================================================
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// ========================================================================
// AcpNotification Tests
// ========================================================================
#[test]
fn test_message_chunk_serialization() {
let notification = AcpNotification::MessageChunk {
session_id: "sess-123".to_string(),
message_id: "msg-456".to_string(),
content: "Hello, world!".to_string(),
content_type: Some("text".to_string()),
};
let json = serde_json::to_string(&notification).unwrap();
assert!(json.contains(r#""type":"message_chunk"#));
assert!(json.contains(r#""sessionId":"sess-123"#));
assert!(json.contains(r#""messageId":"msg-456"#));
assert!(json.contains(r#""content":"Hello, world!"#));
assert!(json.contains(r#""contentType":"text"#));
}
#[test]
fn test_message_complete_serialization() {
let notification = AcpNotification::MessageComplete {
session_id: "sess-123".to_string(),
message_id: "msg-456".to_string(),
};
let json = serde_json::to_string(&notification).unwrap();
assert!(json.contains(r#""type":"message_complete"#));
assert!(json.contains(r#""sessionId":"sess-123"#));
assert!(json.contains(r#""messageId":"msg-456"#));
}
#[test]
fn test_session_idle_serialization() {
let notification = AcpNotification::SessionIdle {
session_id: "sess-123".to_string(),
};
let json = serde_json::to_string(&notification).unwrap();
assert!(json.contains(r#""type":"session_idle"#));
assert!(json.contains(r#""sessionId":"sess-123"#));
}
#[test]
fn test_session_error_serialization() {
let notification = AcpNotification::SessionError {
session_id: "sess-123".to_string(),
error: "Something went wrong".to_string(),
};
let json = serde_json::to_string(&notification).unwrap();
assert!(json.contains(r#""type":"session_error"#));
assert!(json.contains(r#""error":"Something went wrong"#));
}
#[test]
fn test_tool_call_update_serialization() {
let notification = AcpNotification::ToolCallUpdate {
session_id: "sess-123".to_string(),
message_id: "msg-456".to_string(),
tool_call_id: "call-789".to_string(),
tool_name: "bash".to_string(),
status: "running".to_string(),
title: Some("Running command".to_string()),
error: None,
};
let json = serde_json::to_string(&notification).unwrap();
assert!(json.contains(r#""type":"tool_call_update"#));
assert!(json.contains(r#""toolName":"bash"#));
assert!(json.contains(r#""status":"running"#));
assert!(json.contains(r#""title":"Running command"#));
// error should not be present when None
assert!(!json.contains(r#""error""#));
}
#[test]
fn test_event_type() {
assert_eq!(
AcpNotification::MessageChunk {
session_id: "s".to_string(),
message_id: "m".to_string(),
content: "c".to_string(),
content_type: None,
}
.event_type(),
"message/chunk"
);
assert_eq!(
AcpNotification::MessageComplete {
session_id: "s".to_string(),
message_id: "m".to_string(),
}
.event_type(),
"message/complete"
);
assert_eq!(
AcpNotification::SessionIdle {
session_id: "s".to_string(),
}
.event_type(),
"session/idle"
);
assert_eq!(
AcpNotification::SessionError {
session_id: "s".to_string(),
error: "e".to_string(),
}
.event_type(),
"session/error"
);
assert_eq!(
AcpNotification::ToolCallUpdate {
session_id: "s".to_string(),
message_id: "m".to_string(),
tool_call_id: "t".to_string(),
tool_name: "bash".to_string(),
status: "running".to_string(),
title: None,
error: None,
}
.event_type(),
"tool/update"
);
}
#[test]
fn test_to_sse_string() {
let notification = AcpNotification::SessionIdle {
session_id: "sess-123".to_string(),
};
let sse = notification.to_sse_string();
assert!(sse.starts_with("event: session/idle\n"));
assert!(sse.contains("data: "));
assert!(sse.contains("sess-123"));
assert!(sse.ends_with("\n\n"));
}
#[test]
fn test_notification_roundtrip() {
let notifications = vec![
AcpNotification::MessageChunk {
session_id: "s".to_string(),
message_id: "m".to_string(),
content: "c".to_string(),
content_type: Some("text".to_string()),
},
AcpNotification::MessageComplete {
session_id: "s".to_string(),
message_id: "m".to_string(),
},
AcpNotification::SessionIdle {
session_id: "s".to_string(),
},
AcpNotification::SessionError {
session_id: "s".to_string(),
error: "e".to_string(),
},
AcpNotification::ToolCallUpdate {
session_id: "s".to_string(),
message_id: "m".to_string(),
tool_call_id: "t".to_string(),
tool_name: "bash".to_string(),
status: "running".to_string(),
title: Some("title".to_string()),
error: None,
},
];
for notification in notifications {
let json = serde_json::to_string(&notification).unwrap();
let deserialized: AcpNotification = serde_json::from_str(&json).unwrap();
assert_eq!(notification, deserialized);
}
}
// ========================================================================
// SessionUpdate Tag Verification Tests
// ========================================================================
#[test]
fn test_session_update_variant_uses_session_update_tag_tool_call() {
// Test ToolCall variant with flattened structure
let variant = SessionUpdateVariant::ToolCall {
tool_call_id: "call_456".to_string(),
title: Some("Running command".to_string()),
kind: Some("command".to_string()),
raw_input: Some(json!({"command": "ls"})),
status: Some("in_progress".to_string()),
content: vec![],
_meta: Some(json!({
"claudeCode": {
"toolName": "bash"
}
})),
};
let json = serde_json::to_value(&variant).unwrap();
// Verify correct tag name "sessionUpdate" is present
assert!(
json.get("sessionUpdate").is_some(),
"Missing 'sessionUpdate' field in JSON: {:?}",
json
);
assert_eq!(
json["sessionUpdate"], "tool_call",
"Expected sessionUpdate value 'tool_call', got: {:?}",
json["sessionUpdate"]
);
// Verify incorrect tag "type" is NOT present
assert!(
json.get("type").is_none(),
"Should not have 'type' field in ACP output, found: {:?}",
json.get("type")
);
// Verify flattened fields are present
assert!(json.get("toolCallId").is_some(), "Missing toolCallId field");
assert!(json.get("title").is_some(), "Missing title field");
assert!(json.get("status").is_some(), "Missing status field");
}
#[test]
fn test_session_update_variant_roundtrip_with_session_update_tag() {
// Test that serialization and deserialization work correctly with sessionUpdate tag
let variants = vec![
SessionUpdateVariant::SessionIdle {},
SessionUpdateVariant::MessageComplete {
message_id: Some("msg_1".to_string()),
},
SessionUpdateVariant::AgentMessageChunk {
content: dirigent_protocol::ContentBlock::Text {
text: "test".to_string(),
},
},
SessionUpdateVariant::ToolCall {
tool_call_id: "call_1".to_string(),
title: None,
kind: None,
raw_input: None,
status: Some("pending".to_string()),
content: vec![],
_meta: Some(json!({
"claudeCode": {
"toolName": "test_tool"
}
})),
},
];
for variant in variants {
// Serialize
let json_str = serde_json::to_string(&variant).unwrap();
let json_val: serde_json::Value = serde_json::from_str(&json_str).unwrap();
// Verify sessionUpdate tag in JSON
assert!(
json_val.get("sessionUpdate").is_some(),
"Missing sessionUpdate in serialized JSON for variant: {:?}",
variant
);
// Verify no type tag
assert!(
json_val.get("type").is_none(),
"Found unexpected 'type' tag for variant: {:?}",
variant
);
// Deserialize back
let deserialized: SessionUpdateVariant = serde_json::from_str(&json_str).unwrap();
// Verify equality (roundtrip)
assert_eq!(
variant, deserialized,
"Roundtrip failed for variant: {:?}",
variant
);
}
}
// ========================================================================
// camelCase Field Naming Compliance Tests
// ========================================================================
#[test]
fn test_all_fields_use_camel_case() {
// Test SessionUpdateParams with ToolCall variant
let params = SessionUpdateParams {
session_id: "sess_123".to_string(),
update: SessionUpdateVariant::ToolCall {
tool_call_id: "tool_123".to_string(),
title: Some("test".to_string()),
kind: Some("search".to_string()),
raw_input: Some(json!({"key": "value"})),
status: Some("pending".to_string()),
content: vec![],
_meta: None,
},
..Default::default()
};
let json = serde_json::to_value(&params).unwrap();
// Verify top-level camelCase
assert!(
json.get("sessionId").is_some(),
"Should have sessionId (not session_id)"
);
assert!(
json.get("session_id").is_none(),
"Should NOT have session_id"
);
let update = &json["update"];
// Verify ToolCall variant fields use camelCase
assert!(
update.get("toolCallId").is_some(),
"Should have toolCallId (not tool_call_id)"
);
assert!(
update.get("tool_call_id").is_none(),
"Should NOT have tool_call_id"
);
assert!(
update.get("rawInput").is_some(),
"Should have rawInput (not raw_input)"
);
assert!(
update.get("raw_input").is_none(),
"Should NOT have raw_input"
);
}
#[test]
fn test_no_snake_case_in_serialized_json() {
// Create various SessionUpdateParams with different variants
let test_cases = vec![
SessionUpdateParams {
session_id: "s1".to_string(),
update: SessionUpdateVariant::ToolCall {
tool_call_id: "tc1".to_string(),
title: Some("t".to_string()),
kind: Some("k".to_string()),
raw_input: Some(json!({})),
status: Some("pending".to_string()),
content: vec![],
_meta: None,
},
..Default::default()
},
SessionUpdateParams {
session_id: "s2".to_string(),
update: SessionUpdateVariant::ToolCallUpdate {
tool_call_id: "tc2".to_string(),
status: Some("completed".to_string()),
content: vec![],
raw_output: Some(json!({})),
error: None,
_meta: None,
},
..Default::default()
},
SessionUpdateParams {
session_id: "s3".to_string(),
update: SessionUpdateVariant::MessageComplete {
message_id: Some("m3".to_string()),
},
..Default::default()
},
];
for params in test_cases {
let json_str = serde_json::to_string(&params).unwrap();
// Check for common snake_case fields that should NOT appear
let forbidden_patterns = vec![
"session_id",
"tool_call_id",
"raw_input",
"raw_output",
"message_id",
];
for pattern in forbidden_patterns {
assert!(
!json_str.contains(&format!("\"{}\"", pattern)),
"Found forbidden snake_case field '{}' in JSON: {}",
pattern,
json_str
);
}
}
}
}
+403
View File
@@ -0,0 +1,403 @@
//! SSE Notifier for managing client subscriptions
//!
//! This module provides the subscription management infrastructure for SSE streaming.
//! The `SseNotifier` manages per-client broadcast channels for targeted notifications.
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::{Arc, RwLock};
use tokio::sync::broadcast;
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use tokio_stream::wrappers::BroadcastStream;
use tokio_stream::Stream;
use tracing::{debug, trace, warn};
use super::types::SessionUpdateParams;
/// Default broadcast channel capacity
const DEFAULT_CHANNEL_CAPACITY: usize = 256;
/// Internal state for SSE subscriptions
#[derive(Debug, Default)]
struct SseNotifierState {
/// Map from client_id to their broadcast sender
subscriptions: HashMap<String, broadcast::Sender<SessionUpdateParams>>,
}
/// SSE Notifier for managing client subscriptions and broadcasting notifications
///
/// The `SseNotifier` manages per-client subscriptions using tokio broadcast channels.
/// Each client has their own broadcast channel, allowing targeted notifications.
///
/// ## Thread Safety
///
/// The `SseNotifier` is designed to be cloned and shared across async tasks.
/// Internal state is protected by `Arc<RwLock<...>>` for thread-safe access.
///
/// ## Example
///
/// ```rust,ignore
/// let notifier = SseNotifier::new();
///
/// // Subscribe a client and get their stream
/// let stream = notifier.subscribe("client-123");
///
/// // In another task, broadcast notifications
/// notifier.broadcast("client-123", notification);
///
/// // When client disconnects
/// notifier.unsubscribe("client-123");
/// ```
#[derive(Debug, Clone)]
pub struct SseNotifier {
state: Arc<RwLock<SseNotifierState>>,
channel_capacity: usize,
}
impl Default for SseNotifier {
fn default() -> Self {
Self::new()
}
}
impl SseNotifier {
/// Create a new SSE notifier with default capacity
///
/// Creates a new notifier with the default channel capacity of 256 messages.
pub fn new() -> Self {
Self {
state: Arc::new(RwLock::new(SseNotifierState::default())),
channel_capacity: DEFAULT_CHANNEL_CAPACITY,
}
}
/// Create a new SSE notifier with custom channel capacity
///
/// # Parameters
///
/// - `capacity`: The maximum number of messages that can be buffered per client
pub fn with_capacity(capacity: usize) -> Self {
Self {
state: Arc::new(RwLock::new(SseNotifierState::default())),
channel_capacity: capacity,
}
}
/// Subscribe a client and return an async stream of notifications
///
/// Creates a broadcast channel for the client and returns a stream that can
/// be used with Axum's `Sse<impl Stream>` response type.
///
/// If the client is already subscribed, a new receiver is created from the
/// existing sender (allowing multiple connections per client).
///
/// # Parameters
///
/// - `client_id`: Unique identifier for the client
///
/// # Returns
///
/// A pinned stream of `SessionUpdateParams` wrapped in `Result` for error handling.
/// The stream is compatible with Axum's SSE handler.
pub fn subscribe(
&self,
client_id: &str,
) -> Pin<Box<dyn Stream<Item = Result<SessionUpdateParams, BroadcastStreamRecvError>> + Send>>
{
let mut state = self.state.write().expect("Lock poisoned");
// Get or create the sender for this client
let sender = state
.subscriptions
.entry(client_id.to_string())
.or_insert_with(|| {
debug!("Creating new broadcast channel for client: {}", client_id);
let (tx, _rx) = broadcast::channel(self.channel_capacity);
tx
});
// Create a new receiver
let receiver = sender.subscribe();
debug!(
"Client subscribed to SSE: {} (subscribers: {})",
client_id,
sender.receiver_count()
);
// Convert to stream using tokio_stream
let stream = BroadcastStream::new(receiver);
Box::pin(stream)
}
/// Unsubscribe a client and cleanup resources
///
/// Removes the client's subscription and drops the sender, which will
/// cause all associated receivers to receive an error on their next poll.
///
/// # Parameters
///
/// - `client_id`: The ID of the client to unsubscribe
///
/// # Returns
///
/// `true` if the client was subscribed and removed, `false` if not found
pub fn unsubscribe(&self, client_id: &str) -> bool {
let mut state = self.state.write().expect("Lock poisoned");
if let Some(sender) = state.subscriptions.remove(client_id) {
debug!(
"Client unsubscribed from SSE: {} (was {} receivers)",
client_id,
sender.receiver_count()
);
// Sender is dropped here, which will close the channel
true
} else {
debug!("Client was not subscribed: {}", client_id);
false
}
}
/// Broadcast a session update to a specific client
///
/// Sends the session update to all receivers subscribed to the specified client.
/// If no receivers are active (e.g., all have dropped or lagged), the send
/// is handled gracefully.
///
/// # Parameters
///
/// - `client_id`: The ID of the client to send to
/// - `update_params`: The session update parameters to send
///
/// # Returns
///
/// - `Ok(n)` where n is the number of receivers that received the update
/// - `Err(())` if the client is not subscribed
pub fn broadcast(
&self,
client_id: &str,
update_params: SessionUpdateParams,
) -> Result<usize, ()> {
let state = self.state.read().expect("Lock poisoned");
if let Some(sender) = state.subscriptions.get(client_id) {
match sender.send(update_params) {
Ok(n) => {
trace!(
"Broadcast session update to client {}: {} receivers",
client_id,
n
);
Ok(n)
}
Err(_) => {
// No active receivers - this is okay, the update is just dropped
trace!("Broadcast to client {}: no active receivers", client_id);
Ok(0)
}
}
} else {
warn!("Attempted broadcast to unsubscribed client: {}", client_id);
Err(())
}
}
/// Broadcast a session update to all subscribed clients
///
/// Sends the session update to all connected clients. Failed sends to
/// individual clients are logged but don't stop the broadcast.
///
/// # Parameters
///
/// - `update_params`: The session update parameters to broadcast
///
/// # Returns
///
/// The total number of receivers that received the update
pub fn broadcast_all(&self, update_params: SessionUpdateParams) -> usize {
let state = self.state.read().expect("Lock poisoned");
let mut total_receivers = 0;
for (client_id, sender) in state.subscriptions.iter() {
match sender.send(update_params.clone()) {
Ok(n) => {
total_receivers += n;
trace!("Broadcast to client {}: {} receivers", client_id, n);
}
Err(_) => {
trace!("Broadcast to client {}: no active receivers", client_id);
}
}
}
debug!(
"Broadcast to all clients: {} total receivers across {} clients",
total_receivers,
state.subscriptions.len()
);
total_receivers
}
/// Get the number of subscribed clients
pub fn client_count(&self) -> usize {
let state = self.state.read().expect("Lock poisoned");
state.subscriptions.len()
}
/// Check if a client is subscribed
pub fn is_subscribed(&self, client_id: &str) -> bool {
let state = self.state.read().expect("Lock poisoned");
state.subscriptions.contains_key(client_id)
}
/// Get the list of subscribed client IDs
pub fn subscribed_clients(&self) -> Vec<String> {
let state = self.state.read().expect("Lock poisoned");
state.subscriptions.keys().cloned().collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::sse::types::SessionUpdateVariant;
use tokio_stream::StreamExt;
#[test]
fn test_sse_notifier_new() {
let notifier = SseNotifier::new();
assert_eq!(notifier.client_count(), 0);
}
#[test]
fn test_sse_notifier_with_capacity() {
let notifier = SseNotifier::with_capacity(512);
assert_eq!(notifier.channel_capacity, 512);
}
#[tokio::test]
async fn test_subscribe_unsubscribe() {
let notifier = SseNotifier::new();
// Subscribe
let _stream = notifier.subscribe("client-1");
assert!(notifier.is_subscribed("client-1"));
assert_eq!(notifier.client_count(), 1);
// Unsubscribe
assert!(notifier.unsubscribe("client-1"));
assert!(!notifier.is_subscribed("client-1"));
assert_eq!(notifier.client_count(), 0);
// Unsubscribe non-existent
assert!(!notifier.unsubscribe("client-1"));
}
#[tokio::test]
async fn test_broadcast_to_client() {
let notifier = SseNotifier::new();
// Subscribe
let mut stream = notifier.subscribe("client-1");
// Broadcast
let update_params = SessionUpdateParams {
session_id: "sess-1".to_string(),
update: SessionUpdateVariant::SessionIdle {},
..Default::default()
};
let result = notifier.broadcast("client-1", update_params.clone());
assert!(result.is_ok());
assert_eq!(result.unwrap(), 1);
// Receive
let received = stream.next().await;
assert!(received.is_some());
let received = received.unwrap();
assert!(received.is_ok());
assert_eq!(received.unwrap(), update_params);
}
#[tokio::test]
async fn test_broadcast_to_unsubscribed_client() {
let notifier = SseNotifier::new();
let update_params = SessionUpdateParams {
session_id: "sess-1".to_string(),
update: SessionUpdateVariant::SessionIdle {},
..Default::default()
};
let result = notifier.broadcast("unknown-client", update_params);
assert!(result.is_err());
}
#[tokio::test]
async fn test_broadcast_all() {
let notifier = SseNotifier::new();
// Subscribe multiple clients
let mut stream1 = notifier.subscribe("client-1");
let mut stream2 = notifier.subscribe("client-2");
let update_params = SessionUpdateParams {
session_id: "sess-1".to_string(),
update: SessionUpdateVariant::SessionIdle {},
..Default::default()
};
let count = notifier.broadcast_all(update_params.clone());
assert_eq!(count, 2);
// Both should receive
let r1 = stream1.next().await.unwrap().unwrap();
let r2 = stream2.next().await.unwrap().unwrap();
assert_eq!(r1, update_params);
assert_eq!(r2, update_params);
}
#[test]
fn test_subscribed_clients() {
let notifier = SseNotifier::new();
let _s1 = notifier.subscribe("client-1");
let _s2 = notifier.subscribe("client-2");
let clients = notifier.subscribed_clients();
assert_eq!(clients.len(), 2);
assert!(clients.contains(&"client-1".to_string()));
assert!(clients.contains(&"client-2".to_string()));
}
#[tokio::test]
async fn test_multiple_receivers_same_client() {
let notifier = SseNotifier::new();
// Subscribe same client twice
let mut stream1 = notifier.subscribe("client-1");
let mut stream2 = notifier.subscribe("client-1");
let update_params = SessionUpdateParams {
session_id: "sess-1".to_string(),
update: SessionUpdateVariant::SessionIdle {},
..Default::default()
};
let count = notifier.broadcast("client-1", update_params.clone());
assert!(count.is_ok());
assert_eq!(count.unwrap(), 2); // Both receivers
// Both should receive
let r1 = stream1.next().await.unwrap().unwrap();
let r2 = stream2.next().await.unwrap().unwrap();
assert_eq!(r1, update_params);
assert_eq!(r2, update_params);
}
}
+548
View File
@@ -0,0 +1,548 @@
//! SSE Type Definitions for ACP Server
//!
//! This module contains all the data structures and types used for SSE notifications.
//! These types define the wire format for ACP protocol messages.
use serde::{Deserialize, Serialize};
use dirigent_protocol::ContentBlock;
// ============================================================================
// SessionUpdateParams
// ============================================================================
/// Params for session/update JSON-RPC notification (ACP spec compliant)
///
/// This matches the structure sent by Claude and expected by Zed.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "camelCase", default)]
pub struct SessionUpdateParams {
/// The session ID this update belongs to
pub session_id: String,
/// The update content
pub update: SessionUpdateVariant,
/// Optional event type override (default: "session/update")
/// Reserved for future protocol extensions. Currently unused.
#[serde(skip)]
pub event_type_override: Option<String>,
}
impl SessionUpdateParams {
/// Get the SSE event type for session updates
///
/// According to ACP spec, session updates use "session/update" event type.
pub fn event_type(&self) -> &str {
self.event_type_override
.as_deref()
.unwrap_or("session/update")
}
/// Create a raw event with a custom event type
///
/// This creates an event that will be serialized as raw JSON data
/// (without the SessionUpdateParams wrapper) when sent over SSE.
/// Used for `_rpc_response` deferred responses pushed via SSE after gateway transfer.
pub fn raw_event(event_type: &str, data: serde_json::Value) -> Self {
Self {
session_id: String::new(), // Not used for raw events
update: SessionUpdateVariant::Unknown { data },
event_type_override: Some(event_type.to_string()),
}
}
/// Check if this is a raw event (should serialize data directly)
pub fn is_raw_event(&self) -> bool {
self.event_type_override.is_some()
}
/// Get the data to serialize for SSE
///
/// For raw events (with event_type_override), returns the raw JSON data.
/// For normal events, returns None (caller should serialize the full struct).
pub fn raw_data(&self) -> Option<&serde_json::Value> {
if self.event_type_override.is_some() {
if let SessionUpdateVariant::Unknown { data } = &self.update {
return Some(data);
}
}
None
}
/// Serialize to JSON string for SSE transmission
///
/// For raw events, serializes just the raw data.
/// For normal events, serializes the full SessionUpdateParams struct.
pub fn to_sse_json(&self) -> String {
if let Some(raw_data) = self.raw_data() {
// Raw event: serialize just the data
serde_json::to_string(raw_data).unwrap_or_else(|_| "{}".to_string())
} else {
// Normal event: serialize the full struct
serde_json::to_string(self).unwrap_or_else(|_| "{}".to_string())
}
}
}
// ============================================================================
// ConfigOption Types
// ============================================================================
/// A single configuration option for session settings (outgoing ACP wire format)
///
/// Used in `config_option_update` notifications to send the complete set of
/// configuration options (modes, models, etc.) to the client.
///
/// Note: This is the **outgoing** wire format for ACP Server SSE.
/// For the **incoming** format (parsed from agents), see `dirigent_protocol::ConfigOption`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct ConfigOption {
/// Unique identifier for this option (e.g., "mode", "model")
pub id: String,
/// Human-readable display name (e.g., "Session Mode", "Model")
pub name: String,
/// Optional description
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
/// Semantic category for UX grouping (e.g., "mode", "model", "thought_level")
#[serde(skip_serializing_if = "Option::is_none")]
pub category: Option<String>,
/// Input type (e.g., "select", "toggle", "text")
#[serde(rename = "type")]
pub option_type: ConfigOptionType,
/// Currently selected value
pub current_value: String,
/// Available options for "select" type
#[serde(skip_serializing_if = "Option::is_none")]
pub options: Option<Vec<ConfigOptionChoice>>,
}
/// Type of configuration option input
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum ConfigOptionType {
Select,
Toggle,
Text,
}
/// A single choice within a select-type config option
///
/// Matches ACP's `SessionConfigSelectOption` structure.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct ConfigOptionChoice {
/// Value identifier for this choice (sent back when selected)
pub value: String,
/// Human-readable display name
pub name: String,
/// Optional description
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
}
// ============================================================================
// SessionUpdateVariant
// ============================================================================
/// Variants of session updates (ACP spec compliant)
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "sessionUpdate", rename_all = "snake_case")]
pub enum SessionUpdateVariant {
/// Agent message chunk (streaming text)
#[serde(rename = "agent_message_chunk")]
AgentMessageChunk {
/// The content block (text, image, etc.)
content: ContentBlock,
},
/// Message generation complete
#[serde(rename = "message_complete")]
MessageComplete {
/// The message ID that completed
#[serde(rename = "messageId", skip_serializing_if = "Option::is_none")]
message_id: Option<String>,
},
/// Session is idle and ready for input
#[serde(rename = "session_idle")]
SessionIdle {},
/// Available slash commands update
#[serde(rename = "available_commands_update")]
AvailableCommandsUpdate {
/// List of available slash commands
#[serde(rename = "availableCommands")]
available_commands: Vec<SlashCommand>,
},
/// Connector changed (session transferred to a different connector)
#[serde(rename = "connector_changed")]
ConnectorChanged {
/// The new connector ID
#[serde(rename = "newConnectorId")]
new_connector_id: String,
/// The new internal session ID in the target connector
#[serde(rename = "newInternalSessionId")]
new_internal_session_id: String,
/// Whether a new session was created (true) or existing session loaded (false)
#[serde(rename = "isNewSession")]
is_new_session: bool,
},
/// Tool call started/initiated
#[serde(rename = "tool_call")]
ToolCall {
/// The tool call ID
#[serde(rename = "toolCallId")]
tool_call_id: String,
/// Optional title for the tool call
#[serde(skip_serializing_if = "Option::is_none")]
title: Option<String>,
/// Optional kind/category (e.g., "search", "edit")
#[serde(skip_serializing_if = "Option::is_none")]
kind: Option<String>,
/// Raw input parameters
#[serde(rename = "rawInput", skip_serializing_if = "Option::is_none")]
raw_input: Option<serde_json::Value>,
/// Current status (pending, in_progress, completed, failed)
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
/// Content blocks (e.g., tool output)
#[serde(default, skip_serializing_if = "Vec::is_empty")]
content: Vec<ContentBlock>,
/// Metadata (can include claudeCode.toolName, toolResponse, etc.)
#[serde(skip_serializing_if = "Option::is_none")]
_meta: Option<serde_json::Value>,
},
/// Tool call update (status change, output, etc.)
#[serde(rename = "tool_call_update")]
ToolCallUpdate {
/// The tool call ID being updated
#[serde(rename = "toolCallId")]
tool_call_id: String,
/// Updated status
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
/// Updated content blocks
#[serde(default, skip_serializing_if = "Vec::is_empty")]
content: Vec<ContentBlock>,
/// Raw output from the tool
#[serde(rename = "rawOutput", skip_serializing_if = "Option::is_none")]
raw_output: Option<serde_json::Value>,
/// Error message if tool call failed
#[serde(skip_serializing_if = "Option::is_none")]
error: Option<String>,
/// Metadata (can include toolResponse from Claude)
#[serde(skip_serializing_if = "Option::is_none")]
_meta: Option<serde_json::Value>,
},
/// Agent request needing client response (e.g., permission prompt)
///
/// This variant is used for bidirectional requests where the agent initiates
/// a request (like `session/request_permission`) that requires a response from
/// the client. The client should respond via the `/acp/agent_response` endpoint.
#[serde(rename = "agent_request")]
AgentRequest {
/// The request ID from the agent (to correlate response)
#[serde(rename = "requestId")]
request_id: serde_json::Value,
/// The method being requested (e.g., "session/request_permission")
method: String,
/// The request parameters
params: serde_json::Value,
},
/// Session modes update (full state) - NOT SUPPORTED BY ZED
/// Kept for future ACP protocol support
#[serde(rename = "modes_update")]
ModesUpdate {
/// Session mode state (flattened)
#[serde(flatten)]
modes: dirigent_protocol::SessionModeState,
},
/// Session models update (full state) - NOT SUPPORTED BY ZED
/// Kept for future ACP protocol support
#[serde(rename = "models_update")]
ModelsUpdate {
/// Session model state (flattened)
#[serde(flatten)]
models: dirigent_protocol::SessionModelState,
},
/// Current mode update (standard ACP notification)
/// Signals which mode is currently selected. Zed accepts this.
#[serde(rename = "current_mode_update")]
CurrentModeUpdate {
/// The ID of the currently selected mode
/// Note: Zed expects "currentModeId" not "modeId"
#[serde(rename = "currentModeId")]
mode_id: String,
},
/// Current model update - NOT SUPPORTED BY ZED
/// Kept for completeness but Zed only accepts current_mode_update
#[serde(rename = "current_model_update")]
CurrentModelUpdate {
/// The ID of the currently selected model
#[serde(rename = "modelId")]
model_id: String,
},
/// Config option update - send full configuration options to client
///
/// This is used after session transfer to provide the target connector's
/// actual modes/models instead of Gateway's placeholder values.
/// See: https://agentclientprotocol.com/rfds/session-config-options
#[serde(rename = "config_option_update")]
ConfigOptionUpdate {
/// List of configuration options to display
#[serde(rename = "configOptions")]
config_options: Vec<ConfigOption>,
},
/// Unknown update type (forward compatibility - pass through as raw JSON)
#[serde(untagged)]
Unknown {
#[serde(flatten)]
data: serde_json::Value,
},
}
impl SessionUpdateVariant {
/// Returns the variant name for logging (without data)
pub fn variant_name(&self) -> &'static str {
match self {
SessionUpdateVariant::AgentMessageChunk { .. } => "AgentMessageChunk",
SessionUpdateVariant::MessageComplete { .. } => "MessageComplete",
SessionUpdateVariant::SessionIdle { .. } => "SessionIdle",
SessionUpdateVariant::AvailableCommandsUpdate { .. } => "AvailableCommandsUpdate",
SessionUpdateVariant::ConnectorChanged { .. } => "ConnectorChanged",
SessionUpdateVariant::ToolCall { .. } => "ToolCall",
SessionUpdateVariant::ToolCallUpdate { .. } => "ToolCallUpdate",
SessionUpdateVariant::AgentRequest { .. } => "AgentRequest",
SessionUpdateVariant::ModesUpdate { .. } => "ModesUpdate",
SessionUpdateVariant::ModelsUpdate { .. } => "ModelsUpdate",
SessionUpdateVariant::CurrentModeUpdate { .. } => "CurrentModeUpdate",
SessionUpdateVariant::CurrentModelUpdate { .. } => "CurrentModelUpdate",
SessionUpdateVariant::ConfigOptionUpdate { .. } => "ConfigOptionUpdate",
SessionUpdateVariant::Unknown { .. } => "Unknown",
}
}
}
impl Default for SessionUpdateVariant {
fn default() -> Self {
SessionUpdateVariant::SessionIdle {}
}
}
// ============================================================================
// SlashCommand
// ============================================================================
/// Slash command definition
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct SlashCommand {
pub name: String,
pub description: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub input: Option<serde_json::Value>,
}
// ============================================================================
// AcpNotification (Legacy)
// ============================================================================
/// ACP Notification types for SSE streaming (DEPRECATED - keeping for compatibility)
///
/// These notifications are sent to clients over the SSE connection to inform
/// them about session events, message streaming, and errors.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(
tag = "type",
rename_all = "snake_case",
rename_all_fields = "camelCase"
)]
pub enum AcpNotification {
/// A chunk of message content has been received
///
/// Sent during streaming to provide incremental content updates.
/// The client should append this content to the message being built.
MessageChunk {
/// The session ID this chunk belongs to
session_id: String,
/// The message ID this chunk belongs to
message_id: String,
/// The content chunk (typically text)
content: String,
/// Optional content type (e.g., "text", "thought")
#[serde(skip_serializing_if = "Option::is_none")]
content_type: Option<String>,
},
/// A message has been completed
///
/// Sent when the agent has finished generating a message.
/// The client should finalize the message display.
MessageComplete {
/// The session ID
session_id: String,
/// The completed message ID
message_id: String,
},
/// A session has become idle
///
/// Sent when the session transitions to an idle state, meaning
/// no active generation is occurring.
SessionIdle {
/// The session ID that became idle
session_id: String,
},
/// An error occurred in a session
///
/// Sent when an error occurs during processing.
SessionError {
/// The session ID where the error occurred
session_id: String,
/// Human-readable error message
error: String,
},
/// A tool call has started or been updated
///
/// Sent when a tool call is initiated or its status changes.
ToolCallUpdate {
/// The session ID
session_id: String,
/// The message ID containing this tool call
message_id: String,
/// The tool call ID
tool_call_id: String,
/// The tool name
tool_name: String,
/// Current status (pending, running, completed, error)
status: String,
/// Optional title for the tool call
#[serde(skip_serializing_if = "Option::is_none")]
title: Option<String>,
/// Optional error message
#[serde(skip_serializing_if = "Option::is_none")]
error: Option<String>,
},
}
impl AcpNotification {
/// Get the SSE event type for this notification
///
/// This is used as the `event:` field in the SSE stream.
pub fn event_type(&self) -> &'static str {
match self {
AcpNotification::MessageChunk { .. } => "message/chunk",
AcpNotification::MessageComplete { .. } => "message/complete",
AcpNotification::SessionIdle { .. } => "session/idle",
AcpNotification::SessionError { .. } => "session/error",
AcpNotification::ToolCallUpdate { .. } => "tool/update",
}
}
/// Convert to SSE format string
///
/// Returns a string formatted for SSE transmission:
/// ```text
/// event: <event_type>
/// data: <json_data>
/// ```
pub fn to_sse_string(&self) -> String {
let data = serde_json::to_string(self).unwrap_or_else(|_| "{}".to_string());
format!("event: {}\ndata: {}\n\n", self.event_type(), data)
}
}
// ============================================================================
// ConfigOption Conversion Functions
// ============================================================================
/// Convert SessionModeState to a ConfigOption
///
/// Transforms the protocol's mode state into the ACP `config_option_update` format.
/// Uses "mode" ID to match the ACP protocol standard.
pub fn modes_to_config_option(modes: &dirigent_protocol::SessionModeState) -> ConfigOption {
ConfigOption {
id: "mode".to_string(),
name: "Session Mode".to_string(),
description: None,
category: Some("mode".to_string()),
option_type: ConfigOptionType::Select,
current_value: modes.current_mode_id.clone(),
options: Some(
modes
.available_modes
.iter()
.map(|m| ConfigOptionChoice {
value: m.id.clone(),
name: m.name.clone(),
description: m.description.clone(),
})
.collect(),
),
}
}
/// Convert SessionModelState to a ConfigOption
///
/// Transforms the protocol's model state into the ACP `config_option_update` format.
/// Uses "model" ID to match the ACP protocol standard.
pub fn models_to_config_option(models: &dirigent_protocol::SessionModelState) -> ConfigOption {
ConfigOption {
id: "model".to_string(),
name: "Model".to_string(),
description: None,
category: Some("model".to_string()),
option_type: ConfigOptionType::Select,
current_value: models.current_model_id.clone(),
options: Some(
models
.available_models
.iter()
.map(|m| ConfigOptionChoice {
value: m.model_id.clone(),
name: m.name.clone(),
description: m.description.clone(),
})
.collect(),
),
}
}
+338
View File
@@ -0,0 +1,338 @@
//! Common test utilities for bidirectional flow testing.
//!
//! This module provides mock implementations and test helpers for testing
//! the bidirectional request/response flow in the ACP Server.
use anyhow::Result;
use serde_json::{json, Value};
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot, Mutex};
use std::collections::HashMap;
/// Mock event for testing event bridge
#[derive(Debug, Clone)]
pub struct MockEvent {
pub event_type: String,
pub data: Value,
}
/// Mock SSE client for testing
///
/// Simulates an HTTP client that receives SSE events and posts responses.
pub struct MockSseClient {
pub client_id: String,
/// Events received via SSE
pub received_events: Arc<Mutex<Vec<MockEvent>>>,
/// Sender for simulating HTTP POST responses
pub response_tx: mpsc::UnboundedSender<(Value, oneshot::Sender<Result<()>>)>,
}
impl MockSseClient {
pub fn new(client_id: String) -> (Self, mpsc::UnboundedReceiver<(Value, oneshot::Sender<Result<()>>)>) {
let (response_tx, response_rx) = mpsc::unbounded_channel();
(
Self {
client_id,
received_events: Arc::new(Mutex::new(Vec::new())),
response_tx,
},
response_rx,
)
}
/// Simulate receiving an SSE event
pub async fn receive_sse(&self, event_type: String, data: Value) {
let mut events = self.received_events.lock().await;
events.push(MockEvent { event_type, data });
}
/// Get all received events
pub async fn get_events(&self) -> Vec<MockEvent> {
let events = self.received_events.lock().await;
events.clone()
}
/// Get the most recent event of a specific type
pub async fn get_latest_event(&self, event_type: &str) -> Option<MockEvent> {
let events = self.received_events.lock().await;
events.iter()
.filter(|e| e.event_type == event_type)
.last()
.cloned()
}
/// Clear received events
pub async fn clear_events(&self) {
let mut events = self.received_events.lock().await;
events.clear();
}
/// Simulate sending a response to /acp/agent_response
pub async fn send_response(&self, response: Value) -> Result<()> {
let (tx, rx) = oneshot::channel();
self.response_tx.send((response, tx))
.map_err(|_| anyhow::anyhow!("Failed to send response"))?;
rx.await?
}
}
/// Mock connector for testing
///
/// Simulates an ACP connector that can send agent requests and receive responses.
pub struct MockConnector {
pub connector_id: String,
/// Channel for receiving agent requests from this connector
pub request_tx: mpsc::UnboundedSender<(Value, String, String, Value)>,
/// Pending responses (request_id -> sender)
pub pending_responses: Arc<Mutex<HashMap<Value, oneshot::Sender<Value>>>>,
}
impl MockConnector {
pub fn new(
connector_id: String,
) -> (Self, mpsc::UnboundedReceiver<(Value, String, String, Value)>) {
let (request_tx, request_rx) = mpsc::unbounded_channel();
(
Self {
connector_id,
request_tx,
pending_responses: Arc::new(Mutex::new(HashMap::new())),
},
request_rx,
)
}
/// Simulate sending an agent request
pub async fn send_agent_request(
&self,
session_id: String,
request_id: Value,
method: String,
params: Value,
) -> oneshot::Receiver<Value> {
let (tx, rx) = oneshot::channel();
// Store the response sender
let mut pending = self.pending_responses.lock().await;
pending.insert(request_id.clone(), tx);
drop(pending);
// Send the request
self.request_tx.send((request_id, session_id, method, params))
.expect("Failed to send agent request");
rx
}
/// Complete a pending response (simulating response from ACP Server)
pub async fn complete_response(&self, request_id: Value, response: Value) -> Result<()> {
let mut pending = self.pending_responses.lock().await;
if let Some(tx) = pending.remove(&request_id) {
tx.send(response)
.map_err(|_| anyhow::anyhow!("Failed to send response to connector"))?;
Ok(())
} else {
Err(anyhow::anyhow!("No pending response for request_id: {}", request_id))
}
}
}
/// Test context for integration tests
///
/// Provides a complete test environment with mocked components.
pub struct TestContext {
/// Mock SSE clients by client_id
pub clients: Arc<Mutex<HashMap<String, MockSseClient>>>,
/// Mock connectors by connector_id
pub connectors: Arc<Mutex<HashMap<String, MockConnector>>>,
}
impl TestContext {
pub fn new() -> Self {
Self {
clients: Arc::new(Mutex::new(HashMap::new())),
connectors: Arc::new(Mutex::new(HashMap::new())),
}
}
/// Create a mock SSE client
pub async fn create_client(
&self,
client_id: String,
) -> (MockSseClient, mpsc::UnboundedReceiver<(Value, oneshot::Sender<Result<()>>)>) {
let (client, response_rx) = MockSseClient::new(client_id.clone());
let mut clients = self.clients.lock().await;
clients.insert(client_id.clone(), client.clone());
(client, response_rx)
}
/// Create a mock connector
pub async fn create_connector(
&self,
connector_id: String,
) -> (MockConnector, mpsc::UnboundedReceiver<(Value, String, String, Value)>) {
let (connector, request_rx) = MockConnector::new(connector_id.clone());
let mut connectors = self.connectors.lock().await;
connectors.insert(connector_id.clone(), connector.clone());
(connector, request_rx)
}
/// Get a client by ID
pub async fn get_client(&self, client_id: &str) -> Option<MockSseClient> {
let clients = self.clients.lock().await;
clients.get(client_id).cloned()
}
/// Get a connector by ID
pub async fn get_connector(&self, connector_id: &str) -> Option<MockConnector> {
let connectors = self.connectors.lock().await;
connectors.get(connector_id).cloned()
}
}
impl Default for TestContext {
fn default() -> Self {
Self::new()
}
}
/// Helper to create a sample permission request
pub fn sample_permission_request(request_id: u64) -> Value {
json!({
"jsonrpc": "2.0",
"id": request_id,
"method": "session/request_permission",
"params": {
"sessionId": "test-session",
"tool": "Write",
"parameters": {
"path": "/tmp/test.txt",
"content": "test"
}
}
})
}
/// Helper to create a sample permission response
pub fn sample_permission_response(request_id: u64, allow: bool) -> Value {
json!({
"jsonrpc": "2.0",
"id": request_id,
"result": {
"selectedOptionId": if allow { "allow" } else { "deny" }
}
})
}
/// Helper to extract agent_request data from SSE event
pub fn extract_agent_request(event: &MockEvent) -> Option<(Value, String, Value)> {
if event.event_type != "session/update" {
return None;
}
let update = event.data.get("update")?;
if update.get("sessionUpdate")?.as_str()? != "agent_request" {
return None;
}
let request_id = update.get("requestId")?.clone();
let method = update.get("method")?.as_str()?.to_string();
let params = update.get("params")?.clone();
Some((request_id, method, params))
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_mock_client() {
let (client, _response_rx) = MockSseClient::new("test-client".to_string());
// Simulate receiving an event
client.receive_sse("session/update".to_string(), json!({"test": "data"})).await;
// Verify event was received
let events = client.get_events().await;
assert_eq!(events.len(), 1);
assert_eq!(events[0].event_type, "session/update");
}
#[tokio::test]
async fn test_mock_connector() {
let (connector, mut request_rx) = MockConnector::new("test-connector".to_string());
// Simulate sending an agent request
let response_fut = connector.send_agent_request(
"session-1".to_string(),
json!(0),
"session/request_permission".to_string(),
json!({"tool": "Write"}),
).await;
// Verify request was sent
let request = request_rx.recv().await.unwrap();
assert_eq!(request.0, json!(0));
assert_eq!(request.1, "session-1");
assert_eq!(request.2, "session/request_permission");
// Simulate completing the response
connector.complete_response(json!(0), json!({"result": "success"})).await.unwrap();
// Verify response was received
let response = response_fut.await.unwrap();
assert_eq!(response, json!({"result": "success"}));
}
#[tokio::test]
async fn test_test_context() {
let ctx = TestContext::new();
// Create client and connector
let (client, _) = ctx.create_client("client-1".to_string()).await;
let (connector, _) = ctx.create_connector("connector-1".to_string()).await;
// Verify we can retrieve them
assert!(ctx.get_client("client-1").await.is_some());
assert!(ctx.get_connector("connector-1").await.is_some());
assert!(ctx.get_client("non-existent").await.is_none());
}
#[test]
fn test_sample_helpers() {
let request = sample_permission_request(0);
assert_eq!(request["method"], "session/request_permission");
let response = sample_permission_response(0, true);
assert_eq!(response["result"]["selectedOptionId"], "allow");
}
#[test]
fn test_extract_agent_request() {
let event = MockEvent {
event_type: "session/update".to_string(),
data: json!({
"sessionId": "session-1",
"update": {
"sessionUpdate": "agent_request",
"requestId": 0,
"method": "session/request_permission",
"params": {"tool": "Write"}
}
}),
};
let (request_id, method, params) = extract_agent_request(&event).unwrap();
assert_eq!(request_id, json!(0));
assert_eq!(method, "session/request_permission");
assert_eq!(params["tool"], "Write");
}
}
@@ -0,0 +1,327 @@
//! Integration test for concurrent agent requests (T050)
//!
//! This test verifies that the system can handle multiple agent requests
//! simultaneously without cross-contamination.
//!
//! Test scenario:
//! 1. Register multiple pending requests concurrently
//! 2. Complete them in random/different order
//! 3. Verify each response goes to the correct request
//! 4. Verify no cross-contamination
use dirigent_acp_api::agent_requests::AgentRequestTracker;
use serde_json::json;
use tokio::task::JoinSet;
#[tokio::test]
async fn test_concurrent_requests_basic() {
let tracker = AgentRequestTracker::new();
let client_id = "test-client";
// Register 5 concurrent requests
let mut receivers = Vec::new();
for i in 0..5 {
let rx = tracker.register(client_id, json!(i));
receivers.push((i, rx));
}
assert_eq!(tracker.pending_count(), 5);
// Complete them in reverse order
for i in (0..5).rev() {
let response = json!({"request": i, "result": "success"});
tracker.complete(client_id, json!(i), response).unwrap();
}
assert_eq!(tracker.pending_count(), 0);
// Verify each receiver got the correct response
for (i, rx) in receivers {
let response = rx.await.unwrap();
assert_eq!(response["request"], i);
}
}
#[tokio::test]
async fn test_concurrent_requests_multiple_clients() {
let tracker = AgentRequestTracker::new();
// Two clients, each with 3 requests
let client1 = "client-1";
let client2 = "client-2";
let mut receivers1 = Vec::new();
let mut receivers2 = Vec::new();
for i in 0..3 {
let rx1 = tracker.register(client1, json!(i));
let rx2 = tracker.register(client2, json!(i));
receivers1.push((i, rx1));
receivers2.push((i, rx2));
}
assert_eq!(tracker.pending_count(), 6);
assert_eq!(tracker.client_pending_count(client1), 3);
assert_eq!(tracker.client_pending_count(client2), 3);
// Complete client1's requests
for i in 0..3 {
let response = json!({"client": 1, "request": i});
tracker.complete(client1, json!(i), response).unwrap();
}
assert_eq!(tracker.client_pending_count(client1), 0);
assert_eq!(tracker.client_pending_count(client2), 3);
// Complete client2's requests
for i in 0..3 {
let response = json!({"client": 2, "request": i});
tracker.complete(client2, json!(i), response).unwrap();
}
assert_eq!(tracker.pending_count(), 0);
// Verify each receiver got the correct response
for (i, rx) in receivers1 {
let response = rx.await.unwrap();
assert_eq!(response["client"], 1);
assert_eq!(response["request"], i);
}
for (i, rx) in receivers2 {
let response = rx.await.unwrap();
assert_eq!(response["client"], 2);
assert_eq!(response["request"], i);
}
}
#[tokio::test]
async fn test_concurrent_requests_same_id_different_clients() {
// Test that same request_id for different clients are handled independently
let tracker = AgentRequestTracker::new();
let client1 = "client-1";
let client2 = "client-2";
let request_id = json!(0); // Same ID for both
let rx1 = tracker.register(client1, request_id.clone());
let rx2 = tracker.register(client2, request_id.clone());
assert_eq!(tracker.pending_count(), 2);
// Complete client1's request
let response1 = json!({"client": "client-1"});
tracker.complete(client1, request_id.clone(), response1.clone()).unwrap();
// Complete client2's request
let response2 = json!({"client": "client-2"});
tracker.complete(client2, request_id, response2.clone()).unwrap();
assert_eq!(tracker.pending_count(), 0);
// Verify each got the correct response
let received1 = rx1.await.unwrap();
let received2 = rx2.await.unwrap();
assert_eq!(received1["client"], "client-1");
assert_eq!(received2["client"], "client-2");
}
#[tokio::test]
async fn test_concurrent_async_completion() {
// Test completing requests from multiple async tasks concurrently
let tracker = AgentRequestTracker::new();
let client_id = "test-client";
let num_requests = 10;
// Register requests
let mut receivers = Vec::new();
for i in 0..num_requests {
let rx = tracker.register(client_id, json!(i));
receivers.push((i, rx));
}
assert_eq!(tracker.pending_count(), num_requests);
// Spawn tasks to complete requests concurrently
let mut join_set = JoinSet::new();
for i in 0..num_requests {
let tracker_clone = tracker.clone();
join_set.spawn(async move {
// Small delay to ensure concurrency
tokio::time::sleep(tokio::time::Duration::from_millis(((i % 3) * 10) as u64)).await;
let response = json!({"request": i, "result": "success"});
tracker_clone.complete(client_id, json!(i), response)
});
}
// Wait for all completions
while let Some(result) = join_set.join_next().await {
assert!(result.unwrap().is_ok());
}
assert_eq!(tracker.pending_count(), 0);
// Verify all receivers got correct responses
for (i, rx) in receivers {
let response = rx.await.unwrap();
assert_eq!(response["request"], i);
}
}
#[tokio::test]
async fn test_concurrent_register_and_complete() {
// Test registering and completing requests concurrently
let tracker = AgentRequestTracker::new();
let client_id = "test-client";
let num_requests = 20;
let mut join_set = JoinSet::new();
// Spawn tasks to register and complete requests
for i in 0..num_requests {
let tracker_clone = tracker.clone();
join_set.spawn(async move {
// Register
let rx = tracker_clone.register(client_id, json!(i));
// Small delay
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
// Complete
let response = json!({"request": i});
tracker_clone.complete(client_id, json!(i), response.clone()).unwrap();
// Wait for response
let received = rx.await.unwrap();
assert_eq!(received["request"], i);
i
});
}
// Wait for all tasks
let mut completed = Vec::new();
while let Some(result) = join_set.join_next().await {
completed.push(result.unwrap());
}
// All requests should have completed
assert_eq!(completed.len(), num_requests);
assert_eq!(tracker.pending_count(), 0);
}
#[tokio::test]
async fn test_concurrent_mixed_operations() {
// Test mix of register, complete, and timeout operations
let tracker = AgentRequestTracker::new();
let client_id = "test-client";
let mut join_set = JoinSet::new();
// Spawn 15 tasks with different behaviors
for i in 0..15 {
let tracker_clone = tracker.clone();
join_set.spawn(async move {
let rx = tracker_clone.register(client_id, json!(i));
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
match i % 3 {
0 => {
// Complete normally
let response = json!({"request": i, "type": "complete"});
tracker_clone.complete(client_id, json!(i), response.clone()).unwrap();
let received = rx.await.unwrap();
assert_eq!(received["type"], "complete");
"completed"
}
1 => {
// Timeout
tracker_clone.timeout(client_id, json!(i));
assert!(rx.await.is_err());
"timeout"
}
_ => {
// Complete with delay
tokio::time::sleep(tokio::time::Duration::from_millis(20)).await;
let response = json!({"request": i, "type": "delayed"});
tracker_clone.complete(client_id, json!(i), response.clone()).unwrap();
let received = rx.await.unwrap();
assert_eq!(received["type"], "delayed");
"delayed"
}
}
});
}
// Wait for all tasks
let mut results = Vec::new();
while let Some(result) = join_set.join_next().await {
results.push(result.unwrap());
}
assert_eq!(results.len(), 15);
// Count outcomes
let completed = results.iter().filter(|&r| r == &"completed").count();
let timeout = results.iter().filter(|&r| r == &"timeout").count();
let delayed = results.iter().filter(|&r| r == &"delayed").count();
assert_eq!(completed, 5);
assert_eq!(timeout, 5);
assert_eq!(delayed, 5);
// All should be cleaned up
assert_eq!(tracker.pending_count(), 0);
}
#[tokio::test]
async fn test_high_concurrency() {
// Stress test with many concurrent requests
let tracker = AgentRequestTracker::new();
let client_id = "test-client";
let num_requests = 100;
let mut join_set = JoinSet::new();
for i in 0..num_requests {
let tracker_clone = tracker.clone();
join_set.spawn(async move {
let rx = tracker_clone.register(client_id, json!(i));
// Random-ish delay
let delay = ((i * 7) % 20) as u64;
tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
let response = json!({"request": i});
tracker_clone.complete(client_id, json!(i), response).unwrap();
rx.await.unwrap()["request"].as_u64().unwrap()
});
}
// Collect all results
let mut results = Vec::new();
while let Some(result) = join_set.join_next().await {
results.push(result.unwrap());
}
// Verify all requests completed
assert_eq!(results.len(), num_requests);
// Verify all request IDs are present
results.sort();
for (idx, &val) in results.iter().enumerate() {
assert_eq!(val, idx as u64);
}
// All cleaned up
assert_eq!(tracker.pending_count(), 0);
}
@@ -0,0 +1,286 @@
//! Integration test for client disconnection (T051)
//!
//! This test verifies that pending agent requests are cleaned up when a client
//! disconnects, and that other clients are unaffected.
//!
//! Test scenario:
//! 1. Register pending requests for multiple clients
//! 2. Simulate client disconnection
//! 3. Verify cleanup occurs for disconnected client
//! 4. Verify other clients are unaffected
use dirigent_acp_api::agent_requests::AgentRequestTracker;
use serde_json::json;
#[tokio::test]
async fn test_disconnect_single_client() {
let tracker = AgentRequestTracker::new();
let client_id = "test-client";
// Register 3 pending requests
let rx1 = tracker.register(client_id, json!(0));
let rx2 = tracker.register(client_id, json!(1));
let rx3 = tracker.register(client_id, json!(2));
assert_eq!(tracker.pending_count(), 3);
assert_eq!(tracker.client_pending_count(client_id), 3);
// Simulate client disconnection - clear all requests for this client
tracker.clear(Some(client_id));
// All requests should be removed
assert_eq!(tracker.pending_count(), 0);
assert_eq!(tracker.client_pending_count(client_id), 0);
// All receivers should get errors (channels closed)
assert!(rx1.await.is_err());
assert!(rx2.await.is_err());
assert!(rx3.await.is_err());
}
#[tokio::test]
async fn test_disconnect_multiple_clients() {
let tracker = AgentRequestTracker::new();
let client1 = "client-1";
let client2 = "client-2";
let client3 = "client-3";
// Register requests for all clients
let rx1_1 = tracker.register(client1, json!(0));
let rx1_2 = tracker.register(client1, json!(1));
let rx2_1 = tracker.register(client2, json!(0));
let rx2_2 = tracker.register(client2, json!(1));
let rx2_3 = tracker.register(client2, json!(2));
let rx3_1 = tracker.register(client3, json!(0));
assert_eq!(tracker.pending_count(), 6);
assert_eq!(tracker.client_pending_count(client1), 2);
assert_eq!(tracker.client_pending_count(client2), 3);
assert_eq!(tracker.client_pending_count(client3), 1);
// Disconnect client2
tracker.clear(Some(client2));
// Only client2's requests should be removed
assert_eq!(tracker.pending_count(), 3);
assert_eq!(tracker.client_pending_count(client1), 2);
assert_eq!(tracker.client_pending_count(client2), 0);
assert_eq!(tracker.client_pending_count(client3), 1);
// Client2's receivers should error
assert!(rx2_1.await.is_err());
assert!(rx2_2.await.is_err());
assert!(rx2_3.await.is_err());
// Client1 and client3 should still work
tracker.complete(client1, json!(0), json!({"result": "client1-0"})).unwrap();
assert_eq!(rx1_1.await.unwrap()["result"], "client1-0");
tracker.complete(client3, json!(0), json!({"result": "client3-0"})).unwrap();
assert_eq!(rx3_1.await.unwrap()["result"], "client3-0");
// Complete remaining client1 request
tracker.complete(client1, json!(1), json!({"result": "client1-1"})).unwrap();
assert_eq!(rx1_2.await.unwrap()["result"], "client1-1");
// All cleaned up
assert_eq!(tracker.pending_count(), 0);
}
#[tokio::test]
async fn test_disconnect_then_reconnect() {
let tracker = AgentRequestTracker::new();
let client_id = "test-client";
// First connection - register requests
let rx1 = tracker.register(client_id, json!(0));
let rx2 = tracker.register(client_id, json!(1));
assert_eq!(tracker.pending_count(), 2);
// Disconnect - cleanup
tracker.clear(Some(client_id));
assert_eq!(tracker.pending_count(), 0);
// Old receivers should error
assert!(rx1.await.is_err());
assert!(rx2.await.is_err());
// Reconnect - register new requests (same client_id, same request_ids)
let rx3 = tracker.register(client_id, json!(0));
let rx4 = tracker.register(client_id, json!(1));
assert_eq!(tracker.pending_count(), 2);
// Complete new requests
tracker.complete(client_id, json!(0), json!({"result": "new-0"})).unwrap();
tracker.complete(client_id, json!(1), json!({"result": "new-1"})).unwrap();
// New receivers should get responses
assert_eq!(rx3.await.unwrap()["result"], "new-0");
assert_eq!(rx4.await.unwrap()["result"], "new-1");
assert_eq!(tracker.pending_count(), 0);
}
#[tokio::test]
async fn test_disconnect_no_pending_requests() {
let tracker = AgentRequestTracker::new();
let client_id = "test-client";
// No pending requests
assert_eq!(tracker.client_pending_count(client_id), 0);
// Disconnect should be no-op
tracker.clear(Some(client_id));
assert_eq!(tracker.pending_count(), 0);
}
#[tokio::test]
async fn test_clear_all_clients() {
let tracker = AgentRequestTracker::new();
let client1 = "client-1";
let client2 = "client-2";
let client3 = "client-3";
// Register requests for multiple clients
let rx1 = tracker.register(client1, json!(0));
let rx2 = tracker.register(client2, json!(0));
let rx3 = tracker.register(client3, json!(0));
assert_eq!(tracker.pending_count(), 3);
// Clear all (simulating server shutdown)
tracker.clear(None);
// All should be removed
assert_eq!(tracker.pending_count(), 0);
assert_eq!(tracker.client_pending_count(client1), 0);
assert_eq!(tracker.client_pending_count(client2), 0);
assert_eq!(tracker.client_pending_count(client3), 0);
// All receivers should error
assert!(rx1.await.is_err());
assert!(rx2.await.is_err());
assert!(rx3.await.is_err());
}
#[tokio::test]
async fn test_disconnect_race_with_completion() {
let tracker = AgentRequestTracker::new();
let client_id = "test-client";
// Register multiple requests
let rx1 = tracker.register(client_id, json!(0));
let rx2 = tracker.register(client_id, json!(1));
let rx3 = tracker.register(client_id, json!(2));
assert_eq!(tracker.pending_count(), 3);
// Complete one request
tracker.complete(client_id, json!(0), json!({"result": "0"})).unwrap();
// Verify it was removed
assert_eq!(tracker.pending_count(), 2);
// Now disconnect (should only clear remaining requests)
tracker.clear(Some(client_id));
assert_eq!(tracker.pending_count(), 0);
// First receiver should have gotten response
assert_eq!(rx1.await.unwrap()["result"], "0");
// Other receivers should error
assert!(rx2.await.is_err());
assert!(rx3.await.is_err());
}
#[tokio::test]
async fn test_partial_disconnect_completion() {
// Test that completing a request after disconnect fails gracefully
let tracker = AgentRequestTracker::new();
let client_id = "test-client";
let _rx = tracker.register(client_id, json!(0));
assert_eq!(tracker.pending_count(), 1);
// Disconnect
tracker.clear(Some(client_id));
assert_eq!(tracker.pending_count(), 0);
// Try to complete after disconnect - should fail
let result = tracker.complete(client_id, json!(0), json!({"result": "late"}));
assert!(result.is_err());
}
#[tokio::test]
async fn test_concurrent_disconnect_and_complete() {
use tokio::task::JoinSet;
let tracker = AgentRequestTracker::new();
let client_id = "test-client";
// Register many requests
for i in 0..50 {
tracker.register(client_id, json!(i));
}
assert_eq!(tracker.pending_count(), 50);
let mut join_set = JoinSet::new();
// Spawn task to disconnect after delay
{
let tracker_clone = tracker.clone();
join_set.spawn(async move {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
tracker_clone.clear(Some(client_id));
"disconnected"
});
}
// Spawn tasks to complete requests
for i in 0..50 {
let tracker_clone = tracker.clone();
join_set.spawn(async move {
// Small delay to ensure some complete before disconnect
tokio::time::sleep(tokio::time::Duration::from_millis((i % 5) as u64)).await;
let response = json!({"request": i});
match tracker_clone.complete(client_id, json!(i), response) {
Ok(_) => "completed",
Err(_) => "failed",
}
});
}
// Wait for all tasks
let mut results = Vec::new();
while let Some(result) = join_set.join_next().await {
results.push(result.unwrap());
}
// Should have 1 disconnect + 50 completion attempts
assert_eq!(results.len(), 51);
// Some completions succeeded, some failed (after disconnect)
let disconnects = results.iter().filter(|r| r == &&"disconnected").count();
assert_eq!(disconnects, 1);
// All requests should be cleaned up
assert_eq!(tracker.pending_count(), 0);
}
@@ -0,0 +1,153 @@
//! Integration tests for session/list RPC method
use dirigent_acp_api::{
NoOpConnectorOperations, RpcHandler, SessionManager,
};
use dirigent_acp_api::agent_requests::AgentRequestTracker;
use dirigent_acp_api::sse::SseNotifier;
use serde_json::json;
use std::sync::Arc;
fn create_test_handler() -> RpcHandler<NoOpConnectorOperations> {
RpcHandler::new(
SessionManager::new(),
NoOpConnectorOperations,
SseNotifier::new(),
Arc::new(AgentRequestTracker::new()),
)
}
#[tokio::test]
async fn test_session_list_returns_sessions() {
let handler = create_test_handler();
let request_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "session/list",
"params": {
"connectorId": "stub-connector"
}
});
let response = handler.handle_request(&request_body.to_string(), None).await;
let response_json = serde_json::to_value(&response).unwrap();
// NoOp returns one stub session
let result = &response_json["result"];
assert!(result["sessions"].is_array());
let sessions = result["sessions"].as_array().unwrap();
assert!(!sessions.is_empty());
assert!(sessions[0]["sessionId"].is_string());
}
#[tokio::test]
async fn test_session_list_no_params() {
let handler = create_test_handler();
// session/list with no params should use default connector
let request_body = json!({
"jsonrpc": "2.0",
"id": 2,
"method": "session/list"
});
let response = handler.handle_request(&request_body.to_string(), None).await;
let response_json = serde_json::to_value(&response).unwrap();
// Should succeed using default connector from NoOp
let result = &response_json["result"];
assert!(result["sessions"].is_array());
}
#[tokio::test]
async fn test_session_list_creates_mappings() {
let session_manager = SessionManager::new();
let handler = RpcHandler::new(
session_manager.clone(),
NoOpConnectorOperations,
SseNotifier::new(),
Arc::new(AgentRequestTracker::new()),
);
// First, initialize to register a client
let init_body = json!({
"jsonrpc": "2.0",
"id": 0,
"method": "initialize",
"params": {}
});
let init_response = handler.handle_request(&init_body.to_string(), Some("test-client")).await;
let init_json = serde_json::to_value(&init_response).unwrap();
let client_id = init_json["result"]["clientId"].as_str().unwrap();
// Now list sessions
let request_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "session/list",
"params": {
"connectorId": "stub-connector"
}
});
let response = handler.handle_request(&request_body.to_string(), Some(client_id)).await;
let response_json = serde_json::to_value(&response).unwrap();
let sessions = response_json["result"]["sessions"].as_array().unwrap();
assert!(!sessions.is_empty());
// Session mapping should exist for the returned session
let session_id = sessions[0]["sessionId"].as_str().unwrap();
let mapping = session_manager.get_mapping(session_id);
assert!(mapping.is_some(), "Session mapping should be created for listed sessions");
}
#[tokio::test]
async fn test_session_load_without_connector_id() {
let handler = create_test_handler();
// Standard ACP: only sessionId + cwd + mcpServers, no connectorId
let request_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "session/load",
"params": {
"sessionId": "sess-abc",
"cwd": "G:\\dev\\projects\\test",
"mcpServers": []
}
});
let response = handler.handle_request(&request_body.to_string(), None).await;
let response_json = serde_json::to_value(&response).unwrap();
let result = &response_json["result"];
assert!(result["sessionId"].is_string(), "session/load should succeed without connectorId");
assert!(result["createdAt"].is_string());
}
#[tokio::test]
async fn test_initialize_advertises_list_sessions() {
let handler = create_test_handler();
let request_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {}
});
let response = handler.handle_request(&request_body.to_string(), None).await;
let response_json = serde_json::to_value(&response).unwrap();
let caps = &response_json["result"]["agentCapabilities"];
assert_eq!(caps["listSessions"], true, "listSessions capability should be advertised");
assert_eq!(caps["loadSession"], true, "loadSession capability should still be advertised");
// Verify nested sessionCapabilities.list is advertised (required by Zed v0.9.4+)
assert!(
caps["sessionCapabilities"]["list"].is_object(),
"sessionCapabilities.list should be advertised as an empty object"
);
}
@@ -0,0 +1,134 @@
//! Integration tests for session/resume RPC method
use dirigent_acp_api::{
NoOpConnectorOperations, RpcHandler, SessionManager,
};
use dirigent_acp_api::agent_requests::AgentRequestTracker;
use dirigent_acp_api::sse::SseNotifier;
use serde_json::json;
use std::sync::Arc;
fn create_test_handler() -> RpcHandler<NoOpConnectorOperations> {
RpcHandler::new(
SessionManager::new(),
NoOpConnectorOperations,
SseNotifier::new(),
Arc::new(AgentRequestTracker::new()),
)
}
#[tokio::test]
async fn test_session_resume_returns_session() {
let handler = create_test_handler();
let request_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "session/resume",
"params": {
"sessionId": "sess-123",
"connectorId": "stub-connector"
}
});
let response = handler.handle_request(&request_body.to_string(), None).await;
let response_json = serde_json::to_value(&response).unwrap();
let result = &response_json["result"];
assert!(result["sessionId"].is_string());
assert!(result["connectorId"].is_string());
assert!(result["createdAt"].is_string());
}
#[tokio::test]
async fn test_session_resume_missing_params() {
let handler = create_test_handler();
let request_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "session/resume"
});
let response = handler.handle_request(&request_body.to_string(), None).await;
let response_json = serde_json::to_value(&response).unwrap();
// Should return error for missing params
assert!(response_json["error"].is_object());
}
#[tokio::test]
async fn test_session_resume_creates_mapping() {
let session_manager = SessionManager::new();
let handler = RpcHandler::new(
session_manager.clone(),
NoOpConnectorOperations,
SseNotifier::new(),
Arc::new(AgentRequestTracker::new()),
);
let request_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "session/resume",
"params": {
"sessionId": "sess-456",
"connectorId": "stub-connector"
}
});
let response = handler.handle_request(&request_body.to_string(), None).await;
let response_json = serde_json::to_value(&response).unwrap();
let session_id = response_json["result"]["sessionId"].as_str().unwrap();
let mapping = session_manager.get_mapping(session_id);
assert!(mapping.is_some(), "Session mapping should be created for resumed session");
}
#[tokio::test]
async fn test_session_resume_without_connector_id() {
let handler = create_test_handler();
// Standard ACP: only sessionId, no connectorId — should resolve via default connector
let request_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "session/resume",
"params": {
"sessionId": "sess-789",
"cwd": "G:\\dev\\projects\\test"
}
});
let response = handler.handle_request(&request_body.to_string(), None).await;
let response_json = serde_json::to_value(&response).unwrap();
let result = &response_json["result"];
assert!(result["sessionId"].is_string(), "Should succeed without connectorId");
assert!(result["createdAt"].is_string());
}
#[tokio::test]
async fn test_initialize_advertises_session_resume() {
let handler = create_test_handler();
let request_body = json!({
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {}
});
let response = handler.handle_request(&request_body.to_string(), None).await;
let response_json = serde_json::to_value(&response).unwrap();
let caps = &response_json["result"]["agentCapabilities"];
assert!(
caps["sessionCapabilities"]["list"].is_object(),
"sessionCapabilities.list should be advertised as an empty object"
);
assert!(
caps["sessionCapabilities"]["resume"].is_object(),
"sessionCapabilities.resume should be advertised as an empty object"
);
}
@@ -0,0 +1,229 @@
//! Integration test for timeout handling (T049)
//!
//! This test verifies that the system handles timeouts gracefully when a client
//! fails to respond to an agent request within the timeout period.
//!
//! Test scenario:
//! 1. Register a pending agent request
//! 2. Wait for timeout (using reduced timeout for testing)
//! 3. Verify timeout occurs
//! 4. Verify cleanup happens correctly
//! 5. Verify no resource leaks
use dirigent_acp_api::agent_requests::AgentRequestTracker;
use serde_json::json;
use tokio::time::{timeout, Duration};
#[tokio::test]
async fn test_timeout_basic() {
// Create tracker
let tracker = AgentRequestTracker::new();
let client_id = "test-client";
let request_id = json!(0);
// Register request
let receiver = tracker.register(client_id, request_id.clone());
assert_eq!(tracker.pending_count(), 1);
// Wait for timeout (use short timeout for testing)
let result = timeout(Duration::from_millis(100), receiver).await;
// Should timeout
assert!(result.is_err(), "Expected timeout but request completed");
// Manually trigger cleanup (in production, event bridge does this)
tracker.timeout(client_id, request_id);
// Verify cleanup
assert_eq!(tracker.pending_count(), 0);
}
#[tokio::test]
async fn test_timeout_cleanup() {
let tracker = AgentRequestTracker::new();
let client_id = "test-client";
let request_id = json!(123);
// Register request
let receiver = tracker.register(client_id, request_id.clone());
assert_eq!(tracker.pending_count(), 1);
// Trigger timeout before waiting
tracker.timeout(client_id, request_id);
// Verify cleanup happened
assert_eq!(tracker.pending_count(), 0);
// Receiver should get error (channel closed)
let result = receiver.await;
assert!(result.is_err(), "Expected receiver to get error after timeout");
}
#[tokio::test]
async fn test_timeout_multiple_clients() {
let tracker = AgentRequestTracker::new();
let client1 = "client-1";
let client2 = "client-2";
// Register requests for both clients
let rx1 = tracker.register(client1, json!(0));
let rx2 = tracker.register(client2, json!(0));
assert_eq!(tracker.pending_count(), 2);
assert_eq!(tracker.client_pending_count(client1), 1);
assert_eq!(tracker.client_pending_count(client2), 1);
// Timeout only client1's request
tracker.timeout(client1, json!(0));
// Verify only client1's request is removed
assert_eq!(tracker.pending_count(), 1);
assert_eq!(tracker.client_pending_count(client1), 0);
assert_eq!(tracker.client_pending_count(client2), 1);
// Client1's receiver should error
assert!(rx1.await.is_err());
// Complete client2's request normally
let response = json!({"result": "success"});
let result = tracker.complete(client2, json!(0), response);
assert!(result.is_ok());
// Client2's receiver should get response
let received = rx2.await.unwrap();
assert_eq!(received, json!({"result": "success"}));
// All cleaned up
assert_eq!(tracker.pending_count(), 0);
}
#[tokio::test]
async fn test_timeout_no_double_cleanup() {
let tracker = AgentRequestTracker::new();
let client_id = "test-client";
let request_id = json!(0);
// Register request
let _receiver = tracker.register(client_id, request_id.clone());
assert_eq!(tracker.pending_count(), 1);
// First timeout - should remove
tracker.timeout(client_id, request_id.clone());
assert_eq!(tracker.pending_count(), 0);
// Second timeout - should be no-op (not panic)
tracker.timeout(client_id, request_id);
assert_eq!(tracker.pending_count(), 0);
}
#[tokio::test]
async fn test_timeout_race_with_complete() {
let tracker = AgentRequestTracker::new();
let client_id = "test-client";
let request_id = json!(0);
// Register request
let receiver = tracker.register(client_id, request_id.clone());
assert_eq!(tracker.pending_count(), 1);
// Complete the request
let response = json!({"result": "success"});
let result = tracker.complete(client_id, request_id.clone(), response.clone());
assert!(result.is_ok());
assert_eq!(tracker.pending_count(), 0);
// Try to timeout after completion - should be no-op
tracker.timeout(client_id, request_id);
assert_eq!(tracker.pending_count(), 0);
// Receiver should still get the response
let received = receiver.await.unwrap();
assert_eq!(received, response);
}
#[tokio::test]
async fn test_concurrent_timeouts() {
let tracker = AgentRequestTracker::new();
let client_id = "test-client";
// Register 10 requests
let mut receivers = Vec::new();
for i in 0..10 {
let rx = tracker.register(client_id, json!(i));
receivers.push((i, rx));
}
assert_eq!(tracker.pending_count(), 10);
// Spawn tasks to timeout each request after random delays
let tracker_clone = tracker.clone();
let timeout_handles: Vec<_> = (0..10)
.map(|i| {
let tracker = tracker_clone.clone();
tokio::spawn(async move {
// Small random-ish delay based on index
tokio::time::sleep(Duration::from_millis((i * 10) as u64)).await;
tracker.timeout(client_id, json!(i));
})
})
.collect();
// Wait for all timeouts to complete
for handle in timeout_handles {
handle.await.unwrap();
}
// All should be cleaned up
assert_eq!(tracker.pending_count(), 0);
// All receivers should get errors
for (_i, rx) in receivers {
assert!(rx.await.is_err());
}
}
#[tokio::test]
async fn test_timeout_with_actual_delay() {
// This test uses actual time delays to verify timeout behavior more realistically
let tracker = AgentRequestTracker::new();
let client_id = "test-client";
let request_id = json!(0);
let start = std::time::Instant::now();
// Register request
let receiver = tracker.register(client_id, request_id.clone());
// Spawn task to timeout after 200ms
let tracker_clone = tracker.clone();
tokio::spawn(async move {
tokio::time::sleep(Duration::from_millis(200)).await;
tracker_clone.timeout(client_id, json!(0));
});
// Wait on receiver with longer timeout
let result = timeout(Duration::from_secs(1), receiver).await;
let elapsed = start.elapsed();
// Should complete due to timeout() call, not tokio::time::timeout
assert!(result.is_ok(), "Should complete when timeout() is called");
assert!(result.unwrap().is_err(), "Receiver should get error");
// Should take approximately 200ms
assert!(
elapsed >= Duration::from_millis(180) && elapsed < Duration::from_millis(300),
"Expected ~200ms but took {:?}",
elapsed
);
// Should be cleaned up
assert_eq!(tracker.pending_count(), 0);
}