sync from monorepo @ 2452e92e
This commit is contained in:
@@ -0,0 +1,124 @@
|
||||
# Package: dirigent_acp_api
|
||||
|
||||
ACP Server implementation for accepting incoming ACP connections from external agents.
|
||||
|
||||
## Quick Facts
|
||||
- **Type**: Library
|
||||
- **Main Entry**: src/lib.rs
|
||||
- **Dependencies**: axum, tokio, serde, tracing, uuid, async-trait, dirigent_protocol
|
||||
- **Status**: Core structure complete, integration with CoreRuntime pending
|
||||
|
||||
## Overview
|
||||
|
||||
The `dirigent_acp_api` package implements an ACP (Agent-Client Protocol) server that allows Dirigent to accept incoming connections from external ACP clients like Claude Code or custom agents. This enables session sharing, remote orchestration, and multi-client collaboration.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Core Components
|
||||
|
||||
- **config.rs** - Server configuration types (`AcpServerConfig`)
|
||||
- **error.rs** - Error types (`AcpServerError`, `JsonRpcErrorObject`)
|
||||
- **jsonrpc.rs** - JSON-RPC 2.0 types and parsing
|
||||
- **rpc.rs** - RPC handler and method dispatch
|
||||
- **session_manager.rs** - Session/client tracking (TODO)
|
||||
- **sse.rs** - SSE notification system (TODO)
|
||||
- **event_bridge.rs** - Event translation (TODO)
|
||||
|
||||
### Key Types
|
||||
|
||||
```rust
|
||||
pub struct AcpServerConfig {
|
||||
pub enabled: bool, // Enable/disable server
|
||||
pub port: u16, // Listen port (default: 3001)
|
||||
pub allowed_origins: Option<Vec<String>>, // CORS origins
|
||||
pub max_connections: usize, // Connection limit (default: 100)
|
||||
}
|
||||
```
|
||||
|
||||
### ConnectorOperations Trait
|
||||
|
||||
The RPC handler uses a trait abstraction to avoid circular dependencies with dirigent_core:
|
||||
|
||||
```rust
|
||||
#[async_trait]
|
||||
pub trait ConnectorOperations: Send + Sync {
|
||||
async fn create_session(&self, connector_id: &str) -> Result<String>;
|
||||
async fn load_session(&self, connector_id: &str, session_id: &str) -> Result<Session>;
|
||||
async fn send_prompt(&self, connector_id: &str, session_id: &str, prompt: &str) -> Result<String>;
|
||||
// ... more methods
|
||||
}
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### POST `/rpc`
|
||||
|
||||
JSON-RPC 2.0 endpoint supporting:
|
||||
- `initialize` - Client handshake
|
||||
- `session/new` - Create session
|
||||
- `session/load` - Load existing session
|
||||
- `session/prompt` - Send prompt
|
||||
- `session/cancel` - Cancel generation
|
||||
- `session/close` - Close session
|
||||
|
||||
### GET `/events`
|
||||
|
||||
Server-Sent Events for streaming notifications:
|
||||
- `acp/messageChunk` - Streaming content
|
||||
- `acp/messageComplete` - Generation complete
|
||||
- `acp/sessionIdle` - Ready for input
|
||||
|
||||
### GET `/health`
|
||||
|
||||
Health check endpoint.
|
||||
|
||||
## Configuration UI
|
||||
|
||||
The ACP Server is configured via the web UI at **Configuration > ACP Server**:
|
||||
|
||||
- Enable/disable toggle
|
||||
- Port configuration
|
||||
- Max connections limit
|
||||
- Allowed origins (CORS)
|
||||
- Default connector selection
|
||||
- Connected clients management
|
||||
|
||||
Server functions in `crates/api/src/acp_server.rs` bridge the UI and this package.
|
||||
|
||||
## Implementation Status
|
||||
|
||||
**Completed:**
|
||||
- Configuration types (`AcpServerConfig`)
|
||||
- Error types (`AcpServerError`)
|
||||
- JSON-RPC types and parsing
|
||||
- RPC handler structure with ConnectorOperations trait
|
||||
|
||||
**Pending:**
|
||||
- Session Manager implementation
|
||||
- SSE Notifier implementation
|
||||
- Event Bridge implementation
|
||||
- Axum router integration
|
||||
- Web server integration
|
||||
|
||||
## Key Files
|
||||
|
||||
| File | Description |
|
||||
|------|-------------|
|
||||
| `src/lib.rs` | Module exports and router creation |
|
||||
| `src/config.rs` | AcpServerConfig with validation |
|
||||
| `src/error.rs` | Error types and codes |
|
||||
| `src/jsonrpc.rs` | JSON-RPC 2.0 implementation |
|
||||
| `src/rpc.rs` | RPC handler and method dispatch |
|
||||
|
||||
## Related Packages
|
||||
|
||||
- **dirigent_core** - Provides CoreHandle implementation of ConnectorOperations
|
||||
- **dirigent_protocol** - Shared event and message types
|
||||
- **api** - Server functions for UI configuration
|
||||
- **web** - Configuration UI components
|
||||
|
||||
## Documentation
|
||||
|
||||
- **Architecture**: `docs/architecture/acp_server.md`
|
||||
- **Configuration**: `docs/configuration/acp-connectors.md`
|
||||
- **Tasks**: `docs/building/07_acp_serve/02_acp_server_tasks.md`
|
||||
@@ -0,0 +1,30 @@
|
||||
[package]
|
||||
name = "dirigent_acp_api"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
path = "src/lib.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
async-trait = "0.1"
|
||||
# ACP (Agent-Client Protocol) dependencies
|
||||
axum = "0.8"
|
||||
# ACP Server dependencies (Phase 2)
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
# Workspace dependencies
|
||||
# Note: dirigent_protocol is used for Event types
|
||||
dirigent_protocol = { path = "../dirigent_protocol" }
|
||||
# Note: dirigent_core is NOT a direct dependency to avoid circular dependency.
|
||||
# CoreHandle is passed into the ACP server at runtime via generics or trait objects.
|
||||
# The mode/model mapping logic is duplicated here for legacy mode support.
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tokio-stream = { version = "0.1", features = ["sync"] }
|
||||
tower = "0.5"
|
||||
tower-http = { version = "0.6", features = ["cors"] }
|
||||
# ACP Server dependencies (Phase 1)
|
||||
tracing = "0.1"
|
||||
uuid = { version = "1.0", features = ["serde", "v4", "v7"] }
|
||||
@@ -0,0 +1,5 @@
|
||||
This crate should expose an ACP API for dirigent to be used by another ACP Client.
|
||||
|
||||
Funny Integration Test would be to use Dirigent using Dirigent (but that would need probably a dummy acp agent to be used by that)
|
||||
|
||||
This will however also require dirigent to "pass through" functionality, it believes to be responsible for.
|
||||
@@ -0,0 +1,438 @@
|
||||
//! Agent Request Tracker
|
||||
//!
|
||||
//! This module provides infrastructure for tracking agent-initiated requests
|
||||
//! that require responses from HTTP clients. It's used to implement bidirectional
|
||||
//! JSON-RPC communication in the ACP Server.
|
||||
//!
|
||||
//! ## Use Case
|
||||
//!
|
||||
//! When an agent (like Claude) sends a request to the client (like a permission
|
||||
//! prompt), the ACP Server needs to:
|
||||
//! 1. Forward the request to the client via SSE
|
||||
//! 2. Wait for the client's response via HTTP POST
|
||||
//! 3. Deliver the response back to the agent
|
||||
//!
|
||||
//! The `AgentRequestTracker` manages the pending requests and provides a way
|
||||
//! to correlate responses with their corresponding requests.
|
||||
//!
|
||||
//! ## Example Flow
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use dirigent_acp_api::agent_requests::AgentRequestTracker;
|
||||
//! use serde_json::json;
|
||||
//!
|
||||
//! // Create tracker
|
||||
//! let tracker = AgentRequestTracker::new();
|
||||
//!
|
||||
//! // Agent sends request - register it and get receiver
|
||||
//! let request_id = json!(0);
|
||||
//! let client_id = "client-123";
|
||||
//! let receiver = tracker.register(client_id, request_id.clone());
|
||||
//!
|
||||
//! // Forward request to client via SSE...
|
||||
//!
|
||||
//! // Client responds via HTTP POST - complete the request
|
||||
//! let response = json!({"selectedOptionId": "allow"});
|
||||
//! tracker.complete(client_id, request_id, response)?;
|
||||
//!
|
||||
//! // The receiver now gets the response
|
||||
//! let response_value = receiver.await?;
|
||||
//! ```
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use anyhow::{anyhow, Result};
|
||||
use serde_json::Value;
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
/// Tracks pending agent requests awaiting client responses
|
||||
///
|
||||
/// This struct provides thread-safe storage for correlating agent requests
|
||||
/// with their eventual client responses. It uses oneshot channels to deliver
|
||||
/// responses to waiting tasks.
|
||||
///
|
||||
/// ## Thread Safety
|
||||
///
|
||||
/// The tracker is designed to be cloned and shared across async tasks.
|
||||
/// Internal state is protected by `Arc<Mutex<...>>` for thread-safe access.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AgentRequestTracker {
|
||||
/// Maps (client_id, request_id_string) to oneshot sender for delivering response
|
||||
///
|
||||
/// The key is a tuple of client ID and request ID (as string) to uniquely
|
||||
/// identify each pending request. The value is a oneshot sender that will
|
||||
/// be used to deliver the response when it arrives.
|
||||
pending: Arc<Mutex<HashMap<(String, String), oneshot::Sender<Value>>>>,
|
||||
}
|
||||
|
||||
impl Default for AgentRequestTracker {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl AgentRequestTracker {
|
||||
/// Create a new agent request tracker
|
||||
///
|
||||
/// Returns an empty tracker ready to register pending requests.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
pending: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a pending agent request and return a receiver for the response
|
||||
///
|
||||
/// This method creates a oneshot channel, stores the sender in the pending
|
||||
/// requests map, and returns the receiver. The caller can await on the
|
||||
/// receiver to get the client's response.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `client_id`: The ID of the client that should respond to this request
|
||||
/// - `request_id`: The request ID from the agent (JSON-RPC id field)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A oneshot receiver that will receive the client's response when
|
||||
/// `complete()` is called with matching client_id and request_id.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// let receiver = tracker.register("client-123", json!(0));
|
||||
///
|
||||
/// // Later, when client responds...
|
||||
/// tracker.complete("client-123", json!(0), response)?;
|
||||
///
|
||||
/// // The receiver gets the response
|
||||
/// let response = receiver.await?;
|
||||
/// ```
|
||||
pub fn register(&self, client_id: &str, request_id: Value) -> oneshot::Receiver<Value> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
let key = (client_id.to_string(), request_id.to_string());
|
||||
|
||||
let mut pending = self.pending.lock().expect("Lock poisoned");
|
||||
pending.insert(key.clone(), tx);
|
||||
|
||||
debug!(
|
||||
client_id = %client_id,
|
||||
request_id = %request_id,
|
||||
"Registered pending agent request"
|
||||
);
|
||||
|
||||
rx
|
||||
}
|
||||
|
||||
/// Complete a pending agent request with the client's response
|
||||
///
|
||||
/// This method looks up the pending request, sends the response through
|
||||
/// the oneshot channel, and removes it from the pending map.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `client_id`: The ID of the client sending the response
|
||||
/// - `request_id`: The request ID from the original agent request
|
||||
/// - `response`: The client's response (JSON-RPC response object)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// - `Ok(())` if the request was found and the response was delivered
|
||||
/// - `Err` if the request_id was not found in pending requests
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if:
|
||||
/// - The request_id is not found in pending requests (may have timed out)
|
||||
/// - The receiver has been dropped (unlikely but possible)
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// // Client POSTs response to /acp/agent_response
|
||||
/// let response = json!({
|
||||
/// "jsonrpc": "2.0",
|
||||
/// "id": 0,
|
||||
/// "result": {"selectedOptionId": "allow"}
|
||||
/// });
|
||||
///
|
||||
/// tracker.complete("client-123", json!(0), response)?;
|
||||
/// ```
|
||||
pub fn complete(&self, client_id: &str, request_id: Value, response: Value) -> Result<()> {
|
||||
let key = (client_id.to_string(), request_id.to_string());
|
||||
|
||||
let mut pending = self.pending.lock().expect("Lock poisoned");
|
||||
|
||||
if let Some(sender) = pending.remove(&key) {
|
||||
debug!(
|
||||
client_id = %client_id,
|
||||
request_id = %request_id,
|
||||
"Completing pending agent request"
|
||||
);
|
||||
|
||||
// Send the response through the oneshot channel
|
||||
sender.send(response).map_err(|_| {
|
||||
anyhow!(
|
||||
"Failed to send response for request {}: receiver dropped",
|
||||
request_id
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
warn!(
|
||||
client_id = %client_id,
|
||||
request_id = %request_id,
|
||||
"Attempted to complete non-existent agent request (may have timed out)"
|
||||
);
|
||||
|
||||
Err(anyhow!(
|
||||
"Request ID {} not found for client {}",
|
||||
request_id,
|
||||
client_id
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Timeout a pending agent request
|
||||
///
|
||||
/// This method removes a pending request from the map and logs a timeout
|
||||
/// warning. It should be called when a request has been pending for too
|
||||
/// long (e.g., 30 seconds) without a client response.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `client_id`: The ID of the client that was supposed to respond
|
||||
/// - `request_id`: The request ID that timed out
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// This method does not send any error response through the channel.
|
||||
/// The receiver will get a `RecvError` when it tries to receive, which
|
||||
/// the caller should interpret as a timeout.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use tokio::time::{timeout, Duration};
|
||||
///
|
||||
/// let receiver = tracker.register("client-123", json!(0));
|
||||
///
|
||||
/// // Wait up to 30 seconds for response
|
||||
/// match timeout(Duration::from_secs(30), receiver).await {
|
||||
/// Ok(Ok(response)) => {
|
||||
/// // Got response
|
||||
/// }
|
||||
/// Ok(Err(_)) => {
|
||||
/// // Channel closed (timeout was called)
|
||||
/// tracker.timeout("client-123", json!(0));
|
||||
/// }
|
||||
/// Err(_) => {
|
||||
/// // Timeout elapsed
|
||||
/// tracker.timeout("client-123", json!(0));
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub fn timeout(&self, client_id: &str, request_id: Value) {
|
||||
let key = (client_id.to_string(), request_id.to_string());
|
||||
|
||||
let mut pending = self.pending.lock().expect("Lock poisoned");
|
||||
|
||||
if pending.remove(&key).is_some() {
|
||||
warn!(
|
||||
client_id = %client_id,
|
||||
request_id = %request_id,
|
||||
"Agent request timed out (30s elapsed without client response)"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the number of pending requests
|
||||
///
|
||||
/// This method returns the total number of requests currently awaiting
|
||||
/// client responses across all clients.
|
||||
pub fn pending_count(&self) -> usize {
|
||||
let pending = self.pending.lock().expect("Lock poisoned");
|
||||
pending.len()
|
||||
}
|
||||
|
||||
/// Get the number of pending requests for a specific client
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `client_id`: The ID of the client to query
|
||||
pub fn client_pending_count(&self, client_id: &str) -> usize {
|
||||
let pending = self.pending.lock().expect("Lock poisoned");
|
||||
pending
|
||||
.keys()
|
||||
.filter(|(cid, _)| cid == client_id)
|
||||
.count()
|
||||
}
|
||||
|
||||
/// Clear all pending requests (used when shutting down or on client disconnect)
|
||||
///
|
||||
/// This method removes all pending requests from the tracker. The oneshot
|
||||
/// senders are dropped, which will cause their receivers to get `RecvError`.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `client_id`: Optional client ID to clear only that client's pending requests.
|
||||
/// If None, clears all pending requests.
|
||||
pub fn clear(&self, client_id: Option<&str>) {
|
||||
let mut pending = self.pending.lock().expect("Lock poisoned");
|
||||
|
||||
match client_id {
|
||||
Some(id) => {
|
||||
let keys_to_remove: Vec<_> = pending
|
||||
.keys()
|
||||
.filter(|(cid, _)| cid == id)
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
for key in keys_to_remove {
|
||||
pending.remove(&key);
|
||||
}
|
||||
|
||||
debug!(
|
||||
client_id = %id,
|
||||
"Cleared all pending agent requests for client"
|
||||
);
|
||||
}
|
||||
None => {
|
||||
let count = pending.len();
|
||||
pending.clear();
|
||||
|
||||
debug!(
|
||||
count = count,
|
||||
"Cleared all pending agent requests"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_register_and_complete() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "client-123";
|
||||
let request_id = json!(0);
|
||||
let response = json!({"result": "success"});
|
||||
|
||||
// Register request
|
||||
let receiver = tracker.register(client_id, request_id.clone());
|
||||
assert_eq!(tracker.pending_count(), 1);
|
||||
|
||||
// Complete request
|
||||
let result = tracker.complete(client_id, request_id, response.clone());
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Receiver should get the response
|
||||
let received = receiver.await.unwrap();
|
||||
assert_eq!(received, response);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_complete_non_existent_request() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "client-123";
|
||||
let request_id = json!(999);
|
||||
let response = json!({"result": "success"});
|
||||
|
||||
// Try to complete non-existent request
|
||||
let result = tracker.complete(client_id, request_id, response);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "client-123";
|
||||
let request_id = json!(0);
|
||||
|
||||
// Register request
|
||||
let receiver = tracker.register(client_id, request_id.clone());
|
||||
assert_eq!(tracker.pending_count(), 1);
|
||||
|
||||
// Timeout the request
|
||||
tracker.timeout(client_id, request_id);
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Receiver should get error (channel closed)
|
||||
let result = receiver.await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_pending_requests() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client1 = "client-1";
|
||||
let client2 = "client-2";
|
||||
|
||||
// Register multiple requests
|
||||
let _rx1 = tracker.register(client1, json!(0));
|
||||
let _rx2 = tracker.register(client1, json!(1));
|
||||
let _rx3 = tracker.register(client2, json!(0));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 3);
|
||||
assert_eq!(tracker.client_pending_count(client1), 2);
|
||||
assert_eq!(tracker.client_pending_count(client2), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_clear_all() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let _rx1 = tracker.register("client-1", json!(0));
|
||||
let _rx2 = tracker.register("client-2", json!(0));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 2);
|
||||
|
||||
tracker.clear(None);
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_clear_client() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client1 = "client-1";
|
||||
let client2 = "client-2";
|
||||
|
||||
let _rx1 = tracker.register(client1, json!(0));
|
||||
let _rx2 = tracker.register(client1, json!(1));
|
||||
let _rx3 = tracker.register(client2, json!(0));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 3);
|
||||
|
||||
// Clear only client1's requests
|
||||
tracker.clear(Some(client1));
|
||||
assert_eq!(tracker.pending_count(), 1);
|
||||
assert_eq!(tracker.client_pending_count(client1), 0);
|
||||
assert_eq!(tracker.client_pending_count(client2), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tracker_clone() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
let tracker_clone = tracker.clone();
|
||||
|
||||
// Both should point to same underlying state
|
||||
let _rx = tracker.register("client-1", json!(0));
|
||||
assert_eq!(tracker_clone.pending_count(), 1);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,275 @@
|
||||
//! Configuration types for the ACP Server
|
||||
//!
|
||||
//! This module defines configuration types for the ACP Server, including
|
||||
//! server settings, connection limits, and CORS configuration.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Default port for the ACP Server
|
||||
pub const DEFAULT_PORT: u16 = 3001;
|
||||
|
||||
/// Default maximum number of concurrent connections
|
||||
pub const DEFAULT_MAX_CONNECTIONS: usize = 100;
|
||||
|
||||
/// Configuration for the ACP Server
|
||||
///
|
||||
/// This struct contains all configurable options for the ACP Server,
|
||||
/// including network settings, security options, and resource limits.
|
||||
///
|
||||
/// **Note**: This config is used when starting an actual TCP server
|
||||
/// (separate port mode only). For integrated mode (mounting at /acp),
|
||||
/// the port field is still required but represents which port was chosen
|
||||
/// for the separate server path. Higher-level configs (dirigent_core, api)
|
||||
/// use `Option<u16>` to represent the integrated vs separate distinction.
|
||||
///
|
||||
/// TODO: Consider moving the AcpPortConfig enum from web package to core
|
||||
/// and using it here to make the integrated/separate distinction explicit.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct AcpServerConfig {
|
||||
/// Whether the ACP Server is enabled
|
||||
///
|
||||
/// When disabled, the server will not accept incoming connections.
|
||||
/// Default: false
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
|
||||
/// The port to listen on for incoming connections
|
||||
///
|
||||
/// This is always a concrete port number because this config is used
|
||||
/// to start an actual TCP server. Use Option<u16> in higher-level configs
|
||||
/// to represent integrated mode (None) vs separate mode (Some(port)).
|
||||
///
|
||||
/// Default: 3001
|
||||
#[serde(default = "default_port")]
|
||||
pub port: u16,
|
||||
|
||||
/// List of allowed origins for CORS
|
||||
///
|
||||
/// When Some, only requests from these origins are allowed.
|
||||
/// When None, all origins are allowed (use with caution).
|
||||
///
|
||||
/// Example: ["http://localhost:3000", "https://app.example.com"]
|
||||
#[serde(default)]
|
||||
pub allowed_origins: Option<Vec<String>>,
|
||||
|
||||
/// Maximum number of concurrent client connections
|
||||
///
|
||||
/// New connections will be rejected when this limit is reached.
|
||||
/// Default: 100
|
||||
#[serde(default = "default_max_connections")]
|
||||
pub max_connections: usize,
|
||||
}
|
||||
|
||||
/// Returns the default port value
|
||||
fn default_port() -> u16 {
|
||||
DEFAULT_PORT
|
||||
}
|
||||
|
||||
/// Returns the default max connections value
|
||||
fn default_max_connections() -> usize {
|
||||
DEFAULT_MAX_CONNECTIONS
|
||||
}
|
||||
|
||||
impl Default for AcpServerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
port: DEFAULT_PORT,
|
||||
allowed_origins: None,
|
||||
max_connections: DEFAULT_MAX_CONNECTIONS,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AcpServerConfig {
|
||||
/// Create a new configuration with default values
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Create a configuration with the server enabled
|
||||
pub fn enabled() -> Self {
|
||||
Self {
|
||||
enabled: true,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a configuration with a specific port
|
||||
pub fn with_port(port: u16) -> Self {
|
||||
Self {
|
||||
port,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
/// Set whether the server is enabled
|
||||
pub fn set_enabled(mut self, enabled: bool) -> Self {
|
||||
self.enabled = enabled;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the port
|
||||
pub fn set_port(mut self, port: u16) -> Self {
|
||||
self.port = port;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the allowed origins
|
||||
pub fn set_allowed_origins(mut self, origins: Option<Vec<String>>) -> Self {
|
||||
self.allowed_origins = origins;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the maximum number of connections
|
||||
pub fn set_max_connections(mut self, max: usize) -> Self {
|
||||
self.max_connections = max;
|
||||
self
|
||||
}
|
||||
|
||||
/// Check if the configuration is valid
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
if self.port == 0 {
|
||||
return Err("Port cannot be 0".to_string());
|
||||
}
|
||||
|
||||
if self.max_connections == 0 {
|
||||
return Err("max_connections must be at least 1".to_string());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_default_config() {
|
||||
let config = AcpServerConfig::default();
|
||||
assert!(!config.enabled);
|
||||
assert_eq!(config.port, DEFAULT_PORT);
|
||||
assert!(config.allowed_origins.is_none());
|
||||
assert_eq!(config.max_connections, DEFAULT_MAX_CONNECTIONS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_enabled_config() {
|
||||
let config = AcpServerConfig::enabled();
|
||||
assert!(config.enabled);
|
||||
assert_eq!(config.port, DEFAULT_PORT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_with_port() {
|
||||
let config = AcpServerConfig::with_port(8080);
|
||||
assert!(!config.enabled);
|
||||
assert_eq!(config.port, 8080);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_builder_pattern() {
|
||||
let config = AcpServerConfig::new()
|
||||
.set_enabled(true)
|
||||
.set_port(4000)
|
||||
.set_allowed_origins(Some(vec!["http://localhost:3000".to_string()]))
|
||||
.set_max_connections(50);
|
||||
|
||||
assert!(config.enabled);
|
||||
assert_eq!(config.port, 4000);
|
||||
assert_eq!(
|
||||
config.allowed_origins,
|
||||
Some(vec!["http://localhost:3000".to_string()])
|
||||
);
|
||||
assert_eq!(config.max_connections, 50);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validation_valid() {
|
||||
let config = AcpServerConfig::default();
|
||||
assert!(config.validate().is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validation_invalid_port() {
|
||||
let config = AcpServerConfig {
|
||||
port: 0,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(config.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validation_invalid_max_connections() {
|
||||
let config = AcpServerConfig {
|
||||
max_connections: 0,
|
||||
..Default::default()
|
||||
};
|
||||
assert!(config.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_serialization() {
|
||||
let config = AcpServerConfig {
|
||||
enabled: true,
|
||||
port: 3001,
|
||||
allowed_origins: Some(vec!["http://localhost:3000".to_string()]),
|
||||
max_connections: 100,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&config).unwrap();
|
||||
assert!(json.contains("\"enabled\":true"));
|
||||
assert!(json.contains("\"port\":3001"));
|
||||
assert!(json.contains("\"allowed_origins\""));
|
||||
assert!(json.contains("\"max_connections\":100"));
|
||||
|
||||
// Deserialize back
|
||||
let parsed: AcpServerConfig = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed, config);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialization_with_defaults() {
|
||||
// Minimal JSON with defaults
|
||||
let json = r#"{"enabled":true}"#;
|
||||
let config: AcpServerConfig = serde_json::from_str(json).unwrap();
|
||||
|
||||
assert!(config.enabled);
|
||||
assert_eq!(config.port, DEFAULT_PORT);
|
||||
assert!(config.allowed_origins.is_none());
|
||||
assert_eq!(config.max_connections, DEFAULT_MAX_CONNECTIONS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialization_empty() {
|
||||
// Empty JSON should use all defaults
|
||||
let json = "{}";
|
||||
let config: AcpServerConfig = serde_json::from_str(json).unwrap();
|
||||
|
||||
assert!(!config.enabled);
|
||||
assert_eq!(config.port, DEFAULT_PORT);
|
||||
assert!(config.allowed_origins.is_none());
|
||||
assert_eq!(config.max_connections, DEFAULT_MAX_CONNECTIONS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_equality() {
|
||||
let config1 = AcpServerConfig::default();
|
||||
let config2 = AcpServerConfig::default();
|
||||
assert_eq!(config1, config2);
|
||||
|
||||
let config3 = AcpServerConfig::enabled();
|
||||
assert_ne!(config1, config3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clone() {
|
||||
let config = AcpServerConfig::enabled()
|
||||
.set_port(8080)
|
||||
.set_allowed_origins(Some(vec!["origin".to_string()]));
|
||||
|
||||
let cloned = config.clone();
|
||||
assert_eq!(config, cloned);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,362 @@
|
||||
//! Error types for the ACP Server
|
||||
//!
|
||||
//! This module defines error types used throughout the ACP Server implementation,
|
||||
//! including conversions to JSON-RPC error format for client responses.
|
||||
|
||||
use std::fmt;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Standard JSON-RPC error codes as defined in the specification
|
||||
pub mod error_codes {
|
||||
/// Parse error - Invalid JSON was received by the server
|
||||
pub const PARSE_ERROR: i32 = -32700;
|
||||
|
||||
/// Invalid Request - The JSON sent is not a valid Request object
|
||||
pub const INVALID_REQUEST: i32 = -32600;
|
||||
|
||||
/// Method not found - The method does not exist / is not available
|
||||
pub const METHOD_NOT_FOUND: i32 = -32601;
|
||||
|
||||
/// Invalid params - Invalid method parameter(s)
|
||||
pub const INVALID_PARAMS: i32 = -32602;
|
||||
|
||||
/// Internal error - Internal JSON-RPC error
|
||||
pub const INTERNAL_ERROR: i32 = -32603;
|
||||
|
||||
// Server errors reserved for implementation-defined errors (-32000 to -32099)
|
||||
|
||||
/// Session not found error
|
||||
pub const SESSION_NOT_FOUND: i32 = -32001;
|
||||
|
||||
/// Connector not found error
|
||||
pub const CONNECTOR_NOT_FOUND: i32 = -32002;
|
||||
|
||||
/// Invalid session error
|
||||
pub const INVALID_SESSION: i32 = -32003;
|
||||
|
||||
/// Transport error
|
||||
pub const TRANSPORT_ERROR: i32 = -32004;
|
||||
|
||||
/// Client not found error
|
||||
pub const CLIENT_NOT_FOUND: i32 = -32005;
|
||||
}
|
||||
|
||||
/// ACP Server error enum representing all possible error conditions
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum AcpServerError {
|
||||
/// The session ID provided is invalid or malformed
|
||||
InvalidSession,
|
||||
|
||||
/// An RPC-related error occurred
|
||||
///
|
||||
/// Contains a description of the RPC error, such as method not found,
|
||||
/// invalid params, or parse errors.
|
||||
RpcError(String),
|
||||
|
||||
/// A transport-level error occurred
|
||||
///
|
||||
/// Contains a description of the transport error, such as connection
|
||||
/// failures, timeout, or network issues.
|
||||
TransportError(String),
|
||||
|
||||
/// The requested session was not found
|
||||
///
|
||||
/// The session ID does not correspond to any known session in the
|
||||
/// session manager.
|
||||
SessionNotFound,
|
||||
|
||||
/// The requested connector was not found
|
||||
///
|
||||
/// The connector ID does not correspond to any registered connector
|
||||
/// in the runtime. Contains the connector ID that was not found.
|
||||
ConnectorNotFound(String),
|
||||
|
||||
/// The connector is not in a ready state
|
||||
///
|
||||
/// The connector exists but is not available to handle requests
|
||||
/// (e.g., still connecting, error state, or stopped).
|
||||
/// Contains the connector ID that is not ready.
|
||||
ConnectorNotReady(String),
|
||||
|
||||
/// Operation timed out
|
||||
///
|
||||
/// Contains a description of what timed out.
|
||||
Timeout(String),
|
||||
|
||||
/// An internal server error occurred
|
||||
///
|
||||
/// Contains a description of the internal error. Used for unexpected
|
||||
/// conditions that don't fit other categories.
|
||||
Internal(String),
|
||||
|
||||
/// The requested client was not found
|
||||
///
|
||||
/// The client ID does not correspond to any connected client in the
|
||||
/// session manager. Contains the client ID that was not found.
|
||||
ClientNotFound(String),
|
||||
}
|
||||
|
||||
impl fmt::Display for AcpServerError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
AcpServerError::InvalidSession => {
|
||||
write!(f, "Invalid session ID provided")
|
||||
}
|
||||
AcpServerError::RpcError(msg) => {
|
||||
write!(f, "RPC error: {}", msg)
|
||||
}
|
||||
AcpServerError::TransportError(msg) => {
|
||||
write!(f, "Transport error: {}", msg)
|
||||
}
|
||||
AcpServerError::SessionNotFound => {
|
||||
write!(f, "Session not found")
|
||||
}
|
||||
AcpServerError::ConnectorNotFound(id) => {
|
||||
write!(f, "Connector not found: {}", id)
|
||||
}
|
||||
AcpServerError::ConnectorNotReady(id) => {
|
||||
write!(f, "Connector not ready: {}", id)
|
||||
}
|
||||
AcpServerError::Timeout(msg) => {
|
||||
write!(f, "Timeout: {}", msg)
|
||||
}
|
||||
AcpServerError::Internal(msg) => {
|
||||
write!(f, "Internal error: {}", msg)
|
||||
}
|
||||
AcpServerError::ClientNotFound(id) => {
|
||||
write!(f, "Client not found: {}", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for AcpServerError {
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
// None of our error variants wrap other errors currently
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// JSON-RPC error representation for wire format
|
||||
///
|
||||
/// This struct is used in JSON-RPC responses when an error occurs.
|
||||
/// It follows the JSON-RPC 2.0 specification for error objects.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct JsonRpcErrorObject {
|
||||
/// A Number that indicates the error type that occurred
|
||||
pub code: i32,
|
||||
|
||||
/// A String providing a short description of the error
|
||||
pub message: String,
|
||||
|
||||
/// Optional additional information about the error
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl JsonRpcErrorObject {
|
||||
/// Create a new JSON-RPC error object
|
||||
pub fn new(code: i32, message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
code,
|
||||
message: message.into(),
|
||||
data: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new JSON-RPC error object with additional data
|
||||
pub fn with_data(code: i32, message: impl Into<String>, data: serde_json::Value) -> Self {
|
||||
Self {
|
||||
code,
|
||||
message: message.into(),
|
||||
data: Some(data),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a parse error
|
||||
pub fn parse_error(message: impl Into<String>) -> Self {
|
||||
Self::new(error_codes::PARSE_ERROR, message)
|
||||
}
|
||||
|
||||
/// Create an invalid request error
|
||||
pub fn invalid_request(message: impl Into<String>) -> Self {
|
||||
Self::new(error_codes::INVALID_REQUEST, message)
|
||||
}
|
||||
|
||||
/// Create a method not found error
|
||||
pub fn method_not_found(method: impl Into<String>) -> Self {
|
||||
Self::new(
|
||||
error_codes::METHOD_NOT_FOUND,
|
||||
format!("Method not found: {}", method.into()),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create an invalid params error
|
||||
pub fn invalid_params(message: impl Into<String>) -> Self {
|
||||
Self::new(error_codes::INVALID_PARAMS, message)
|
||||
}
|
||||
|
||||
/// Create an internal error
|
||||
pub fn internal_error(message: impl Into<String>) -> Self {
|
||||
Self::new(error_codes::INTERNAL_ERROR, message)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<AcpServerError> for JsonRpcErrorObject {
|
||||
fn from(error: AcpServerError) -> Self {
|
||||
match error {
|
||||
AcpServerError::InvalidSession => {
|
||||
JsonRpcErrorObject::new(error_codes::INVALID_SESSION, "Invalid session ID provided")
|
||||
}
|
||||
AcpServerError::RpcError(msg) => {
|
||||
JsonRpcErrorObject::new(error_codes::INTERNAL_ERROR, msg)
|
||||
}
|
||||
AcpServerError::TransportError(msg) => {
|
||||
JsonRpcErrorObject::new(error_codes::TRANSPORT_ERROR, msg)
|
||||
}
|
||||
AcpServerError::SessionNotFound => {
|
||||
JsonRpcErrorObject::new(error_codes::SESSION_NOT_FOUND, "Session not found")
|
||||
}
|
||||
AcpServerError::ConnectorNotFound(id) => {
|
||||
JsonRpcErrorObject::new(error_codes::CONNECTOR_NOT_FOUND, format!("Connector not found: {}", id))
|
||||
}
|
||||
AcpServerError::ConnectorNotReady(id) => {
|
||||
JsonRpcErrorObject::new(error_codes::CONNECTOR_NOT_FOUND, format!("Connector not ready: {}", id))
|
||||
}
|
||||
AcpServerError::Timeout(msg) => {
|
||||
JsonRpcErrorObject::new(error_codes::TRANSPORT_ERROR, format!("Timeout: {}", msg))
|
||||
}
|
||||
AcpServerError::Internal(msg) => JsonRpcErrorObject::internal_error(msg),
|
||||
AcpServerError::ClientNotFound(id) => {
|
||||
JsonRpcErrorObject::new(error_codes::CLIENT_NOT_FOUND, format!("Client not found: {}", id))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_error_display() {
|
||||
assert_eq!(
|
||||
AcpServerError::InvalidSession.to_string(),
|
||||
"Invalid session ID provided"
|
||||
);
|
||||
assert_eq!(
|
||||
AcpServerError::RpcError("test".to_string()).to_string(),
|
||||
"RPC error: test"
|
||||
);
|
||||
assert_eq!(
|
||||
AcpServerError::TransportError("timeout".to_string()).to_string(),
|
||||
"Transport error: timeout"
|
||||
);
|
||||
assert_eq!(
|
||||
AcpServerError::SessionNotFound.to_string(),
|
||||
"Session not found"
|
||||
);
|
||||
assert_eq!(
|
||||
AcpServerError::ConnectorNotFound("test-conn".to_string()).to_string(),
|
||||
"Connector not found: test-conn"
|
||||
);
|
||||
assert_eq!(
|
||||
AcpServerError::ConnectorNotReady("test-conn".to_string()).to_string(),
|
||||
"Connector not ready: test-conn"
|
||||
);
|
||||
assert_eq!(
|
||||
AcpServerError::Timeout("request timed out".to_string()).to_string(),
|
||||
"Timeout: request timed out"
|
||||
);
|
||||
assert_eq!(
|
||||
AcpServerError::Internal("something went wrong".to_string()).to_string(),
|
||||
"Internal error: something went wrong"
|
||||
);
|
||||
assert_eq!(
|
||||
AcpServerError::ClientNotFound("client-123".to_string()).to_string(),
|
||||
"Client not found: client-123"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_to_jsonrpc() {
|
||||
let error: JsonRpcErrorObject = AcpServerError::InvalidSession.into();
|
||||
assert_eq!(error.code, error_codes::INVALID_SESSION);
|
||||
assert_eq!(error.message, "Invalid session ID provided");
|
||||
|
||||
let error: JsonRpcErrorObject = AcpServerError::SessionNotFound.into();
|
||||
assert_eq!(error.code, error_codes::SESSION_NOT_FOUND);
|
||||
|
||||
let error: JsonRpcErrorObject = AcpServerError::ConnectorNotFound("conn1".to_string()).into();
|
||||
assert_eq!(error.code, error_codes::CONNECTOR_NOT_FOUND);
|
||||
assert!(error.message.contains("conn1"));
|
||||
|
||||
let error: JsonRpcErrorObject = AcpServerError::ConnectorNotReady("conn2".to_string()).into();
|
||||
assert_eq!(error.code, error_codes::CONNECTOR_NOT_FOUND);
|
||||
assert!(error.message.contains("conn2"));
|
||||
|
||||
let error: JsonRpcErrorObject = AcpServerError::Timeout("session creation".to_string()).into();
|
||||
assert_eq!(error.code, error_codes::TRANSPORT_ERROR);
|
||||
|
||||
let error: JsonRpcErrorObject = AcpServerError::TransportError("net error".to_string()).into();
|
||||
assert_eq!(error.code, error_codes::TRANSPORT_ERROR);
|
||||
assert_eq!(error.message, "net error");
|
||||
|
||||
let error: JsonRpcErrorObject = AcpServerError::Internal("internal".to_string()).into();
|
||||
assert_eq!(error.code, error_codes::INTERNAL_ERROR);
|
||||
|
||||
let error: JsonRpcErrorObject = AcpServerError::ClientNotFound("client-456".to_string()).into();
|
||||
assert_eq!(error.code, error_codes::CLIENT_NOT_FOUND);
|
||||
assert!(error.message.contains("client-456"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_jsonrpc_error_serialization() {
|
||||
let error = JsonRpcErrorObject::new(error_codes::PARSE_ERROR, "Invalid JSON");
|
||||
let json = serde_json::to_string(&error).unwrap();
|
||||
assert!(json.contains("-32700"));
|
||||
assert!(json.contains("Invalid JSON"));
|
||||
|
||||
// With data
|
||||
let error = JsonRpcErrorObject::with_data(
|
||||
error_codes::INVALID_PARAMS,
|
||||
"Missing field",
|
||||
serde_json::json!({"field": "session_id"}),
|
||||
);
|
||||
let json = serde_json::to_string(&error).unwrap();
|
||||
assert!(json.contains("session_id"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_jsonrpc_error_factories() {
|
||||
let error = JsonRpcErrorObject::parse_error("bad json");
|
||||
assert_eq!(error.code, error_codes::PARSE_ERROR);
|
||||
|
||||
let error = JsonRpcErrorObject::invalid_request("missing jsonrpc field");
|
||||
assert_eq!(error.code, error_codes::INVALID_REQUEST);
|
||||
|
||||
let error = JsonRpcErrorObject::method_not_found("session.unknown");
|
||||
assert_eq!(error.code, error_codes::METHOD_NOT_FOUND);
|
||||
assert!(error.message.contains("session.unknown"));
|
||||
|
||||
let error = JsonRpcErrorObject::invalid_params("session_id required");
|
||||
assert_eq!(error.code, error_codes::INVALID_PARAMS);
|
||||
|
||||
let error = JsonRpcErrorObject::internal_error("panic");
|
||||
assert_eq!(error.code, error_codes::INTERNAL_ERROR);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_is_error_trait() {
|
||||
fn assert_is_error<T: std::error::Error>() {}
|
||||
assert_is_error::<AcpServerError>();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_error_clone() {
|
||||
let error = AcpServerError::Internal("test".to_string());
|
||||
let cloned = error.clone();
|
||||
assert_eq!(error, cloned);
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,460 @@
|
||||
//! JSON-RPC 2.0 types for the ACP Server
|
||||
//!
|
||||
//! This module implements JSON-RPC 2.0 request/response types according to
|
||||
//! the specification at https://www.jsonrpc.org/specification.
|
||||
//!
|
||||
//! Key features:
|
||||
//! - Support for both numeric and string IDs
|
||||
//! - Batch request/response handling
|
||||
//! - Proper serialization of null vs missing fields
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::JsonRpcErrorObject;
|
||||
|
||||
/// JSON-RPC protocol version constant
|
||||
pub const JSONRPC_VERSION: &str = "2.0";
|
||||
|
||||
/// JSON-RPC request/response identifier
|
||||
///
|
||||
/// According to the JSON-RPC 2.0 spec, an id can be a String, Number,
|
||||
/// or Null. This type uses an untagged enum to handle both string and
|
||||
/// number identifiers.
|
||||
///
|
||||
/// Note: The spec recommends not using Null as an id for requests.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum JsonRpcId {
|
||||
/// Numeric identifier (integer)
|
||||
Number(i64),
|
||||
|
||||
/// String identifier
|
||||
String(String),
|
||||
|
||||
/// Null identifier (typically used in error responses for invalid requests)
|
||||
Null,
|
||||
}
|
||||
|
||||
impl From<i64> for JsonRpcId {
|
||||
fn from(n: i64) -> Self {
|
||||
JsonRpcId::Number(n)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for JsonRpcId {
|
||||
fn from(s: String) -> Self {
|
||||
JsonRpcId::String(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for JsonRpcId {
|
||||
fn from(s: &str) -> Self {
|
||||
JsonRpcId::String(s.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// A JSON-RPC 2.0 request object
|
||||
///
|
||||
/// Represents a remote procedure call with optional parameters.
|
||||
/// The `id` field determines whether this is a request (with id) or
|
||||
/// notification (without id).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcRequest {
|
||||
/// JSON-RPC protocol version (must be "2.0")
|
||||
pub jsonrpc: String,
|
||||
|
||||
/// A String containing the name of the method to be invoked
|
||||
pub method: String,
|
||||
|
||||
/// Optional structured value that holds the parameter values
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<serde_json::Value>,
|
||||
|
||||
/// Optional identifier established by the client
|
||||
///
|
||||
/// If absent, the request is a notification (no response expected)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<JsonRpcId>,
|
||||
}
|
||||
|
||||
impl JsonRpcRequest {
|
||||
/// Create a new JSON-RPC request
|
||||
pub fn new(method: impl Into<String>, params: Option<serde_json::Value>, id: JsonRpcId) -> Self {
|
||||
Self {
|
||||
jsonrpc: JSONRPC_VERSION.to_string(),
|
||||
method: method.into(),
|
||||
params,
|
||||
id: Some(id),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new JSON-RPC notification (request without id)
|
||||
pub fn notification(method: impl Into<String>, params: Option<serde_json::Value>) -> Self {
|
||||
Self {
|
||||
jsonrpc: JSONRPC_VERSION.to_string(),
|
||||
method: method.into(),
|
||||
params,
|
||||
id: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this request is a notification (no id)
|
||||
pub fn is_notification(&self) -> bool {
|
||||
self.id.is_none()
|
||||
}
|
||||
|
||||
/// Validate the request format
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
if self.jsonrpc != JSONRPC_VERSION {
|
||||
return Err(format!(
|
||||
"Invalid JSON-RPC version: expected '{}', got '{}'",
|
||||
JSONRPC_VERSION, self.jsonrpc
|
||||
));
|
||||
}
|
||||
|
||||
if self.method.is_empty() {
|
||||
return Err("Method name cannot be empty".to_string());
|
||||
}
|
||||
|
||||
// Methods starting with "rpc." are reserved for internal use
|
||||
if self.method.starts_with("rpc.") {
|
||||
return Err(format!(
|
||||
"Method name '{}' is reserved (starts with 'rpc.')",
|
||||
self.method
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// A JSON-RPC 2.0 response object
|
||||
///
|
||||
/// Contains either a result (success) or an error (failure), never both.
|
||||
/// The id must match the corresponding request id.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcResponse {
|
||||
/// JSON-RPC protocol version (must be "2.0")
|
||||
pub jsonrpc: String,
|
||||
|
||||
/// The result of the call (on success)
|
||||
///
|
||||
/// This member is REQUIRED on success and MUST NOT exist on error.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<serde_json::Value>,
|
||||
|
||||
/// The error object (on failure)
|
||||
///
|
||||
/// This member is REQUIRED on error and MUST NOT exist on success.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<JsonRpcErrorObject>,
|
||||
|
||||
/// The identifier matching the request
|
||||
///
|
||||
/// If there was an error detecting the id in the Request object
|
||||
/// (e.g. Parse error/Invalid Request), it MUST be Null.
|
||||
pub id: JsonRpcId,
|
||||
}
|
||||
|
||||
impl JsonRpcResponse {
|
||||
/// Create a successful response
|
||||
pub fn success(result: serde_json::Value, id: JsonRpcId) -> Self {
|
||||
Self {
|
||||
jsonrpc: JSONRPC_VERSION.to_string(),
|
||||
result: Some(result),
|
||||
error: None,
|
||||
id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an error response
|
||||
pub fn error(error: JsonRpcErrorObject, id: JsonRpcId) -> Self {
|
||||
Self {
|
||||
jsonrpc: JSONRPC_VERSION.to_string(),
|
||||
result: None,
|
||||
error: Some(error),
|
||||
id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an error response with a null id (for parse errors)
|
||||
pub fn error_with_null_id(error: JsonRpcErrorObject) -> Self {
|
||||
Self::error(error, JsonRpcId::Null)
|
||||
}
|
||||
|
||||
/// Check if this response represents success
|
||||
pub fn is_success(&self) -> bool {
|
||||
self.result.is_some() && self.error.is_none()
|
||||
}
|
||||
|
||||
/// Check if this response represents an error
|
||||
pub fn is_error(&self) -> bool {
|
||||
self.error.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents either a single request or a batch of requests
|
||||
///
|
||||
/// The JSON-RPC 2.0 spec allows sending multiple requests in a single
|
||||
/// JSON array for batch processing.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum JsonRpcRequestBatch {
|
||||
/// A single request
|
||||
Single(JsonRpcRequest),
|
||||
|
||||
/// A batch of requests
|
||||
Batch(Vec<JsonRpcRequest>),
|
||||
}
|
||||
|
||||
impl JsonRpcRequestBatch {
|
||||
/// Check if this is an empty batch
|
||||
pub fn is_empty(&self) -> bool {
|
||||
match self {
|
||||
JsonRpcRequestBatch::Single(_) => false,
|
||||
JsonRpcRequestBatch::Batch(batch) => batch.is_empty(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the number of requests
|
||||
pub fn len(&self) -> usize {
|
||||
match self {
|
||||
JsonRpcRequestBatch::Single(_) => 1,
|
||||
JsonRpcRequestBatch::Batch(batch) => batch.len(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to a vector of requests
|
||||
pub fn into_vec(self) -> Vec<JsonRpcRequest> {
|
||||
match self {
|
||||
JsonRpcRequestBatch::Single(req) => vec![req],
|
||||
JsonRpcRequestBatch::Batch(batch) => batch,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents either a single response or a batch of responses
|
||||
///
|
||||
/// The response format must match the request format: single request
|
||||
/// gets single response, batch request gets batch response.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum JsonRpcResponseBatch {
|
||||
/// A single response
|
||||
Single(JsonRpcResponse),
|
||||
|
||||
/// A batch of responses
|
||||
Batch(Vec<JsonRpcResponse>),
|
||||
}
|
||||
|
||||
impl JsonRpcResponseBatch {
|
||||
/// Create a batch response from a vector
|
||||
///
|
||||
/// Returns Single if there's exactly one response, otherwise Batch.
|
||||
pub fn from_vec(responses: Vec<JsonRpcResponse>) -> Self {
|
||||
if responses.len() == 1 {
|
||||
JsonRpcResponseBatch::Single(responses.into_iter().next().unwrap())
|
||||
} else {
|
||||
JsonRpcResponseBatch::Batch(responses)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_jsonrpc_id_number() {
|
||||
let id = JsonRpcId::Number(42);
|
||||
let json = serde_json::to_string(&id).unwrap();
|
||||
assert_eq!(json, "42");
|
||||
|
||||
let parsed: JsonRpcId = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed, id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_jsonrpc_id_string() {
|
||||
let id = JsonRpcId::String("abc-123".to_string());
|
||||
let json = serde_json::to_string(&id).unwrap();
|
||||
assert_eq!(json, "\"abc-123\"");
|
||||
|
||||
let parsed: JsonRpcId = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed, id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_jsonrpc_id_null() {
|
||||
let id = JsonRpcId::Null;
|
||||
let json = serde_json::to_string(&id).unwrap();
|
||||
assert_eq!(json, "null");
|
||||
|
||||
let parsed: JsonRpcId = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed, id);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_jsonrpc_id_from_conversions() {
|
||||
let id: JsonRpcId = 42i64.into();
|
||||
assert_eq!(id, JsonRpcId::Number(42));
|
||||
|
||||
let id: JsonRpcId = "test".into();
|
||||
assert_eq!(id, JsonRpcId::String("test".to_string()));
|
||||
|
||||
let id: JsonRpcId = String::from("owned").into();
|
||||
assert_eq!(id, JsonRpcId::String("owned".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_creation() {
|
||||
let req = JsonRpcRequest::new("session.new", Some(json!({"title": "Test"})), 1.into());
|
||||
assert_eq!(req.jsonrpc, "2.0");
|
||||
assert_eq!(req.method, "session.new");
|
||||
assert!(req.params.is_some());
|
||||
assert_eq!(req.id, Some(JsonRpcId::Number(1)));
|
||||
assert!(!req.is_notification());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_notification_creation() {
|
||||
let notif = JsonRpcRequest::notification("event.ping", None);
|
||||
assert!(notif.is_notification());
|
||||
assert_eq!(notif.id, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_validation() {
|
||||
let valid = JsonRpcRequest::new("test.method", None, 1.into());
|
||||
assert!(valid.validate().is_ok());
|
||||
|
||||
// Invalid version
|
||||
let mut invalid_version = valid.clone();
|
||||
invalid_version.jsonrpc = "1.0".to_string();
|
||||
assert!(invalid_version.validate().is_err());
|
||||
|
||||
// Empty method
|
||||
let mut empty_method = valid.clone();
|
||||
empty_method.method = String::new();
|
||||
assert!(empty_method.validate().is_err());
|
||||
|
||||
// Reserved method
|
||||
let mut reserved = valid.clone();
|
||||
reserved.method = "rpc.internal".to_string();
|
||||
assert!(reserved.validate().is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_serialization() {
|
||||
let req = JsonRpcRequest::new(
|
||||
"session.prompt",
|
||||
Some(json!({"session_id": "abc", "content": "Hello"})),
|
||||
"req-123".into(),
|
||||
);
|
||||
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(json.contains("\"jsonrpc\":\"2.0\""));
|
||||
assert!(json.contains("\"method\":\"session.prompt\""));
|
||||
assert!(json.contains("\"id\":\"req-123\""));
|
||||
|
||||
// Deserialize back
|
||||
let parsed: JsonRpcRequest = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.method, "session.prompt");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_response_success() {
|
||||
let resp = JsonRpcResponse::success(json!({"session_id": "new-123"}), 1.into());
|
||||
assert!(resp.is_success());
|
||||
assert!(!resp.is_error());
|
||||
assert_eq!(resp.result, Some(json!({"session_id": "new-123"})));
|
||||
assert!(resp.error.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_response_error() {
|
||||
let error = JsonRpcErrorObject::method_not_found("unknown.method");
|
||||
let resp = JsonRpcResponse::error(error, 1.into());
|
||||
assert!(!resp.is_success());
|
||||
assert!(resp.is_error());
|
||||
assert!(resp.result.is_none());
|
||||
assert!(resp.error.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_response_serialization() {
|
||||
// Success response
|
||||
let success = JsonRpcResponse::success(json!({"ok": true}), 42.into());
|
||||
let json = serde_json::to_string(&success).unwrap();
|
||||
assert!(json.contains("\"result\""));
|
||||
assert!(!json.contains("\"error\""));
|
||||
|
||||
// Error response
|
||||
let error = JsonRpcResponse::error(
|
||||
JsonRpcErrorObject::internal_error("Something broke"),
|
||||
42.into(),
|
||||
);
|
||||
let json = serde_json::to_string(&error).unwrap();
|
||||
assert!(!json.contains("\"result\""));
|
||||
assert!(json.contains("\"error\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_batch_request_single() {
|
||||
let req = JsonRpcRequest::new("test", None, 1.into());
|
||||
let batch = JsonRpcRequestBatch::Single(req);
|
||||
assert_eq!(batch.len(), 1);
|
||||
assert!(!batch.is_empty());
|
||||
|
||||
let vec = batch.into_vec();
|
||||
assert_eq!(vec.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_batch_request_multiple() {
|
||||
let req1 = JsonRpcRequest::new("test1", None, 1.into());
|
||||
let req2 = JsonRpcRequest::new("test2", None, 2.into());
|
||||
let batch = JsonRpcRequestBatch::Batch(vec![req1, req2]);
|
||||
assert_eq!(batch.len(), 2);
|
||||
|
||||
let vec = batch.into_vec();
|
||||
assert_eq!(vec.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_batch_request_empty() {
|
||||
let batch = JsonRpcRequestBatch::Batch(vec![]);
|
||||
assert!(batch.is_empty());
|
||||
assert_eq!(batch.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_batch_request_deserialization() {
|
||||
// Single request
|
||||
let single_json = r#"{"jsonrpc":"2.0","method":"test","id":1}"#;
|
||||
let single: JsonRpcRequestBatch = serde_json::from_str(single_json).unwrap();
|
||||
assert_eq!(single.len(), 1);
|
||||
|
||||
// Batch request
|
||||
let batch_json = r#"[{"jsonrpc":"2.0","method":"test1","id":1},{"jsonrpc":"2.0","method":"test2","id":2}]"#;
|
||||
let batch: JsonRpcRequestBatch = serde_json::from_str(batch_json).unwrap();
|
||||
assert_eq!(batch.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_batch_response_from_vec() {
|
||||
// Single response becomes Single variant
|
||||
let responses = vec![JsonRpcResponse::success(json!(null), 1.into())];
|
||||
let batch = JsonRpcResponseBatch::from_vec(responses);
|
||||
matches!(batch, JsonRpcResponseBatch::Single(_));
|
||||
|
||||
// Multiple responses become Batch variant
|
||||
let responses = vec![
|
||||
JsonRpcResponse::success(json!(null), 1.into()),
|
||||
JsonRpcResponse::success(json!(null), 2.into()),
|
||||
];
|
||||
let batch = JsonRpcResponseBatch::from_vec(responses);
|
||||
matches!(batch, JsonRpcResponseBatch::Batch(_));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
//! Dirigent ACP API
|
||||
//!
|
||||
//! This crate exposes an ACP (Agent-Client Protocol) API for Dirigent,
|
||||
//! allowing other ACP clients to interact with Dirigent.
|
||||
//!
|
||||
//! ## Modules
|
||||
//!
|
||||
//! - [`config`] - Server configuration types
|
||||
//! - [`error`] - Error types and JSON-RPC error conversion
|
||||
//! - [`event_bridge`] - Event forwarding from source streams to SSE clients
|
||||
//! - [`jsonrpc`] - JSON-RPC 2.0 request/response types
|
||||
//! - [`router`] - Axum router and HTTP handlers
|
||||
//! - [`rpc`] - JSON-RPC request handler and method dispatch
|
||||
//! - [`session_manager`] - Session mapping and client connection tracking
|
||||
//! - [`sse`] - SSE notifications for streaming events to clients
|
||||
//!
|
||||
//! ## Example
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use dirigent_acp_api::{
|
||||
//! AcpServerConfig, NoOpConnectorOperations,
|
||||
//! router::{AcpServerState, create_acp_server_router},
|
||||
//! };
|
||||
//!
|
||||
//! // Create server configuration
|
||||
//! let config = AcpServerConfig::enabled()
|
||||
//! .set_port(3001)
|
||||
//! .set_max_connections(100);
|
||||
//!
|
||||
//! // Create server state
|
||||
//! let state = AcpServerState::new(config);
|
||||
//!
|
||||
//! // Create the router with connector operations
|
||||
//! let router = create_acp_server_router(state, NoOpConnectorOperations);
|
||||
//!
|
||||
//! // Run with axum
|
||||
//! let listener = tokio::net::TcpListener::bind("0.0.0.0:3001").await?;
|
||||
//! axum::serve(listener, router).await?;
|
||||
//! ```
|
||||
|
||||
// Modules
|
||||
pub mod agent_requests;
|
||||
pub mod config;
|
||||
pub mod error;
|
||||
pub mod event_bridge;
|
||||
pub mod jsonrpc;
|
||||
pub mod router;
|
||||
pub mod rpc;
|
||||
pub mod session_manager;
|
||||
pub mod sse;
|
||||
|
||||
// Re-exports for convenience
|
||||
pub use agent_requests::AgentRequestTracker;
|
||||
pub use config::AcpServerConfig;
|
||||
pub use error::{AcpServerError, JsonRpcErrorObject};
|
||||
pub use event_bridge::{EventBridge, EventBridgeConfig};
|
||||
pub use jsonrpc::{
|
||||
JsonRpcId, JsonRpcRequest, JsonRpcRequestBatch, JsonRpcResponse, JsonRpcResponseBatch,
|
||||
JSONRPC_VERSION,
|
||||
};
|
||||
pub use router::{create_acp_server_router, AcpServerState, RouterState};
|
||||
pub use rpc::{
|
||||
ConnectorInfo, ConnectorOperations, NoOpConnectorOperations, RpcHandler, SessionInfo,
|
||||
ACP_PROTOCOL_VERSION, SERVER_NAME,
|
||||
};
|
||||
pub use session_manager::{ClientConnection, ClientInfo, SessionManager, SessionMapping};
|
||||
pub use sse::{AcpNotification, SseNotifier, translate_event};
|
||||
|
||||
use axum::{response::Json, routing::get, Router};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ApiInfo {
|
||||
pub name: String,
|
||||
pub version: String,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
/// Create the ACP API router
|
||||
pub fn create_api_router() -> Router {
|
||||
Router::new()
|
||||
.route("/", get(api_info))
|
||||
.route("/health", get(health_check))
|
||||
}
|
||||
|
||||
/// Get API information
|
||||
async fn api_info() -> Json<ApiInfo> {
|
||||
Json(ApiInfo {
|
||||
name: "dirigent_acp_api".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
description: "ACP API for Dirigent - orchestrates agentic clients".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Health check endpoint
|
||||
async fn health_check() -> Json<serde_json::Value> {
|
||||
Json(serde_json::json!({
|
||||
"status": "healthy",
|
||||
"message": "Dirigent ACP API is running"
|
||||
}))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_api_info() {
|
||||
let info = ApiInfo {
|
||||
name: "test".to_string(),
|
||||
version: "0.1.0".to_string(),
|
||||
description: "test description".to_string(),
|
||||
};
|
||||
assert_eq!(info.name, "test");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,950 @@
|
||||
//! Axum Router Integration for ACP Server
|
||||
//!
|
||||
//! This module provides the Axum router and HTTP handlers for the ACP Server.
|
||||
//! It exposes a JSON-RPC endpoint at `/rpc` and an SSE events endpoint at `/events`.
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! The router is generic over `ConnectorOperations`, allowing different implementations
|
||||
//! to be used (e.g., one backed by `CoreHandle` for production, or `NoOpConnectorOperations`
|
||||
//! for testing).
|
||||
//!
|
||||
//! ## Endpoints
|
||||
//!
|
||||
//! - `POST /rpc` - JSON-RPC 2.0 endpoint for method calls
|
||||
//! - `GET /events?client_id=...` - SSE stream for real-time notifications
|
||||
//! - `GET /health` - Health check endpoint
|
||||
//!
|
||||
//! ## Example
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use dirigent_acp_api::router::{AcpServerState, create_acp_server_router};
|
||||
//! use dirigent_acp_api::{AcpServerConfig, NoOpConnectorOperations};
|
||||
//!
|
||||
//! // Create state
|
||||
//! let state = AcpServerState::new(AcpServerConfig::enabled());
|
||||
//!
|
||||
//! // Create router with no-op operations for testing
|
||||
//! let router = create_acp_server_router(state, NoOpConnectorOperations);
|
||||
//! ```
|
||||
|
||||
use std::convert::Infallible;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{
|
||||
sse::{Event, KeepAlive, Sse},
|
||||
IntoResponse, Json,
|
||||
},
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio_stream::StreamExt;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use tracing::{debug, info, trace, warn};
|
||||
|
||||
use crate::config::AcpServerConfig;
|
||||
use crate::rpc::{ConnectorOperations, RpcHandler};
|
||||
use crate::session_manager::SessionManager;
|
||||
use crate::sse::SseNotifier;
|
||||
|
||||
// ============================================================================
|
||||
// AcpServerState (T031)
|
||||
// ============================================================================
|
||||
|
||||
/// Internal state shared across handlers
|
||||
struct AcpServerStateInner {
|
||||
/// Session manager for tracking client sessions
|
||||
session_manager: SessionManager,
|
||||
|
||||
/// SSE notifier for broadcasting events to clients
|
||||
sse_notifier: SseNotifier,
|
||||
|
||||
/// Agent request tracker for bidirectional request/response
|
||||
agent_request_tracker: Arc<crate::agent_requests::AgentRequestTracker>,
|
||||
|
||||
/// Server configuration
|
||||
config: AcpServerConfig,
|
||||
}
|
||||
|
||||
/// Shared state for the ACP Server Axum handlers
|
||||
///
|
||||
/// This struct contains all the state needed by the HTTP handlers:
|
||||
/// - Session manager for tracking client sessions and mappings
|
||||
/// - SSE notifier for broadcasting events to connected clients
|
||||
/// - Server configuration
|
||||
///
|
||||
/// The state is wrapped in `Arc` internally, making it cheap to clone and
|
||||
/// share across async tasks and handlers.
|
||||
///
|
||||
/// Note: `RpcHandler` is created per-request with `ConnectorOperations` passed in,
|
||||
/// as the connector operations implementation cannot be stored in the shared state
|
||||
/// (it may contain non-Clone types like `CoreHandle`).
|
||||
#[derive(Clone)]
|
||||
pub struct AcpServerState {
|
||||
inner: Arc<AcpServerStateInner>,
|
||||
}
|
||||
|
||||
impl AcpServerState {
|
||||
/// Create a new ACP server state with the given configuration
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `config`: The server configuration
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use dirigent_acp_api::router::AcpServerState;
|
||||
/// use dirigent_acp_api::AcpServerConfig;
|
||||
///
|
||||
/// let state = AcpServerState::new(AcpServerConfig::enabled());
|
||||
/// ```
|
||||
pub fn new(config: AcpServerConfig) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(AcpServerStateInner {
|
||||
session_manager: SessionManager::new(),
|
||||
sse_notifier: SseNotifier::new(),
|
||||
agent_request_tracker: Arc::new(crate::agent_requests::AgentRequestTracker::new()),
|
||||
config,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new ACP server state with custom components
|
||||
///
|
||||
/// This is useful for testing or when you need to share a session manager
|
||||
/// or SSE notifier with other parts of the application.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `session_manager`: The session manager instance
|
||||
/// - `sse_notifier`: The SSE notifier instance
|
||||
/// - `agent_request_tracker`: The agent request tracker instance
|
||||
/// - `config`: The server configuration
|
||||
pub fn with_components(
|
||||
session_manager: SessionManager,
|
||||
sse_notifier: SseNotifier,
|
||||
agent_request_tracker: Arc<crate::agent_requests::AgentRequestTracker>,
|
||||
config: AcpServerConfig,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(AcpServerStateInner {
|
||||
session_manager,
|
||||
sse_notifier,
|
||||
agent_request_tracker,
|
||||
config,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a reference to the session manager
|
||||
pub fn session_manager(&self) -> &SessionManager {
|
||||
&self.inner.session_manager
|
||||
}
|
||||
|
||||
/// Get a reference to the SSE notifier
|
||||
pub fn sse_notifier(&self) -> &SseNotifier {
|
||||
&self.inner.sse_notifier
|
||||
}
|
||||
|
||||
/// Get a reference to the agent request tracker
|
||||
pub fn agent_request_tracker(&self) -> &Arc<crate::agent_requests::AgentRequestTracker> {
|
||||
&self.inner.agent_request_tracker
|
||||
}
|
||||
|
||||
/// Get a reference to the configuration
|
||||
pub fn config(&self) -> &AcpServerConfig {
|
||||
&self.inner.config
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for AcpServerState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("AcpServerState")
|
||||
.field("session_count", &self.inner.session_manager.mapping_count())
|
||||
.field("client_count", &self.inner.sse_notifier.client_count())
|
||||
.field("config", &self.inner.config)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Router State with ConnectorOperations (T028)
|
||||
// ============================================================================
|
||||
|
||||
/// Combined state for Axum handlers that includes both server state and connector operations
|
||||
///
|
||||
/// This struct combines the shared `AcpServerState` with a `ConnectorOperations`
|
||||
/// implementation. It's used as the Axum state to provide handlers with access
|
||||
/// to both session management and connector functionality.
|
||||
#[derive(Clone)]
|
||||
pub struct RouterState<C: ConnectorOperations + Clone + Send + Sync + 'static> {
|
||||
/// The ACP server state (session manager, SSE notifier, config)
|
||||
pub state: AcpServerState,
|
||||
|
||||
/// The connector operations implementation
|
||||
pub connector_ops: C,
|
||||
}
|
||||
|
||||
impl<C: ConnectorOperations + Clone + Send + Sync + 'static> RouterState<C> {
|
||||
/// Create a new router state
|
||||
pub fn new(state: AcpServerState, connector_ops: C) -> Self {
|
||||
Self {
|
||||
state,
|
||||
connector_ops,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Router Factory (T028)
|
||||
// ============================================================================
|
||||
|
||||
/// Create the ACP server router with all endpoints configured
|
||||
///
|
||||
/// This function creates an Axum router with the following endpoints:
|
||||
/// - `POST /rpc` - JSON-RPC 2.0 endpoint
|
||||
/// - `GET /events` - SSE events stream
|
||||
/// - `GET /health` - Health check
|
||||
///
|
||||
/// CORS middleware is applied based on the configuration.
|
||||
///
|
||||
/// # Type Parameters
|
||||
///
|
||||
/// - `C`: The connector operations implementation
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `state`: The ACP server state
|
||||
/// - `connector_ops`: The connector operations implementation
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use dirigent_acp_api::router::{AcpServerState, create_acp_server_router};
|
||||
/// use dirigent_acp_api::{AcpServerConfig, NoOpConnectorOperations};
|
||||
///
|
||||
/// let state = AcpServerState::new(AcpServerConfig::enabled());
|
||||
/// let router = create_acp_server_router(state, NoOpConnectorOperations);
|
||||
///
|
||||
/// // Use with axum server
|
||||
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3001").await?;
|
||||
/// axum::serve(listener, router).await?;
|
||||
/// ```
|
||||
pub fn create_acp_server_router<C>(state: AcpServerState, connector_ops: C) -> Router
|
||||
where
|
||||
C: ConnectorOperations + Clone + Send + Sync + 'static,
|
||||
{
|
||||
// Build CORS layer based on configuration
|
||||
let cors = build_cors_layer(&state.config());
|
||||
|
||||
// Create combined router state
|
||||
let router_state = RouterState::new(state, connector_ops);
|
||||
|
||||
// Build the router (T023: added /agent_response route)
|
||||
Router::new()
|
||||
.route("/rpc", post(handle_rpc::<C>))
|
||||
.route("/events", get(handle_sse::<C>))
|
||||
.route("/health", get(handle_health::<C>))
|
||||
.route("/agent_response", post(handle_agent_response::<C>))
|
||||
.layer(cors)
|
||||
.with_state(router_state)
|
||||
}
|
||||
|
||||
/// Build CORS layer from configuration
|
||||
fn build_cors_layer(config: &AcpServerConfig) -> CorsLayer {
|
||||
let cors = CorsLayer::new()
|
||||
.allow_methods([
|
||||
axum::http::Method::GET,
|
||||
axum::http::Method::POST,
|
||||
axum::http::Method::OPTIONS,
|
||||
])
|
||||
.allow_headers(Any);
|
||||
|
||||
match &config.allowed_origins {
|
||||
Some(origins) if !origins.is_empty() => {
|
||||
// Parse origins and set specific allowed origins
|
||||
let parsed_origins: Vec<_> = origins
|
||||
.iter()
|
||||
.filter_map(|o| o.parse().ok())
|
||||
.collect();
|
||||
|
||||
if parsed_origins.is_empty() {
|
||||
warn!("No valid origins in allowed_origins, allowing any origin");
|
||||
cors.allow_origin(Any)
|
||||
} else {
|
||||
debug!("CORS configured with {} allowed origins", parsed_origins.len());
|
||||
cors.allow_origin(parsed_origins)
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
debug!("CORS configured to allow any origin");
|
||||
cors.allow_origin(Any)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Request/Response Types
|
||||
// ============================================================================
|
||||
|
||||
/// Query parameters for the SSE events endpoint
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SseQuery {
|
||||
/// The client ID (required for SSE subscription)
|
||||
pub client_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Health check response
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct HealthResponse {
|
||||
/// Health status
|
||||
pub status: String,
|
||||
|
||||
/// Server message
|
||||
pub message: String,
|
||||
|
||||
/// Number of connected clients
|
||||
pub clients: usize,
|
||||
|
||||
/// Number of active sessions
|
||||
pub sessions: usize,
|
||||
}
|
||||
|
||||
/// Error response for SSE endpoint
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct SseErrorResponse {
|
||||
/// Error message
|
||||
pub error: String,
|
||||
|
||||
/// Error code
|
||||
pub code: String,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Endpoint Handlers (T029, T030)
|
||||
// ============================================================================
|
||||
|
||||
/// Handle POST /rpc requests (T029)
|
||||
///
|
||||
/// Extracts the JSON body, processes it through the RPC handler, and returns
|
||||
/// the JSON-RPC response.
|
||||
async fn handle_rpc<C>(
|
||||
State(router_state): State<RouterState<C>>,
|
||||
headers: HeaderMap,
|
||||
body: String,
|
||||
) -> impl IntoResponse
|
||||
where
|
||||
C: ConnectorOperations + Clone + Send + Sync + 'static,
|
||||
{
|
||||
debug!("Received RPC request: {} bytes", body.len());
|
||||
|
||||
// Extract client_id from X-Client-ID header
|
||||
let client_id = headers
|
||||
.get("X-Client-ID")
|
||||
.and_then(|v| v.to_str().ok());
|
||||
|
||||
// Extract select_connector from X-Select-Connector header (sent during initialize)
|
||||
let select_connector = headers
|
||||
.get("X-Select-Connector")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
if let Some(id) = client_id {
|
||||
info!(client_id = %id, "Received X-Client-ID header from RPC request");
|
||||
debug!("RPC headers: {:?}", headers);
|
||||
} else {
|
||||
warn!("No X-Client-ID header provided in RPC request");
|
||||
debug!("Available headers: {:?}", headers.keys().collect::<Vec<_>>());
|
||||
}
|
||||
|
||||
// Log the incoming request
|
||||
debug!("Received RPC request body: {}", serde_json::to_string(&body).unwrap_or_else(|_| "failed to serialize".to_string()));
|
||||
|
||||
// Create an RPC handler for this request
|
||||
let handler = RpcHandler::new(
|
||||
router_state.state.session_manager().clone(),
|
||||
router_state.connector_ops.clone(),
|
||||
router_state.state.sse_notifier().clone(),
|
||||
router_state.state.agent_request_tracker().clone(),
|
||||
);
|
||||
|
||||
// Process the request with the client_id from header
|
||||
let response = handler.handle_request(&body, client_id).await;
|
||||
|
||||
// If this was an initialize request with X-Select-Connector header,
|
||||
// store the preference after the client is registered
|
||||
if let Some(ref connector) = select_connector {
|
||||
// Check if this is an initialize response by looking for clientId in result
|
||||
if let crate::jsonrpc::JsonRpcResponseBatch::Single(ref resp) = response {
|
||||
if let Some(ref result) = resp.result {
|
||||
if let Some(new_client_id) = result.get("clientId").and_then(|v| v.as_str()) {
|
||||
router_state.state.session_manager().update_client_preferred_connector(
|
||||
new_client_id,
|
||||
Some(connector.clone()),
|
||||
);
|
||||
info!(
|
||||
client_id = %new_client_id,
|
||||
select_connector = %connector,
|
||||
"Stored preferred connector from X-Select-Connector header"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Log the response being sent
|
||||
debug!("Sending RPC response: {}", serde_json::to_string(&response).unwrap_or_else(|_| "failed to serialize".to_string()));
|
||||
|
||||
// Return JSON response
|
||||
Json(response)
|
||||
}
|
||||
|
||||
/// Handle GET /events SSE requests (T030)
|
||||
///
|
||||
/// Creates an SSE stream subscription for the client. The client_id must be
|
||||
/// provided as a query parameter.
|
||||
async fn handle_sse<C>(
|
||||
State(router_state): State<RouterState<C>>,
|
||||
Query(query): Query<SseQuery>,
|
||||
) -> Result<Sse<impl tokio_stream::Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<SseErrorResponse>)>
|
||||
where
|
||||
C: ConnectorOperations + Clone + Send + Sync + 'static,
|
||||
{
|
||||
// Validate client_id
|
||||
let client_id = match query.client_id {
|
||||
Some(id) if !id.is_empty() => id,
|
||||
_ => {
|
||||
warn!("SSE request missing client_id");
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(SseErrorResponse {
|
||||
error: "Missing required parameter: client_id".to_string(),
|
||||
code: "MISSING_CLIENT_ID".to_string(),
|
||||
}),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
info!("SSE subscription requested for client: {}", client_id);
|
||||
debug!(
|
||||
"SSE notifier state before subscribe: client_count={}, subscribed_clients={:?}",
|
||||
router_state.state.sse_notifier().client_count(),
|
||||
router_state.state.sse_notifier().subscribed_clients()
|
||||
);
|
||||
|
||||
// Subscribe to notifications
|
||||
let notifier = router_state.state.sse_notifier();
|
||||
let notification_stream = notifier.subscribe(&client_id);
|
||||
|
||||
info!(
|
||||
"SSE client subscribed: client_id={}, total_clients={}, is_subscribed={}",
|
||||
client_id,
|
||||
notifier.client_count(),
|
||||
notifier.is_subscribed(&client_id)
|
||||
);
|
||||
|
||||
// Check if this client has any sessions that need tool updates
|
||||
// This handles the race condition where session/new happens before SSE subscription
|
||||
let session_manager = router_state.state.session_manager();
|
||||
let client_sessions = session_manager.list_client_sessions(&client_id);
|
||||
|
||||
if !client_sessions.is_empty() {
|
||||
debug!(
|
||||
"Client {} has {} existing sessions, sending tool updates",
|
||||
client_id,
|
||||
client_sessions.len()
|
||||
);
|
||||
|
||||
// For each session, get the connector and send tool updates
|
||||
for session_id in client_sessions {
|
||||
if let Some(mapping) = session_manager.get_mapping(&session_id) {
|
||||
debug!(
|
||||
"Fetching tool updates for session {} on connector {}",
|
||||
session_id,
|
||||
mapping.connector_id
|
||||
);
|
||||
|
||||
// Get available commands from the connector
|
||||
match router_state.connector_ops.get_connector_commands(&mapping.connector_id).await {
|
||||
Ok(commands) => {
|
||||
let update_params = crate::sse::SessionUpdateParams {
|
||||
session_id: session_id.clone(),
|
||||
update: crate::sse::SessionUpdateVariant::AvailableCommandsUpdate {
|
||||
available_commands: commands.clone(),
|
||||
},
|
||||
event_type_override: None,
|
||||
};
|
||||
|
||||
// Broadcast the tool update now that client is subscribed
|
||||
match notifier.broadcast(&client_id, update_params) {
|
||||
Ok(n) => {
|
||||
info!(
|
||||
"Sent deferred available_commands_update for session {}: {} commands to {} receivers",
|
||||
session_id,
|
||||
commands.len(),
|
||||
n
|
||||
);
|
||||
}
|
||||
Err(_) => {
|
||||
warn!(
|
||||
"Failed to send deferred available_commands_update for session {}",
|
||||
session_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Failed to get connector commands for session {}: {}",
|
||||
session_id,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Map the notification stream to SSE events
|
||||
let client_id_for_log = client_id.clone();
|
||||
let sse_stream = notification_stream.map(move |result| {
|
||||
match result {
|
||||
Ok(notification) => {
|
||||
// T013: Start timing for SSE event creation
|
||||
let event_start = Instant::now();
|
||||
|
||||
// Convert notification to SSE event
|
||||
// Use to_sse_json() which handles raw events correctly
|
||||
let event_type = notification.event_type();
|
||||
let data = notification.to_sse_json();
|
||||
|
||||
// T011: Log before SSE event is written (T023: includes session_id for correlation)
|
||||
trace!(
|
||||
client_id = %client_id_for_log,
|
||||
event_type = %event_type,
|
||||
session_id = %notification.session_id,
|
||||
data_len = data.len(),
|
||||
"Writing SSE event to client stream"
|
||||
);
|
||||
|
||||
debug!(
|
||||
"SSE: Sending event to client {}: type={}, session_id={}, data_len={}",
|
||||
client_id_for_log,
|
||||
event_type,
|
||||
notification.session_id,
|
||||
data.len()
|
||||
);
|
||||
trace!("SSE event data: {}", data);
|
||||
|
||||
let event = Event::default()
|
||||
.event(event_type)
|
||||
.data(data);
|
||||
|
||||
// T012: Log after Event::default() construction (T023: includes session_id for correlation)
|
||||
trace!(
|
||||
client_id = %client_id_for_log,
|
||||
event_type = %event_type,
|
||||
session_id = %notification.session_id,
|
||||
"SSE event constructed, sending to client"
|
||||
);
|
||||
|
||||
// T013: Log SSE event write completion with timing (T023: includes session_id for correlation)
|
||||
let elapsed_ms = event_start.elapsed().as_millis();
|
||||
trace!(
|
||||
client_id = %client_id_for_log,
|
||||
event_type = %event_type,
|
||||
session_id = %notification.session_id,
|
||||
elapsed_ms = elapsed_ms,
|
||||
"SSE event write completed"
|
||||
);
|
||||
|
||||
// T014: Warn for slow SSE event writes (T023: includes session_id for correlation)
|
||||
if elapsed_ms > 100 {
|
||||
warn!(
|
||||
elapsed_ms = elapsed_ms,
|
||||
client_id = %client_id_for_log,
|
||||
event_type = %event_type,
|
||||
session_id = %notification.session_id,
|
||||
"Slow SSE event write detected"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(event)
|
||||
}
|
||||
Err(e) => {
|
||||
// Broadcast stream error (e.g., lagged receiver)
|
||||
// We send an error event but keep the stream open
|
||||
warn!("SSE stream error for client {}: {:?}", client_id_for_log, e);
|
||||
Ok(Event::default()
|
||||
.event("error")
|
||||
.data(format!(r#"{{"error":"Stream error: {:?}"}}"#, e)))
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Return SSE response with keep-alive
|
||||
Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
|
||||
}
|
||||
|
||||
/// Handle GET /health requests
|
||||
///
|
||||
/// Returns basic health information about the server.
|
||||
async fn handle_health<C>(
|
||||
State(router_state): State<RouterState<C>>,
|
||||
) -> Json<HealthResponse>
|
||||
where
|
||||
C: ConnectorOperations + Clone + Send + Sync + 'static,
|
||||
{
|
||||
Json(HealthResponse {
|
||||
status: "healthy".to_string(),
|
||||
message: "Dirigent ACP Server is running".to_string(),
|
||||
clients: router_state.state.sse_notifier().client_count(),
|
||||
sessions: router_state.state.session_manager().mapping_count(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Response for agent response endpoint
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct AgentResponseResult {
|
||||
/// Status of the response
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
/// Handle POST /agent_response requests (T019)
|
||||
///
|
||||
/// Accepts JSON-RPC responses from clients for pending agent requests.
|
||||
/// When a client approves/denies a permission request, they POST the
|
||||
/// response to this endpoint, which completes the pending request.
|
||||
/// (Note: When nested at /acp, the full path becomes /acp/agent_response)
|
||||
///
|
||||
/// Expected body format:
|
||||
/// ```json
|
||||
/// {
|
||||
/// "jsonrpc": "2.0",
|
||||
/// "id": 0,
|
||||
/// "result": { "selectedOptionId": "allow" }
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// The response delivery to the connector happens through the oneshot channel
|
||||
/// registered in the event bridge (Phase 2.4). The tracker's `complete()` method
|
||||
/// sends the response value through the oneshot channel, and the event bridge task
|
||||
/// (which registered the request) will receive it and send the `ConnectorCommand::AgentResponse`.
|
||||
async fn handle_agent_response<C>(
|
||||
State(router_state): State<RouterState<C>>,
|
||||
headers: HeaderMap,
|
||||
Json(response): Json<serde_json::Value>,
|
||||
) -> Result<Json<AgentResponseResult>, (StatusCode, String)>
|
||||
where
|
||||
C: ConnectorOperations + Clone + Send + Sync + 'static,
|
||||
{
|
||||
// Extract client_id from X-Client-ID header (T020)
|
||||
let client_id = match headers.get("X-Client-ID").and_then(|v| v.to_str().ok()) {
|
||||
Some(id) if !id.is_empty() => id,
|
||||
_ => {
|
||||
warn!("Agent response request missing X-Client-ID header");
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Missing required header: X-Client-ID".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"Received agent response from client: {}",
|
||||
client_id
|
||||
);
|
||||
trace!("Agent response body: {}", response);
|
||||
|
||||
// Extract request_id from JSON body (T020)
|
||||
let request_id = match response.get("id") {
|
||||
Some(id) => id.clone(),
|
||||
None => {
|
||||
warn!(
|
||||
client_id = %client_id,
|
||||
"Agent response missing 'id' field in body"
|
||||
);
|
||||
return Err((
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Missing required field: id".to_string(),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
info!(
|
||||
client_id = %client_id,
|
||||
request_id = %request_id,
|
||||
"Processing agent response"
|
||||
);
|
||||
|
||||
// Call AgentRequestTracker::complete() (T021)
|
||||
let tracker = router_state.state.agent_request_tracker();
|
||||
match tracker.complete(client_id, request_id.clone(), response.clone()) {
|
||||
Ok(()) => {
|
||||
info!(
|
||||
client_id = %client_id,
|
||||
request_id = %request_id,
|
||||
"Agent response delivered successfully"
|
||||
);
|
||||
|
||||
// Note (T022): The response delivery to the connector happens through
|
||||
// the oneshot channel in the event bridge. When the event bridge handles
|
||||
// an `Event::AgentRequest`, it registers the request with the tracker and
|
||||
// gets a receiver. After we complete the request here, the receiver in the
|
||||
// event bridge gets the response value and sends the `ConnectorCommand::AgentResponse`
|
||||
// to the connector. This flow is implemented in Phase 2.4 (T024-T030).
|
||||
|
||||
Ok(Json(AgentResponseResult {
|
||||
status: "ok".to_string(),
|
||||
}))
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
client_id = %client_id,
|
||||
request_id = %request_id,
|
||||
error = %e,
|
||||
"Agent response failed: request not found (may have timed out)"
|
||||
);
|
||||
|
||||
Err((
|
||||
StatusCode::NOT_FOUND,
|
||||
format!("Request ID {} not found or already completed", request_id),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::jsonrpc::JsonRpcResponseBatch;
|
||||
use crate::rpc::NoOpConnectorOperations;
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{Request, StatusCode},
|
||||
};
|
||||
use tower::ServiceExt;
|
||||
|
||||
fn create_test_router() -> Router {
|
||||
let state = AcpServerState::new(AcpServerConfig::enabled());
|
||||
create_acp_server_router(state, NoOpConnectorOperations)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_health_endpoint() {
|
||||
let router = create_test_router();
|
||||
|
||||
let request = Request::builder()
|
||||
.uri("/health")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = router.oneshot(request).await.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let health: HealthResponse = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
assert_eq!(health.status, "healthy");
|
||||
assert_eq!(health.clients, 0);
|
||||
assert_eq!(health.sessions, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rpc_initialize() {
|
||||
let router = create_test_router();
|
||||
|
||||
let request_body = r#"{"jsonrpc":"2.0","method":"initialize","id":1}"#;
|
||||
|
||||
let request = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/rpc")
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(request_body))
|
||||
.unwrap();
|
||||
|
||||
let response = router.oneshot(request).await.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let resp: JsonRpcResponseBatch = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
match resp {
|
||||
JsonRpcResponseBatch::Single(r) => {
|
||||
assert!(r.is_success());
|
||||
let result = r.result.unwrap();
|
||||
assert_eq!(result["agentInfo"]["name"], "dirigent-acp-server");
|
||||
}
|
||||
_ => panic!("Expected single response"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rpc_invalid_json() {
|
||||
let router = create_test_router();
|
||||
|
||||
let request = Request::builder()
|
||||
.method("POST")
|
||||
.uri("/rpc")
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from("not valid json"))
|
||||
.unwrap();
|
||||
|
||||
let response = router.oneshot(request).await.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let resp: JsonRpcResponseBatch = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
match resp {
|
||||
JsonRpcResponseBatch::Single(r) => {
|
||||
assert!(r.is_error());
|
||||
let error = r.error.unwrap();
|
||||
assert_eq!(error.code, crate::error::error_codes::PARSE_ERROR);
|
||||
}
|
||||
_ => panic!("Expected single response"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sse_missing_client_id() {
|
||||
let router = create_test_router();
|
||||
|
||||
let request = Request::builder()
|
||||
.uri("/events")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = router.oneshot(request).await.unwrap();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.unwrap();
|
||||
let error: SseErrorResponse = serde_json::from_slice(&body).unwrap();
|
||||
|
||||
assert_eq!(error.code, "MISSING_CLIENT_ID");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sse_with_client_id() {
|
||||
let router = create_test_router();
|
||||
|
||||
let request = Request::builder()
|
||||
.uri("/events?client_id=test-client")
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let response = router.oneshot(request).await.unwrap();
|
||||
|
||||
// Should return 200 OK with SSE content type
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let content_type = response
|
||||
.headers()
|
||||
.get("content-type")
|
||||
.and_then(|v| v.to_str().ok());
|
||||
|
||||
assert!(content_type.is_some());
|
||||
assert!(content_type.unwrap().contains("text/event-stream"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_acp_server_state_new() {
|
||||
let state = AcpServerState::new(AcpServerConfig::enabled());
|
||||
|
||||
assert!(state.config().enabled);
|
||||
assert_eq!(state.session_manager().mapping_count(), 0);
|
||||
assert_eq!(state.sse_notifier().client_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_acp_server_state_with_components() {
|
||||
let session_manager = SessionManager::new();
|
||||
let sse_notifier = SseNotifier::new();
|
||||
let config = AcpServerConfig::enabled().set_port(4000);
|
||||
|
||||
// Pre-populate some state
|
||||
session_manager.register_client(None);
|
||||
|
||||
let agent_request_tracker = Arc::new(crate::agent_requests::AgentRequestTracker::new());
|
||||
|
||||
let state = AcpServerState::with_components(
|
||||
session_manager,
|
||||
sse_notifier,
|
||||
agent_request_tracker,
|
||||
config,
|
||||
);
|
||||
|
||||
assert!(state.config().enabled);
|
||||
assert_eq!(state.config().port, 4000);
|
||||
assert_eq!(state.session_manager().client_count(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_acp_server_state_clone() {
|
||||
let state1 = AcpServerState::new(AcpServerConfig::enabled());
|
||||
let state2 = state1.clone();
|
||||
|
||||
// Both should share the same underlying state
|
||||
let client_id = state1.session_manager().register_client(None);
|
||||
assert_eq!(state2.session_manager().client_count(), 1);
|
||||
assert!(state2.session_manager().get_client(&client_id).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_acp_server_state_debug() {
|
||||
let state = AcpServerState::new(AcpServerConfig::enabled());
|
||||
let debug_str = format!("{:?}", state);
|
||||
|
||||
assert!(debug_str.contains("AcpServerState"));
|
||||
assert!(debug_str.contains("session_count"));
|
||||
assert!(debug_str.contains("client_count"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_cors_layer_any_origin() {
|
||||
let config = AcpServerConfig::default();
|
||||
let _cors = build_cors_layer(&config);
|
||||
// No assertion needed - just verify it doesn't panic
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_cors_layer_specific_origins() {
|
||||
let config = AcpServerConfig::default()
|
||||
.set_allowed_origins(Some(vec![
|
||||
"http://localhost:3000".to_string(),
|
||||
"https://app.example.com".to_string(),
|
||||
]));
|
||||
let _cors = build_cors_layer(&config);
|
||||
// No assertion needed - just verify it doesn't panic
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_cors_layer_empty_origins() {
|
||||
let config = AcpServerConfig::default()
|
||||
.set_allowed_origins(Some(vec![]));
|
||||
let _cors = build_cors_layer(&config);
|
||||
// Should fall back to any origin
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,185 @@
|
||||
//! NoOp Connector Operations for Testing
|
||||
//!
|
||||
//! This module provides a stub implementation of `ConnectorOperations` that
|
||||
//! can be used for testing the RPC handler without actual connector access.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::error::AcpServerError;
|
||||
|
||||
use super::types::{ConnectorInfo, ConnectorOperations, SessionInfo};
|
||||
|
||||
/// A no-op implementation of ConnectorOperations for testing
|
||||
///
|
||||
/// This stub provides minimal implementations that return success without
|
||||
/// actually doing anything. Useful for unit testing the RPC dispatch logic.
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct NoOpConnectorOperations;
|
||||
|
||||
#[async_trait]
|
||||
impl ConnectorOperations for NoOpConnectorOperations {
|
||||
async fn create_session(
|
||||
&self,
|
||||
connector_id: &str,
|
||||
_cwd: Option<String>,
|
||||
_ownership: dirigent_protocol::SessionOwnership,
|
||||
) -> Result<SessionInfo, AcpServerError> {
|
||||
Ok(SessionInfo {
|
||||
session_id: uuid::Uuid::new_v4().to_string(),
|
||||
title: Some("New Session".to_string()),
|
||||
connector_id: connector_id.to_string(),
|
||||
cwd: None,
|
||||
created_at: chrono::Utc::now().to_rfc3339(),
|
||||
models: None,
|
||||
modes: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn load_session(
|
||||
&self,
|
||||
connector_id: &str,
|
||||
session_id: &str,
|
||||
_cwd: Option<String>,
|
||||
_mcp_servers: Option<serde_json::Value>,
|
||||
) -> Result<SessionInfo, AcpServerError> {
|
||||
Ok(SessionInfo {
|
||||
session_id: session_id.to_string(),
|
||||
title: Some("Loaded Session".to_string()),
|
||||
connector_id: connector_id.to_string(),
|
||||
cwd: None,
|
||||
created_at: chrono::Utc::now().to_rfc3339(),
|
||||
models: None,
|
||||
modes: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_message(
|
||||
&self,
|
||||
_connector_id: &str,
|
||||
_session_id: &str,
|
||||
text: String,
|
||||
) -> Result<String, AcpServerError> {
|
||||
debug!("NoOp: send_message - {} chars", text.len());
|
||||
Ok("end_turn".to_string())
|
||||
}
|
||||
|
||||
async fn cancel_generation(
|
||||
&self,
|
||||
_connector_id: &str,
|
||||
_session_id: &str,
|
||||
) -> Result<(), AcpServerError> {
|
||||
debug!("NoOp: cancel_generation");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn list_connectors(&self) -> Result<Vec<ConnectorInfo>, AcpServerError> {
|
||||
Ok(vec![ConnectorInfo {
|
||||
id: "stub-connector".to_string(),
|
||||
name: "Stub Connector".to_string(),
|
||||
connector_type: "stub".to_string(),
|
||||
available: true,
|
||||
}])
|
||||
}
|
||||
|
||||
async fn default_connector_id(&self) -> Option<String> {
|
||||
Some("stub-connector".to_string())
|
||||
}
|
||||
|
||||
async fn get_connector_commands(
|
||||
&self,
|
||||
_connector_id: &str,
|
||||
) -> Result<Vec<crate::sse::SlashCommand>, AcpServerError> {
|
||||
// Return a stub list of commands
|
||||
Ok(vec![crate::sse::SlashCommand {
|
||||
name: "echo".to_string(),
|
||||
description: "Echo command (stub)".to_string(),
|
||||
input: None,
|
||||
}])
|
||||
}
|
||||
|
||||
async fn send_agent_response(
|
||||
&self,
|
||||
connector_id: &str,
|
||||
request_id: serde_json::Value,
|
||||
_response: serde_json::Value,
|
||||
) -> Result<(), AcpServerError> {
|
||||
debug!(
|
||||
"NoOp: send_agent_response - connector: {}, request: {}",
|
||||
connector_id, request_id
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_session_metadata(
|
||||
&self,
|
||||
_connector_id: &str,
|
||||
_session_id: &str,
|
||||
) -> Result<
|
||||
(
|
||||
Option<dirigent_protocol::SessionModelState>,
|
||||
Option<dirigent_protocol::SessionModeState>,
|
||||
),
|
||||
AcpServerError,
|
||||
> {
|
||||
debug!("NoOp: get_session_metadata");
|
||||
// Return None for both - no metadata available in stub
|
||||
Ok((None, None))
|
||||
}
|
||||
|
||||
async fn set_session_mode(
|
||||
&self,
|
||||
_connector_id: &str,
|
||||
_session_id: &str,
|
||||
mode_id: &str,
|
||||
) -> Result<(), AcpServerError> {
|
||||
debug!("NoOp: set_session_mode - mode_id: {}", mode_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn set_session_model(
|
||||
&self,
|
||||
_connector_id: &str,
|
||||
_session_id: &str,
|
||||
model_id: &str,
|
||||
) -> Result<(), AcpServerError> {
|
||||
debug!("NoOp: set_session_model - model_id: {}", model_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_connector_agent_type(
|
||||
&self,
|
||||
_connector_id: &str,
|
||||
) -> Result<Option<String>, AcpServerError> {
|
||||
debug!("NoOp: get_connector_agent_type");
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn list_sessions(
|
||||
&self,
|
||||
connector_id: &str,
|
||||
) -> Result<Vec<SessionInfo>, AcpServerError> {
|
||||
debug!("NoOp: list_sessions for connector: {}", connector_id);
|
||||
Ok(vec![
|
||||
SessionInfo {
|
||||
session_id: "stub-session-1".to_string(),
|
||||
title: Some("Stub Session".to_string()),
|
||||
connector_id: connector_id.to_string(),
|
||||
cwd: None,
|
||||
created_at: chrono::Utc::now().to_rfc3339(),
|
||||
models: None,
|
||||
modes: None,
|
||||
},
|
||||
])
|
||||
}
|
||||
|
||||
async fn resolve_session_connector(&self, _session_id: &str) -> Option<String> {
|
||||
debug!("NoOp: resolve_session_connector");
|
||||
None
|
||||
}
|
||||
|
||||
async fn list_all_sessions(&self) -> Result<Vec<SessionInfo>, AcpServerError> {
|
||||
debug!("NoOp: list_all_sessions");
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,705 @@
|
||||
//! JSON-RPC Request/Response Types for ACP Server
|
||||
//!
|
||||
//! This module contains all the parameter and result types used in
|
||||
//! the ACP JSON-RPC protocol.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::AcpServerError;
|
||||
use crate::sse::ConfigOption;
|
||||
|
||||
// ============================================================================
|
||||
// Session and Connector Info Types
|
||||
// ============================================================================
|
||||
|
||||
/// Information about a session
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionInfo {
|
||||
/// The session ID
|
||||
pub session_id: String,
|
||||
|
||||
/// Optional title for the session
|
||||
pub title: Option<String>,
|
||||
|
||||
/// The connector handling this session
|
||||
pub connector_id: String,
|
||||
|
||||
/// Working directory for this session
|
||||
pub cwd: Option<String>,
|
||||
|
||||
/// When the session was created (ISO 8601 format)
|
||||
pub created_at: String,
|
||||
|
||||
/// Available models and current model (optional, connector-dependent)
|
||||
pub models: Option<dirigent_protocol::SessionModelState>,
|
||||
|
||||
/// Available modes and current mode (optional, connector-dependent)
|
||||
pub modes: Option<dirigent_protocol::SessionModeState>,
|
||||
}
|
||||
|
||||
/// Information about a connector
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ConnectorInfo {
|
||||
/// Unique identifier for this connector
|
||||
pub id: String,
|
||||
|
||||
/// Human-readable name
|
||||
pub name: String,
|
||||
|
||||
/// Type of connector (e.g., "opencode", "gateway")
|
||||
pub connector_type: String,
|
||||
|
||||
/// Whether the connector is currently available
|
||||
pub available: bool,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Initialize Types
|
||||
// ============================================================================
|
||||
|
||||
/// Parameters for the initialize request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InitializeParams {
|
||||
/// Client capabilities
|
||||
#[serde(default)]
|
||||
pub capabilities: Option<serde_json::Value>,
|
||||
|
||||
/// Client name
|
||||
#[serde(default)]
|
||||
pub client_name: Option<String>,
|
||||
|
||||
/// Client version
|
||||
#[serde(default)]
|
||||
pub client_version: Option<String>,
|
||||
}
|
||||
|
||||
/// Result of the initialize request (ACP spec compliant)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InitializeResult {
|
||||
/// Protocol version (integer)
|
||||
pub protocol_version: u32,
|
||||
|
||||
/// Agent capabilities
|
||||
pub agent_capabilities: AgentCapabilities,
|
||||
|
||||
/// Agent info (name, title, version)
|
||||
pub agent_info: AgentInfo,
|
||||
|
||||
/// Authentication methods (empty for now)
|
||||
pub auth_methods: Vec<AuthMethod>,
|
||||
|
||||
/// Client ID for HTTP-based transports (required for SSE routing)
|
||||
/// Note: This is not in the official ACP spec but is required for stateless
|
||||
/// HTTP connections where the client needs an ID to subscribe to SSE events.
|
||||
pub client_id: String,
|
||||
}
|
||||
|
||||
/// Agent information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct AgentInfo {
|
||||
/// Agent name (programmatic)
|
||||
pub name: String,
|
||||
|
||||
/// Agent title (human-readable)
|
||||
pub title: String,
|
||||
|
||||
/// Agent version
|
||||
pub version: String,
|
||||
}
|
||||
|
||||
/// Agent capabilities advertised to clients
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct AgentCapabilities {
|
||||
/// Whether session/load is supported
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub load_session: Option<bool>,
|
||||
|
||||
/// Whether session/list is supported
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub list_sessions: Option<bool>,
|
||||
|
||||
/// Session capabilities (resume, fork, etc.)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub session_capabilities: Option<SessionCapabilities>,
|
||||
|
||||
/// Prompt capabilities (content types supported)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt_capabilities: Option<PromptCapabilities>,
|
||||
|
||||
/// MCP server support
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub mcp: Option<McpCapabilities>,
|
||||
}
|
||||
|
||||
/// Extended session capabilities
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionCapabilities {
|
||||
/// Whether session/list is supported (empty object = supported)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub list: Option<serde_json::Value>,
|
||||
|
||||
/// Whether session/resume is supported (empty object = supported)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub resume: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Prompt capabilities (content types)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptCapabilities {
|
||||
/// Image content support
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub image: Option<bool>,
|
||||
|
||||
/// Audio content support
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub audio: Option<bool>,
|
||||
|
||||
/// Embedded context support
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub embedded_context: Option<bool>,
|
||||
}
|
||||
|
||||
/// MCP capabilities
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct McpCapabilities {
|
||||
/// HTTP transport support
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub http: Option<bool>,
|
||||
|
||||
/// SSE transport support (deprecated)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub sse: Option<bool>,
|
||||
}
|
||||
|
||||
/// Authentication method
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct AuthMethod {
|
||||
/// Method ID
|
||||
pub id: String,
|
||||
|
||||
/// Method name
|
||||
pub name: String,
|
||||
|
||||
/// Method description
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
impl Default for AgentCapabilities {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
load_session: Some(true),
|
||||
list_sessions: Some(true),
|
||||
session_capabilities: Some(SessionCapabilities {
|
||||
list: Some(serde_json::json!({})),
|
||||
resume: Some(serde_json::json!({})),
|
||||
}),
|
||||
prompt_capabilities: Some(PromptCapabilities {
|
||||
image: Some(true),
|
||||
audio: None,
|
||||
embedded_context: Some(true),
|
||||
}),
|
||||
mcp: Some(McpCapabilities {
|
||||
http: Some(true),
|
||||
sse: Some(true),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Session/New Types
|
||||
// ============================================================================
|
||||
|
||||
/// Parameters for session/new request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionNewParams {
|
||||
/// Optional connector ID to create the session on
|
||||
#[serde(default)]
|
||||
pub connector_id: Option<String>,
|
||||
|
||||
/// Optional working directory for the session
|
||||
#[serde(default)]
|
||||
pub cwd: Option<String>,
|
||||
|
||||
/// Optional client-provided session ID
|
||||
#[serde(default)]
|
||||
pub session_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Result of session/new request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionNewResult {
|
||||
/// The client-facing session ID
|
||||
pub session_id: String,
|
||||
|
||||
/// Optional title for the session
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
|
||||
/// The connector handling this session
|
||||
pub connector_id: String,
|
||||
|
||||
/// When the session was created (ISO 8601 format)
|
||||
pub created_at: String,
|
||||
|
||||
/// Available models and current model (UNSTABLE in ACP spec)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub models: Option<dirigent_protocol::SessionModelState>,
|
||||
|
||||
/// Available modes and current mode
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub modes: Option<dirigent_protocol::SessionModeState>,
|
||||
|
||||
/// Configuration options (modes, models as unified config)
|
||||
/// This is the preferred way to expose configuration in ACP
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub config_options: Option<Vec<ConfigOption>>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Session/Load Types
|
||||
// ============================================================================
|
||||
|
||||
/// Parameters for session/load request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionLoadParams {
|
||||
/// The session ID to load
|
||||
pub session_id: String,
|
||||
|
||||
/// The connector ID where the session exists (optional; resolved automatically if omitted)
|
||||
#[serde(default)]
|
||||
pub connector_id: Option<String>,
|
||||
|
||||
/// Optional working directory (standard ACP field sent by clients like Zed)
|
||||
#[serde(default)]
|
||||
pub cwd: Option<String>,
|
||||
|
||||
/// Optional MCP server configurations (standard ACP field)
|
||||
#[serde(default)]
|
||||
pub mcp_servers: Option<Vec<serde_json::Value>>,
|
||||
}
|
||||
|
||||
/// Result of session/load request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionLoadResult {
|
||||
/// The client-facing session ID
|
||||
pub session_id: String,
|
||||
|
||||
/// Optional title for the session
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
|
||||
/// The connector handling this session
|
||||
pub connector_id: String,
|
||||
|
||||
/// When the session was created (ISO 8601 format)
|
||||
pub created_at: String,
|
||||
|
||||
/// Available models and current model (UNSTABLE in ACP spec)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub models: Option<dirigent_protocol::SessionModelState>,
|
||||
|
||||
/// Available modes and current mode
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub modes: Option<dirigent_protocol::SessionModeState>,
|
||||
|
||||
/// Configuration options (modes, models as unified config)
|
||||
/// This is the preferred way to expose configuration in ACP
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub config_options: Option<Vec<ConfigOption>>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Session/List Types
|
||||
// ============================================================================
|
||||
|
||||
/// Parameters for session/list request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionListParams {
|
||||
/// Optional connector ID to list sessions from
|
||||
#[serde(default)]
|
||||
pub connector_id: Option<String>,
|
||||
|
||||
/// Optional working directory filter
|
||||
#[serde(default)]
|
||||
pub cwd: Option<String>,
|
||||
|
||||
/// Optional pagination cursor from previous response
|
||||
#[serde(default)]
|
||||
pub cursor: Option<String>,
|
||||
}
|
||||
|
||||
/// A session entry in a session/list response (ACP spec)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionListEntry {
|
||||
/// Unique session identifier
|
||||
pub session_id: String,
|
||||
|
||||
/// Working directory for this session
|
||||
pub cwd: String,
|
||||
|
||||
/// Human-readable title
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
|
||||
/// Last activity timestamp (ISO 8601)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub updated_at: Option<String>,
|
||||
|
||||
/// Agent-specific metadata
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[serde(rename = "_meta")]
|
||||
pub meta: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Result of session/list request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionListResult {
|
||||
/// List of available sessions
|
||||
pub sessions: Vec<SessionListEntry>,
|
||||
|
||||
/// Pagination cursor for next page (absent when no more results)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub next_cursor: Option<String>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Session/Resume Types
|
||||
// ============================================================================
|
||||
|
||||
/// Parameters for session/resume request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionResumeParams {
|
||||
/// The session ID to resume
|
||||
pub session_id: String,
|
||||
|
||||
/// The connector ID where the session exists (optional; resolved automatically if omitted)
|
||||
#[serde(default)]
|
||||
pub connector_id: Option<String>,
|
||||
|
||||
/// Optional working directory (standard ACP field sent by clients like Zed)
|
||||
#[serde(default)]
|
||||
pub cwd: Option<String>,
|
||||
|
||||
/// Optional MCP server configurations (standard ACP field)
|
||||
#[serde(default)]
|
||||
pub mcp_servers: Option<Vec<serde_json::Value>>,
|
||||
}
|
||||
|
||||
/// Result of session/resume request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionResumeResult {
|
||||
/// The client-facing session ID
|
||||
pub session_id: String,
|
||||
|
||||
/// Optional title for the session
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
|
||||
/// The connector handling this session
|
||||
pub connector_id: String,
|
||||
|
||||
/// When the session was created (ISO 8601 format)
|
||||
pub created_at: String,
|
||||
|
||||
/// Available models and current model
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub models: Option<dirigent_protocol::SessionModelState>,
|
||||
|
||||
/// Available modes and current mode
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub modes: Option<dirigent_protocol::SessionModeState>,
|
||||
|
||||
/// Configuration options
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub config_options: Option<Vec<crate::sse::ConfigOption>>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Session/Prompt Types
|
||||
// ============================================================================
|
||||
|
||||
/// Parameters for session/prompt request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionPromptParams {
|
||||
/// The session ID to send the prompt to
|
||||
pub session_id: String,
|
||||
|
||||
/// The prompt content (array of content blocks)
|
||||
pub prompt: PromptContent,
|
||||
}
|
||||
|
||||
/// Content for a prompt - either simple text or structured blocks
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum PromptContent {
|
||||
/// Simple text content
|
||||
Text(String),
|
||||
|
||||
/// Structured content blocks
|
||||
Blocks(Vec<ContentBlock>),
|
||||
}
|
||||
|
||||
impl PromptContent {
|
||||
/// Convert to plain text representation
|
||||
pub fn to_text(&self) -> String {
|
||||
match self {
|
||||
PromptContent::Text(text) => text.clone(),
|
||||
PromptContent::Blocks(blocks) => blocks
|
||||
.iter()
|
||||
.filter_map(|b| {
|
||||
if let ContentBlock::Text { text } = b {
|
||||
Some(text.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A content block in a prompt
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ContentBlock {
|
||||
/// Text content
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
|
||||
/// Image content (base64 encoded)
|
||||
#[serde(rename = "image")]
|
||||
Image { data: String, media_type: String },
|
||||
}
|
||||
|
||||
/// Result of session/prompt request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionPromptResult {
|
||||
/// The reason the turn stopped
|
||||
pub stop_reason: String,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Session/Cancel Types
|
||||
// ============================================================================
|
||||
|
||||
/// Parameters for session/cancel request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionCancelParams {
|
||||
/// The session ID to cancel
|
||||
pub session_id: String,
|
||||
}
|
||||
|
||||
/// Result of session/cancel request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionCancelResult {
|
||||
/// Whether the cancellation was successful
|
||||
pub cancelled: bool,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Session/Close Types
|
||||
// ============================================================================
|
||||
|
||||
/// Parameters for session/close request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionCloseParams {
|
||||
/// The session ID to close
|
||||
pub session_id: String,
|
||||
}
|
||||
|
||||
/// Result of session/close request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionCloseResult {
|
||||
/// Whether the session was successfully closed
|
||||
pub closed: bool,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Session/SetMode and Session/SetModel Types
|
||||
// ============================================================================
|
||||
|
||||
/// Parameters for session/set_mode request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionSetModeParams {
|
||||
/// The session ID to update
|
||||
pub session_id: String,
|
||||
|
||||
/// The mode ID to switch to
|
||||
pub mode_id: String,
|
||||
}
|
||||
|
||||
/// Parameters for session/set_model request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionSetModelParams {
|
||||
/// The session ID to update
|
||||
pub session_id: String,
|
||||
|
||||
/// The model ID to switch to
|
||||
pub model_id: String,
|
||||
}
|
||||
|
||||
/// Parameters for session/set_config_option request
|
||||
///
|
||||
/// This is the unified way to set configuration options (modes, models, etc.)
|
||||
/// for clients using the new config_options system.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionSetConfigOptionParams {
|
||||
/// The session ID to update
|
||||
pub session_id: String,
|
||||
|
||||
/// The config option ID (e.g., "mode", "model")
|
||||
pub config_id: String,
|
||||
|
||||
/// The value to set
|
||||
pub value: String,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ConnectorOperations Trait
|
||||
// ============================================================================
|
||||
|
||||
/// Trait for connector operations that will be implemented by the web server
|
||||
///
|
||||
/// This trait abstracts the operations that require access to `CoreHandle`,
|
||||
/// allowing the RPC handler to remain decoupled from `dirigent_core`.
|
||||
/// The web server implements this trait to bridge between the ACP server
|
||||
/// and the actual connector implementations.
|
||||
#[async_trait::async_trait]
|
||||
pub trait ConnectorOperations: Send + Sync {
|
||||
/// Create a new session on the specified connector
|
||||
async fn create_session(
|
||||
&self,
|
||||
connector_id: &str,
|
||||
cwd: Option<String>,
|
||||
ownership: dirigent_protocol::SessionOwnership,
|
||||
) -> Result<SessionInfo, AcpServerError>;
|
||||
|
||||
/// Load an existing session from a connector
|
||||
async fn load_session(
|
||||
&self,
|
||||
connector_id: &str,
|
||||
session_id: &str,
|
||||
cwd: Option<String>,
|
||||
mcp_servers: Option<serde_json::Value>,
|
||||
) -> Result<SessionInfo, AcpServerError>;
|
||||
|
||||
/// Send a message to a session and wait for completion
|
||||
async fn send_message(
|
||||
&self,
|
||||
connector_id: &str,
|
||||
session_id: &str,
|
||||
text: String,
|
||||
) -> Result<String, AcpServerError>;
|
||||
|
||||
/// Cancel active generation on a session
|
||||
async fn cancel_generation(
|
||||
&self,
|
||||
connector_id: &str,
|
||||
session_id: &str,
|
||||
) -> Result<(), AcpServerError>;
|
||||
|
||||
/// List all available connectors
|
||||
async fn list_connectors(&self) -> Result<Vec<ConnectorInfo>, AcpServerError>;
|
||||
|
||||
/// Get the default connector ID (if configured)
|
||||
async fn default_connector_id(&self) -> Option<String>;
|
||||
|
||||
/// Get available commands/tools from a connector
|
||||
async fn get_connector_commands(
|
||||
&self,
|
||||
connector_id: &str,
|
||||
) -> Result<Vec<crate::sse::SlashCommand>, AcpServerError>;
|
||||
|
||||
/// Send an agent response back to a connector
|
||||
async fn send_agent_response(
|
||||
&self,
|
||||
connector_id: &str,
|
||||
request_id: serde_json::Value,
|
||||
response: serde_json::Value,
|
||||
) -> Result<(), AcpServerError>;
|
||||
|
||||
/// Get session metadata (models/modes) from a session
|
||||
async fn get_session_metadata(
|
||||
&self,
|
||||
connector_id: &str,
|
||||
session_id: &str,
|
||||
) -> Result<
|
||||
(
|
||||
Option<dirigent_protocol::SessionModelState>,
|
||||
Option<dirigent_protocol::SessionModeState>,
|
||||
),
|
||||
AcpServerError,
|
||||
>;
|
||||
|
||||
/// Set the session mode
|
||||
async fn set_session_mode(
|
||||
&self,
|
||||
connector_id: &str,
|
||||
session_id: &str,
|
||||
mode_id: &str,
|
||||
) -> Result<(), AcpServerError>;
|
||||
|
||||
/// Set the session model
|
||||
async fn set_session_model(
|
||||
&self,
|
||||
connector_id: &str,
|
||||
session_id: &str,
|
||||
model_id: &str,
|
||||
) -> Result<(), AcpServerError>;
|
||||
|
||||
/// Get the agent type for a connector (for mode/model mapping)
|
||||
async fn get_connector_agent_type(
|
||||
&self,
|
||||
connector_id: &str,
|
||||
) -> Result<Option<String>, AcpServerError>;
|
||||
|
||||
/// List all sessions on a connector
|
||||
async fn list_sessions(
|
||||
&self,
|
||||
connector_id: &str,
|
||||
) -> Result<Vec<SessionInfo>, AcpServerError>;
|
||||
|
||||
/// Look up which connector owns a session by its session ID.
|
||||
/// Uses archivist to find the connector. Returns connector_id (not UID).
|
||||
/// Default: returns None (no archivist available).
|
||||
async fn resolve_session_connector(&self, _session_id: &str) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
/// List sessions across all connectors (archivist-backed).
|
||||
/// Used when the resolved connector is a gateway to provide cross-connector view.
|
||||
/// Default: returns empty vec (no archivist available).
|
||||
async fn list_all_sessions(&self) -> Result<Vec<SessionInfo>, AcpServerError> {
|
||||
Ok(vec![])
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,435 @@
|
||||
//! Content and Metadata Transformation Utilities
|
||||
//!
|
||||
//! This module provides helper functions for transforming content between
|
||||
//! the internal Dirigent protocol and the ACP wire format.
|
||||
|
||||
use dirigent_protocol::ContentBlock;
|
||||
use serde_json::json;
|
||||
|
||||
/// Extract 'kind' from tool call metadata (stored as acp_kind)
|
||||
pub fn extract_kind(tool_call: &dirigent_protocol::ToolCall) -> Option<String> {
|
||||
tool_call
|
||||
.metadata
|
||||
.as_ref()
|
||||
.and_then(|m| m.get("acp_kind"))
|
||||
.and_then(|k| k.as_str())
|
||||
.map(String::from)
|
||||
}
|
||||
|
||||
/// Convert ToolCallStatus to ACP string format
|
||||
pub fn tool_call_status_to_string(status: &dirigent_protocol::ToolCallStatus) -> String {
|
||||
match status {
|
||||
dirigent_protocol::ToolCallStatus::Pending => "pending".to_string(),
|
||||
dirigent_protocol::ToolCallStatus::Running => "in_progress".to_string(),
|
||||
dirigent_protocol::ToolCallStatus::Completed => "completed".to_string(),
|
||||
dirigent_protocol::ToolCallStatus::Error => "failed".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Unwrap ToolCallContent wrappers to extract ContentBlocks for SSE output
|
||||
///
|
||||
/// The internal protocol uses ToolCallContent wrappers, but the SSE/ACP wire format
|
||||
/// expects flat ContentBlock arrays. This function extracts Content variants only.
|
||||
pub fn unwrap_tool_call_content(
|
||||
content: &[dirigent_protocol::ToolCallContent],
|
||||
) -> Vec<ContentBlock> {
|
||||
content
|
||||
.iter()
|
||||
.filter_map(|wrapper| match wrapper {
|
||||
dirigent_protocol::ToolCallContent::Content { content } => Some(content.clone()),
|
||||
// Diff and Terminal variants are not yet supported in SSE output
|
||||
_ => None,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Rebuild _meta for ACP output (ensures claudeCode.toolName is present)
|
||||
pub fn rebuild_meta(
|
||||
meta: &Option<dirigent_protocol::Meta>,
|
||||
tool_call: &dirigent_protocol::ToolCall,
|
||||
) -> Option<serde_json::Value> {
|
||||
let mut meta_obj = serde_json::Map::new();
|
||||
|
||||
// Add claudeCode.toolName
|
||||
let mut claude_code = serde_json::Map::new();
|
||||
claude_code.insert("toolName".to_string(), json!(tool_call.tool_name));
|
||||
|
||||
// Merge any existing _meta (preserves toolResponse, etc.)
|
||||
if let Some(existing_meta) = meta {
|
||||
// Convert Meta to JSON Value for merging
|
||||
if let Ok(meta_value) = serde_json::to_value(existing_meta) {
|
||||
if let Some(obj) = meta_value.as_object() {
|
||||
for (k, v) in obj {
|
||||
if k == "claudeCode" {
|
||||
// Merge into our claudeCode object
|
||||
if let Some(cc) = v.as_object() {
|
||||
for (ck, cv) in cc {
|
||||
// Don't override toolName we just set
|
||||
if ck != "toolName" {
|
||||
claude_code.insert(ck.clone(), cv.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
meta_obj.insert(k.clone(), v.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
meta_obj.insert("claudeCode".to_string(), json!(claude_code));
|
||||
Some(json!(meta_obj))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_extract_kind_present() {
|
||||
let mut metadata = serde_json::Map::new();
|
||||
metadata.insert("acp_kind".to_string(), json!("search"));
|
||||
|
||||
let tool_call = dirigent_protocol::ToolCall {
|
||||
id: "call_123".to_string(),
|
||||
tool_name: "bash".to_string(),
|
||||
status: dirigent_protocol::ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
title: None,
|
||||
error: None,
|
||||
metadata: Some(json!(metadata)),
|
||||
origin: None,
|
||||
};
|
||||
|
||||
assert_eq!(extract_kind(&tool_call), Some("search".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_kind_missing_metadata() {
|
||||
let tool_call = dirigent_protocol::ToolCall {
|
||||
id: "call_123".to_string(),
|
||||
tool_name: "bash".to_string(),
|
||||
status: dirigent_protocol::ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
title: None,
|
||||
error: None,
|
||||
metadata: None,
|
||||
origin: None,
|
||||
};
|
||||
|
||||
assert_eq!(extract_kind(&tool_call), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_kind_missing_field() {
|
||||
let mut metadata = serde_json::Map::new();
|
||||
metadata.insert("other_field".to_string(), json!("value"));
|
||||
|
||||
let tool_call = dirigent_protocol::ToolCall {
|
||||
id: "call_123".to_string(),
|
||||
tool_name: "bash".to_string(),
|
||||
status: dirigent_protocol::ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
title: None,
|
||||
error: None,
|
||||
metadata: Some(json!(metadata)),
|
||||
origin: None,
|
||||
};
|
||||
|
||||
assert_eq!(extract_kind(&tool_call), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_kind_malformed_metadata() {
|
||||
// acp_kind is not a string
|
||||
let mut metadata = serde_json::Map::new();
|
||||
metadata.insert("acp_kind".to_string(), json!(123));
|
||||
|
||||
let tool_call = dirigent_protocol::ToolCall {
|
||||
id: "call_123".to_string(),
|
||||
tool_name: "bash".to_string(),
|
||||
status: dirigent_protocol::ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
title: None,
|
||||
error: None,
|
||||
metadata: Some(json!(metadata)),
|
||||
origin: None,
|
||||
};
|
||||
|
||||
assert_eq!(extract_kind(&tool_call), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_status_to_string_all_variants() {
|
||||
assert_eq!(
|
||||
tool_call_status_to_string(&dirigent_protocol::ToolCallStatus::Pending),
|
||||
"pending"
|
||||
);
|
||||
assert_eq!(
|
||||
tool_call_status_to_string(&dirigent_protocol::ToolCallStatus::Running),
|
||||
"in_progress"
|
||||
);
|
||||
assert_eq!(
|
||||
tool_call_status_to_string(&dirigent_protocol::ToolCallStatus::Completed),
|
||||
"completed"
|
||||
);
|
||||
assert_eq!(
|
||||
tool_call_status_to_string(&dirigent_protocol::ToolCallStatus::Error),
|
||||
"failed"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rebuild_meta_basic() {
|
||||
let tool_call = dirigent_protocol::ToolCall {
|
||||
id: "call_123".to_string(),
|
||||
tool_name: "some_tool".to_string(),
|
||||
status: dirigent_protocol::ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
title: None,
|
||||
error: None,
|
||||
metadata: None,
|
||||
origin: None,
|
||||
};
|
||||
|
||||
let meta = rebuild_meta(&None, &tool_call);
|
||||
|
||||
assert!(meta.is_some());
|
||||
let meta_obj = meta.unwrap();
|
||||
assert_eq!(meta_obj["claudeCode"]["toolName"], "some_tool");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rebuild_meta_preserves_existing_fields() {
|
||||
let existing_meta = dirigent_protocol::Meta {
|
||||
provider: Some(dirigent_protocol::ProviderMeta {
|
||||
name: "test_provider".to_string(),
|
||||
original_ids: None,
|
||||
raw_excerpt: None,
|
||||
}),
|
||||
extra: std::collections::HashMap::from([(
|
||||
"customField".to_string(),
|
||||
json!("customValue"),
|
||||
)]),
|
||||
};
|
||||
|
||||
let tool_call = dirigent_protocol::ToolCall {
|
||||
id: "call_123".to_string(),
|
||||
tool_name: "bash".to_string(),
|
||||
status: dirigent_protocol::ToolCallStatus::Running,
|
||||
content: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
title: None,
|
||||
error: None,
|
||||
metadata: None,
|
||||
origin: None,
|
||||
};
|
||||
|
||||
let meta = rebuild_meta(&Some(existing_meta), &tool_call);
|
||||
|
||||
assert!(meta.is_some());
|
||||
let meta_obj = meta.unwrap();
|
||||
|
||||
// Verify claudeCode.toolName is present
|
||||
assert_eq!(meta_obj["claudeCode"]["toolName"], "bash");
|
||||
|
||||
// Verify existing fields are preserved
|
||||
assert_eq!(meta_obj["provider"]["name"], "test_provider");
|
||||
assert_eq!(meta_obj["customField"], "customValue");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rebuild_meta_merges_claude_code_fields() {
|
||||
let mut existing_meta = dirigent_protocol::Meta::default();
|
||||
existing_meta.extra.insert(
|
||||
"claudeCode".to_string(),
|
||||
json!({
|
||||
"toolResponse": "some response",
|
||||
"otherField": "other value"
|
||||
}),
|
||||
);
|
||||
|
||||
let tool_call = dirigent_protocol::ToolCall {
|
||||
id: "call_123".to_string(),
|
||||
tool_name: "test_tool".to_string(),
|
||||
status: dirigent_protocol::ToolCallStatus::Completed,
|
||||
content: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
title: None,
|
||||
error: None,
|
||||
metadata: None,
|
||||
origin: None,
|
||||
};
|
||||
|
||||
let meta = rebuild_meta(&Some(existing_meta), &tool_call);
|
||||
|
||||
assert!(meta.is_some());
|
||||
let meta_obj = meta.unwrap();
|
||||
|
||||
// Verify toolName is present (we set it)
|
||||
assert_eq!(meta_obj["claudeCode"]["toolName"], "test_tool");
|
||||
|
||||
// Verify existing claudeCode fields are preserved
|
||||
assert_eq!(meta_obj["claudeCode"]["toolResponse"], "some response");
|
||||
assert_eq!(meta_obj["claudeCode"]["otherField"], "other value");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rebuild_meta_does_not_override_tool_name() {
|
||||
// Existing meta has a different toolName
|
||||
let mut existing_meta = dirigent_protocol::Meta::default();
|
||||
existing_meta.extra.insert(
|
||||
"claudeCode".to_string(),
|
||||
json!({
|
||||
"toolName": "wrong_tool",
|
||||
"toolResponse": "response"
|
||||
}),
|
||||
);
|
||||
|
||||
let tool_call = dirigent_protocol::ToolCall {
|
||||
id: "call_123".to_string(),
|
||||
tool_name: "correct_tool".to_string(),
|
||||
status: dirigent_protocol::ToolCallStatus::Completed,
|
||||
content: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
title: None,
|
||||
error: None,
|
||||
metadata: None,
|
||||
origin: None,
|
||||
};
|
||||
|
||||
let meta = rebuild_meta(&Some(existing_meta), &tool_call);
|
||||
|
||||
assert!(meta.is_some());
|
||||
let meta_obj = meta.unwrap();
|
||||
|
||||
// Verify our toolName is used (not the one from existing meta)
|
||||
assert_eq!(meta_obj["claudeCode"]["toolName"], "correct_tool");
|
||||
|
||||
// Verify other fields are preserved
|
||||
assert_eq!(meta_obj["claudeCode"]["toolResponse"], "response");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rebuild_meta_preserves_tool_response() {
|
||||
// Test that rebuild_meta preserves toolResponse from incoming Claude updates
|
||||
let mut existing_meta = dirigent_protocol::Meta::default();
|
||||
existing_meta.extra.insert(
|
||||
"claudeCode".to_string(),
|
||||
json!({
|
||||
"toolResponse": {
|
||||
"mode": "content",
|
||||
"numFiles": 0,
|
||||
"filenames": [],
|
||||
"content": "some output",
|
||||
"numLines": 58,
|
||||
"appliedLimit": 100
|
||||
},
|
||||
"toolName": "OldToolName"
|
||||
}),
|
||||
);
|
||||
|
||||
let tool_call = dirigent_protocol::ToolCall {
|
||||
id: "call_456".to_string(),
|
||||
tool_name: "Grep".to_string(),
|
||||
status: dirigent_protocol::ToolCallStatus::Running,
|
||||
content: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
title: None,
|
||||
error: None,
|
||||
metadata: None,
|
||||
origin: None,
|
||||
};
|
||||
|
||||
let meta = rebuild_meta(&Some(existing_meta), &tool_call);
|
||||
|
||||
assert!(meta.is_some());
|
||||
let meta_obj = meta.unwrap();
|
||||
|
||||
// toolName should be updated to current tool_call.tool_name
|
||||
assert_eq!(meta_obj["claudeCode"]["toolName"], "Grep");
|
||||
|
||||
// toolResponse should be preserved completely
|
||||
assert!(meta_obj["claudeCode"]["toolResponse"].is_object());
|
||||
assert_eq!(meta_obj["claudeCode"]["toolResponse"]["mode"], "content");
|
||||
assert_eq!(meta_obj["claudeCode"]["toolResponse"]["numFiles"], 0);
|
||||
assert_eq!(meta_obj["claudeCode"]["toolResponse"]["numLines"], 58);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rebuild_meta_round_trip_no_data_loss() {
|
||||
// Test incoming → internal → outgoing preserves all fields
|
||||
let mut existing_meta = dirigent_protocol::Meta::default();
|
||||
existing_meta.extra.insert(
|
||||
"claudeCode".to_string(),
|
||||
json!({
|
||||
"toolResponse": {
|
||||
"mode": "content",
|
||||
"numFiles": 3,
|
||||
"filenames": ["file1.rs", "file2.rs", "file3.rs"],
|
||||
"content": "grep results here",
|
||||
"numLines": 42,
|
||||
"appliedLimit": 100,
|
||||
"customField": "should be preserved",
|
||||
"nestedObject": {
|
||||
"deep": "value"
|
||||
}
|
||||
},
|
||||
"additionalField": "also preserved"
|
||||
}),
|
||||
);
|
||||
existing_meta
|
||||
.extra
|
||||
.insert("customTopLevel".to_string(), json!("preserved too"));
|
||||
|
||||
let tool_call = dirigent_protocol::ToolCall {
|
||||
id: "call_789".to_string(),
|
||||
tool_name: "Grep".to_string(),
|
||||
status: dirigent_protocol::ToolCallStatus::Completed,
|
||||
content: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
title: None,
|
||||
error: None,
|
||||
metadata: None,
|
||||
origin: None,
|
||||
};
|
||||
|
||||
let meta = rebuild_meta(&Some(existing_meta), &tool_call);
|
||||
|
||||
assert!(meta.is_some());
|
||||
let meta_obj = meta.unwrap();
|
||||
|
||||
// Verify NO data loss
|
||||
assert_eq!(meta_obj["claudeCode"]["toolName"], "Grep");
|
||||
assert_eq!(meta_obj["claudeCode"]["toolResponse"]["mode"], "content");
|
||||
assert_eq!(meta_obj["claudeCode"]["toolResponse"]["numFiles"], 3);
|
||||
assert_eq!(meta_obj["claudeCode"]["toolResponse"]["numLines"], 42);
|
||||
assert_eq!(
|
||||
meta_obj["claudeCode"]["toolResponse"]["customField"],
|
||||
"should be preserved"
|
||||
);
|
||||
assert_eq!(
|
||||
meta_obj["claudeCode"]["toolResponse"]["nestedObject"]["deep"],
|
||||
"value"
|
||||
);
|
||||
assert_eq!(meta_obj["claudeCode"]["additionalField"], "also preserved");
|
||||
assert_eq!(meta_obj["customTopLevel"], "preserved too");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,453 @@
|
||||
//! Event Translation Layer
|
||||
//!
|
||||
//! This module provides translation from Dirigent protocol Events to ACP notifications.
|
||||
//! It handles the mapping between internal event representations and the wire format
|
||||
//! expected by ACP clients.
|
||||
|
||||
use dirigent_protocol::{Event, SessionUpdate};
|
||||
|
||||
use super::content_transform::{
|
||||
extract_kind, rebuild_meta, tool_call_status_to_string, unwrap_tool_call_content,
|
||||
};
|
||||
use super::types::{
|
||||
models_to_config_option, modes_to_config_option, ConfigOption, ConfigOptionChoice,
|
||||
ConfigOptionType, SessionUpdateParams, SessionUpdateVariant,
|
||||
};
|
||||
|
||||
/// Translate a Dirigent protocol Event to ACP notifications
|
||||
///
|
||||
/// Not all Dirigent events need to be forwarded to ACP clients. This function
|
||||
/// returns a vector of notifications for events that should be sent, and an empty
|
||||
/// vector for events that should be filtered out.
|
||||
///
|
||||
/// ## Mapped Events
|
||||
///
|
||||
/// - `Event::SessionUpdate` with `SessionUpdate::AgentMessageChunk` -> `MessageChunk`
|
||||
/// - `Event::SessionUpdate` with `SessionUpdate::AgentThoughtChunk` -> `MessageChunk` (with content_type)
|
||||
/// - `Event::SessionUpdate` with `SessionUpdate::ToolCall` -> `ToolCallUpdate`
|
||||
/// - `Event::SessionUpdate` with `SessionUpdate::ToolCallUpdate` -> `ToolCallUpdate`
|
||||
/// - `Event::MessageCompleted` -> `MessageComplete`
|
||||
/// - `Event::SessionIdle` -> `SessionIdle`
|
||||
/// - `Event::SessionMetadataReceived` -> `ConfigOptionUpdate` (with modes and models)
|
||||
/// - `Event::MessageFailed` -> `SessionError`
|
||||
/// - `Event::Error` -> `SessionError` (system-wide error)
|
||||
///
|
||||
/// ## Filtered Events
|
||||
///
|
||||
/// - `Event::SessionsListed` - List operations don't need streaming
|
||||
/// - `Event::SessionCreated` - Handled via RPC response
|
||||
/// - `Event::SessionUpdated` - Metadata updates, not content
|
||||
/// - `Event::SessionMetadataUpdated` - Metadata updates
|
||||
/// - `Event::SessionDeleted` - Handled via RPC
|
||||
/// - `Event::Connected` / `Event::Disconnected` - System events
|
||||
/// - `Event::ConnectorCreated` / `Event::ConnectorRemoved` - System events
|
||||
/// - `Event::MessagesListed` - List operations
|
||||
/// - `Event::MessageStarted` - Initial message, content comes via chunks
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `event`: The Dirigent protocol event to translate
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Vec of `SessionUpdateParams` to be sent. Most events return 0 or 1 update.
|
||||
pub fn translate_event(event: &Event) -> Vec<SessionUpdateParams> {
|
||||
match event {
|
||||
// SessionUpdate events - the main streaming content
|
||||
Event::SessionUpdate {
|
||||
session_id, update, ..
|
||||
} => translate_session_update(session_id, update)
|
||||
.map(|u| vec![u])
|
||||
.unwrap_or_default(),
|
||||
|
||||
// Message completion
|
||||
Event::MessageCompleted { message, .. } => vec![SessionUpdateParams {
|
||||
session_id: message.session_id.clone(),
|
||||
update: SessionUpdateVariant::MessageComplete {
|
||||
message_id: Some(message.id.clone()),
|
||||
},
|
||||
..Default::default()
|
||||
}],
|
||||
|
||||
// Session idle
|
||||
Event::SessionIdle { session_id, .. } => vec![SessionUpdateParams {
|
||||
session_id: session_id.clone(),
|
||||
update: SessionUpdateVariant::SessionIdle {},
|
||||
..Default::default()
|
||||
}],
|
||||
|
||||
// Session metadata received - forward as BOTH config_option_update AND current_mode_update
|
||||
//
|
||||
// We emit both notification types to support:
|
||||
// - New clients (acp-beta flag): Use config_option_update
|
||||
// - Legacy clients: Use current_mode_update (mode only - no legacy model updates exist)
|
||||
//
|
||||
// Clients safely ignore notification types they don't understand.
|
||||
// See: https://agentclientprotocol.com/rfds/session-config-options
|
||||
Event::SessionMetadataReceived {
|
||||
session_id,
|
||||
models,
|
||||
modes,
|
||||
config_options: event_config_options,
|
||||
..
|
||||
} => {
|
||||
let mut updates = vec![];
|
||||
|
||||
// If the event already has config_options from the agent, prefer those
|
||||
// over converting from modes/models (agent-provided options are authoritative).
|
||||
if let Some(agent_config_options) = event_config_options {
|
||||
if !agent_config_options.is_empty() {
|
||||
// Emit ConfigOptionUpdate with agent-provided options
|
||||
updates.push(SessionUpdateParams {
|
||||
session_id: session_id.clone(),
|
||||
update: SessionUpdateVariant::ConfigOptionUpdate {
|
||||
config_options: agent_config_options.iter().map(|co| ConfigOption {
|
||||
id: co.id.clone(),
|
||||
name: co.name.clone(),
|
||||
description: co.description.clone(),
|
||||
category: co.category.clone(),
|
||||
option_type: match co.option_type {
|
||||
dirigent_protocol::ConfigOptionType::Select => ConfigOptionType::Select,
|
||||
},
|
||||
current_value: co.current_value.clone(),
|
||||
options: Some(co.options.iter().map(|v| ConfigOptionChoice {
|
||||
value: v.value.clone(),
|
||||
name: v.name.clone(),
|
||||
description: v.description.clone(),
|
||||
}).collect()),
|
||||
}).collect(),
|
||||
},
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
// Also emit CurrentModeUpdate for legacy clients if a mode option exists
|
||||
if let Some(mode_opt) = agent_config_options.iter().find(|co| co.id == "mode" || co.category.as_deref() == Some("mode")) {
|
||||
updates.push(SessionUpdateParams {
|
||||
session_id: session_id.clone(),
|
||||
update: SessionUpdateVariant::CurrentModeUpdate {
|
||||
mode_id: mode_opt.current_value.clone(),
|
||||
},
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
return updates;
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to building config_options from legacy modes/models fields
|
||||
let mut config_options: Vec<ConfigOption> = vec![];
|
||||
if let Some(modes_state) = modes {
|
||||
config_options.push(modes_to_config_option(modes_state));
|
||||
|
||||
// Also emit CurrentModeUpdate for legacy clients
|
||||
// Legacy protocol only supports mode updates (no model updates)
|
||||
updates.push(SessionUpdateParams {
|
||||
session_id: session_id.clone(),
|
||||
update: SessionUpdateVariant::CurrentModeUpdate {
|
||||
mode_id: modes_state.current_mode_id.clone(),
|
||||
},
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(models_state) = models {
|
||||
config_options.push(models_to_config_option(models_state));
|
||||
// Note: Legacy protocol has no CurrentModelUpdate - only new ConfigOptionUpdate
|
||||
}
|
||||
|
||||
if !config_options.is_empty() {
|
||||
updates.push(SessionUpdateParams {
|
||||
session_id: session_id.clone(),
|
||||
update: SessionUpdateVariant::ConfigOptionUpdate { config_options },
|
||||
..Default::default()
|
||||
});
|
||||
}
|
||||
|
||||
updates
|
||||
}
|
||||
|
||||
// Message failure - not forwarded in new structure (no equivalent variant)
|
||||
Event::MessageFailed { .. } => vec![],
|
||||
|
||||
// System-wide error - not forwarded in new structure (no equivalent variant)
|
||||
Event::Error { .. } => vec![],
|
||||
|
||||
// Agent request - will be handled in Phase 2 (bidirectional request/response)
|
||||
// For now, not forwarded (requires new SSE variant and response mechanism)
|
||||
Event::AgentRequest { .. } => vec![],
|
||||
|
||||
// Events that should not be forwarded via this translation path
|
||||
// (they may be handled elsewhere or are not relevant to ACP clients)
|
||||
Event::SessionsListed { .. }
|
||||
| Event::SessionCreated { .. }
|
||||
| Event::SessionUpdated { .. }
|
||||
| Event::SessionMetadataUpdated { .. }
|
||||
| Event::SessionDeleted { .. }
|
||||
| Event::SessionClosed { .. }
|
||||
| Event::SessionSystemMessageSet { .. }
|
||||
| Event::SessionError { .. }
|
||||
| Event::SessionTransferred { .. }
|
||||
| Event::ForwardingPanic { .. }
|
||||
| Event::MessagesListed { .. }
|
||||
| Event::MessageStarted { .. }
|
||||
| Event::TurnComplete { .. }
|
||||
| Event::ConnectorCreated { .. }
|
||||
| Event::ConnectorRemoved { .. }
|
||||
| Event::ConnectorStateChanged { .. }
|
||||
| Event::Connected
|
||||
| Event::Disconnected => vec![],
|
||||
|
||||
// ACP client connection events - these are handled by the UI directly
|
||||
// via the main SSE endpoint, not translated to session updates
|
||||
Event::AcpClientConnected { .. }
|
||||
| Event::AcpClientDisconnected { .. }
|
||||
| Event::AcpClientSessionOpened { .. }
|
||||
| Event::AcpClientSessionRouted { .. } => vec![],
|
||||
|
||||
// Inspector events - not relevant for ACP session updates
|
||||
Event::InspectorSnapshot { .. }
|
||||
| Event::InspectorNodeRegistered { .. }
|
||||
| Event::InspectorNodeRemoved { .. }
|
||||
| Event::InspectorStateChanged { .. }
|
||||
| Event::InspectorPropertiesUpdated { .. } => vec![],
|
||||
|
||||
// Archivist-to-frontend signal — not relevant for ACP clients
|
||||
Event::SessionRegistered { .. } => vec![],
|
||||
|
||||
// System task events — internal UI concern, not relevant for ACP clients
|
||||
Event::SystemTaskStatusChanged { .. } => vec![],
|
||||
}
|
||||
}
|
||||
|
||||
/// Translate a SessionUpdate to a SessionUpdateParams
|
||||
fn translate_session_update(
|
||||
session_id: &str,
|
||||
update: &SessionUpdate,
|
||||
) -> Option<SessionUpdateParams> {
|
||||
match update {
|
||||
// Agent message content chunk
|
||||
SessionUpdate::AgentMessageChunk { content, .. } => Some(SessionUpdateParams {
|
||||
session_id: session_id.to_string(),
|
||||
update: SessionUpdateVariant::AgentMessageChunk {
|
||||
content: content.clone(),
|
||||
},
|
||||
..Default::default()
|
||||
}),
|
||||
|
||||
// Agent thought chunk - render as agent message (thoughts shown as agent text in UI)
|
||||
SessionUpdate::AgentThoughtChunk { content, .. } => Some(SessionUpdateParams {
|
||||
session_id: session_id.to_string(),
|
||||
update: SessionUpdateVariant::AgentMessageChunk {
|
||||
content: content.clone(),
|
||||
},
|
||||
..Default::default()
|
||||
}),
|
||||
|
||||
// User message chunks are typically not streamed to other clients
|
||||
// Filter them out for now
|
||||
SessionUpdate::UserMessageChunk { .. } => None,
|
||||
|
||||
// Tool call initiated - forward to client
|
||||
SessionUpdate::ToolCall {
|
||||
tool_call, _meta, ..
|
||||
} => Some(SessionUpdateParams {
|
||||
session_id: session_id.to_string(),
|
||||
update: SessionUpdateVariant::ToolCall {
|
||||
tool_call_id: tool_call.id.clone(),
|
||||
title: tool_call.title.clone(),
|
||||
kind: extract_kind(tool_call),
|
||||
raw_input: tool_call.raw_input.clone(),
|
||||
status: Some(tool_call_status_to_string(&tool_call.status)),
|
||||
content: unwrap_tool_call_content(&tool_call.content),
|
||||
_meta: rebuild_meta(_meta, tool_call),
|
||||
},
|
||||
..Default::default()
|
||||
}),
|
||||
|
||||
// Tool call update - forward to client
|
||||
SessionUpdate::ToolCallUpdate {
|
||||
tool_call, _meta, ..
|
||||
} => Some(SessionUpdateParams {
|
||||
session_id: session_id.to_string(),
|
||||
update: SessionUpdateVariant::ToolCallUpdate {
|
||||
tool_call_id: tool_call.id.clone(),
|
||||
status: Some(tool_call_status_to_string(&tool_call.status)),
|
||||
content: unwrap_tool_call_content(&tool_call.content),
|
||||
raw_output: tool_call.raw_output.clone(),
|
||||
error: tool_call.error.clone(),
|
||||
_meta: rebuild_meta(_meta, tool_call),
|
||||
},
|
||||
..Default::default()
|
||||
}),
|
||||
|
||||
// Unknown update type - forward as-is for transparent proxying
|
||||
SessionUpdate::Unknown { data } => Some(SessionUpdateParams {
|
||||
session_id: session_id.to_string(),
|
||||
update: SessionUpdateVariant::Unknown { data: data.clone() },
|
||||
..Default::default()
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use dirigent_protocol::{ContentBlock, Message, MessageRole, MessageStatus};
|
||||
|
||||
#[test]
|
||||
fn test_translate_session_update_agent_chunk() {
|
||||
let event = Event::SessionUpdate {
|
||||
connector_id: "conn-1".to_string(),
|
||||
session_id: "sess-1".to_string(),
|
||||
update: SessionUpdate::AgentMessageChunk {
|
||||
message_id: "msg-1".to_string(),
|
||||
content: ContentBlock::Text {
|
||||
text: "Hello".to_string(),
|
||||
},
|
||||
_meta: None,
|
||||
},
|
||||
};
|
||||
|
||||
let updates = translate_event(&event);
|
||||
assert_eq!(updates.len(), 1);
|
||||
|
||||
let update_params = &updates[0];
|
||||
assert_eq!(update_params.session_id, "sess-1");
|
||||
|
||||
match &update_params.update {
|
||||
SessionUpdateVariant::AgentMessageChunk { content } => match content {
|
||||
ContentBlock::Text { text } => {
|
||||
assert_eq!(text, "Hello");
|
||||
}
|
||||
_ => panic!("Expected Text content"),
|
||||
},
|
||||
_ => panic!("Expected AgentMessageChunk variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_translate_session_update_thought_chunk() {
|
||||
let event = Event::SessionUpdate {
|
||||
connector_id: "conn-1".to_string(),
|
||||
session_id: "sess-1".to_string(),
|
||||
update: SessionUpdate::AgentThoughtChunk {
|
||||
message_id: "msg-1".to_string(),
|
||||
content: ContentBlock::Text {
|
||||
text: "Thinking...".to_string(),
|
||||
},
|
||||
_meta: None,
|
||||
},
|
||||
};
|
||||
|
||||
let updates = translate_event(&event);
|
||||
assert_eq!(updates.len(), 1);
|
||||
|
||||
let update_params = &updates[0];
|
||||
assert_eq!(update_params.session_id, "sess-1");
|
||||
|
||||
// Thought chunks are rendered as agent message chunks
|
||||
match &update_params.update {
|
||||
SessionUpdateVariant::AgentMessageChunk { content } => match content {
|
||||
ContentBlock::Text { text } => {
|
||||
assert_eq!(text, "Thinking...");
|
||||
}
|
||||
_ => panic!("Expected Text content"),
|
||||
},
|
||||
_ => panic!("Expected AgentMessageChunk variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_translate_message_completed() {
|
||||
let now = chrono::Utc::now();
|
||||
let event = Event::MessageCompleted {
|
||||
connector_id: "conn-1".to_string(),
|
||||
message: Message {
|
||||
id: "msg-1".to_string(),
|
||||
session_id: "sess-1".to_string(),
|
||||
role: MessageRole::Assistant,
|
||||
created_at: now,
|
||||
content: vec![],
|
||||
status: MessageStatus::Completed,
|
||||
metadata: None,
|
||||
},
|
||||
};
|
||||
|
||||
let updates = translate_event(&event);
|
||||
assert_eq!(updates.len(), 1);
|
||||
|
||||
let update_params = &updates[0];
|
||||
assert_eq!(update_params.session_id, "sess-1");
|
||||
|
||||
match &update_params.update {
|
||||
SessionUpdateVariant::MessageComplete { message_id } => {
|
||||
assert_eq!(message_id, &Some("msg-1".to_string()));
|
||||
}
|
||||
_ => panic!("Expected MessageComplete variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_translate_session_idle() {
|
||||
let event = Event::SessionIdle {
|
||||
connector_id: "test-connector".to_string(),
|
||||
session_id: "sess-1".to_string(),
|
||||
};
|
||||
|
||||
let updates = translate_event(&event);
|
||||
assert_eq!(updates.len(), 1);
|
||||
|
||||
let update_params = &updates[0];
|
||||
assert_eq!(update_params.session_id, "sess-1");
|
||||
|
||||
match &update_params.update {
|
||||
SessionUpdateVariant::SessionIdle {} => {
|
||||
// Expected
|
||||
}
|
||||
_ => panic!("Expected SessionIdle variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_translate_message_failed() {
|
||||
let event = Event::MessageFailed {
|
||||
message_id: "msg-1".to_string(),
|
||||
error: "Connection lost".to_string(),
|
||||
};
|
||||
|
||||
// MessageFailed events are now filtered (return empty vec)
|
||||
let updates = translate_event(&event);
|
||||
assert!(
|
||||
updates.is_empty(),
|
||||
"MessageFailed events should be filtered"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_translate_filtered_events() {
|
||||
// These events should not produce notifications
|
||||
let filtered_events = vec![
|
||||
Event::Connected,
|
||||
Event::Disconnected,
|
||||
Event::SessionDeleted {
|
||||
session_id: "s".to_string(),
|
||||
},
|
||||
Event::Error {
|
||||
message: "System error".to_string(),
|
||||
},
|
||||
Event::MessageFailed {
|
||||
message_id: "msg-1".to_string(),
|
||||
error: "Failed".to_string(),
|
||||
},
|
||||
];
|
||||
|
||||
for event in filtered_events {
|
||||
assert!(
|
||||
translate_event(&event).is_empty(),
|
||||
"Expected {:?} to be filtered",
|
||||
event
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,478 @@
|
||||
//! SSE (Server-Sent Events) Notifications for ACP Server
|
||||
//!
|
||||
//! This module provides the SSE notification infrastructure for streaming events
|
||||
//! to connected ACP clients. It handles per-client subscriptions using broadcast
|
||||
//! channels and provides translation from Dirigent protocol events to ACP notifications.
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! The module is organized into several sub-modules:
|
||||
//!
|
||||
//! - `types` - Data structures for SSE notifications (SessionUpdateParams, SessionUpdateVariant, etc.)
|
||||
//! - `notifier` - SseNotifier for managing client subscriptions and broadcasting
|
||||
//! - `event_translator` - Translation from Dirigent Events to ACP notifications
|
||||
//! - `content_transform` - Helper functions for content and metadata transformation
|
||||
//!
|
||||
//! ## SSE Wire Format
|
||||
//!
|
||||
//! The SSE stream uses the following format:
|
||||
//!
|
||||
//! ```text
|
||||
//! event: session/update
|
||||
//! data: {"sessionId": "...", "update": {"sessionUpdate": "agent_message_chunk", ...}}
|
||||
//!
|
||||
//! event: session/update
|
||||
//! data: {"sessionId": "...", "update": {"sessionUpdate": "message_complete", ...}}
|
||||
//! ```
|
||||
//!
|
||||
//! ## Example
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use dirigent_acp_api::sse::{SseNotifier, SessionUpdateParams, translate_event};
|
||||
//!
|
||||
//! // Create the notifier
|
||||
//! let notifier = SseNotifier::new();
|
||||
//!
|
||||
//! // Subscribe a client
|
||||
//! let stream = notifier.subscribe("client-123");
|
||||
//!
|
||||
//! // Translate and broadcast an event
|
||||
//! let updates = translate_event(&event);
|
||||
//! for update in updates {
|
||||
//! notifier.broadcast("client-123", update);
|
||||
//! }
|
||||
//!
|
||||
//! // Unsubscribe when done
|
||||
//! notifier.unsubscribe("client-123");
|
||||
//! ```
|
||||
|
||||
mod content_transform;
|
||||
mod event_translator;
|
||||
mod notifier;
|
||||
mod types;
|
||||
|
||||
// Re-export public types
|
||||
pub use content_transform::{
|
||||
extract_kind, rebuild_meta, tool_call_status_to_string, unwrap_tool_call_content,
|
||||
};
|
||||
pub use event_translator::translate_event;
|
||||
pub use notifier::SseNotifier;
|
||||
pub use types::{
|
||||
models_to_config_option, modes_to_config_option, AcpNotification, ConfigOption,
|
||||
ConfigOptionChoice, ConfigOptionType, SessionUpdateParams, SessionUpdateVariant, SlashCommand,
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Tests (moved from original sse.rs)
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
// ========================================================================
|
||||
// AcpNotification Tests
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_message_chunk_serialization() {
|
||||
let notification = AcpNotification::MessageChunk {
|
||||
session_id: "sess-123".to_string(),
|
||||
message_id: "msg-456".to_string(),
|
||||
content: "Hello, world!".to_string(),
|
||||
content_type: Some("text".to_string()),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(¬ification).unwrap();
|
||||
assert!(json.contains(r#""type":"message_chunk"#));
|
||||
assert!(json.contains(r#""sessionId":"sess-123"#));
|
||||
assert!(json.contains(r#""messageId":"msg-456"#));
|
||||
assert!(json.contains(r#""content":"Hello, world!"#));
|
||||
assert!(json.contains(r#""contentType":"text"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_complete_serialization() {
|
||||
let notification = AcpNotification::MessageComplete {
|
||||
session_id: "sess-123".to_string(),
|
||||
message_id: "msg-456".to_string(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(¬ification).unwrap();
|
||||
assert!(json.contains(r#""type":"message_complete"#));
|
||||
assert!(json.contains(r#""sessionId":"sess-123"#));
|
||||
assert!(json.contains(r#""messageId":"msg-456"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_idle_serialization() {
|
||||
let notification = AcpNotification::SessionIdle {
|
||||
session_id: "sess-123".to_string(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(¬ification).unwrap();
|
||||
assert!(json.contains(r#""type":"session_idle"#));
|
||||
assert!(json.contains(r#""sessionId":"sess-123"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_error_serialization() {
|
||||
let notification = AcpNotification::SessionError {
|
||||
session_id: "sess-123".to_string(),
|
||||
error: "Something went wrong".to_string(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(¬ification).unwrap();
|
||||
assert!(json.contains(r#""type":"session_error"#));
|
||||
assert!(json.contains(r#""error":"Something went wrong"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_update_serialization() {
|
||||
let notification = AcpNotification::ToolCallUpdate {
|
||||
session_id: "sess-123".to_string(),
|
||||
message_id: "msg-456".to_string(),
|
||||
tool_call_id: "call-789".to_string(),
|
||||
tool_name: "bash".to_string(),
|
||||
status: "running".to_string(),
|
||||
title: Some("Running command".to_string()),
|
||||
error: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(¬ification).unwrap();
|
||||
assert!(json.contains(r#""type":"tool_call_update"#));
|
||||
assert!(json.contains(r#""toolName":"bash"#));
|
||||
assert!(json.contains(r#""status":"running"#));
|
||||
assert!(json.contains(r#""title":"Running command"#));
|
||||
// error should not be present when None
|
||||
assert!(!json.contains(r#""error""#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_event_type() {
|
||||
assert_eq!(
|
||||
AcpNotification::MessageChunk {
|
||||
session_id: "s".to_string(),
|
||||
message_id: "m".to_string(),
|
||||
content: "c".to_string(),
|
||||
content_type: None,
|
||||
}
|
||||
.event_type(),
|
||||
"message/chunk"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
AcpNotification::MessageComplete {
|
||||
session_id: "s".to_string(),
|
||||
message_id: "m".to_string(),
|
||||
}
|
||||
.event_type(),
|
||||
"message/complete"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
AcpNotification::SessionIdle {
|
||||
session_id: "s".to_string(),
|
||||
}
|
||||
.event_type(),
|
||||
"session/idle"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
AcpNotification::SessionError {
|
||||
session_id: "s".to_string(),
|
||||
error: "e".to_string(),
|
||||
}
|
||||
.event_type(),
|
||||
"session/error"
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
AcpNotification::ToolCallUpdate {
|
||||
session_id: "s".to_string(),
|
||||
message_id: "m".to_string(),
|
||||
tool_call_id: "t".to_string(),
|
||||
tool_name: "bash".to_string(),
|
||||
status: "running".to_string(),
|
||||
title: None,
|
||||
error: None,
|
||||
}
|
||||
.event_type(),
|
||||
"tool/update"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_sse_string() {
|
||||
let notification = AcpNotification::SessionIdle {
|
||||
session_id: "sess-123".to_string(),
|
||||
};
|
||||
|
||||
let sse = notification.to_sse_string();
|
||||
assert!(sse.starts_with("event: session/idle\n"));
|
||||
assert!(sse.contains("data: "));
|
||||
assert!(sse.contains("sess-123"));
|
||||
assert!(sse.ends_with("\n\n"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_notification_roundtrip() {
|
||||
let notifications = vec![
|
||||
AcpNotification::MessageChunk {
|
||||
session_id: "s".to_string(),
|
||||
message_id: "m".to_string(),
|
||||
content: "c".to_string(),
|
||||
content_type: Some("text".to_string()),
|
||||
},
|
||||
AcpNotification::MessageComplete {
|
||||
session_id: "s".to_string(),
|
||||
message_id: "m".to_string(),
|
||||
},
|
||||
AcpNotification::SessionIdle {
|
||||
session_id: "s".to_string(),
|
||||
},
|
||||
AcpNotification::SessionError {
|
||||
session_id: "s".to_string(),
|
||||
error: "e".to_string(),
|
||||
},
|
||||
AcpNotification::ToolCallUpdate {
|
||||
session_id: "s".to_string(),
|
||||
message_id: "m".to_string(),
|
||||
tool_call_id: "t".to_string(),
|
||||
tool_name: "bash".to_string(),
|
||||
status: "running".to_string(),
|
||||
title: Some("title".to_string()),
|
||||
error: None,
|
||||
},
|
||||
];
|
||||
|
||||
for notification in notifications {
|
||||
let json = serde_json::to_string(¬ification).unwrap();
|
||||
let deserialized: AcpNotification = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(notification, deserialized);
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// SessionUpdate Tag Verification Tests
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_session_update_variant_uses_session_update_tag_tool_call() {
|
||||
// Test ToolCall variant with flattened structure
|
||||
let variant = SessionUpdateVariant::ToolCall {
|
||||
tool_call_id: "call_456".to_string(),
|
||||
title: Some("Running command".to_string()),
|
||||
kind: Some("command".to_string()),
|
||||
raw_input: Some(json!({"command": "ls"})),
|
||||
status: Some("in_progress".to_string()),
|
||||
content: vec![],
|
||||
_meta: Some(json!({
|
||||
"claudeCode": {
|
||||
"toolName": "bash"
|
||||
}
|
||||
})),
|
||||
};
|
||||
|
||||
let json = serde_json::to_value(&variant).unwrap();
|
||||
|
||||
// Verify correct tag name "sessionUpdate" is present
|
||||
assert!(
|
||||
json.get("sessionUpdate").is_some(),
|
||||
"Missing 'sessionUpdate' field in JSON: {:?}",
|
||||
json
|
||||
);
|
||||
assert_eq!(
|
||||
json["sessionUpdate"], "tool_call",
|
||||
"Expected sessionUpdate value 'tool_call', got: {:?}",
|
||||
json["sessionUpdate"]
|
||||
);
|
||||
|
||||
// Verify incorrect tag "type" is NOT present
|
||||
assert!(
|
||||
json.get("type").is_none(),
|
||||
"Should not have 'type' field in ACP output, found: {:?}",
|
||||
json.get("type")
|
||||
);
|
||||
|
||||
// Verify flattened fields are present
|
||||
assert!(json.get("toolCallId").is_some(), "Missing toolCallId field");
|
||||
assert!(json.get("title").is_some(), "Missing title field");
|
||||
assert!(json.get("status").is_some(), "Missing status field");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_update_variant_roundtrip_with_session_update_tag() {
|
||||
// Test that serialization and deserialization work correctly with sessionUpdate tag
|
||||
let variants = vec![
|
||||
SessionUpdateVariant::SessionIdle {},
|
||||
SessionUpdateVariant::MessageComplete {
|
||||
message_id: Some("msg_1".to_string()),
|
||||
},
|
||||
SessionUpdateVariant::AgentMessageChunk {
|
||||
content: dirigent_protocol::ContentBlock::Text {
|
||||
text: "test".to_string(),
|
||||
},
|
||||
},
|
||||
SessionUpdateVariant::ToolCall {
|
||||
tool_call_id: "call_1".to_string(),
|
||||
title: None,
|
||||
kind: None,
|
||||
raw_input: None,
|
||||
status: Some("pending".to_string()),
|
||||
content: vec![],
|
||||
_meta: Some(json!({
|
||||
"claudeCode": {
|
||||
"toolName": "test_tool"
|
||||
}
|
||||
})),
|
||||
},
|
||||
];
|
||||
|
||||
for variant in variants {
|
||||
// Serialize
|
||||
let json_str = serde_json::to_string(&variant).unwrap();
|
||||
let json_val: serde_json::Value = serde_json::from_str(&json_str).unwrap();
|
||||
|
||||
// Verify sessionUpdate tag in JSON
|
||||
assert!(
|
||||
json_val.get("sessionUpdate").is_some(),
|
||||
"Missing sessionUpdate in serialized JSON for variant: {:?}",
|
||||
variant
|
||||
);
|
||||
|
||||
// Verify no type tag
|
||||
assert!(
|
||||
json_val.get("type").is_none(),
|
||||
"Found unexpected 'type' tag for variant: {:?}",
|
||||
variant
|
||||
);
|
||||
|
||||
// Deserialize back
|
||||
let deserialized: SessionUpdateVariant = serde_json::from_str(&json_str).unwrap();
|
||||
|
||||
// Verify equality (roundtrip)
|
||||
assert_eq!(
|
||||
variant, deserialized,
|
||||
"Roundtrip failed for variant: {:?}",
|
||||
variant
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// camelCase Field Naming Compliance Tests
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_all_fields_use_camel_case() {
|
||||
// Test SessionUpdateParams with ToolCall variant
|
||||
let params = SessionUpdateParams {
|
||||
session_id: "sess_123".to_string(),
|
||||
update: SessionUpdateVariant::ToolCall {
|
||||
tool_call_id: "tool_123".to_string(),
|
||||
title: Some("test".to_string()),
|
||||
kind: Some("search".to_string()),
|
||||
raw_input: Some(json!({"key": "value"})),
|
||||
status: Some("pending".to_string()),
|
||||
content: vec![],
|
||||
_meta: None,
|
||||
},
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let json = serde_json::to_value(¶ms).unwrap();
|
||||
|
||||
// Verify top-level camelCase
|
||||
assert!(
|
||||
json.get("sessionId").is_some(),
|
||||
"Should have sessionId (not session_id)"
|
||||
);
|
||||
assert!(
|
||||
json.get("session_id").is_none(),
|
||||
"Should NOT have session_id"
|
||||
);
|
||||
|
||||
let update = &json["update"];
|
||||
|
||||
// Verify ToolCall variant fields use camelCase
|
||||
assert!(
|
||||
update.get("toolCallId").is_some(),
|
||||
"Should have toolCallId (not tool_call_id)"
|
||||
);
|
||||
assert!(
|
||||
update.get("tool_call_id").is_none(),
|
||||
"Should NOT have tool_call_id"
|
||||
);
|
||||
assert!(
|
||||
update.get("rawInput").is_some(),
|
||||
"Should have rawInput (not raw_input)"
|
||||
);
|
||||
assert!(
|
||||
update.get("raw_input").is_none(),
|
||||
"Should NOT have raw_input"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_snake_case_in_serialized_json() {
|
||||
// Create various SessionUpdateParams with different variants
|
||||
let test_cases = vec![
|
||||
SessionUpdateParams {
|
||||
session_id: "s1".to_string(),
|
||||
update: SessionUpdateVariant::ToolCall {
|
||||
tool_call_id: "tc1".to_string(),
|
||||
title: Some("t".to_string()),
|
||||
kind: Some("k".to_string()),
|
||||
raw_input: Some(json!({})),
|
||||
status: Some("pending".to_string()),
|
||||
content: vec![],
|
||||
_meta: None,
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
SessionUpdateParams {
|
||||
session_id: "s2".to_string(),
|
||||
update: SessionUpdateVariant::ToolCallUpdate {
|
||||
tool_call_id: "tc2".to_string(),
|
||||
status: Some("completed".to_string()),
|
||||
content: vec![],
|
||||
raw_output: Some(json!({})),
|
||||
error: None,
|
||||
_meta: None,
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
SessionUpdateParams {
|
||||
session_id: "s3".to_string(),
|
||||
update: SessionUpdateVariant::MessageComplete {
|
||||
message_id: Some("m3".to_string()),
|
||||
},
|
||||
..Default::default()
|
||||
},
|
||||
];
|
||||
|
||||
for params in test_cases {
|
||||
let json_str = serde_json::to_string(¶ms).unwrap();
|
||||
|
||||
// Check for common snake_case fields that should NOT appear
|
||||
let forbidden_patterns = vec![
|
||||
"session_id",
|
||||
"tool_call_id",
|
||||
"raw_input",
|
||||
"raw_output",
|
||||
"message_id",
|
||||
];
|
||||
|
||||
for pattern in forbidden_patterns {
|
||||
assert!(
|
||||
!json_str.contains(&format!("\"{}\"", pattern)),
|
||||
"Found forbidden snake_case field '{}' in JSON: {}",
|
||||
pattern,
|
||||
json_str
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,403 @@
|
||||
//! SSE Notifier for managing client subscriptions
|
||||
//!
|
||||
//! This module provides the subscription management infrastructure for SSE streaming.
|
||||
//! The `SseNotifier` manages per-client broadcast channels for targeted notifications.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use tokio::sync::broadcast;
|
||||
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
|
||||
use tokio_stream::wrappers::BroadcastStream;
|
||||
use tokio_stream::Stream;
|
||||
use tracing::{debug, trace, warn};
|
||||
|
||||
use super::types::SessionUpdateParams;
|
||||
|
||||
/// Default broadcast channel capacity
|
||||
const DEFAULT_CHANNEL_CAPACITY: usize = 256;
|
||||
|
||||
/// Internal state for SSE subscriptions
|
||||
#[derive(Debug, Default)]
|
||||
struct SseNotifierState {
|
||||
/// Map from client_id to their broadcast sender
|
||||
subscriptions: HashMap<String, broadcast::Sender<SessionUpdateParams>>,
|
||||
}
|
||||
|
||||
/// SSE Notifier for managing client subscriptions and broadcasting notifications
|
||||
///
|
||||
/// The `SseNotifier` manages per-client subscriptions using tokio broadcast channels.
|
||||
/// Each client has their own broadcast channel, allowing targeted notifications.
|
||||
///
|
||||
/// ## Thread Safety
|
||||
///
|
||||
/// The `SseNotifier` is designed to be cloned and shared across async tasks.
|
||||
/// Internal state is protected by `Arc<RwLock<...>>` for thread-safe access.
|
||||
///
|
||||
/// ## Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// let notifier = SseNotifier::new();
|
||||
///
|
||||
/// // Subscribe a client and get their stream
|
||||
/// let stream = notifier.subscribe("client-123");
|
||||
///
|
||||
/// // In another task, broadcast notifications
|
||||
/// notifier.broadcast("client-123", notification);
|
||||
///
|
||||
/// // When client disconnects
|
||||
/// notifier.unsubscribe("client-123");
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SseNotifier {
|
||||
state: Arc<RwLock<SseNotifierState>>,
|
||||
channel_capacity: usize,
|
||||
}
|
||||
|
||||
impl Default for SseNotifier {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl SseNotifier {
|
||||
/// Create a new SSE notifier with default capacity
|
||||
///
|
||||
/// Creates a new notifier with the default channel capacity of 256 messages.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
state: Arc::new(RwLock::new(SseNotifierState::default())),
|
||||
channel_capacity: DEFAULT_CHANNEL_CAPACITY,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new SSE notifier with custom channel capacity
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `capacity`: The maximum number of messages that can be buffered per client
|
||||
pub fn with_capacity(capacity: usize) -> Self {
|
||||
Self {
|
||||
state: Arc::new(RwLock::new(SseNotifierState::default())),
|
||||
channel_capacity: capacity,
|
||||
}
|
||||
}
|
||||
|
||||
/// Subscribe a client and return an async stream of notifications
|
||||
///
|
||||
/// Creates a broadcast channel for the client and returns a stream that can
|
||||
/// be used with Axum's `Sse<impl Stream>` response type.
|
||||
///
|
||||
/// If the client is already subscribed, a new receiver is created from the
|
||||
/// existing sender (allowing multiple connections per client).
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `client_id`: Unique identifier for the client
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A pinned stream of `SessionUpdateParams` wrapped in `Result` for error handling.
|
||||
/// The stream is compatible with Axum's SSE handler.
|
||||
pub fn subscribe(
|
||||
&self,
|
||||
client_id: &str,
|
||||
) -> Pin<Box<dyn Stream<Item = Result<SessionUpdateParams, BroadcastStreamRecvError>> + Send>>
|
||||
{
|
||||
let mut state = self.state.write().expect("Lock poisoned");
|
||||
|
||||
// Get or create the sender for this client
|
||||
let sender = state
|
||||
.subscriptions
|
||||
.entry(client_id.to_string())
|
||||
.or_insert_with(|| {
|
||||
debug!("Creating new broadcast channel for client: {}", client_id);
|
||||
let (tx, _rx) = broadcast::channel(self.channel_capacity);
|
||||
tx
|
||||
});
|
||||
|
||||
// Create a new receiver
|
||||
let receiver = sender.subscribe();
|
||||
|
||||
debug!(
|
||||
"Client subscribed to SSE: {} (subscribers: {})",
|
||||
client_id,
|
||||
sender.receiver_count()
|
||||
);
|
||||
|
||||
// Convert to stream using tokio_stream
|
||||
let stream = BroadcastStream::new(receiver);
|
||||
|
||||
Box::pin(stream)
|
||||
}
|
||||
|
||||
/// Unsubscribe a client and cleanup resources
|
||||
///
|
||||
/// Removes the client's subscription and drops the sender, which will
|
||||
/// cause all associated receivers to receive an error on their next poll.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `client_id`: The ID of the client to unsubscribe
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// `true` if the client was subscribed and removed, `false` if not found
|
||||
pub fn unsubscribe(&self, client_id: &str) -> bool {
|
||||
let mut state = self.state.write().expect("Lock poisoned");
|
||||
|
||||
if let Some(sender) = state.subscriptions.remove(client_id) {
|
||||
debug!(
|
||||
"Client unsubscribed from SSE: {} (was {} receivers)",
|
||||
client_id,
|
||||
sender.receiver_count()
|
||||
);
|
||||
// Sender is dropped here, which will close the channel
|
||||
true
|
||||
} else {
|
||||
debug!("Client was not subscribed: {}", client_id);
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Broadcast a session update to a specific client
|
||||
///
|
||||
/// Sends the session update to all receivers subscribed to the specified client.
|
||||
/// If no receivers are active (e.g., all have dropped or lagged), the send
|
||||
/// is handled gracefully.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `client_id`: The ID of the client to send to
|
||||
/// - `update_params`: The session update parameters to send
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// - `Ok(n)` where n is the number of receivers that received the update
|
||||
/// - `Err(())` if the client is not subscribed
|
||||
pub fn broadcast(
|
||||
&self,
|
||||
client_id: &str,
|
||||
update_params: SessionUpdateParams,
|
||||
) -> Result<usize, ()> {
|
||||
let state = self.state.read().expect("Lock poisoned");
|
||||
|
||||
if let Some(sender) = state.subscriptions.get(client_id) {
|
||||
match sender.send(update_params) {
|
||||
Ok(n) => {
|
||||
trace!(
|
||||
"Broadcast session update to client {}: {} receivers",
|
||||
client_id,
|
||||
n
|
||||
);
|
||||
Ok(n)
|
||||
}
|
||||
Err(_) => {
|
||||
// No active receivers - this is okay, the update is just dropped
|
||||
trace!("Broadcast to client {}: no active receivers", client_id);
|
||||
Ok(0)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!("Attempted broadcast to unsubscribed client: {}", client_id);
|
||||
Err(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Broadcast a session update to all subscribed clients
|
||||
///
|
||||
/// Sends the session update to all connected clients. Failed sends to
|
||||
/// individual clients are logged but don't stop the broadcast.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `update_params`: The session update parameters to broadcast
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The total number of receivers that received the update
|
||||
pub fn broadcast_all(&self, update_params: SessionUpdateParams) -> usize {
|
||||
let state = self.state.read().expect("Lock poisoned");
|
||||
let mut total_receivers = 0;
|
||||
|
||||
for (client_id, sender) in state.subscriptions.iter() {
|
||||
match sender.send(update_params.clone()) {
|
||||
Ok(n) => {
|
||||
total_receivers += n;
|
||||
trace!("Broadcast to client {}: {} receivers", client_id, n);
|
||||
}
|
||||
Err(_) => {
|
||||
trace!("Broadcast to client {}: no active receivers", client_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debug!(
|
||||
"Broadcast to all clients: {} total receivers across {} clients",
|
||||
total_receivers,
|
||||
state.subscriptions.len()
|
||||
);
|
||||
|
||||
total_receivers
|
||||
}
|
||||
|
||||
/// Get the number of subscribed clients
|
||||
pub fn client_count(&self) -> usize {
|
||||
let state = self.state.read().expect("Lock poisoned");
|
||||
state.subscriptions.len()
|
||||
}
|
||||
|
||||
/// Check if a client is subscribed
|
||||
pub fn is_subscribed(&self, client_id: &str) -> bool {
|
||||
let state = self.state.read().expect("Lock poisoned");
|
||||
state.subscriptions.contains_key(client_id)
|
||||
}
|
||||
|
||||
/// Get the list of subscribed client IDs
|
||||
pub fn subscribed_clients(&self) -> Vec<String> {
|
||||
let state = self.state.read().expect("Lock poisoned");
|
||||
state.subscriptions.keys().cloned().collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::sse::types::SessionUpdateVariant;
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
#[test]
|
||||
fn test_sse_notifier_new() {
|
||||
let notifier = SseNotifier::new();
|
||||
assert_eq!(notifier.client_count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sse_notifier_with_capacity() {
|
||||
let notifier = SseNotifier::with_capacity(512);
|
||||
assert_eq!(notifier.channel_capacity, 512);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_subscribe_unsubscribe() {
|
||||
let notifier = SseNotifier::new();
|
||||
|
||||
// Subscribe
|
||||
let _stream = notifier.subscribe("client-1");
|
||||
assert!(notifier.is_subscribed("client-1"));
|
||||
assert_eq!(notifier.client_count(), 1);
|
||||
|
||||
// Unsubscribe
|
||||
assert!(notifier.unsubscribe("client-1"));
|
||||
assert!(!notifier.is_subscribed("client-1"));
|
||||
assert_eq!(notifier.client_count(), 0);
|
||||
|
||||
// Unsubscribe non-existent
|
||||
assert!(!notifier.unsubscribe("client-1"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_broadcast_to_client() {
|
||||
let notifier = SseNotifier::new();
|
||||
|
||||
// Subscribe
|
||||
let mut stream = notifier.subscribe("client-1");
|
||||
|
||||
// Broadcast
|
||||
let update_params = SessionUpdateParams {
|
||||
session_id: "sess-1".to_string(),
|
||||
update: SessionUpdateVariant::SessionIdle {},
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = notifier.broadcast("client-1", update_params.clone());
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), 1);
|
||||
|
||||
// Receive
|
||||
let received = stream.next().await;
|
||||
assert!(received.is_some());
|
||||
let received = received.unwrap();
|
||||
assert!(received.is_ok());
|
||||
assert_eq!(received.unwrap(), update_params);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_broadcast_to_unsubscribed_client() {
|
||||
let notifier = SseNotifier::new();
|
||||
|
||||
let update_params = SessionUpdateParams {
|
||||
session_id: "sess-1".to_string(),
|
||||
update: SessionUpdateVariant::SessionIdle {},
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let result = notifier.broadcast("unknown-client", update_params);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_broadcast_all() {
|
||||
let notifier = SseNotifier::new();
|
||||
|
||||
// Subscribe multiple clients
|
||||
let mut stream1 = notifier.subscribe("client-1");
|
||||
let mut stream2 = notifier.subscribe("client-2");
|
||||
|
||||
let update_params = SessionUpdateParams {
|
||||
session_id: "sess-1".to_string(),
|
||||
update: SessionUpdateVariant::SessionIdle {},
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let count = notifier.broadcast_all(update_params.clone());
|
||||
assert_eq!(count, 2);
|
||||
|
||||
// Both should receive
|
||||
let r1 = stream1.next().await.unwrap().unwrap();
|
||||
let r2 = stream2.next().await.unwrap().unwrap();
|
||||
|
||||
assert_eq!(r1, update_params);
|
||||
assert_eq!(r2, update_params);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subscribed_clients() {
|
||||
let notifier = SseNotifier::new();
|
||||
|
||||
let _s1 = notifier.subscribe("client-1");
|
||||
let _s2 = notifier.subscribe("client-2");
|
||||
|
||||
let clients = notifier.subscribed_clients();
|
||||
assert_eq!(clients.len(), 2);
|
||||
assert!(clients.contains(&"client-1".to_string()));
|
||||
assert!(clients.contains(&"client-2".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_receivers_same_client() {
|
||||
let notifier = SseNotifier::new();
|
||||
|
||||
// Subscribe same client twice
|
||||
let mut stream1 = notifier.subscribe("client-1");
|
||||
let mut stream2 = notifier.subscribe("client-1");
|
||||
|
||||
let update_params = SessionUpdateParams {
|
||||
session_id: "sess-1".to_string(),
|
||||
update: SessionUpdateVariant::SessionIdle {},
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let count = notifier.broadcast("client-1", update_params.clone());
|
||||
assert!(count.is_ok());
|
||||
assert_eq!(count.unwrap(), 2); // Both receivers
|
||||
|
||||
// Both should receive
|
||||
let r1 = stream1.next().await.unwrap().unwrap();
|
||||
let r2 = stream2.next().await.unwrap().unwrap();
|
||||
|
||||
assert_eq!(r1, update_params);
|
||||
assert_eq!(r2, update_params);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,548 @@
|
||||
//! SSE Type Definitions for ACP Server
|
||||
//!
|
||||
//! This module contains all the data structures and types used for SSE notifications.
|
||||
//! These types define the wire format for ACP protocol messages.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use dirigent_protocol::ContentBlock;
|
||||
|
||||
// ============================================================================
|
||||
// SessionUpdateParams
|
||||
// ============================================================================
|
||||
|
||||
/// Params for session/update JSON-RPC notification (ACP spec compliant)
|
||||
///
|
||||
/// This matches the structure sent by Claude and expected by Zed.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
|
||||
#[serde(rename_all = "camelCase", default)]
|
||||
pub struct SessionUpdateParams {
|
||||
/// The session ID this update belongs to
|
||||
pub session_id: String,
|
||||
|
||||
/// The update content
|
||||
pub update: SessionUpdateVariant,
|
||||
|
||||
/// Optional event type override (default: "session/update")
|
||||
/// Reserved for future protocol extensions. Currently unused.
|
||||
#[serde(skip)]
|
||||
pub event_type_override: Option<String>,
|
||||
}
|
||||
|
||||
impl SessionUpdateParams {
|
||||
/// Get the SSE event type for session updates
|
||||
///
|
||||
/// According to ACP spec, session updates use "session/update" event type.
|
||||
pub fn event_type(&self) -> &str {
|
||||
self.event_type_override
|
||||
.as_deref()
|
||||
.unwrap_or("session/update")
|
||||
}
|
||||
|
||||
/// Create a raw event with a custom event type
|
||||
///
|
||||
/// This creates an event that will be serialized as raw JSON data
|
||||
/// (without the SessionUpdateParams wrapper) when sent over SSE.
|
||||
/// Used for `_rpc_response` deferred responses pushed via SSE after gateway transfer.
|
||||
pub fn raw_event(event_type: &str, data: serde_json::Value) -> Self {
|
||||
Self {
|
||||
session_id: String::new(), // Not used for raw events
|
||||
update: SessionUpdateVariant::Unknown { data },
|
||||
event_type_override: Some(event_type.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this is a raw event (should serialize data directly)
|
||||
pub fn is_raw_event(&self) -> bool {
|
||||
self.event_type_override.is_some()
|
||||
}
|
||||
|
||||
/// Get the data to serialize for SSE
|
||||
///
|
||||
/// For raw events (with event_type_override), returns the raw JSON data.
|
||||
/// For normal events, returns None (caller should serialize the full struct).
|
||||
pub fn raw_data(&self) -> Option<&serde_json::Value> {
|
||||
if self.event_type_override.is_some() {
|
||||
if let SessionUpdateVariant::Unknown { data } = &self.update {
|
||||
return Some(data);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Serialize to JSON string for SSE transmission
|
||||
///
|
||||
/// For raw events, serializes just the raw data.
|
||||
/// For normal events, serializes the full SessionUpdateParams struct.
|
||||
pub fn to_sse_json(&self) -> String {
|
||||
if let Some(raw_data) = self.raw_data() {
|
||||
// Raw event: serialize just the data
|
||||
serde_json::to_string(raw_data).unwrap_or_else(|_| "{}".to_string())
|
||||
} else {
|
||||
// Normal event: serialize the full struct
|
||||
serde_json::to_string(self).unwrap_or_else(|_| "{}".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ConfigOption Types
|
||||
// ============================================================================
|
||||
|
||||
/// A single configuration option for session settings (outgoing ACP wire format)
|
||||
///
|
||||
/// Used in `config_option_update` notifications to send the complete set of
|
||||
/// configuration options (modes, models, etc.) to the client.
|
||||
///
|
||||
/// Note: This is the **outgoing** wire format for ACP Server SSE.
|
||||
/// For the **incoming** format (parsed from agents), see `dirigent_protocol::ConfigOption`.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ConfigOption {
|
||||
/// Unique identifier for this option (e.g., "mode", "model")
|
||||
pub id: String,
|
||||
/// Human-readable display name (e.g., "Session Mode", "Model")
|
||||
pub name: String,
|
||||
/// Optional description
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
/// Semantic category for UX grouping (e.g., "mode", "model", "thought_level")
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub category: Option<String>,
|
||||
/// Input type (e.g., "select", "toggle", "text")
|
||||
#[serde(rename = "type")]
|
||||
pub option_type: ConfigOptionType,
|
||||
/// Currently selected value
|
||||
pub current_value: String,
|
||||
/// Available options for "select" type
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub options: Option<Vec<ConfigOptionChoice>>,
|
||||
}
|
||||
|
||||
/// Type of configuration option input
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ConfigOptionType {
|
||||
Select,
|
||||
Toggle,
|
||||
Text,
|
||||
}
|
||||
|
||||
/// A single choice within a select-type config option
|
||||
///
|
||||
/// Matches ACP's `SessionConfigSelectOption` structure.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ConfigOptionChoice {
|
||||
/// Value identifier for this choice (sent back when selected)
|
||||
pub value: String,
|
||||
/// Human-readable display name
|
||||
pub name: String,
|
||||
/// Optional description
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SessionUpdateVariant
|
||||
// ============================================================================
|
||||
|
||||
/// Variants of session updates (ACP spec compliant)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(tag = "sessionUpdate", rename_all = "snake_case")]
|
||||
pub enum SessionUpdateVariant {
|
||||
/// Agent message chunk (streaming text)
|
||||
#[serde(rename = "agent_message_chunk")]
|
||||
AgentMessageChunk {
|
||||
/// The content block (text, image, etc.)
|
||||
content: ContentBlock,
|
||||
},
|
||||
|
||||
/// Message generation complete
|
||||
#[serde(rename = "message_complete")]
|
||||
MessageComplete {
|
||||
/// The message ID that completed
|
||||
#[serde(rename = "messageId", skip_serializing_if = "Option::is_none")]
|
||||
message_id: Option<String>,
|
||||
},
|
||||
|
||||
/// Session is idle and ready for input
|
||||
#[serde(rename = "session_idle")]
|
||||
SessionIdle {},
|
||||
|
||||
/// Available slash commands update
|
||||
#[serde(rename = "available_commands_update")]
|
||||
AvailableCommandsUpdate {
|
||||
/// List of available slash commands
|
||||
#[serde(rename = "availableCommands")]
|
||||
available_commands: Vec<SlashCommand>,
|
||||
},
|
||||
|
||||
/// Connector changed (session transferred to a different connector)
|
||||
#[serde(rename = "connector_changed")]
|
||||
ConnectorChanged {
|
||||
/// The new connector ID
|
||||
#[serde(rename = "newConnectorId")]
|
||||
new_connector_id: String,
|
||||
/// The new internal session ID in the target connector
|
||||
#[serde(rename = "newInternalSessionId")]
|
||||
new_internal_session_id: String,
|
||||
/// Whether a new session was created (true) or existing session loaded (false)
|
||||
#[serde(rename = "isNewSession")]
|
||||
is_new_session: bool,
|
||||
},
|
||||
|
||||
/// Tool call started/initiated
|
||||
#[serde(rename = "tool_call")]
|
||||
ToolCall {
|
||||
/// The tool call ID
|
||||
#[serde(rename = "toolCallId")]
|
||||
tool_call_id: String,
|
||||
|
||||
/// Optional title for the tool call
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
title: Option<String>,
|
||||
|
||||
/// Optional kind/category (e.g., "search", "edit")
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
kind: Option<String>,
|
||||
|
||||
/// Raw input parameters
|
||||
#[serde(rename = "rawInput", skip_serializing_if = "Option::is_none")]
|
||||
raw_input: Option<serde_json::Value>,
|
||||
|
||||
/// Current status (pending, in_progress, completed, failed)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
status: Option<String>,
|
||||
|
||||
/// Content blocks (e.g., tool output)
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
content: Vec<ContentBlock>,
|
||||
|
||||
/// Metadata (can include claudeCode.toolName, toolResponse, etc.)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
_meta: Option<serde_json::Value>,
|
||||
},
|
||||
|
||||
/// Tool call update (status change, output, etc.)
|
||||
#[serde(rename = "tool_call_update")]
|
||||
ToolCallUpdate {
|
||||
/// The tool call ID being updated
|
||||
#[serde(rename = "toolCallId")]
|
||||
tool_call_id: String,
|
||||
|
||||
/// Updated status
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
status: Option<String>,
|
||||
|
||||
/// Updated content blocks
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
content: Vec<ContentBlock>,
|
||||
|
||||
/// Raw output from the tool
|
||||
#[serde(rename = "rawOutput", skip_serializing_if = "Option::is_none")]
|
||||
raw_output: Option<serde_json::Value>,
|
||||
|
||||
/// Error message if tool call failed
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
error: Option<String>,
|
||||
|
||||
/// Metadata (can include toolResponse from Claude)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
_meta: Option<serde_json::Value>,
|
||||
},
|
||||
|
||||
/// Agent request needing client response (e.g., permission prompt)
|
||||
///
|
||||
/// This variant is used for bidirectional requests where the agent initiates
|
||||
/// a request (like `session/request_permission`) that requires a response from
|
||||
/// the client. The client should respond via the `/acp/agent_response` endpoint.
|
||||
#[serde(rename = "agent_request")]
|
||||
AgentRequest {
|
||||
/// The request ID from the agent (to correlate response)
|
||||
#[serde(rename = "requestId")]
|
||||
request_id: serde_json::Value,
|
||||
|
||||
/// The method being requested (e.g., "session/request_permission")
|
||||
method: String,
|
||||
|
||||
/// The request parameters
|
||||
params: serde_json::Value,
|
||||
},
|
||||
|
||||
/// Session modes update (full state) - NOT SUPPORTED BY ZED
|
||||
/// Kept for future ACP protocol support
|
||||
#[serde(rename = "modes_update")]
|
||||
ModesUpdate {
|
||||
/// Session mode state (flattened)
|
||||
#[serde(flatten)]
|
||||
modes: dirigent_protocol::SessionModeState,
|
||||
},
|
||||
|
||||
/// Session models update (full state) - NOT SUPPORTED BY ZED
|
||||
/// Kept for future ACP protocol support
|
||||
#[serde(rename = "models_update")]
|
||||
ModelsUpdate {
|
||||
/// Session model state (flattened)
|
||||
#[serde(flatten)]
|
||||
models: dirigent_protocol::SessionModelState,
|
||||
},
|
||||
|
||||
/// Current mode update (standard ACP notification)
|
||||
/// Signals which mode is currently selected. Zed accepts this.
|
||||
#[serde(rename = "current_mode_update")]
|
||||
CurrentModeUpdate {
|
||||
/// The ID of the currently selected mode
|
||||
/// Note: Zed expects "currentModeId" not "modeId"
|
||||
#[serde(rename = "currentModeId")]
|
||||
mode_id: String,
|
||||
},
|
||||
|
||||
/// Current model update - NOT SUPPORTED BY ZED
|
||||
/// Kept for completeness but Zed only accepts current_mode_update
|
||||
#[serde(rename = "current_model_update")]
|
||||
CurrentModelUpdate {
|
||||
/// The ID of the currently selected model
|
||||
#[serde(rename = "modelId")]
|
||||
model_id: String,
|
||||
},
|
||||
|
||||
/// Config option update - send full configuration options to client
|
||||
///
|
||||
/// This is used after session transfer to provide the target connector's
|
||||
/// actual modes/models instead of Gateway's placeholder values.
|
||||
/// See: https://agentclientprotocol.com/rfds/session-config-options
|
||||
#[serde(rename = "config_option_update")]
|
||||
ConfigOptionUpdate {
|
||||
/// List of configuration options to display
|
||||
#[serde(rename = "configOptions")]
|
||||
config_options: Vec<ConfigOption>,
|
||||
},
|
||||
|
||||
/// Unknown update type (forward compatibility - pass through as raw JSON)
|
||||
#[serde(untagged)]
|
||||
Unknown {
|
||||
#[serde(flatten)]
|
||||
data: serde_json::Value,
|
||||
},
|
||||
}
|
||||
|
||||
impl SessionUpdateVariant {
|
||||
/// Returns the variant name for logging (without data)
|
||||
pub fn variant_name(&self) -> &'static str {
|
||||
match self {
|
||||
SessionUpdateVariant::AgentMessageChunk { .. } => "AgentMessageChunk",
|
||||
SessionUpdateVariant::MessageComplete { .. } => "MessageComplete",
|
||||
SessionUpdateVariant::SessionIdle { .. } => "SessionIdle",
|
||||
SessionUpdateVariant::AvailableCommandsUpdate { .. } => "AvailableCommandsUpdate",
|
||||
SessionUpdateVariant::ConnectorChanged { .. } => "ConnectorChanged",
|
||||
SessionUpdateVariant::ToolCall { .. } => "ToolCall",
|
||||
SessionUpdateVariant::ToolCallUpdate { .. } => "ToolCallUpdate",
|
||||
SessionUpdateVariant::AgentRequest { .. } => "AgentRequest",
|
||||
SessionUpdateVariant::ModesUpdate { .. } => "ModesUpdate",
|
||||
SessionUpdateVariant::ModelsUpdate { .. } => "ModelsUpdate",
|
||||
SessionUpdateVariant::CurrentModeUpdate { .. } => "CurrentModeUpdate",
|
||||
SessionUpdateVariant::CurrentModelUpdate { .. } => "CurrentModelUpdate",
|
||||
SessionUpdateVariant::ConfigOptionUpdate { .. } => "ConfigOptionUpdate",
|
||||
SessionUpdateVariant::Unknown { .. } => "Unknown",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SessionUpdateVariant {
|
||||
fn default() -> Self {
|
||||
SessionUpdateVariant::SessionIdle {}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SlashCommand
|
||||
// ============================================================================
|
||||
|
||||
/// Slash command definition
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SlashCommand {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub input: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// AcpNotification (Legacy)
|
||||
// ============================================================================
|
||||
|
||||
/// ACP Notification types for SSE streaming (DEPRECATED - keeping for compatibility)
|
||||
///
|
||||
/// These notifications are sent to clients over the SSE connection to inform
|
||||
/// them about session events, message streaming, and errors.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(
|
||||
tag = "type",
|
||||
rename_all = "snake_case",
|
||||
rename_all_fields = "camelCase"
|
||||
)]
|
||||
pub enum AcpNotification {
|
||||
/// A chunk of message content has been received
|
||||
///
|
||||
/// Sent during streaming to provide incremental content updates.
|
||||
/// The client should append this content to the message being built.
|
||||
MessageChunk {
|
||||
/// The session ID this chunk belongs to
|
||||
session_id: String,
|
||||
|
||||
/// The message ID this chunk belongs to
|
||||
message_id: String,
|
||||
|
||||
/// The content chunk (typically text)
|
||||
content: String,
|
||||
|
||||
/// Optional content type (e.g., "text", "thought")
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
content_type: Option<String>,
|
||||
},
|
||||
|
||||
/// A message has been completed
|
||||
///
|
||||
/// Sent when the agent has finished generating a message.
|
||||
/// The client should finalize the message display.
|
||||
MessageComplete {
|
||||
/// The session ID
|
||||
session_id: String,
|
||||
|
||||
/// The completed message ID
|
||||
message_id: String,
|
||||
},
|
||||
|
||||
/// A session has become idle
|
||||
///
|
||||
/// Sent when the session transitions to an idle state, meaning
|
||||
/// no active generation is occurring.
|
||||
SessionIdle {
|
||||
/// The session ID that became idle
|
||||
session_id: String,
|
||||
},
|
||||
|
||||
/// An error occurred in a session
|
||||
///
|
||||
/// Sent when an error occurs during processing.
|
||||
SessionError {
|
||||
/// The session ID where the error occurred
|
||||
session_id: String,
|
||||
|
||||
/// Human-readable error message
|
||||
error: String,
|
||||
},
|
||||
|
||||
/// A tool call has started or been updated
|
||||
///
|
||||
/// Sent when a tool call is initiated or its status changes.
|
||||
ToolCallUpdate {
|
||||
/// The session ID
|
||||
session_id: String,
|
||||
|
||||
/// The message ID containing this tool call
|
||||
message_id: String,
|
||||
|
||||
/// The tool call ID
|
||||
tool_call_id: String,
|
||||
|
||||
/// The tool name
|
||||
tool_name: String,
|
||||
|
||||
/// Current status (pending, running, completed, error)
|
||||
status: String,
|
||||
|
||||
/// Optional title for the tool call
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
title: Option<String>,
|
||||
|
||||
/// Optional error message
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
error: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
impl AcpNotification {
|
||||
/// Get the SSE event type for this notification
|
||||
///
|
||||
/// This is used as the `event:` field in the SSE stream.
|
||||
pub fn event_type(&self) -> &'static str {
|
||||
match self {
|
||||
AcpNotification::MessageChunk { .. } => "message/chunk",
|
||||
AcpNotification::MessageComplete { .. } => "message/complete",
|
||||
AcpNotification::SessionIdle { .. } => "session/idle",
|
||||
AcpNotification::SessionError { .. } => "session/error",
|
||||
AcpNotification::ToolCallUpdate { .. } => "tool/update",
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to SSE format string
|
||||
///
|
||||
/// Returns a string formatted for SSE transmission:
|
||||
/// ```text
|
||||
/// event: <event_type>
|
||||
/// data: <json_data>
|
||||
/// ```
|
||||
pub fn to_sse_string(&self) -> String {
|
||||
let data = serde_json::to_string(self).unwrap_or_else(|_| "{}".to_string());
|
||||
format!("event: {}\ndata: {}\n\n", self.event_type(), data)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ConfigOption Conversion Functions
|
||||
// ============================================================================
|
||||
|
||||
/// Convert SessionModeState to a ConfigOption
|
||||
///
|
||||
/// Transforms the protocol's mode state into the ACP `config_option_update` format.
|
||||
/// Uses "mode" ID to match the ACP protocol standard.
|
||||
pub fn modes_to_config_option(modes: &dirigent_protocol::SessionModeState) -> ConfigOption {
|
||||
ConfigOption {
|
||||
id: "mode".to_string(),
|
||||
name: "Session Mode".to_string(),
|
||||
description: None,
|
||||
category: Some("mode".to_string()),
|
||||
option_type: ConfigOptionType::Select,
|
||||
current_value: modes.current_mode_id.clone(),
|
||||
options: Some(
|
||||
modes
|
||||
.available_modes
|
||||
.iter()
|
||||
.map(|m| ConfigOptionChoice {
|
||||
value: m.id.clone(),
|
||||
name: m.name.clone(),
|
||||
description: m.description.clone(),
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert SessionModelState to a ConfigOption
|
||||
///
|
||||
/// Transforms the protocol's model state into the ACP `config_option_update` format.
|
||||
/// Uses "model" ID to match the ACP protocol standard.
|
||||
pub fn models_to_config_option(models: &dirigent_protocol::SessionModelState) -> ConfigOption {
|
||||
ConfigOption {
|
||||
id: "model".to_string(),
|
||||
name: "Model".to_string(),
|
||||
description: None,
|
||||
category: Some("model".to_string()),
|
||||
option_type: ConfigOptionType::Select,
|
||||
current_value: models.current_model_id.clone(),
|
||||
options: Some(
|
||||
models
|
||||
.available_models
|
||||
.iter()
|
||||
.map(|m| ConfigOptionChoice {
|
||||
value: m.model_id.clone(),
|
||||
name: m.name.clone(),
|
||||
description: m.description.clone(),
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,338 @@
|
||||
//! Common test utilities for bidirectional flow testing.
|
||||
//!
|
||||
//! This module provides mock implementations and test helpers for testing
|
||||
//! the bidirectional request/response flow in the ACP Server.
|
||||
|
||||
use anyhow::Result;
|
||||
use serde_json::{json, Value};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, oneshot, Mutex};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Mock event for testing event bridge
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MockEvent {
|
||||
pub event_type: String,
|
||||
pub data: Value,
|
||||
}
|
||||
|
||||
/// Mock SSE client for testing
|
||||
///
|
||||
/// Simulates an HTTP client that receives SSE events and posts responses.
|
||||
pub struct MockSseClient {
|
||||
pub client_id: String,
|
||||
/// Events received via SSE
|
||||
pub received_events: Arc<Mutex<Vec<MockEvent>>>,
|
||||
/// Sender for simulating HTTP POST responses
|
||||
pub response_tx: mpsc::UnboundedSender<(Value, oneshot::Sender<Result<()>>)>,
|
||||
}
|
||||
|
||||
impl MockSseClient {
|
||||
pub fn new(client_id: String) -> (Self, mpsc::UnboundedReceiver<(Value, oneshot::Sender<Result<()>>)>) {
|
||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||
|
||||
(
|
||||
Self {
|
||||
client_id,
|
||||
received_events: Arc::new(Mutex::new(Vec::new())),
|
||||
response_tx,
|
||||
},
|
||||
response_rx,
|
||||
)
|
||||
}
|
||||
|
||||
/// Simulate receiving an SSE event
|
||||
pub async fn receive_sse(&self, event_type: String, data: Value) {
|
||||
let mut events = self.received_events.lock().await;
|
||||
events.push(MockEvent { event_type, data });
|
||||
}
|
||||
|
||||
/// Get all received events
|
||||
pub async fn get_events(&self) -> Vec<MockEvent> {
|
||||
let events = self.received_events.lock().await;
|
||||
events.clone()
|
||||
}
|
||||
|
||||
/// Get the most recent event of a specific type
|
||||
pub async fn get_latest_event(&self, event_type: &str) -> Option<MockEvent> {
|
||||
let events = self.received_events.lock().await;
|
||||
events.iter()
|
||||
.filter(|e| e.event_type == event_type)
|
||||
.last()
|
||||
.cloned()
|
||||
}
|
||||
|
||||
/// Clear received events
|
||||
pub async fn clear_events(&self) {
|
||||
let mut events = self.received_events.lock().await;
|
||||
events.clear();
|
||||
}
|
||||
|
||||
/// Simulate sending a response to /acp/agent_response
|
||||
pub async fn send_response(&self, response: Value) -> Result<()> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.response_tx.send((response, tx))
|
||||
.map_err(|_| anyhow::anyhow!("Failed to send response"))?;
|
||||
rx.await?
|
||||
}
|
||||
}
|
||||
|
||||
/// Mock connector for testing
|
||||
///
|
||||
/// Simulates an ACP connector that can send agent requests and receive responses.
|
||||
pub struct MockConnector {
|
||||
pub connector_id: String,
|
||||
/// Channel for receiving agent requests from this connector
|
||||
pub request_tx: mpsc::UnboundedSender<(Value, String, String, Value)>,
|
||||
/// Pending responses (request_id -> sender)
|
||||
pub pending_responses: Arc<Mutex<HashMap<Value, oneshot::Sender<Value>>>>,
|
||||
}
|
||||
|
||||
impl MockConnector {
|
||||
pub fn new(
|
||||
connector_id: String,
|
||||
) -> (Self, mpsc::UnboundedReceiver<(Value, String, String, Value)>) {
|
||||
let (request_tx, request_rx) = mpsc::unbounded_channel();
|
||||
|
||||
(
|
||||
Self {
|
||||
connector_id,
|
||||
request_tx,
|
||||
pending_responses: Arc::new(Mutex::new(HashMap::new())),
|
||||
},
|
||||
request_rx,
|
||||
)
|
||||
}
|
||||
|
||||
/// Simulate sending an agent request
|
||||
pub async fn send_agent_request(
|
||||
&self,
|
||||
session_id: String,
|
||||
request_id: Value,
|
||||
method: String,
|
||||
params: Value,
|
||||
) -> oneshot::Receiver<Value> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// Store the response sender
|
||||
let mut pending = self.pending_responses.lock().await;
|
||||
pending.insert(request_id.clone(), tx);
|
||||
drop(pending);
|
||||
|
||||
// Send the request
|
||||
self.request_tx.send((request_id, session_id, method, params))
|
||||
.expect("Failed to send agent request");
|
||||
|
||||
rx
|
||||
}
|
||||
|
||||
/// Complete a pending response (simulating response from ACP Server)
|
||||
pub async fn complete_response(&self, request_id: Value, response: Value) -> Result<()> {
|
||||
let mut pending = self.pending_responses.lock().await;
|
||||
|
||||
if let Some(tx) = pending.remove(&request_id) {
|
||||
tx.send(response)
|
||||
.map_err(|_| anyhow::anyhow!("Failed to send response to connector"))?;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::anyhow!("No pending response for request_id: {}", request_id))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Test context for integration tests
|
||||
///
|
||||
/// Provides a complete test environment with mocked components.
|
||||
pub struct TestContext {
|
||||
/// Mock SSE clients by client_id
|
||||
pub clients: Arc<Mutex<HashMap<String, MockSseClient>>>,
|
||||
/// Mock connectors by connector_id
|
||||
pub connectors: Arc<Mutex<HashMap<String, MockConnector>>>,
|
||||
}
|
||||
|
||||
impl TestContext {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
clients: Arc::new(Mutex::new(HashMap::new())),
|
||||
connectors: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a mock SSE client
|
||||
pub async fn create_client(
|
||||
&self,
|
||||
client_id: String,
|
||||
) -> (MockSseClient, mpsc::UnboundedReceiver<(Value, oneshot::Sender<Result<()>>)>) {
|
||||
let (client, response_rx) = MockSseClient::new(client_id.clone());
|
||||
|
||||
let mut clients = self.clients.lock().await;
|
||||
clients.insert(client_id.clone(), client.clone());
|
||||
|
||||
(client, response_rx)
|
||||
}
|
||||
|
||||
/// Create a mock connector
|
||||
pub async fn create_connector(
|
||||
&self,
|
||||
connector_id: String,
|
||||
) -> (MockConnector, mpsc::UnboundedReceiver<(Value, String, String, Value)>) {
|
||||
let (connector, request_rx) = MockConnector::new(connector_id.clone());
|
||||
|
||||
let mut connectors = self.connectors.lock().await;
|
||||
connectors.insert(connector_id.clone(), connector.clone());
|
||||
|
||||
(connector, request_rx)
|
||||
}
|
||||
|
||||
/// Get a client by ID
|
||||
pub async fn get_client(&self, client_id: &str) -> Option<MockSseClient> {
|
||||
let clients = self.clients.lock().await;
|
||||
clients.get(client_id).cloned()
|
||||
}
|
||||
|
||||
/// Get a connector by ID
|
||||
pub async fn get_connector(&self, connector_id: &str) -> Option<MockConnector> {
|
||||
let connectors = self.connectors.lock().await;
|
||||
connectors.get(connector_id).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TestContext {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to create a sample permission request
|
||||
pub fn sample_permission_request(request_id: u64) -> Value {
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"method": "session/request_permission",
|
||||
"params": {
|
||||
"sessionId": "test-session",
|
||||
"tool": "Write",
|
||||
"parameters": {
|
||||
"path": "/tmp/test.txt",
|
||||
"content": "test"
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Helper to create a sample permission response
|
||||
pub fn sample_permission_response(request_id: u64, allow: bool) -> Value {
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": {
|
||||
"selectedOptionId": if allow { "allow" } else { "deny" }
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Helper to extract agent_request data from SSE event
|
||||
pub fn extract_agent_request(event: &MockEvent) -> Option<(Value, String, Value)> {
|
||||
if event.event_type != "session/update" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let update = event.data.get("update")?;
|
||||
if update.get("sessionUpdate")?.as_str()? != "agent_request" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let request_id = update.get("requestId")?.clone();
|
||||
let method = update.get("method")?.as_str()?.to_string();
|
||||
let params = update.get("params")?.clone();
|
||||
|
||||
Some((request_id, method, params))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mock_client() {
|
||||
let (client, _response_rx) = MockSseClient::new("test-client".to_string());
|
||||
|
||||
// Simulate receiving an event
|
||||
client.receive_sse("session/update".to_string(), json!({"test": "data"})).await;
|
||||
|
||||
// Verify event was received
|
||||
let events = client.get_events().await;
|
||||
assert_eq!(events.len(), 1);
|
||||
assert_eq!(events[0].event_type, "session/update");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mock_connector() {
|
||||
let (connector, mut request_rx) = MockConnector::new("test-connector".to_string());
|
||||
|
||||
// Simulate sending an agent request
|
||||
let response_fut = connector.send_agent_request(
|
||||
"session-1".to_string(),
|
||||
json!(0),
|
||||
"session/request_permission".to_string(),
|
||||
json!({"tool": "Write"}),
|
||||
).await;
|
||||
|
||||
// Verify request was sent
|
||||
let request = request_rx.recv().await.unwrap();
|
||||
assert_eq!(request.0, json!(0));
|
||||
assert_eq!(request.1, "session-1");
|
||||
assert_eq!(request.2, "session/request_permission");
|
||||
|
||||
// Simulate completing the response
|
||||
connector.complete_response(json!(0), json!({"result": "success"})).await.unwrap();
|
||||
|
||||
// Verify response was received
|
||||
let response = response_fut.await.unwrap();
|
||||
assert_eq!(response, json!({"result": "success"}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_test_context() {
|
||||
let ctx = TestContext::new();
|
||||
|
||||
// Create client and connector
|
||||
let (client, _) = ctx.create_client("client-1".to_string()).await;
|
||||
let (connector, _) = ctx.create_connector("connector-1".to_string()).await;
|
||||
|
||||
// Verify we can retrieve them
|
||||
assert!(ctx.get_client("client-1").await.is_some());
|
||||
assert!(ctx.get_connector("connector-1").await.is_some());
|
||||
assert!(ctx.get_client("non-existent").await.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sample_helpers() {
|
||||
let request = sample_permission_request(0);
|
||||
assert_eq!(request["method"], "session/request_permission");
|
||||
|
||||
let response = sample_permission_response(0, true);
|
||||
assert_eq!(response["result"]["selectedOptionId"], "allow");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_agent_request() {
|
||||
let event = MockEvent {
|
||||
event_type: "session/update".to_string(),
|
||||
data: json!({
|
||||
"sessionId": "session-1",
|
||||
"update": {
|
||||
"sessionUpdate": "agent_request",
|
||||
"requestId": 0,
|
||||
"method": "session/request_permission",
|
||||
"params": {"tool": "Write"}
|
||||
}
|
||||
}),
|
||||
};
|
||||
|
||||
let (request_id, method, params) = extract_agent_request(&event).unwrap();
|
||||
assert_eq!(request_id, json!(0));
|
||||
assert_eq!(method, "session/request_permission");
|
||||
assert_eq!(params["tool"], "Write");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,327 @@
|
||||
//! Integration test for concurrent agent requests (T050)
|
||||
//!
|
||||
//! This test verifies that the system can handle multiple agent requests
|
||||
//! simultaneously without cross-contamination.
|
||||
//!
|
||||
//! Test scenario:
|
||||
//! 1. Register multiple pending requests concurrently
|
||||
//! 2. Complete them in random/different order
|
||||
//! 3. Verify each response goes to the correct request
|
||||
//! 4. Verify no cross-contamination
|
||||
|
||||
use dirigent_acp_api::agent_requests::AgentRequestTracker;
|
||||
use serde_json::json;
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_requests_basic() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
|
||||
// Register 5 concurrent requests
|
||||
let mut receivers = Vec::new();
|
||||
for i in 0..5 {
|
||||
let rx = tracker.register(client_id, json!(i));
|
||||
receivers.push((i, rx));
|
||||
}
|
||||
|
||||
assert_eq!(tracker.pending_count(), 5);
|
||||
|
||||
// Complete them in reverse order
|
||||
for i in (0..5).rev() {
|
||||
let response = json!({"request": i, "result": "success"});
|
||||
tracker.complete(client_id, json!(i), response).unwrap();
|
||||
}
|
||||
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Verify each receiver got the correct response
|
||||
for (i, rx) in receivers {
|
||||
let response = rx.await.unwrap();
|
||||
assert_eq!(response["request"], i);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_requests_multiple_clients() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
// Two clients, each with 3 requests
|
||||
let client1 = "client-1";
|
||||
let client2 = "client-2";
|
||||
|
||||
let mut receivers1 = Vec::new();
|
||||
let mut receivers2 = Vec::new();
|
||||
|
||||
for i in 0..3 {
|
||||
let rx1 = tracker.register(client1, json!(i));
|
||||
let rx2 = tracker.register(client2, json!(i));
|
||||
receivers1.push((i, rx1));
|
||||
receivers2.push((i, rx2));
|
||||
}
|
||||
|
||||
assert_eq!(tracker.pending_count(), 6);
|
||||
assert_eq!(tracker.client_pending_count(client1), 3);
|
||||
assert_eq!(tracker.client_pending_count(client2), 3);
|
||||
|
||||
// Complete client1's requests
|
||||
for i in 0..3 {
|
||||
let response = json!({"client": 1, "request": i});
|
||||
tracker.complete(client1, json!(i), response).unwrap();
|
||||
}
|
||||
|
||||
assert_eq!(tracker.client_pending_count(client1), 0);
|
||||
assert_eq!(tracker.client_pending_count(client2), 3);
|
||||
|
||||
// Complete client2's requests
|
||||
for i in 0..3 {
|
||||
let response = json!({"client": 2, "request": i});
|
||||
tracker.complete(client2, json!(i), response).unwrap();
|
||||
}
|
||||
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Verify each receiver got the correct response
|
||||
for (i, rx) in receivers1 {
|
||||
let response = rx.await.unwrap();
|
||||
assert_eq!(response["client"], 1);
|
||||
assert_eq!(response["request"], i);
|
||||
}
|
||||
|
||||
for (i, rx) in receivers2 {
|
||||
let response = rx.await.unwrap();
|
||||
assert_eq!(response["client"], 2);
|
||||
assert_eq!(response["request"], i);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_requests_same_id_different_clients() {
|
||||
// Test that same request_id for different clients are handled independently
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client1 = "client-1";
|
||||
let client2 = "client-2";
|
||||
let request_id = json!(0); // Same ID for both
|
||||
|
||||
let rx1 = tracker.register(client1, request_id.clone());
|
||||
let rx2 = tracker.register(client2, request_id.clone());
|
||||
|
||||
assert_eq!(tracker.pending_count(), 2);
|
||||
|
||||
// Complete client1's request
|
||||
let response1 = json!({"client": "client-1"});
|
||||
tracker.complete(client1, request_id.clone(), response1.clone()).unwrap();
|
||||
|
||||
// Complete client2's request
|
||||
let response2 = json!({"client": "client-2"});
|
||||
tracker.complete(client2, request_id, response2.clone()).unwrap();
|
||||
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Verify each got the correct response
|
||||
let received1 = rx1.await.unwrap();
|
||||
let received2 = rx2.await.unwrap();
|
||||
|
||||
assert_eq!(received1["client"], "client-1");
|
||||
assert_eq!(received2["client"], "client-2");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_async_completion() {
|
||||
// Test completing requests from multiple async tasks concurrently
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
let num_requests = 10;
|
||||
|
||||
// Register requests
|
||||
let mut receivers = Vec::new();
|
||||
for i in 0..num_requests {
|
||||
let rx = tracker.register(client_id, json!(i));
|
||||
receivers.push((i, rx));
|
||||
}
|
||||
|
||||
assert_eq!(tracker.pending_count(), num_requests);
|
||||
|
||||
// Spawn tasks to complete requests concurrently
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
for i in 0..num_requests {
|
||||
let tracker_clone = tracker.clone();
|
||||
join_set.spawn(async move {
|
||||
// Small delay to ensure concurrency
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(((i % 3) * 10) as u64)).await;
|
||||
let response = json!({"request": i, "result": "success"});
|
||||
tracker_clone.complete(client_id, json!(i), response)
|
||||
});
|
||||
}
|
||||
|
||||
// Wait for all completions
|
||||
while let Some(result) = join_set.join_next().await {
|
||||
assert!(result.unwrap().is_ok());
|
||||
}
|
||||
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Verify all receivers got correct responses
|
||||
for (i, rx) in receivers {
|
||||
let response = rx.await.unwrap();
|
||||
assert_eq!(response["request"], i);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_register_and_complete() {
|
||||
// Test registering and completing requests concurrently
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
let num_requests = 20;
|
||||
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
// Spawn tasks to register and complete requests
|
||||
for i in 0..num_requests {
|
||||
let tracker_clone = tracker.clone();
|
||||
join_set.spawn(async move {
|
||||
// Register
|
||||
let rx = tracker_clone.register(client_id, json!(i));
|
||||
|
||||
// Small delay
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
|
||||
|
||||
// Complete
|
||||
let response = json!({"request": i});
|
||||
tracker_clone.complete(client_id, json!(i), response.clone()).unwrap();
|
||||
|
||||
// Wait for response
|
||||
let received = rx.await.unwrap();
|
||||
assert_eq!(received["request"], i);
|
||||
|
||||
i
|
||||
});
|
||||
}
|
||||
|
||||
// Wait for all tasks
|
||||
let mut completed = Vec::new();
|
||||
while let Some(result) = join_set.join_next().await {
|
||||
completed.push(result.unwrap());
|
||||
}
|
||||
|
||||
// All requests should have completed
|
||||
assert_eq!(completed.len(), num_requests);
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_mixed_operations() {
|
||||
// Test mix of register, complete, and timeout operations
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
// Spawn 15 tasks with different behaviors
|
||||
for i in 0..15 {
|
||||
let tracker_clone = tracker.clone();
|
||||
join_set.spawn(async move {
|
||||
let rx = tracker_clone.register(client_id, json!(i));
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
|
||||
|
||||
match i % 3 {
|
||||
0 => {
|
||||
// Complete normally
|
||||
let response = json!({"request": i, "type": "complete"});
|
||||
tracker_clone.complete(client_id, json!(i), response.clone()).unwrap();
|
||||
let received = rx.await.unwrap();
|
||||
assert_eq!(received["type"], "complete");
|
||||
"completed"
|
||||
}
|
||||
1 => {
|
||||
// Timeout
|
||||
tracker_clone.timeout(client_id, json!(i));
|
||||
assert!(rx.await.is_err());
|
||||
"timeout"
|
||||
}
|
||||
_ => {
|
||||
// Complete with delay
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(20)).await;
|
||||
let response = json!({"request": i, "type": "delayed"});
|
||||
tracker_clone.complete(client_id, json!(i), response.clone()).unwrap();
|
||||
let received = rx.await.unwrap();
|
||||
assert_eq!(received["type"], "delayed");
|
||||
"delayed"
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Wait for all tasks
|
||||
let mut results = Vec::new();
|
||||
while let Some(result) = join_set.join_next().await {
|
||||
results.push(result.unwrap());
|
||||
}
|
||||
|
||||
assert_eq!(results.len(), 15);
|
||||
|
||||
// Count outcomes
|
||||
let completed = results.iter().filter(|&r| r == &"completed").count();
|
||||
let timeout = results.iter().filter(|&r| r == &"timeout").count();
|
||||
let delayed = results.iter().filter(|&r| r == &"delayed").count();
|
||||
|
||||
assert_eq!(completed, 5);
|
||||
assert_eq!(timeout, 5);
|
||||
assert_eq!(delayed, 5);
|
||||
|
||||
// All should be cleaned up
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_high_concurrency() {
|
||||
// Stress test with many concurrent requests
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
let num_requests = 100;
|
||||
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
for i in 0..num_requests {
|
||||
let tracker_clone = tracker.clone();
|
||||
join_set.spawn(async move {
|
||||
let rx = tracker_clone.register(client_id, json!(i));
|
||||
|
||||
// Random-ish delay
|
||||
let delay = ((i * 7) % 20) as u64;
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
|
||||
|
||||
let response = json!({"request": i});
|
||||
tracker_clone.complete(client_id, json!(i), response).unwrap();
|
||||
|
||||
rx.await.unwrap()["request"].as_u64().unwrap()
|
||||
});
|
||||
}
|
||||
|
||||
// Collect all results
|
||||
let mut results = Vec::new();
|
||||
while let Some(result) = join_set.join_next().await {
|
||||
results.push(result.unwrap());
|
||||
}
|
||||
|
||||
// Verify all requests completed
|
||||
assert_eq!(results.len(), num_requests);
|
||||
|
||||
// Verify all request IDs are present
|
||||
results.sort();
|
||||
for (idx, &val) in results.iter().enumerate() {
|
||||
assert_eq!(val, idx as u64);
|
||||
}
|
||||
|
||||
// All cleaned up
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
@@ -0,0 +1,286 @@
|
||||
//! Integration test for client disconnection (T051)
|
||||
//!
|
||||
//! This test verifies that pending agent requests are cleaned up when a client
|
||||
//! disconnects, and that other clients are unaffected.
|
||||
//!
|
||||
//! Test scenario:
|
||||
//! 1. Register pending requests for multiple clients
|
||||
//! 2. Simulate client disconnection
|
||||
//! 3. Verify cleanup occurs for disconnected client
|
||||
//! 4. Verify other clients are unaffected
|
||||
|
||||
use dirigent_acp_api::agent_requests::AgentRequestTracker;
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_disconnect_single_client() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
|
||||
// Register 3 pending requests
|
||||
let rx1 = tracker.register(client_id, json!(0));
|
||||
let rx2 = tracker.register(client_id, json!(1));
|
||||
let rx3 = tracker.register(client_id, json!(2));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 3);
|
||||
assert_eq!(tracker.client_pending_count(client_id), 3);
|
||||
|
||||
// Simulate client disconnection - clear all requests for this client
|
||||
tracker.clear(Some(client_id));
|
||||
|
||||
// All requests should be removed
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
assert_eq!(tracker.client_pending_count(client_id), 0);
|
||||
|
||||
// All receivers should get errors (channels closed)
|
||||
assert!(rx1.await.is_err());
|
||||
assert!(rx2.await.is_err());
|
||||
assert!(rx3.await.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_disconnect_multiple_clients() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client1 = "client-1";
|
||||
let client2 = "client-2";
|
||||
let client3 = "client-3";
|
||||
|
||||
// Register requests for all clients
|
||||
let rx1_1 = tracker.register(client1, json!(0));
|
||||
let rx1_2 = tracker.register(client1, json!(1));
|
||||
|
||||
let rx2_1 = tracker.register(client2, json!(0));
|
||||
let rx2_2 = tracker.register(client2, json!(1));
|
||||
let rx2_3 = tracker.register(client2, json!(2));
|
||||
|
||||
let rx3_1 = tracker.register(client3, json!(0));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 6);
|
||||
assert_eq!(tracker.client_pending_count(client1), 2);
|
||||
assert_eq!(tracker.client_pending_count(client2), 3);
|
||||
assert_eq!(tracker.client_pending_count(client3), 1);
|
||||
|
||||
// Disconnect client2
|
||||
tracker.clear(Some(client2));
|
||||
|
||||
// Only client2's requests should be removed
|
||||
assert_eq!(tracker.pending_count(), 3);
|
||||
assert_eq!(tracker.client_pending_count(client1), 2);
|
||||
assert_eq!(tracker.client_pending_count(client2), 0);
|
||||
assert_eq!(tracker.client_pending_count(client3), 1);
|
||||
|
||||
// Client2's receivers should error
|
||||
assert!(rx2_1.await.is_err());
|
||||
assert!(rx2_2.await.is_err());
|
||||
assert!(rx2_3.await.is_err());
|
||||
|
||||
// Client1 and client3 should still work
|
||||
tracker.complete(client1, json!(0), json!({"result": "client1-0"})).unwrap();
|
||||
assert_eq!(rx1_1.await.unwrap()["result"], "client1-0");
|
||||
|
||||
tracker.complete(client3, json!(0), json!({"result": "client3-0"})).unwrap();
|
||||
assert_eq!(rx3_1.await.unwrap()["result"], "client3-0");
|
||||
|
||||
// Complete remaining client1 request
|
||||
tracker.complete(client1, json!(1), json!({"result": "client1-1"})).unwrap();
|
||||
assert_eq!(rx1_2.await.unwrap()["result"], "client1-1");
|
||||
|
||||
// All cleaned up
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_disconnect_then_reconnect() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
|
||||
// First connection - register requests
|
||||
let rx1 = tracker.register(client_id, json!(0));
|
||||
let rx2 = tracker.register(client_id, json!(1));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 2);
|
||||
|
||||
// Disconnect - cleanup
|
||||
tracker.clear(Some(client_id));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Old receivers should error
|
||||
assert!(rx1.await.is_err());
|
||||
assert!(rx2.await.is_err());
|
||||
|
||||
// Reconnect - register new requests (same client_id, same request_ids)
|
||||
let rx3 = tracker.register(client_id, json!(0));
|
||||
let rx4 = tracker.register(client_id, json!(1));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 2);
|
||||
|
||||
// Complete new requests
|
||||
tracker.complete(client_id, json!(0), json!({"result": "new-0"})).unwrap();
|
||||
tracker.complete(client_id, json!(1), json!({"result": "new-1"})).unwrap();
|
||||
|
||||
// New receivers should get responses
|
||||
assert_eq!(rx3.await.unwrap()["result"], "new-0");
|
||||
assert_eq!(rx4.await.unwrap()["result"], "new-1");
|
||||
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_disconnect_no_pending_requests() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
|
||||
// No pending requests
|
||||
assert_eq!(tracker.client_pending_count(client_id), 0);
|
||||
|
||||
// Disconnect should be no-op
|
||||
tracker.clear(Some(client_id));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_clear_all_clients() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client1 = "client-1";
|
||||
let client2 = "client-2";
|
||||
let client3 = "client-3";
|
||||
|
||||
// Register requests for multiple clients
|
||||
let rx1 = tracker.register(client1, json!(0));
|
||||
let rx2 = tracker.register(client2, json!(0));
|
||||
let rx3 = tracker.register(client3, json!(0));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 3);
|
||||
|
||||
// Clear all (simulating server shutdown)
|
||||
tracker.clear(None);
|
||||
|
||||
// All should be removed
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
assert_eq!(tracker.client_pending_count(client1), 0);
|
||||
assert_eq!(tracker.client_pending_count(client2), 0);
|
||||
assert_eq!(tracker.client_pending_count(client3), 0);
|
||||
|
||||
// All receivers should error
|
||||
assert!(rx1.await.is_err());
|
||||
assert!(rx2.await.is_err());
|
||||
assert!(rx3.await.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_disconnect_race_with_completion() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
|
||||
// Register multiple requests
|
||||
let rx1 = tracker.register(client_id, json!(0));
|
||||
let rx2 = tracker.register(client_id, json!(1));
|
||||
let rx3 = tracker.register(client_id, json!(2));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 3);
|
||||
|
||||
// Complete one request
|
||||
tracker.complete(client_id, json!(0), json!({"result": "0"})).unwrap();
|
||||
|
||||
// Verify it was removed
|
||||
assert_eq!(tracker.pending_count(), 2);
|
||||
|
||||
// Now disconnect (should only clear remaining requests)
|
||||
tracker.clear(Some(client_id));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// First receiver should have gotten response
|
||||
assert_eq!(rx1.await.unwrap()["result"], "0");
|
||||
|
||||
// Other receivers should error
|
||||
assert!(rx2.await.is_err());
|
||||
assert!(rx3.await.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_partial_disconnect_completion() {
|
||||
// Test that completing a request after disconnect fails gracefully
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
|
||||
let _rx = tracker.register(client_id, json!(0));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 1);
|
||||
|
||||
// Disconnect
|
||||
tracker.clear(Some(client_id));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Try to complete after disconnect - should fail
|
||||
let result = tracker.complete(client_id, json!(0), json!({"result": "late"}));
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_disconnect_and_complete() {
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
|
||||
// Register many requests
|
||||
for i in 0..50 {
|
||||
tracker.register(client_id, json!(i));
|
||||
}
|
||||
|
||||
assert_eq!(tracker.pending_count(), 50);
|
||||
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
// Spawn task to disconnect after delay
|
||||
{
|
||||
let tracker_clone = tracker.clone();
|
||||
join_set.spawn(async move {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
|
||||
tracker_clone.clear(Some(client_id));
|
||||
"disconnected"
|
||||
});
|
||||
}
|
||||
|
||||
// Spawn tasks to complete requests
|
||||
for i in 0..50 {
|
||||
let tracker_clone = tracker.clone();
|
||||
join_set.spawn(async move {
|
||||
// Small delay to ensure some complete before disconnect
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis((i % 5) as u64)).await;
|
||||
let response = json!({"request": i});
|
||||
match tracker_clone.complete(client_id, json!(i), response) {
|
||||
Ok(_) => "completed",
|
||||
Err(_) => "failed",
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Wait for all tasks
|
||||
let mut results = Vec::new();
|
||||
while let Some(result) = join_set.join_next().await {
|
||||
results.push(result.unwrap());
|
||||
}
|
||||
|
||||
// Should have 1 disconnect + 50 completion attempts
|
||||
assert_eq!(results.len(), 51);
|
||||
|
||||
// Some completions succeeded, some failed (after disconnect)
|
||||
let disconnects = results.iter().filter(|r| r == &&"disconnected").count();
|
||||
assert_eq!(disconnects, 1);
|
||||
|
||||
// All requests should be cleaned up
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
//! Integration tests for session/list RPC method
|
||||
|
||||
use dirigent_acp_api::{
|
||||
NoOpConnectorOperations, RpcHandler, SessionManager,
|
||||
};
|
||||
use dirigent_acp_api::agent_requests::AgentRequestTracker;
|
||||
use dirigent_acp_api::sse::SseNotifier;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
fn create_test_handler() -> RpcHandler<NoOpConnectorOperations> {
|
||||
RpcHandler::new(
|
||||
SessionManager::new(),
|
||||
NoOpConnectorOperations,
|
||||
SseNotifier::new(),
|
||||
Arc::new(AgentRequestTracker::new()),
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_list_returns_sessions() {
|
||||
let handler = create_test_handler();
|
||||
|
||||
let request_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "session/list",
|
||||
"params": {
|
||||
"connectorId": "stub-connector"
|
||||
}
|
||||
});
|
||||
|
||||
let response = handler.handle_request(&request_body.to_string(), None).await;
|
||||
let response_json = serde_json::to_value(&response).unwrap();
|
||||
|
||||
// NoOp returns one stub session
|
||||
let result = &response_json["result"];
|
||||
assert!(result["sessions"].is_array());
|
||||
let sessions = result["sessions"].as_array().unwrap();
|
||||
assert!(!sessions.is_empty());
|
||||
assert!(sessions[0]["sessionId"].is_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_list_no_params() {
|
||||
let handler = create_test_handler();
|
||||
|
||||
// session/list with no params should use default connector
|
||||
let request_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"method": "session/list"
|
||||
});
|
||||
|
||||
let response = handler.handle_request(&request_body.to_string(), None).await;
|
||||
let response_json = serde_json::to_value(&response).unwrap();
|
||||
|
||||
// Should succeed using default connector from NoOp
|
||||
let result = &response_json["result"];
|
||||
assert!(result["sessions"].is_array());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_list_creates_mappings() {
|
||||
let session_manager = SessionManager::new();
|
||||
let handler = RpcHandler::new(
|
||||
session_manager.clone(),
|
||||
NoOpConnectorOperations,
|
||||
SseNotifier::new(),
|
||||
Arc::new(AgentRequestTracker::new()),
|
||||
);
|
||||
|
||||
// First, initialize to register a client
|
||||
let init_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 0,
|
||||
"method": "initialize",
|
||||
"params": {}
|
||||
});
|
||||
let init_response = handler.handle_request(&init_body.to_string(), Some("test-client")).await;
|
||||
let init_json = serde_json::to_value(&init_response).unwrap();
|
||||
let client_id = init_json["result"]["clientId"].as_str().unwrap();
|
||||
|
||||
// Now list sessions
|
||||
let request_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "session/list",
|
||||
"params": {
|
||||
"connectorId": "stub-connector"
|
||||
}
|
||||
});
|
||||
|
||||
let response = handler.handle_request(&request_body.to_string(), Some(client_id)).await;
|
||||
let response_json = serde_json::to_value(&response).unwrap();
|
||||
|
||||
let sessions = response_json["result"]["sessions"].as_array().unwrap();
|
||||
assert!(!sessions.is_empty());
|
||||
|
||||
// Session mapping should exist for the returned session
|
||||
let session_id = sessions[0]["sessionId"].as_str().unwrap();
|
||||
let mapping = session_manager.get_mapping(session_id);
|
||||
assert!(mapping.is_some(), "Session mapping should be created for listed sessions");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_load_without_connector_id() {
|
||||
let handler = create_test_handler();
|
||||
|
||||
// Standard ACP: only sessionId + cwd + mcpServers, no connectorId
|
||||
let request_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "session/load",
|
||||
"params": {
|
||||
"sessionId": "sess-abc",
|
||||
"cwd": "G:\\dev\\projects\\test",
|
||||
"mcpServers": []
|
||||
}
|
||||
});
|
||||
|
||||
let response = handler.handle_request(&request_body.to_string(), None).await;
|
||||
let response_json = serde_json::to_value(&response).unwrap();
|
||||
|
||||
let result = &response_json["result"];
|
||||
assert!(result["sessionId"].is_string(), "session/load should succeed without connectorId");
|
||||
assert!(result["createdAt"].is_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_initialize_advertises_list_sessions() {
|
||||
let handler = create_test_handler();
|
||||
|
||||
let request_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {}
|
||||
});
|
||||
|
||||
let response = handler.handle_request(&request_body.to_string(), None).await;
|
||||
let response_json = serde_json::to_value(&response).unwrap();
|
||||
|
||||
let caps = &response_json["result"]["agentCapabilities"];
|
||||
assert_eq!(caps["listSessions"], true, "listSessions capability should be advertised");
|
||||
assert_eq!(caps["loadSession"], true, "loadSession capability should still be advertised");
|
||||
|
||||
// Verify nested sessionCapabilities.list is advertised (required by Zed v0.9.4+)
|
||||
assert!(
|
||||
caps["sessionCapabilities"]["list"].is_object(),
|
||||
"sessionCapabilities.list should be advertised as an empty object"
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
//! Integration tests for session/resume RPC method
|
||||
|
||||
use dirigent_acp_api::{
|
||||
NoOpConnectorOperations, RpcHandler, SessionManager,
|
||||
};
|
||||
use dirigent_acp_api::agent_requests::AgentRequestTracker;
|
||||
use dirigent_acp_api::sse::SseNotifier;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
fn create_test_handler() -> RpcHandler<NoOpConnectorOperations> {
|
||||
RpcHandler::new(
|
||||
SessionManager::new(),
|
||||
NoOpConnectorOperations,
|
||||
SseNotifier::new(),
|
||||
Arc::new(AgentRequestTracker::new()),
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_resume_returns_session() {
|
||||
let handler = create_test_handler();
|
||||
|
||||
let request_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "session/resume",
|
||||
"params": {
|
||||
"sessionId": "sess-123",
|
||||
"connectorId": "stub-connector"
|
||||
}
|
||||
});
|
||||
|
||||
let response = handler.handle_request(&request_body.to_string(), None).await;
|
||||
let response_json = serde_json::to_value(&response).unwrap();
|
||||
|
||||
let result = &response_json["result"];
|
||||
assert!(result["sessionId"].is_string());
|
||||
assert!(result["connectorId"].is_string());
|
||||
assert!(result["createdAt"].is_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_resume_missing_params() {
|
||||
let handler = create_test_handler();
|
||||
|
||||
let request_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "session/resume"
|
||||
});
|
||||
|
||||
let response = handler.handle_request(&request_body.to_string(), None).await;
|
||||
let response_json = serde_json::to_value(&response).unwrap();
|
||||
|
||||
// Should return error for missing params
|
||||
assert!(response_json["error"].is_object());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_resume_creates_mapping() {
|
||||
let session_manager = SessionManager::new();
|
||||
let handler = RpcHandler::new(
|
||||
session_manager.clone(),
|
||||
NoOpConnectorOperations,
|
||||
SseNotifier::new(),
|
||||
Arc::new(AgentRequestTracker::new()),
|
||||
);
|
||||
|
||||
let request_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "session/resume",
|
||||
"params": {
|
||||
"sessionId": "sess-456",
|
||||
"connectorId": "stub-connector"
|
||||
}
|
||||
});
|
||||
|
||||
let response = handler.handle_request(&request_body.to_string(), None).await;
|
||||
let response_json = serde_json::to_value(&response).unwrap();
|
||||
|
||||
let session_id = response_json["result"]["sessionId"].as_str().unwrap();
|
||||
let mapping = session_manager.get_mapping(session_id);
|
||||
assert!(mapping.is_some(), "Session mapping should be created for resumed session");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_resume_without_connector_id() {
|
||||
let handler = create_test_handler();
|
||||
|
||||
// Standard ACP: only sessionId, no connectorId — should resolve via default connector
|
||||
let request_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "session/resume",
|
||||
"params": {
|
||||
"sessionId": "sess-789",
|
||||
"cwd": "G:\\dev\\projects\\test"
|
||||
}
|
||||
});
|
||||
|
||||
let response = handler.handle_request(&request_body.to_string(), None).await;
|
||||
let response_json = serde_json::to_value(&response).unwrap();
|
||||
|
||||
let result = &response_json["result"];
|
||||
assert!(result["sessionId"].is_string(), "Should succeed without connectorId");
|
||||
assert!(result["createdAt"].is_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_initialize_advertises_session_resume() {
|
||||
let handler = create_test_handler();
|
||||
|
||||
let request_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {}
|
||||
});
|
||||
|
||||
let response = handler.handle_request(&request_body.to_string(), None).await;
|
||||
let response_json = serde_json::to_value(&response).unwrap();
|
||||
|
||||
let caps = &response_json["result"]["agentCapabilities"];
|
||||
assert!(
|
||||
caps["sessionCapabilities"]["list"].is_object(),
|
||||
"sessionCapabilities.list should be advertised as an empty object"
|
||||
);
|
||||
assert!(
|
||||
caps["sessionCapabilities"]["resume"].is_object(),
|
||||
"sessionCapabilities.resume should be advertised as an empty object"
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,229 @@
|
||||
//! Integration test for timeout handling (T049)
|
||||
//!
|
||||
//! This test verifies that the system handles timeouts gracefully when a client
|
||||
//! fails to respond to an agent request within the timeout period.
|
||||
//!
|
||||
//! Test scenario:
|
||||
//! 1. Register a pending agent request
|
||||
//! 2. Wait for timeout (using reduced timeout for testing)
|
||||
//! 3. Verify timeout occurs
|
||||
//! 4. Verify cleanup happens correctly
|
||||
//! 5. Verify no resource leaks
|
||||
|
||||
use dirigent_acp_api::agent_requests::AgentRequestTracker;
|
||||
use serde_json::json;
|
||||
use tokio::time::{timeout, Duration};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_basic() {
|
||||
// Create tracker
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
let request_id = json!(0);
|
||||
|
||||
// Register request
|
||||
let receiver = tracker.register(client_id, request_id.clone());
|
||||
assert_eq!(tracker.pending_count(), 1);
|
||||
|
||||
// Wait for timeout (use short timeout for testing)
|
||||
let result = timeout(Duration::from_millis(100), receiver).await;
|
||||
|
||||
// Should timeout
|
||||
assert!(result.is_err(), "Expected timeout but request completed");
|
||||
|
||||
// Manually trigger cleanup (in production, event bridge does this)
|
||||
tracker.timeout(client_id, request_id);
|
||||
|
||||
// Verify cleanup
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_cleanup() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
let request_id = json!(123);
|
||||
|
||||
// Register request
|
||||
let receiver = tracker.register(client_id, request_id.clone());
|
||||
assert_eq!(tracker.pending_count(), 1);
|
||||
|
||||
// Trigger timeout before waiting
|
||||
tracker.timeout(client_id, request_id);
|
||||
|
||||
// Verify cleanup happened
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Receiver should get error (channel closed)
|
||||
let result = receiver.await;
|
||||
assert!(result.is_err(), "Expected receiver to get error after timeout");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_multiple_clients() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client1 = "client-1";
|
||||
let client2 = "client-2";
|
||||
|
||||
// Register requests for both clients
|
||||
let rx1 = tracker.register(client1, json!(0));
|
||||
let rx2 = tracker.register(client2, json!(0));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 2);
|
||||
assert_eq!(tracker.client_pending_count(client1), 1);
|
||||
assert_eq!(tracker.client_pending_count(client2), 1);
|
||||
|
||||
// Timeout only client1's request
|
||||
tracker.timeout(client1, json!(0));
|
||||
|
||||
// Verify only client1's request is removed
|
||||
assert_eq!(tracker.pending_count(), 1);
|
||||
assert_eq!(tracker.client_pending_count(client1), 0);
|
||||
assert_eq!(tracker.client_pending_count(client2), 1);
|
||||
|
||||
// Client1's receiver should error
|
||||
assert!(rx1.await.is_err());
|
||||
|
||||
// Complete client2's request normally
|
||||
let response = json!({"result": "success"});
|
||||
let result = tracker.complete(client2, json!(0), response);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Client2's receiver should get response
|
||||
let received = rx2.await.unwrap();
|
||||
assert_eq!(received, json!({"result": "success"}));
|
||||
|
||||
// All cleaned up
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_no_double_cleanup() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
let request_id = json!(0);
|
||||
|
||||
// Register request
|
||||
let _receiver = tracker.register(client_id, request_id.clone());
|
||||
assert_eq!(tracker.pending_count(), 1);
|
||||
|
||||
// First timeout - should remove
|
||||
tracker.timeout(client_id, request_id.clone());
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Second timeout - should be no-op (not panic)
|
||||
tracker.timeout(client_id, request_id);
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_race_with_complete() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
let request_id = json!(0);
|
||||
|
||||
// Register request
|
||||
let receiver = tracker.register(client_id, request_id.clone());
|
||||
assert_eq!(tracker.pending_count(), 1);
|
||||
|
||||
// Complete the request
|
||||
let response = json!({"result": "success"});
|
||||
let result = tracker.complete(client_id, request_id.clone(), response.clone());
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Try to timeout after completion - should be no-op
|
||||
tracker.timeout(client_id, request_id);
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Receiver should still get the response
|
||||
let received = receiver.await.unwrap();
|
||||
assert_eq!(received, response);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_timeouts() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
|
||||
// Register 10 requests
|
||||
let mut receivers = Vec::new();
|
||||
for i in 0..10 {
|
||||
let rx = tracker.register(client_id, json!(i));
|
||||
receivers.push((i, rx));
|
||||
}
|
||||
|
||||
assert_eq!(tracker.pending_count(), 10);
|
||||
|
||||
// Spawn tasks to timeout each request after random delays
|
||||
let tracker_clone = tracker.clone();
|
||||
let timeout_handles: Vec<_> = (0..10)
|
||||
.map(|i| {
|
||||
let tracker = tracker_clone.clone();
|
||||
tokio::spawn(async move {
|
||||
// Small random-ish delay based on index
|
||||
tokio::time::sleep(Duration::from_millis((i * 10) as u64)).await;
|
||||
tracker.timeout(client_id, json!(i));
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Wait for all timeouts to complete
|
||||
for handle in timeout_handles {
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
// All should be cleaned up
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// All receivers should get errors
|
||||
for (_i, rx) in receivers {
|
||||
assert!(rx.await.is_err());
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_with_actual_delay() {
|
||||
// This test uses actual time delays to verify timeout behavior more realistically
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
let request_id = json!(0);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
// Register request
|
||||
let receiver = tracker.register(client_id, request_id.clone());
|
||||
|
||||
// Spawn task to timeout after 200ms
|
||||
let tracker_clone = tracker.clone();
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||
tracker_clone.timeout(client_id, json!(0));
|
||||
});
|
||||
|
||||
// Wait on receiver with longer timeout
|
||||
let result = timeout(Duration::from_secs(1), receiver).await;
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
// Should complete due to timeout() call, not tokio::time::timeout
|
||||
assert!(result.is_ok(), "Should complete when timeout() is called");
|
||||
assert!(result.unwrap().is_err(), "Receiver should get error");
|
||||
|
||||
// Should take approximately 200ms
|
||||
assert!(
|
||||
elapsed >= Duration::from_millis(180) && elapsed < Duration::from_millis(300),
|
||||
"Expected ~200ms but took {:?}",
|
||||
elapsed
|
||||
);
|
||||
|
||||
// Should be cleaned up
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
Reference in New Issue
Block a user