chore: rename packages/ to crates/
Move all 29 workspace members from packages/<name>/ to crates/<name>/. Updates: workspace Cargo.toml (members + path deps), justfile, root CLAUDE.md, scripts/build/CARGO_INSTALL.md, docs/architecture/crates.md (renamed from packages.md), structural references in docs/architecture and docs/configuration, per-crate CLAUDE.md self-references. Historical plans, reports, and building/ docs are left untouched. No behavior change; just check-all stays green and fermata tests pass.
This commit is contained in:
@@ -0,0 +1,931 @@
|
||||
//! Bridge mode implementation for relaying stdio ACP clients to Dirigent ACP Server.
|
||||
//!
|
||||
//! This module implements the stdio-to-HTTP/SSE bridge that allows external ACP clients
|
||||
//! (like Claude Code configured for stdio transport) to connect to a Dirigent ACP Server.
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! ```text
|
||||
//! External ACP Client (Claude Code, etc.)
|
||||
//! |
|
||||
//! | stdio (stdin/stdout)
|
||||
//! v
|
||||
//! +-------------------+
|
||||
//! | Conductor Bridge |
|
||||
//! | - stdin parser | <-- Reads JSON-RPC from stdin
|
||||
//! | - HTTP client | <-- POSTs to /rpc endpoint
|
||||
//! | - SSE subscriber | <-- Subscribes to /events
|
||||
//! +-------------------+
|
||||
//! |
|
||||
//! | HTTP/SSE
|
||||
//! v
|
||||
//! Dirigent ACP Server
|
||||
//! ```
|
||||
//!
|
||||
//! ## Protocol Flow
|
||||
//!
|
||||
//! 1. Client sends JSON-RPC request on stdin (line-delimited JSON)
|
||||
//! 2. Bridge reads the request, POSTs to `{server_url}/rpc`
|
||||
//! 3. Bridge writes the HTTP response to stdout
|
||||
//! 4. Meanwhile, SSE subscriber writes server notifications to stdout
|
||||
|
||||
use dirigent_core::acp::transport::json_reader::{JsonLineReader, ReadResult};
|
||||
use crate::Result;
|
||||
use dirigent_protocol::log_utils::mask_json_string;
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::Client;
|
||||
use reqwest_eventsource::{Event, EventSource};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::io::{AsyncWriteExt, BufReader};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// Configuration for the bridge.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BridgeConfig {
|
||||
/// Base URL of the ACP Server (e.g., "http://localhost:3001/acp").
|
||||
pub server_url: String,
|
||||
|
||||
/// Timeout for HTTP requests.
|
||||
pub timeout: Duration,
|
||||
|
||||
/// Enable verbose logging of JSON-RPC messages.
|
||||
pub verbose: bool,
|
||||
|
||||
/// Automatically reconnect SSE stream on disconnect.
|
||||
pub auto_reconnect: bool,
|
||||
|
||||
/// Optional connector selection by ID or agent type magic word
|
||||
pub select_connector: Option<String>,
|
||||
}
|
||||
|
||||
impl BridgeConfig {
|
||||
/// Create a new bridge configuration.
|
||||
pub fn new(server_url: String) -> Self {
|
||||
Self {
|
||||
server_url,
|
||||
timeout: Duration::from_secs(30),
|
||||
verbose: false,
|
||||
auto_reconnect: true,
|
||||
select_connector: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the RPC endpoint URL.
|
||||
pub fn rpc_url(&self) -> String {
|
||||
format!("{}/rpc", self.server_url.trim_end_matches('/'))
|
||||
}
|
||||
|
||||
/// Build the SSE events endpoint URL.
|
||||
pub fn events_url(&self) -> String {
|
||||
let base = format!("{}/events", self.server_url.trim_end_matches('/'));
|
||||
base
|
||||
}
|
||||
|
||||
/// Build the SSE events endpoint URL with client_id query parameter.
|
||||
pub fn events_url_with_client(&self, client_id: &str) -> String {
|
||||
format!("{}/events?client_id={}", self.server_url.trim_end_matches('/'), client_id)
|
||||
}
|
||||
|
||||
/// Build the agent response endpoint URL.
|
||||
pub fn agent_response_url(&self) -> String {
|
||||
format!("{}/agent_response", self.server_url.trim_end_matches('/'))
|
||||
}
|
||||
}
|
||||
|
||||
/// Bridge state shared between tasks.
|
||||
struct BridgeState {
|
||||
/// Client ID received from initialize response.
|
||||
client_id: Option<String>,
|
||||
|
||||
/// Buffered RPC responses for session/prompt requests (session_id -> response_json)
|
||||
/// These are held until we see the corresponding turn.complete notification
|
||||
/// (with session_idle as fallback for backward compatibility)
|
||||
buffered_responses: std::collections::HashMap<String, String>,
|
||||
|
||||
/// Pending agent requests awaiting responses from Zed
|
||||
/// Maps request_id -> insertion timestamp for timeout tracking
|
||||
pending_agent_requests: std::collections::HashMap<serde_json::Value, Instant>,
|
||||
}
|
||||
|
||||
impl BridgeState {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
client_id: None,
|
||||
buffered_responses: std::collections::HashMap::new(),
|
||||
pending_agent_requests: std::collections::HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// JSON-RPC request structure (minimal for parsing).
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct JsonRpcRequest {
|
||||
jsonrpc: String,
|
||||
method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
params: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
id: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// JSON-RPC response structure (minimal for parsing).
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct JsonRpcResponse {
|
||||
jsonrpc: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
result: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
error: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
id: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// JSON-RPC notification structure for SSE events.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct JsonRpcNotification {
|
||||
jsonrpc: String,
|
||||
method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Run the bridge, relaying between stdio and the ACP Server.
|
||||
///
|
||||
/// This function:
|
||||
/// 1. Spawns an SSE subscriber task that writes notifications to stdout
|
||||
/// 2. Reads JSON-RPC requests from stdin in a loop
|
||||
/// 3. Forwards requests to the ACP Server via HTTP
|
||||
/// 4. Writes responses to stdout
|
||||
pub async fn run_bridge(config: BridgeConfig) -> Result<()> {
|
||||
tracing::info!(
|
||||
server_url = %config.server_url,
|
||||
"Starting bridge to Dirigent ACP Server"
|
||||
);
|
||||
|
||||
// Create HTTP client
|
||||
let client = Client::builder()
|
||||
.timeout(config.timeout)
|
||||
.build()
|
||||
.map_err(|e| crate::MockerError::Internal(format!("Failed to create HTTP client: {}", e)))?;
|
||||
|
||||
// Shared state
|
||||
let state = Arc::new(Mutex::new(BridgeState::new()));
|
||||
|
||||
// Channel for stdout writes (to serialize output from multiple tasks)
|
||||
let (stdout_tx, mut stdout_rx) = mpsc::channel::<String>(100);
|
||||
|
||||
// Spawn stdout writer task
|
||||
let stdout_handle = tokio::spawn(async move {
|
||||
let mut stdout = tokio::io::stdout();
|
||||
while let Some(line) = stdout_rx.recv().await {
|
||||
// T019: Start timing for stdout write
|
||||
let write_start = Instant::now();
|
||||
|
||||
if let Err(e) = stdout.write_all(line.as_bytes()).await {
|
||||
tracing::error!(error = %e, "Failed to write to stdout");
|
||||
break;
|
||||
}
|
||||
if let Err(e) = stdout.write_all(b"\n").await {
|
||||
tracing::error!(error = %e, "Failed to write newline to stdout");
|
||||
break;
|
||||
}
|
||||
if let Err(e) = stdout.flush().await {
|
||||
tracing::error!(error = %e, "Failed to flush stdout");
|
||||
break;
|
||||
}
|
||||
|
||||
// T019: Trace logging after flush
|
||||
let elapsed_ms = write_start.elapsed().as_millis();
|
||||
tracing::trace!(
|
||||
elapsed_ms = elapsed_ms,
|
||||
line_len = line.len(),
|
||||
"Flushed event to stdout"
|
||||
);
|
||||
|
||||
// T020: Warning for slow stdout writes
|
||||
if elapsed_ms > 100 {
|
||||
tracing::warn!(
|
||||
elapsed_ms = elapsed_ms,
|
||||
line_len = line.len(),
|
||||
"Slow stdout write + flush detected"
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Spawn SSE subscriber task
|
||||
let sse_config = config.clone();
|
||||
let sse_stdout_tx = stdout_tx.clone();
|
||||
let sse_state = Arc::clone(&state);
|
||||
let sse_handle = tokio::spawn(async move {
|
||||
run_sse_subscriber(sse_config, sse_stdout_tx, sse_state).await
|
||||
});
|
||||
|
||||
// Spawn timeout checker task
|
||||
let timeout_state = Arc::clone(&state);
|
||||
let timeout_handle = tokio::spawn(async move {
|
||||
run_timeout_checker(timeout_state).await
|
||||
});
|
||||
|
||||
// Run stdin reader in main task
|
||||
let stdin_result = run_stdin_reader(config, client, stdout_tx, state).await;
|
||||
|
||||
// Cleanup
|
||||
sse_handle.abort();
|
||||
timeout_handle.abort();
|
||||
stdout_handle.abort();
|
||||
|
||||
stdin_result
|
||||
}
|
||||
|
||||
/// Read JSON-RPC requests from stdin and forward to the ACP Server.
|
||||
async fn run_stdin_reader(
|
||||
config: BridgeConfig,
|
||||
client: Client,
|
||||
stdout_tx: mpsc::Sender<String>,
|
||||
state: Arc<Mutex<BridgeState>>,
|
||||
) -> Result<()> {
|
||||
let stdin = tokio::io::stdin();
|
||||
let mut reader = BufReader::new(stdin);
|
||||
let mut json_reader = JsonLineReader::new();
|
||||
|
||||
tracing::debug!("Stdin reader started, waiting for JSON-RPC requests");
|
||||
|
||||
loop {
|
||||
// Read next JSON message (handles multi-line JSON from clients)
|
||||
let message: serde_json::Value = match json_reader.read_message(&mut reader).await {
|
||||
Ok(ReadResult::Message(msg)) => msg,
|
||||
Ok(ReadResult::Eof) => break,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to parse JSON-RPC message from stdin");
|
||||
let error_response = serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"error": {
|
||||
"code": -32700,
|
||||
"message": format!("Parse error: {}", e)
|
||||
},
|
||||
"id": null
|
||||
});
|
||||
let _ = stdout_tx.send(error_response.to_string()).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Serialize for forwarding (always compact single-line)
|
||||
let line = serde_json::to_string(&message).unwrap_or_default();
|
||||
|
||||
if config.verbose {
|
||||
tracing::debug!(request = %mask_json_string(&line), "Received JSON-RPC request from stdin");
|
||||
}
|
||||
|
||||
// Check if this is a response to an agent request (has id but no method)
|
||||
let is_response = message.get("id").is_some() && message.get("method").is_none();
|
||||
|
||||
if is_response {
|
||||
// This might be a response to an agent request
|
||||
if let Some(response_id) = message.get("id") {
|
||||
let mut state_lock = state.lock().await;
|
||||
|
||||
// Check if this response_id matches a pending agent request
|
||||
if state_lock.pending_agent_requests.contains_key(response_id) {
|
||||
tracing::info!(
|
||||
response_id = ?response_id,
|
||||
"Detected Zed response to agent request"
|
||||
);
|
||||
|
||||
// Remove from pending requests
|
||||
state_lock.pending_agent_requests.remove(response_id);
|
||||
tracing::debug!(
|
||||
response_id = ?response_id,
|
||||
remaining_pending = state_lock.pending_agent_requests.len(),
|
||||
"Removed agent request from pending set"
|
||||
);
|
||||
|
||||
// Get client_id
|
||||
let client_id = state_lock.client_id.clone();
|
||||
drop(state_lock); // Release lock before HTTP request
|
||||
|
||||
if let Some(client_id) = client_id {
|
||||
// POST response to /agent_response
|
||||
let agent_response_url = config.agent_response_url();
|
||||
|
||||
tracing::info!(
|
||||
url = %agent_response_url,
|
||||
response_id = ?response_id,
|
||||
"POSTing agent response to Dirigent"
|
||||
);
|
||||
|
||||
match client
|
||||
.post(&agent_response_url)
|
||||
.header("X-Client-ID", &client_id)
|
||||
.header("Content-Type", "application/json")
|
||||
.body(line.to_string())
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
let status = resp.status();
|
||||
if status.is_success() {
|
||||
tracing::info!(
|
||||
response_id = ?response_id,
|
||||
status = %status,
|
||||
"Successfully posted agent response to Dirigent"
|
||||
);
|
||||
} else {
|
||||
let body = resp.text().await.unwrap_or_else(|_| "".to_string());
|
||||
tracing::error!(
|
||||
response_id = ?response_id,
|
||||
status = %status,
|
||||
body = %body,
|
||||
"Failed to post agent response to Dirigent"
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
error = %e,
|
||||
response_id = ?response_id,
|
||||
"HTTP request failed when posting agent response"
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tracing::warn!(
|
||||
response_id = ?response_id,
|
||||
"Cannot post agent response: client_id not set"
|
||||
);
|
||||
}
|
||||
|
||||
// Skip forwarding to /acp/rpc - this is a response flow, not a new request
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse as request for normal processing
|
||||
let request: JsonRpcRequest = match serde_json::from_value(message) {
|
||||
Ok(req) => req,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, line = %line, "Failed to parse JSON-RPC request");
|
||||
// Send error response
|
||||
let error_response = serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"error": {
|
||||
"code": -32700,
|
||||
"message": format!("Parse error: {}", e)
|
||||
},
|
||||
"id": null
|
||||
});
|
||||
let _ = stdout_tx.send(error_response.to_string()).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Forward request to server (with client_id if available and not initialize)
|
||||
let client_id = if request.method != "initialize" {
|
||||
state.lock().await.client_id.clone()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// For initialize requests, pass select_connector as header
|
||||
let select_connector = if request.method == "initialize" {
|
||||
config.select_connector.as_deref()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let response = forward_request(&client, &config, &line, client_id.as_deref(), select_connector).await;
|
||||
|
||||
match response {
|
||||
Ok(response_text) => {
|
||||
if config.verbose {
|
||||
tracing::debug!(response = %mask_json_string(&response_text), "Received response from server");
|
||||
}
|
||||
|
||||
// Check if this is an initialize response to extract client_id
|
||||
if request.method == "initialize" {
|
||||
if let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(&response_text) {
|
||||
if let Some(result) = resp.result {
|
||||
// Try both camelCase (ACP spec) and snake_case
|
||||
if let Some(client_id) = result.get("clientId")
|
||||
.or_else(|| result.get("client_id"))
|
||||
.and_then(|v| v.as_str()) {
|
||||
let mut state = state.lock().await;
|
||||
state.client_id = Some(client_id.to_string());
|
||||
tracing::info!(client_id = %client_id, "Received client_id from initialize response");
|
||||
} else {
|
||||
tracing::warn!("No clientId found in initialize response");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this is a deferred response (e.g., gateway session/list waiting
|
||||
// for transfer). The real response will arrive later via SSE _rpc_response.
|
||||
if let Ok(resp) = serde_json::from_str::<JsonRpcResponse>(&response_text) {
|
||||
if let Some(ref result) = resp.result {
|
||||
if result.get("_deferred").and_then(|v| v.as_bool()) == Some(true) {
|
||||
tracing::info!(
|
||||
id = ?resp.id,
|
||||
method = %request.method,
|
||||
"Dropping deferred response - real response will arrive via SSE"
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this is a session/prompt response - buffer it instead of sending immediately
|
||||
if request.method == "session/prompt" {
|
||||
// Extract session_id from request params
|
||||
if let Some(params) = &request.params {
|
||||
if let Some(session_id) = params.get("sessionId").and_then(|v| v.as_str()) {
|
||||
tracing::debug!(
|
||||
session_id = %session_id,
|
||||
"Buffering session/prompt response until turn.complete"
|
||||
);
|
||||
let mut state = state.lock().await;
|
||||
state.buffered_responses.insert(session_id.to_string(), response_text);
|
||||
// Don't send to stdout yet - wait for turn.complete (or session_idle as fallback)
|
||||
continue;
|
||||
} else {
|
||||
tracing::warn!("session/prompt request missing sessionId in params");
|
||||
}
|
||||
} else {
|
||||
tracing::warn!("session/prompt request missing params");
|
||||
}
|
||||
}
|
||||
|
||||
// Send response to stdout (for non-session/prompt responses)
|
||||
if stdout_tx.send(response_text).await.is_err() {
|
||||
tracing::error!("Stdout channel closed");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, method = %request.method, "Failed to forward request to server");
|
||||
// Send error response
|
||||
let error_response = serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"error": {
|
||||
"code": -32603,
|
||||
"message": format!("Internal error: {}", e)
|
||||
},
|
||||
"id": request.id
|
||||
});
|
||||
let _ = stdout_tx.send(error_response.to_string()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup: Clear all pending agent requests when stdin closes (Zed disconnected)
|
||||
{
|
||||
let mut state = state.lock().await;
|
||||
let pending_count = state.pending_agent_requests.len();
|
||||
if pending_count > 0 {
|
||||
tracing::warn!(
|
||||
pending_count = pending_count,
|
||||
"Stdin closed with {} pending agent requests - clearing all",
|
||||
pending_count
|
||||
);
|
||||
state.pending_agent_requests.clear();
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("Stdin closed, shutting down bridge");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Forward a JSON-RPC request to the ACP Server.
|
||||
async fn forward_request(
|
||||
client: &Client,
|
||||
config: &BridgeConfig,
|
||||
request_body: &str,
|
||||
client_id: Option<&str>,
|
||||
select_connector: Option<&str>,
|
||||
) -> Result<String> {
|
||||
let mut request_builder = client
|
||||
.post(config.rpc_url())
|
||||
.header("Content-Type", "application/json");
|
||||
|
||||
// Add client_id header if available
|
||||
if let Some(id) = client_id {
|
||||
request_builder = request_builder.header("X-Client-ID", id);
|
||||
tracing::debug!(client_id = %id, "Forwarding request with client_id header");
|
||||
}
|
||||
|
||||
// Add select_connector header if available (for initialize requests)
|
||||
if let Some(connector) = select_connector {
|
||||
request_builder = request_builder.header("X-Select-Connector", connector);
|
||||
tracing::info!(select_connector = %connector, "Forwarding initialize with X-Select-Connector header");
|
||||
}
|
||||
|
||||
let response = request_builder
|
||||
.body(request_body.to_string())
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| crate::MockerError::Internal(format!("HTTP request failed: {}", e)))?;
|
||||
|
||||
let status = response.status();
|
||||
let body = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| crate::MockerError::Internal(format!("Failed to read response body: {}", e)))?;
|
||||
|
||||
if !status.is_success() {
|
||||
tracing::warn!(status = %status, body = %body, "Server returned non-success status");
|
||||
}
|
||||
|
||||
Ok(body)
|
||||
}
|
||||
|
||||
/// Subscribe to SSE events and write notifications to stdout.
|
||||
async fn run_sse_subscriber(
|
||||
config: BridgeConfig,
|
||||
stdout_tx: mpsc::Sender<String>,
|
||||
state: Arc<Mutex<BridgeState>>,
|
||||
) {
|
||||
loop {
|
||||
// Wait for client_id to be set before subscribing
|
||||
let client_id = loop {
|
||||
let state = state.lock().await;
|
||||
if let Some(ref id) = state.client_id {
|
||||
break id.clone();
|
||||
}
|
||||
drop(state);
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
};
|
||||
|
||||
let events_url = config.events_url_with_client(&client_id);
|
||||
tracing::info!(url = %events_url, "Connecting to SSE events endpoint");
|
||||
|
||||
let mut es = EventSource::get(&events_url);
|
||||
|
||||
while let Some(event) = es.next().await {
|
||||
match event {
|
||||
Ok(Event::Open) => {
|
||||
tracing::info!("SSE connection opened");
|
||||
}
|
||||
Ok(Event::Message(message)) => {
|
||||
// T016: Trace logging when SSE event is received from dirigent
|
||||
tracing::trace!(
|
||||
event_type = %message.event,
|
||||
data_len = message.data.len(),
|
||||
"Received SSE event from dirigent"
|
||||
);
|
||||
|
||||
if config.verbose {
|
||||
tracing::debug!(
|
||||
event_type = %message.event,
|
||||
data = %mask_json_string(&message.data),
|
||||
"Received SSE event"
|
||||
);
|
||||
}
|
||||
|
||||
// T017: Start timing for notification processing
|
||||
let notification_start = Instant::now();
|
||||
|
||||
// _rpc_response events are complete JSON-RPC responses pushed via SSE.
|
||||
// Forward the data directly to stdout without wrapping as a notification.
|
||||
if message.event == "_rpc_response" {
|
||||
tracing::info!(
|
||||
data_len = message.data.len(),
|
||||
"Received _rpc_response via SSE - forwarding directly to stdout"
|
||||
);
|
||||
if stdout_tx.send(message.data.clone()).await.is_err() {
|
||||
tracing::error!("Stdout channel closed while forwarding _rpc_response");
|
||||
return;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Convert SSE event to JSON-RPC notification
|
||||
// Note: Use event type as-is for ACP compliance (don't add acp/ prefix)
|
||||
let notification = JsonRpcNotification {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
method: message.event.clone(),
|
||||
params: serde_json::from_str(&message.data).ok(),
|
||||
};
|
||||
|
||||
let notification_json = match serde_json::to_string(¬ification) {
|
||||
Ok(json) => json,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to serialize notification");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// T017: Trace logging after notification serialization
|
||||
tracing::trace!(
|
||||
elapsed_ms = notification_start.elapsed().as_millis(),
|
||||
"Notification JSON serialization completed"
|
||||
);
|
||||
|
||||
// T021: Warning for large notifications
|
||||
if notification_json.len() > 10240 {
|
||||
tracing::warn!(
|
||||
size_bytes = notification_json.len(),
|
||||
"Large notification being sent to stdout"
|
||||
);
|
||||
}
|
||||
|
||||
// T018: Trace logging before writing to stdout
|
||||
tracing::trace!(
|
||||
notification_len = notification_json.len(),
|
||||
elapsed_since_receive_ms = notification_start.elapsed().as_millis(),
|
||||
"Sending notification to stdout channel"
|
||||
);
|
||||
|
||||
// Send the notification to stdout
|
||||
if stdout_tx.send(notification_json).await.is_err() {
|
||||
tracing::error!("Stdout channel closed, stopping SSE subscriber");
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if this is an agent_request - forward to Zed as JSON-RPC request
|
||||
if message.event == "session/update" {
|
||||
if let Some(params) = ¬ification.params {
|
||||
if let Some(update) = params.get("update") {
|
||||
if update.get("sessionUpdate") == Some(&serde_json::json!("agent_request")) {
|
||||
tracing::info!("Detected agent_request from SSE");
|
||||
|
||||
// Extract agent request fields
|
||||
let request_id = match update.get("requestId") {
|
||||
Some(id) => id.clone(),
|
||||
None => {
|
||||
tracing::warn!("agent_request missing requestId, skipping");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let method = match update.get("method").and_then(|m| m.as_str()) {
|
||||
Some(m) => m,
|
||||
None => {
|
||||
tracing::warn!("agent_request missing method, skipping");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let request_params = match update.get("params") {
|
||||
Some(p) => p.clone(),
|
||||
None => {
|
||||
tracing::warn!("agent_request missing params, skipping");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Format as JSON-RPC request for Zed
|
||||
let rpc_request = serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"method": method,
|
||||
"params": request_params
|
||||
});
|
||||
|
||||
let rpc_request_json = match serde_json::to_string(&rpc_request) {
|
||||
Ok(json) => json,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to serialize agent request");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
request_id = ?request_id,
|
||||
method = %method,
|
||||
"Forwarding agent request to Zed"
|
||||
);
|
||||
|
||||
// T022: Start timing for agent request forwarding
|
||||
let agent_request_start = Instant::now();
|
||||
|
||||
// Send formatted request to Zed's stdin
|
||||
if stdout_tx.send(rpc_request_json).await.is_err() {
|
||||
tracing::error!("Stdout channel closed while sending agent request");
|
||||
return;
|
||||
}
|
||||
|
||||
// T022: Trace logging after agent request forwarded
|
||||
tracing::trace!(
|
||||
elapsed_ms = agent_request_start.elapsed().as_millis(),
|
||||
"Agent request forwarded to Zed stdout"
|
||||
);
|
||||
|
||||
// Track pending agent request with timestamp for timeout tracking
|
||||
{
|
||||
let mut state = state.lock().await;
|
||||
state.pending_agent_requests.insert(request_id.clone(), Instant::now());
|
||||
tracing::debug!(
|
||||
request_id = ?request_id,
|
||||
pending_count = state.pending_agent_requests.len(),
|
||||
"Tracking pending agent request"
|
||||
);
|
||||
}
|
||||
|
||||
// Skip regular notification handling for agent requests
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this is a turn.complete notification - if so, flush buffered response
|
||||
// This is the primary signal that all content has been received
|
||||
if message.event == "turn.complete" {
|
||||
tracing::debug!("Received turn.complete notification");
|
||||
if let Some(params) = ¬ification.params {
|
||||
if let Some(session_id) = params.get("session_id").and_then(|v| v.as_str()) {
|
||||
tracing::debug!("Session ID from turn.complete: {}", session_id);
|
||||
// Check if we have a buffered response for this session
|
||||
let mut state = state.lock().await;
|
||||
tracing::debug!(
|
||||
"Looking for buffered response for session {}, have {} buffered",
|
||||
session_id,
|
||||
state.buffered_responses.len()
|
||||
);
|
||||
if let Some(buffered_response) = state.buffered_responses.remove(session_id) {
|
||||
tracing::info!(
|
||||
session_id = %session_id,
|
||||
"Flushing buffered session/prompt response after turn.complete"
|
||||
);
|
||||
drop(state); // Release lock before sending
|
||||
if stdout_tx.send(buffered_response).await.is_err() {
|
||||
tracing::error!("Stdout channel closed while flushing buffered response");
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
tracing::debug!("No buffered response for session {} (already flushed or no pending response)", session_id);
|
||||
}
|
||||
} else {
|
||||
tracing::warn!("No 'session_id' in turn.complete params");
|
||||
}
|
||||
} else {
|
||||
tracing::warn!("turn.complete params is None, raw data: {}", message.data);
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this is a session_idle notification - fallback flush for backward compatibility
|
||||
if message.event == "session/update" {
|
||||
tracing::debug!("Received session/update notification, checking for session_idle");
|
||||
// Parse the notification params to check for session_idle
|
||||
if let Some(params) = ¬ification.params {
|
||||
tracing::debug!("Notification params: {:?}", params);
|
||||
if let Some(session_id) = params.get("sessionId").and_then(|v| v.as_str()) {
|
||||
tracing::debug!("Session ID from notification: {}", session_id);
|
||||
if let Some(update) = params.get("update") {
|
||||
if let Some(session_update) = update.get("sessionUpdate").and_then(|v| v.as_str()) {
|
||||
tracing::debug!("Session update type: {}", session_update);
|
||||
if session_update == "session_idle" {
|
||||
// Check if we have a buffered response for this session
|
||||
let mut state = state.lock().await;
|
||||
tracing::debug!(
|
||||
"Looking for buffered response for session {}, have {} buffered (fallback path)",
|
||||
session_id,
|
||||
state.buffered_responses.len()
|
||||
);
|
||||
if let Some(buffered_response) = state.buffered_responses.remove(session_id) {
|
||||
tracing::info!(
|
||||
session_id = %session_id,
|
||||
"Flushing buffered session/prompt response after session_idle (fallback)"
|
||||
);
|
||||
drop(state); // Release lock before sending
|
||||
if stdout_tx.send(buffered_response).await.is_err() {
|
||||
tracing::error!("Stdout channel closed while flushing buffered response");
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
tracing::debug!("No buffered response found for session {} (already flushed by turn.complete)", session_id);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tracing::warn!("Could not extract sessionUpdate from update");
|
||||
}
|
||||
} else {
|
||||
tracing::warn!("No 'update' field in params");
|
||||
}
|
||||
} else {
|
||||
tracing::warn!("No 'sessionId' in params");
|
||||
}
|
||||
} else {
|
||||
tracing::warn!("Notification params is None, raw data: {}", message.data);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(reqwest_eventsource::Error::StreamEnded) => {
|
||||
tracing::info!("SSE stream ended");
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "SSE error");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cleanup: Clear pending agent requests when SSE stream closes
|
||||
{
|
||||
let mut state = state.lock().await;
|
||||
let pending_count = state.pending_agent_requests.len();
|
||||
if pending_count > 0 {
|
||||
tracing::warn!(
|
||||
pending_count = pending_count,
|
||||
"SSE stream closed with {} pending agent requests - clearing all",
|
||||
pending_count
|
||||
);
|
||||
state.pending_agent_requests.clear();
|
||||
}
|
||||
}
|
||||
|
||||
if !config.auto_reconnect {
|
||||
tracing::info!("Auto-reconnect disabled, stopping SSE subscriber");
|
||||
return;
|
||||
}
|
||||
|
||||
tracing::info!("Reconnecting to SSE endpoint in 1 second...");
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Periodically check for and cleanup stale pending agent requests.
|
||||
async fn run_timeout_checker(state: Arc<Mutex<BridgeState>>) {
|
||||
const TIMEOUT_DURATION: Duration = Duration::from_secs(30);
|
||||
const CHECK_INTERVAL: Duration = Duration::from_secs(5);
|
||||
|
||||
tracing::debug!("Timeout checker started");
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(CHECK_INTERVAL).await;
|
||||
|
||||
let mut state = state.lock().await;
|
||||
let now = Instant::now();
|
||||
let mut stale_requests = Vec::new();
|
||||
|
||||
// Find stale requests
|
||||
for (request_id, inserted_at) in state.pending_agent_requests.iter() {
|
||||
let elapsed = now.duration_since(*inserted_at);
|
||||
if elapsed >= TIMEOUT_DURATION {
|
||||
stale_requests.push(request_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Remove stale requests
|
||||
for request_id in stale_requests {
|
||||
state.pending_agent_requests.remove(&request_id);
|
||||
tracing::error!(
|
||||
request_id = ?request_id,
|
||||
timeout_secs = TIMEOUT_DURATION.as_secs(),
|
||||
"Agent request timed out - removing from pending set"
|
||||
);
|
||||
}
|
||||
|
||||
if !state.pending_agent_requests.is_empty() {
|
||||
tracing::debug!(
|
||||
pending_count = state.pending_agent_requests.len(),
|
||||
"Timeout checker: {} pending agent requests",
|
||||
state.pending_agent_requests.len()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_bridge_config_urls() {
|
||||
let config = BridgeConfig::new("http://localhost:3001/acp".to_string());
|
||||
assert_eq!(config.rpc_url(), "http://localhost:3001/acp/rpc");
|
||||
assert_eq!(config.events_url(), "http://localhost:3001/acp/events");
|
||||
assert_eq!(config.agent_response_url(), "http://localhost:3001/acp/agent_response");
|
||||
|
||||
// Test with trailing slash
|
||||
let config = BridgeConfig::new("http://localhost:3001/acp/".to_string());
|
||||
assert_eq!(config.rpc_url(), "http://localhost:3001/acp/rpc");
|
||||
assert_eq!(config.events_url(), "http://localhost:3001/acp/events");
|
||||
assert_eq!(config.agent_response_url(), "http://localhost:3001/acp/agent_response");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_rpc_request_parsing() {
|
||||
let json = r#"{"jsonrpc":"2.0","method":"initialize","params":{},"id":1}"#;
|
||||
let request: JsonRpcRequest = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(request.method, "initialize");
|
||||
assert_eq!(request.id, Some(serde_json::json!(1)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_json_rpc_notification_serialization() {
|
||||
let notification = JsonRpcNotification {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
method: "acp/messageChunk".to_string(),
|
||||
params: Some(serde_json::json!({"content": "Hello"})),
|
||||
};
|
||||
let json = serde_json::to_string(¬ification).unwrap();
|
||||
assert!(json.contains("acp/messageChunk"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
//! ACP (Agent-Client Protocol) server implementation.
|
||||
//!
|
||||
//! This module provides the HTTP/WebSocket server that implements the ACP protocol.
|
||||
//! It handles incoming client connections, routes requests to fixture responders,
|
||||
//! and streams responses back to clients.
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! - `server.rs` - Axum HTTP server setup and routing
|
||||
//! - `model.rs` - ACP protocol type definitions and serialization
|
||||
//! - `stream.rs` - Server-Sent Events (SSE) streaming for real-time updates
|
||||
//! - `stdio.rs` - Stdin/stdout transport for editors like Zed
|
||||
//! - `bridge.rs` - Bridge mode for relaying stdio to HTTP/SSE
|
||||
|
||||
pub mod bridge;
|
||||
pub mod model;
|
||||
pub mod server;
|
||||
pub mod stdio;
|
||||
pub mod stream;
|
||||
|
||||
// Re-exports for convenience
|
||||
pub use bridge::{run_bridge, BridgeConfig};
|
||||
pub use model::*;
|
||||
pub use server::*;
|
||||
pub use stdio::serve_stdio;
|
||||
pub use stream::{
|
||||
chunk_text, stream_session, StreamConfig, StreamController, StreamEvent,
|
||||
};
|
||||
@@ -0,0 +1,709 @@
|
||||
//! ACP protocol type definitions.
|
||||
//!
|
||||
//! This module defines the core types for the Agent-Client Protocol (ACP) v0.1,
|
||||
//! including JSON-RPC message structures and ACP-specific request/response types.
|
||||
//!
|
||||
//! These types are aligned with the official ACP specification and designed to
|
||||
//! work with the mocker's fixture system for predictable testing scenarios.
|
||||
//!
|
||||
//! # References
|
||||
//! - ACP Specification: https://github.com/agent-client-protocol/spec
|
||||
//! - JSON-RPC 2.0: https://www.jsonrpc.org/specification
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// Re-export types from fixture module that are used in ACP protocol
|
||||
use crate::fixture::types::{Message, Participant};
|
||||
|
||||
// ============================================================================
|
||||
// JSON-RPC Core Types
|
||||
// ============================================================================
|
||||
// Reference: https://www.jsonrpc.org/specification
|
||||
|
||||
/// JSON-RPC 2.0 request message.
|
||||
///
|
||||
/// Represents a client request with a method name, optional parameters, and an ID.
|
||||
/// The ID is used to correlate requests with responses.
|
||||
///
|
||||
/// # JSON-RPC Specification
|
||||
/// A Request object has the following members:
|
||||
/// - `jsonrpc`: MUST be exactly "2.0"
|
||||
/// - `method`: A String containing the name of the method to be invoked
|
||||
/// - `params`: A Structured value (Array or Object) holding the parameter values
|
||||
/// - `id`: An identifier established by the Client (String, Number, or NULL)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcRequest {
|
||||
/// JSON-RPC version (always "2.0").
|
||||
pub jsonrpc: String,
|
||||
|
||||
/// The name of the method to be invoked.
|
||||
pub method: String,
|
||||
|
||||
/// Optional parameters for the method (structured as JSON value).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<serde_json::Value>,
|
||||
|
||||
/// Request identifier (number or string). Used to correlate with response.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// JSON-RPC 2.0 response message.
|
||||
///
|
||||
/// Represents a server response to a request. Contains either a result or an error,
|
||||
/// but never both.
|
||||
///
|
||||
/// # JSON-RPC Specification
|
||||
/// A Response object has the following members:
|
||||
/// - `jsonrpc`: MUST be exactly "2.0"
|
||||
/// - `result`: Required on success, absent on error
|
||||
/// - `error`: Required on error, absent on success
|
||||
/// - `id`: MUST be the same as the value of the id member in the Request Object
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcResponse {
|
||||
/// JSON-RPC version (always "2.0").
|
||||
pub jsonrpc: String,
|
||||
|
||||
/// The result of the method invocation (present on success).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result: Option<serde_json::Value>,
|
||||
|
||||
/// Error object (present on error).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<JsonRpcError>,
|
||||
|
||||
/// Request identifier (must match the request ID).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// JSON-RPC 2.0 notification message.
|
||||
///
|
||||
/// A notification is a request without an ID. The server does not send a response
|
||||
/// to notifications.
|
||||
///
|
||||
/// # JSON-RPC Specification
|
||||
/// A Notification is a Request object without an "id" member. The Server MUST NOT
|
||||
/// reply to a Notification.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcNotification {
|
||||
/// JSON-RPC version (always "2.0").
|
||||
pub jsonrpc: String,
|
||||
|
||||
/// The name of the method to be invoked.
|
||||
pub method: String,
|
||||
|
||||
/// Optional parameters for the method (structured as JSON value).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// JSON-RPC 2.0 error object.
|
||||
///
|
||||
/// Represents an error that occurred during method execution.
|
||||
///
|
||||
/// # JSON-RPC Specification
|
||||
/// An Error object has the following members:
|
||||
/// - `code`: A Number indicating the error type
|
||||
/// - `message`: A String providing a short description
|
||||
/// - `data`: Optional additional information about the error
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct JsonRpcError {
|
||||
/// A numeric error code indicating the type of error.
|
||||
///
|
||||
/// Standard codes:
|
||||
/// - `-32700`: Parse error (invalid JSON)
|
||||
/// - `-32600`: Invalid Request
|
||||
/// - `-32601`: Method not found
|
||||
/// - `-32602`: Invalid params
|
||||
/// - `-32603`: Internal error
|
||||
/// - `-32000` to `-32099`: Server error (implementation-defined)
|
||||
pub code: i32,
|
||||
|
||||
/// A short description of the error.
|
||||
pub message: String,
|
||||
|
||||
/// Optional additional information about the error.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ACP Content Types
|
||||
// ============================================================================
|
||||
// Reference: ACP spec section on content blocks
|
||||
|
||||
/// A block of content in an ACP message.
|
||||
///
|
||||
/// Content blocks represent different types of message content. In ACP v0.1,
|
||||
/// only text content is supported, but the enum structure allows for future
|
||||
/// extensions.
|
||||
///
|
||||
/// # ACP Specification
|
||||
/// Content blocks are the building blocks of messages. Each block has a type
|
||||
/// and type-specific fields.
|
||||
///
|
||||
/// # TODO
|
||||
/// Future content block types to implement:
|
||||
/// - `Image { source: ImageSource, alt_text: Option<String> }`
|
||||
/// - `Tool { id: String, name: String, input: Value }`
|
||||
/// - `ToolResult { tool_call_id: String, content: Vec<ContentBlock> }`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ContentBlock {
|
||||
/// Plain text content.
|
||||
Text {
|
||||
/// The text content.
|
||||
text: String,
|
||||
},
|
||||
// TODO: Add Image variant
|
||||
// TODO: Add Tool variant
|
||||
// TODO: Add ToolResult variant
|
||||
}
|
||||
|
||||
/// Reason why an agent stopped generating content.
|
||||
///
|
||||
/// Indicates the condition that caused the agent to stop producing output.
|
||||
///
|
||||
/// # ACP Specification
|
||||
/// The stop reason indicates why message generation terminated. This helps
|
||||
/// clients understand whether the message is complete or was interrupted.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum StopReason {
|
||||
/// The language model finishes responding without requesting more tools.
|
||||
EndTurn,
|
||||
|
||||
/// The maximum token limit is reached.
|
||||
MaxTokens,
|
||||
|
||||
/// The maximum number of model requests in a single turn is exceeded.
|
||||
MaxTurnRequests,
|
||||
|
||||
/// The Agent refuses to continue.
|
||||
Refusal,
|
||||
|
||||
/// The Client cancels the turn.
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ACP Initialize Types
|
||||
// ============================================================================
|
||||
// Reference: ACP spec section on initialization handshake
|
||||
|
||||
/// Parameters for the `acp.initialize` method.
|
||||
///
|
||||
/// Sent by the client to initiate an ACP session and negotiate capabilities.
|
||||
///
|
||||
/// # ACP Specification
|
||||
/// The initialize request is the first message in an ACP session. It establishes
|
||||
/// the protocol version and allows client and server to negotiate capabilities.
|
||||
///
|
||||
/// # JSON-RPC Method
|
||||
/// Method: `acp.initialize`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InitializeParams {
|
||||
/// The ACP protocol version the client supports (protocol version 1 = integer).
|
||||
pub protocol_version: serde_json::Value, // Can be integer or string
|
||||
|
||||
/// Client capabilities and configuration.
|
||||
///
|
||||
/// This is intentionally flexible to allow for client-specific capabilities.
|
||||
/// The server should examine this and respond with compatible agent capabilities.
|
||||
pub client_capabilities: serde_json::Value,
|
||||
|
||||
/// Optional client information (name, title, version).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub client_info: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Result of the `acp.initialize` method.
|
||||
///
|
||||
/// Returned by the server in response to an initialize request.
|
||||
///
|
||||
/// # ACP Specification
|
||||
/// The initialize response confirms the protocol version and declares the
|
||||
/// agent's capabilities. This allows the client to adapt its behavior based
|
||||
/// on what the agent supports.
|
||||
///
|
||||
/// # JSON-RPC Method
|
||||
/// Method: `acp.initialize`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InitializeResult {
|
||||
/// The ACP protocol version the server supports (protocol version 1 = integer).
|
||||
pub protocol_version: serde_json::Value, // Can be integer or string
|
||||
|
||||
/// Agent capabilities.
|
||||
pub agent_capabilities: AgentCapabilities,
|
||||
|
||||
/// Optional agent information (name, title, version).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub agent_info: Option<serde_json::Value>,
|
||||
|
||||
/// Optional authentication methods supported.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub auth_methods: Option<Vec<String>>,
|
||||
|
||||
/// Optional metadata for protocol extensions.
|
||||
///
|
||||
/// This field uses a leading underscore to indicate it's for extensions
|
||||
/// and not part of the core protocol.
|
||||
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
|
||||
pub _meta: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Agent capabilities declaration.
|
||||
///
|
||||
/// Describes what features the agent supports.
|
||||
///
|
||||
/// # ACP Specification
|
||||
/// Capabilities allow the client to discover what features are available.
|
||||
/// The agent should accurately report its capabilities to avoid client errors.
|
||||
///
|
||||
/// # TODO
|
||||
/// Future capabilities to implement:
|
||||
/// - `tools: bool` - Agent supports tool calls
|
||||
/// - `files: bool` - Agent supports file attachments
|
||||
/// - `streaming: bool` - Agent supports streaming responses
|
||||
/// - `context_window: Option<usize>` - Maximum context size
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct AgentCapabilities {
|
||||
/// Whether the agent supports loading existing sessions.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub load_session: Option<bool>,
|
||||
|
||||
/// Prompt capabilities (what input types are supported).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt_capabilities: Option<PromptCapabilities>,
|
||||
|
||||
/// MCP (Model Context Protocol) capabilities.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub mcp: Option<serde_json::Value>,
|
||||
// TODO: Add tools capability
|
||||
// TODO: Add files capability
|
||||
// TODO: Add streaming capability
|
||||
// TODO: Add context_window capability
|
||||
}
|
||||
|
||||
/// Prompt capabilities declaration.
|
||||
///
|
||||
/// Describes what types of content the agent can accept in prompts.
|
||||
///
|
||||
/// # ACP Specification
|
||||
/// In v0.1, only text prompts are required. Future versions may support
|
||||
/// images, audio, video, etc.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PromptCapabilities {
|
||||
/// Whether the agent supports image prompts.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub image: Option<bool>,
|
||||
|
||||
/// Whether the agent supports audio prompts.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub audio: Option<bool>,
|
||||
|
||||
/// Whether the agent supports embedded context (resources).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub embedded_context: Option<bool>,
|
||||
// Note: text and resourceLink are baseline requirements, not listed here
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ACP Session Types
|
||||
// ============================================================================
|
||||
// Reference: ACP spec sections on session management
|
||||
|
||||
/// Parameters for the `session/new` method.
|
||||
///
|
||||
/// Sent by the client to create a new session.
|
||||
///
|
||||
/// # ACP Specification
|
||||
/// Creates a new conversation session. The session ID is generated by the
|
||||
/// agent and returned in the response.
|
||||
///
|
||||
/// # JSON-RPC Method
|
||||
/// Method: `session/new`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionNewParams {
|
||||
/// The working directory for the session (absolute path).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cwd: Option<String>,
|
||||
|
||||
/// MCP servers the agent should connect to.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub mcp_servers: Option<Vec<serde_json::Value>>,
|
||||
|
||||
/// Optional fixture template ID (mocker-specific extension).
|
||||
///
|
||||
/// When provided, the mocker will create a session based on the specified
|
||||
/// fixture template. If not provided, a default session is created.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub template_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Result of the `session/new` method.
|
||||
///
|
||||
/// Returns the ID of the newly created session.
|
||||
///
|
||||
/// # ACP Specification
|
||||
/// The server generates a unique session ID and returns it to the client.
|
||||
/// The client uses this ID in subsequent requests to interact with the session.
|
||||
///
|
||||
/// # JSON-RPC Method
|
||||
/// Method: `session/new`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionNewResult {
|
||||
/// The unique identifier for the new session.
|
||||
pub session_id: String,
|
||||
}
|
||||
|
||||
/// Parameters for the `session/load` method.
|
||||
///
|
||||
/// Sent by the client to load an existing session.
|
||||
///
|
||||
/// # ACP Specification
|
||||
/// Loads a previously created session, including its full message history.
|
||||
/// This requires the agent to support the `loadSession` capability.
|
||||
/// The agent replays the conversation via session/update notifications.
|
||||
///
|
||||
/// # JSON-RPC Method
|
||||
/// Method: `session/load`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionLoadParams {
|
||||
/// The ID of the session to load.
|
||||
pub session_id: String,
|
||||
|
||||
/// The working directory for the session (absolute path).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cwd: Option<String>,
|
||||
|
||||
/// MCP servers the agent should connect to.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub mcp_servers: Option<Vec<serde_json::Value>>,
|
||||
}
|
||||
|
||||
/// Complete information about a session.
|
||||
///
|
||||
/// Represents a full conversation session with all its data.
|
||||
///
|
||||
/// # ACP Specification
|
||||
/// A session contains metadata, participants, and the full message history.
|
||||
/// This is the primary data structure for representing conversations in ACP.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SessionInfo {
|
||||
/// Unique session identifier.
|
||||
pub id: String,
|
||||
|
||||
/// Human-readable session title.
|
||||
pub title: String,
|
||||
|
||||
/// ISO8601 timestamp when session was created.
|
||||
pub created_at: String,
|
||||
|
||||
/// Participants in this session.
|
||||
pub participants: Vec<Participant>,
|
||||
|
||||
/// Messages in this session.
|
||||
pub messages: Vec<Message>,
|
||||
}
|
||||
|
||||
/// Parameters for the `session/prompt` method.
|
||||
///
|
||||
/// Sent by the client to send a prompt (user message) to the agent.
|
||||
///
|
||||
/// # ACP Specification
|
||||
/// The prompt method is the primary way to interact with an agent. The client
|
||||
/// sends content blocks (usually text) and the agent responds with generated
|
||||
/// content via session/update notifications during processing.
|
||||
///
|
||||
/// # JSON-RPC Method
|
||||
/// Method: `session/prompt`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionPromptParams {
|
||||
/// The ID of the session to send the prompt to.
|
||||
pub session_id: String,
|
||||
|
||||
/// The content blocks to send (prompt content).
|
||||
pub prompt: Vec<ContentBlock>,
|
||||
}
|
||||
|
||||
/// Result of the `session/prompt` method.
|
||||
///
|
||||
/// Returns why the agent stopped processing the prompt.
|
||||
///
|
||||
/// # ACP Specification
|
||||
/// The response only contains the stop reason. All content is sent via
|
||||
/// session/update notifications during processing.
|
||||
///
|
||||
/// # JSON-RPC Method
|
||||
/// Method: `session/prompt`
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionPromptResult {
|
||||
/// The reason why the agent stopped generating.
|
||||
pub stop_reason: StopReason,
|
||||
}
|
||||
|
||||
/// Parameters for the `session/update` notification.
|
||||
///
|
||||
/// Sent by the agent to notify the client of session state changes.
|
||||
///
|
||||
/// # ACP Specification
|
||||
/// Session updates are server-to-client notifications (not requests) that
|
||||
/// inform the client about changes to a session. These are typically sent
|
||||
/// during streaming responses. The update field contains the type-specific
|
||||
/// update information.
|
||||
///
|
||||
/// # JSON-RPC Method
|
||||
/// Method: `session/update` (notification, no response expected)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionUpdateParams {
|
||||
/// The ID of the session that was updated.
|
||||
pub session_id: String,
|
||||
|
||||
/// The update information (type-specific).
|
||||
/// Common types: agent_message_chunk, user_message_chunk, tool_call, etc.
|
||||
pub update: serde_json::Value,
|
||||
}
|
||||
|
||||
/// Parameters for the `session/cancel` notification.
|
||||
///
|
||||
/// Sent by the client to request cancellation of an in-progress generation.
|
||||
///
|
||||
/// # ACP Specification
|
||||
/// The cancel notification tells the agent to stop generating content for
|
||||
/// the specified session. The agent should stop as soon as possible and
|
||||
/// return the content generated so far.
|
||||
///
|
||||
/// # JSON-RPC Method
|
||||
/// Method: `session/cancel` (notification, no response expected)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionCancelParams {
|
||||
/// The ID of the session to cancel.
|
||||
pub session_id: String,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Utility Implementations
|
||||
// ============================================================================
|
||||
|
||||
impl JsonRpcRequest {
|
||||
/// Creates a new JSON-RPC request with the given method and parameters.
|
||||
pub fn new(method: impl Into<String>, params: Option<serde_json::Value>, id: impl Into<serde_json::Value>) -> Self {
|
||||
Self {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
method: method.into(),
|
||||
params,
|
||||
id: Some(id.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl JsonRpcResponse {
|
||||
/// Creates a successful JSON-RPC response with the given result.
|
||||
pub fn success(result: serde_json::Value, id: Option<serde_json::Value>) -> Self {
|
||||
Self {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
result: Some(result),
|
||||
error: None,
|
||||
id,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates an error JSON-RPC response with the given error.
|
||||
pub fn error(error: JsonRpcError, id: Option<serde_json::Value>) -> Self {
|
||||
Self {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
result: None,
|
||||
error: Some(error),
|
||||
id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl JsonRpcNotification {
|
||||
/// Creates a new JSON-RPC notification with the given method and parameters.
|
||||
pub fn new(method: impl Into<String>, params: Option<serde_json::Value>) -> Self {
|
||||
Self {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
method: method.into(),
|
||||
params,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl JsonRpcError {
|
||||
/// Creates a new JSON-RPC error with the given code and message.
|
||||
pub fn new(code: i32, message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
code,
|
||||
message: message.into(),
|
||||
data: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new JSON-RPC error with additional data.
|
||||
pub fn with_data(code: i32, message: impl Into<String>, data: serde_json::Value) -> Self {
|
||||
Self {
|
||||
code,
|
||||
message: message.into(),
|
||||
data: Some(data),
|
||||
}
|
||||
}
|
||||
|
||||
/// Standard error: Parse error (invalid JSON).
|
||||
pub fn parse_error() -> Self {
|
||||
Self::new(-32700, "Parse error")
|
||||
}
|
||||
|
||||
/// Standard error: Invalid Request.
|
||||
pub fn invalid_request() -> Self {
|
||||
Self::new(-32600, "Invalid Request")
|
||||
}
|
||||
|
||||
/// Standard error: Method not found.
|
||||
pub fn method_not_found(method: &str) -> Self {
|
||||
Self::new(-32601, format!("Method not found: {}", method))
|
||||
}
|
||||
|
||||
/// Standard error: Invalid params.
|
||||
pub fn invalid_params(message: impl Into<String>) -> Self {
|
||||
Self::new(-32602, message)
|
||||
}
|
||||
|
||||
/// Standard error: Internal error.
|
||||
pub fn internal_error(message: impl Into<String>) -> Self {
|
||||
Self::new(-32603, message)
|
||||
}
|
||||
}
|
||||
|
||||
impl ContentBlock {
|
||||
/// Creates a new text content block.
|
||||
pub fn text(text: impl Into<String>) -> Self {
|
||||
Self::Text { text: text.into() }
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_jsonrpc_request_serde() {
|
||||
let req = JsonRpcRequest::new("test.method", None, 1);
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(json.contains("\"jsonrpc\":\"2.0\""));
|
||||
assert!(json.contains("\"method\":\"test.method\""));
|
||||
assert!(json.contains("\"id\":1"));
|
||||
|
||||
let parsed: JsonRpcRequest = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.method, "test.method");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_jsonrpc_response_success() {
|
||||
let resp = JsonRpcResponse::success(serde_json::json!({"result": "ok"}), Some(1.into()));
|
||||
let json = serde_json::to_string(&resp).unwrap();
|
||||
assert!(json.contains("\"jsonrpc\":\"2.0\""));
|
||||
assert!(json.contains("\"result\""));
|
||||
assert!(!json.contains("\"error\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_jsonrpc_response_error() {
|
||||
let err = JsonRpcError::method_not_found("test.method");
|
||||
let resp = JsonRpcResponse::error(err, Some(1.into()));
|
||||
let json = serde_json::to_string(&resp).unwrap();
|
||||
assert!(json.contains("\"error\""));
|
||||
assert!(json.contains("-32601"));
|
||||
assert!(!json.contains("\"result\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_content_block_text() {
|
||||
let block = ContentBlock::text("Hello, world!");
|
||||
let json = serde_json::to_string(&block).unwrap();
|
||||
assert!(json.contains("\"type\":\"text\""));
|
||||
assert!(json.contains("\"text\":\"Hello, world!\""));
|
||||
|
||||
let parsed: ContentBlock = serde_json::from_str(&json).unwrap();
|
||||
match parsed {
|
||||
ContentBlock::Text { text } => assert_eq!(text, "Hello, world!"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stop_reason_serde() {
|
||||
let reason = StopReason::EndTurn;
|
||||
let json = serde_json::to_string(&reason).unwrap();
|
||||
assert_eq!(json, "\"end_turn\"");
|
||||
|
||||
let parsed: StopReason = serde_json::from_str("\"cancelled\"").unwrap();
|
||||
matches!(parsed, StopReason::Cancelled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_initialize_params_serde() {
|
||||
let params = InitializeParams {
|
||||
protocol_version: serde_json::json!(1),
|
||||
client_capabilities: serde_json::json!({"streaming": true}),
|
||||
client_info: None,
|
||||
};
|
||||
let json = serde_json::to_string(¶ms).unwrap();
|
||||
assert!(json.contains("\"protocolVersion\":1"));
|
||||
assert!(json.contains("\"clientCapabilities\""));
|
||||
|
||||
let _parsed: InitializeParams = serde_json::from_str(&json).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_new_params_with_template() {
|
||||
let params = SessionNewParams {
|
||||
cwd: Some(".".to_string()),
|
||||
mcp_servers: Some(vec![]),
|
||||
template_id: Some("test-template".to_string()),
|
||||
};
|
||||
let json = serde_json::to_string(¶ms).unwrap();
|
||||
eprintln!("JSON: {}", json);
|
||||
assert!(json.contains("\"templateId\":\"test-template\"")); // camelCase
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_new_params_without_template() {
|
||||
let params = SessionNewParams {
|
||||
cwd: Some(".".to_string()),
|
||||
mcp_servers: Some(vec![]),
|
||||
template_id: None,
|
||||
};
|
||||
let json = serde_json::to_string(¶ms).unwrap();
|
||||
// Should not include template_id field when None
|
||||
assert!(!json.contains("template_id"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_prompt_params() {
|
||||
let params = SessionPromptParams {
|
||||
session_id: "session-123".to_string(),
|
||||
prompt: vec![ContentBlock::text("Hello!")],
|
||||
};
|
||||
let json = serde_json::to_string(¶ms).unwrap();
|
||||
assert!(json.contains("\"sessionId\":\"session-123\"")); // camelCase
|
||||
assert!(json.contains("\"prompt\""));
|
||||
assert!(json.contains("\"type\":\"text\""));
|
||||
}
|
||||
}
|
||||
+2211
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,147 @@
|
||||
//! Stdio transport for JSON-RPC communication.
|
||||
//!
|
||||
//! This module implements JSON-RPC over stdin/stdout for use with editors like Zed
|
||||
//! that launch agent servers as child processes.
|
||||
|
||||
use dirigent_core::acp::transport::json_reader::{JsonLineReader, ReadResult};
|
||||
use crate::{MockerState, Result};
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{AsyncWriteExt, BufReader};
|
||||
|
||||
/// Run the JSON-RPC server using stdin/stdout transport.
|
||||
///
|
||||
/// This reads JSON-RPC requests from stdin (one per line) and writes responses to stdout.
|
||||
/// This is the transport mode used by editors like Zed.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `state` - The mocker state containing fixtures and sessions
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if stdin/stdout I/O fails or JSON parsing fails.
|
||||
pub async fn serve_stdio(state: Arc<MockerState>) -> Result<()> {
|
||||
tracing::info!("Starting ACP server in stdio mode");
|
||||
tracing::debug!("Reading JSON-RPC requests from stdin, writing responses to stdout");
|
||||
|
||||
let stdin = tokio::io::stdin();
|
||||
let mut reader = BufReader::new(stdin);
|
||||
let mut json_reader = JsonLineReader::new();
|
||||
|
||||
// Use a shared stdout wrapped in Mutex to serialize writes
|
||||
let stdout = Arc::new(tokio::sync::Mutex::new(tokio::io::stdout()));
|
||||
|
||||
// Subscribe to session/update notifications
|
||||
let mut event_rx = state.subscribe_events();
|
||||
|
||||
// Spawn a task to forward session/update notifications to stdout
|
||||
let stdout_clone = stdout.clone();
|
||||
let notification_task = tokio::spawn(async move {
|
||||
loop {
|
||||
match event_rx.recv().await {
|
||||
Ok(notification) => {
|
||||
tracing::info!(
|
||||
session_id = %notification.session_id,
|
||||
notification = %notification.notification,
|
||||
"📤 Forwarding session/update notification to stdout"
|
||||
);
|
||||
|
||||
// Acquire lock and write notification
|
||||
let mut stdout_guard = stdout_clone.lock().await;
|
||||
if let Err(e) = stdout_guard.write_all(notification.notification.as_bytes()).await {
|
||||
tracing::error!(error = %e, "Failed to write notification to stdout");
|
||||
break;
|
||||
}
|
||||
if let Err(e) = stdout_guard.write_all(b"\n").await {
|
||||
tracing::error!(error = %e, "Failed to write newline");
|
||||
break;
|
||||
}
|
||||
if let Err(e) = stdout_guard.flush().await {
|
||||
tracing::error!(error = %e, "Failed to flush stdout");
|
||||
break;
|
||||
}
|
||||
drop(stdout_guard); // Release lock
|
||||
|
||||
tracing::info!("✅ Notification written and flushed to stdout");
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => {
|
||||
tracing::warn!(skipped, "Event receiver lagged, some notifications were skipped");
|
||||
}
|
||||
Err(tokio::sync::broadcast::error::RecvError::Closed) => {
|
||||
tracing::info!("Event channel closed, notification task shutting down");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
loop {
|
||||
// Read next JSON message (handles multi-line JSON from clients)
|
||||
let request = match json_reader.read_message(&mut reader).await {
|
||||
Ok(ReadResult::Message(msg)) => msg,
|
||||
Ok(ReadResult::Eof) => {
|
||||
tracing::info!("Stdin closed, shutting down");
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "Failed to read JSON-RPC request from stdin");
|
||||
// Send error response
|
||||
let error_response = serde_json::json!({
|
||||
"jsonrpc": "2.0",
|
||||
"error": {
|
||||
"code": -32700,
|
||||
"message": "Parse error",
|
||||
"data": e
|
||||
},
|
||||
"id": null
|
||||
});
|
||||
let error_json = serde_json::to_string(&error_response).unwrap();
|
||||
let mut stdout_guard = stdout.lock().await;
|
||||
stdout_guard.write_all(error_json.as_bytes()).await
|
||||
.map_err(crate::MockerError::Transport)?;
|
||||
stdout_guard.write_all(b"\n").await
|
||||
.map_err(crate::MockerError::Transport)?;
|
||||
stdout_guard.flush().await.map_err(crate::MockerError::Transport)?;
|
||||
drop(stdout_guard);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let masked_request = dirigent_protocol::log_utils::mask_json_string(&request.to_string());
|
||||
tracing::info!(request = %masked_request, "Received JSON-RPC request");
|
||||
|
||||
// Handle the request using the same handler as HTTP transport
|
||||
let response = super::server::handle_jsonrpc_inner(state.clone(), request).await;
|
||||
|
||||
// Check if this is a notification (no response expected)
|
||||
if response.is_none() {
|
||||
tracing::debug!("Request was a notification, no response sent");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Write the response to stdout
|
||||
let response_json = serde_json::to_string(&response.unwrap()).map_err(|e| {
|
||||
tracing::error!(error = %e, "Failed to serialize response");
|
||||
crate::MockerError::Internal(format!("Failed to serialize response: {}", e))
|
||||
})?;
|
||||
|
||||
let masked_response = dirigent_protocol::log_utils::mask_json_string(&response_json);
|
||||
tracing::info!(response = %masked_response, "Sending JSON-RPC response");
|
||||
|
||||
let mut stdout_guard = stdout.lock().await;
|
||||
stdout_guard.write_all(response_json.as_bytes()).await
|
||||
.map_err(crate::MockerError::Transport)?;
|
||||
stdout_guard.write_all(b"\n").await
|
||||
.map_err(crate::MockerError::Transport)?;
|
||||
stdout_guard.flush().await.map_err(crate::MockerError::Transport)?;
|
||||
drop(stdout_guard);
|
||||
|
||||
tracing::info!("Response sent and flushed");
|
||||
}
|
||||
|
||||
// Clean up notification task
|
||||
notification_task.abort();
|
||||
|
||||
tracing::info!("Stdio server shutting down (stdin closed)");
|
||||
Ok(())
|
||||
}
|
||||
@@ -0,0 +1,554 @@
|
||||
//! Server-Sent Events (SSE) streaming for real-time updates.
|
||||
//!
|
||||
//! This module handles streaming of ACP events to clients using SSE,
|
||||
//! allowing real-time delivery of message chunks, tool executions,
|
||||
//! and other protocol events.
|
||||
//!
|
||||
//! It also provides text chunking and timing logic for simulating
|
||||
//! streaming responses with configurable delays and jitter.
|
||||
|
||||
use crate::Result;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand_chacha::ChaCha8Rng;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::time::{sleep, Duration};
|
||||
use tokio_stream::Stream;
|
||||
|
||||
// ============================================================================
|
||||
// Streaming Configuration
|
||||
// ============================================================================
|
||||
|
||||
/// Configuration for streaming behavior.
|
||||
///
|
||||
/// Controls how text is chunked and the timing between chunk emissions.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamConfig {
|
||||
/// Number of tokens (words) per chunk.
|
||||
pub tokens_per_chunk: usize,
|
||||
|
||||
/// Base interval between chunks in milliseconds.
|
||||
pub chunk_interval_ms: u64,
|
||||
|
||||
/// Optional random jitter to add/subtract from chunk interval (in milliseconds).
|
||||
/// The actual jitter will be uniformly distributed in the range [-jitter_ms, +jitter_ms].
|
||||
pub jitter_ms: Option<u64>,
|
||||
}
|
||||
|
||||
impl StreamConfig {
|
||||
/// Create a new stream configuration.
|
||||
pub fn new(tokens_per_chunk: usize, chunk_interval_ms: u64) -> Self {
|
||||
Self {
|
||||
tokens_per_chunk,
|
||||
chunk_interval_ms,
|
||||
jitter_ms: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new stream configuration with jitter.
|
||||
pub fn with_jitter(tokens_per_chunk: usize, chunk_interval_ms: u64, jitter_ms: u64) -> Self {
|
||||
Self {
|
||||
tokens_per_chunk,
|
||||
chunk_interval_ms,
|
||||
jitter_ms: Some(jitter_ms),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for StreamConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
tokens_per_chunk: 5,
|
||||
chunk_interval_ms: 100,
|
||||
jitter_ms: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Text Chunking
|
||||
// ============================================================================
|
||||
|
||||
/// Splits text into chunks of approximately N tokens (words).
|
||||
///
|
||||
/// Uses whitespace splitting as a simplified proxy for tokenization.
|
||||
/// Chunks preserve word boundaries (no mid-word splits).
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `text` - The text to chunk
|
||||
/// * `tokens_per_chunk` - Approximate number of words per chunk
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A vector of text chunks. If `tokens_per_chunk` is 0 or text is empty,
|
||||
/// returns a single-element vector containing the entire text.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// let text = "Hello world this is a test";
|
||||
/// let chunks = chunk_text(text, 2);
|
||||
/// assert_eq!(chunks, vec!["Hello world", "this is", "a test"]);
|
||||
/// ```
|
||||
pub fn chunk_text(text: &str, tokens_per_chunk: usize) -> Vec<String> {
|
||||
if tokens_per_chunk == 0 || text.is_empty() {
|
||||
return vec![text.to_string()];
|
||||
}
|
||||
|
||||
let words: Vec<&str> = text.split_whitespace().collect();
|
||||
|
||||
if words.is_empty() {
|
||||
return vec![text.to_string()];
|
||||
}
|
||||
|
||||
let mut chunks = Vec::new();
|
||||
let mut current_chunk = Vec::new();
|
||||
|
||||
for word in words {
|
||||
current_chunk.push(word);
|
||||
|
||||
if current_chunk.len() >= tokens_per_chunk {
|
||||
chunks.push(current_chunk.join(" "));
|
||||
current_chunk.clear();
|
||||
}
|
||||
}
|
||||
|
||||
// Add remaining words as final chunk
|
||||
if !current_chunk.is_empty() {
|
||||
chunks.push(current_chunk.join(" "));
|
||||
}
|
||||
|
||||
chunks
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Stream Controller
|
||||
// ============================================================================
|
||||
|
||||
/// Controls the streaming of text chunks with configurable timing and cancellation.
|
||||
///
|
||||
/// The stream controller manages the emission of text chunks with configured
|
||||
/// delays and optional jitter. It supports cancellation to stop streaming
|
||||
/// in response to client requests.
|
||||
#[derive(Clone)]
|
||||
pub struct StreamController {
|
||||
/// Streaming configuration
|
||||
config: StreamConfig,
|
||||
|
||||
/// Random number generator for jitter (seeded for reproducibility)
|
||||
rng: Arc<std::sync::Mutex<ChaCha8Rng>>,
|
||||
|
||||
/// Cancellation flag shared across stream instances
|
||||
cancelled: Arc<AtomicBool>,
|
||||
}
|
||||
|
||||
impl StreamController {
|
||||
/// Create a new stream controller with the given configuration.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config` - Streaming configuration
|
||||
/// * `seed` - Seed for the random number generator (for reproducible jitter)
|
||||
pub fn new(config: StreamConfig, seed: u64) -> Self {
|
||||
Self {
|
||||
config,
|
||||
rng: Arc::new(std::sync::Mutex::new(ChaCha8Rng::seed_from_u64(seed))),
|
||||
cancelled: Arc::new(AtomicBool::new(false)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Cancel the stream.
|
||||
///
|
||||
/// Sets the cancellation flag, causing the stream to stop emitting
|
||||
/// chunks after the current chunk completes.
|
||||
pub fn cancel(&self) {
|
||||
self.cancelled.store(true, Ordering::SeqCst);
|
||||
tracing::debug!("Stream cancellation requested");
|
||||
}
|
||||
|
||||
/// Check if the stream has been cancelled.
|
||||
pub fn is_cancelled(&self) -> bool {
|
||||
self.cancelled.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Reset the cancellation flag.
|
||||
///
|
||||
/// This allows the controller to be reused for a new stream.
|
||||
pub fn reset(&self) {
|
||||
self.cancelled.store(false, Ordering::SeqCst);
|
||||
tracing::debug!("Stream cancellation flag reset");
|
||||
}
|
||||
|
||||
/// Stream text chunks with configured timing.
|
||||
///
|
||||
/// Emits chunks with the configured interval between each, applying
|
||||
/// random jitter if configured. Checks the cancellation flag before
|
||||
/// each emission and stops immediately if cancelled.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `chunks` - The chunks to stream
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An async stream that yields chunks as Strings
|
||||
pub fn stream_chunks(
|
||||
&self,
|
||||
chunks: Vec<String>,
|
||||
) -> impl Stream<Item = String> + '_ {
|
||||
let cancelled = self.cancelled.clone();
|
||||
let rng = self.rng.clone();
|
||||
let config = self.config.clone();
|
||||
let total_chunks = chunks.len();
|
||||
|
||||
async_stream::stream! {
|
||||
for (idx, chunk) in chunks.into_iter().enumerate() {
|
||||
// Check cancellation before each chunk
|
||||
if cancelled.load(Ordering::SeqCst) {
|
||||
tracing::info!(chunks_emitted = idx, "Stream cancelled");
|
||||
break;
|
||||
}
|
||||
|
||||
// Calculate delay with optional jitter
|
||||
let base_delay = config.chunk_interval_ms;
|
||||
let delay_ms = if let Some(jitter) = config.jitter_ms {
|
||||
let mut rng = rng.lock().unwrap();
|
||||
let jitter_amount = rng.gen_range(-(jitter as i64)..=(jitter as i64));
|
||||
(base_delay as i64 + jitter_amount).max(0) as u64
|
||||
} else {
|
||||
base_delay
|
||||
};
|
||||
|
||||
tracing::debug!(
|
||||
chunk_idx = idx,
|
||||
delay_ms,
|
||||
chunk_len = chunk.len(),
|
||||
"Emitting chunk"
|
||||
);
|
||||
|
||||
yield chunk;
|
||||
|
||||
// Sleep after emitting chunk (except for last chunk)
|
||||
if idx < total_chunks - 1 {
|
||||
sleep(Duration::from_millis(delay_ms)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for StreamController {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("StreamController")
|
||||
.field("config", &self.config)
|
||||
.field("cancelled", &self.cancelled.load(Ordering::SeqCst))
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SSE Event (Legacy from previous implementation)
|
||||
// ============================================================================
|
||||
|
||||
/// SSE event sent to clients.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamEvent {
|
||||
/// Event type (e.g., "message", "tool_call", "error").
|
||||
pub event: String,
|
||||
/// Event data (JSON-serialized).
|
||||
pub data: String,
|
||||
}
|
||||
|
||||
impl StreamEvent {
|
||||
/// Create a new stream event.
|
||||
pub fn new(event: impl Into<String>, data: impl Into<String>) -> Self {
|
||||
Self {
|
||||
event: event.into(),
|
||||
data: data.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Format as SSE text.
|
||||
pub fn to_sse(&self) -> String {
|
||||
format!("event: {}\ndata: {}\n\n", self.event, self.data)
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle SSE connection for a session.
|
||||
///
|
||||
/// This function manages the lifecycle of an SSE connection,
|
||||
/// streaming events from fixtures to the connected client.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `session_id` - ID of the session to stream
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An async stream of SSE events.
|
||||
pub async fn stream_session(_session_id: String) -> Result<()> {
|
||||
// TODO: Implement SSE streaming using StreamController
|
||||
// - Subscribe to session events from fixture responder
|
||||
// - Convert events to SSE format
|
||||
// - Handle client disconnection
|
||||
// - Clean up resources
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
// ========================================================================
|
||||
// Text Chunking Tests
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_chunk_text_basic() {
|
||||
let text = "Hello world this is a test";
|
||||
let chunks = chunk_text(text, 2);
|
||||
assert_eq!(chunks, vec!["Hello world", "this is", "a test"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chunk_text_single_chunk() {
|
||||
let text = "Hello world";
|
||||
let chunks = chunk_text(text, 5);
|
||||
assert_eq!(chunks, vec!["Hello world"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chunk_text_exact_fit() {
|
||||
let text = "one two three four";
|
||||
let chunks = chunk_text(text, 2);
|
||||
assert_eq!(chunks, vec!["one two", "three four"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chunk_text_empty() {
|
||||
let text = "";
|
||||
let chunks = chunk_text(text, 2);
|
||||
assert_eq!(chunks, vec![""]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chunk_text_zero_tokens_per_chunk() {
|
||||
let text = "Hello world";
|
||||
let chunks = chunk_text(text, 0);
|
||||
assert_eq!(chunks, vec!["Hello world"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chunk_text_single_word() {
|
||||
let text = "Hello";
|
||||
let chunks = chunk_text(text, 2);
|
||||
assert_eq!(chunks, vec!["Hello"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chunk_text_multiple_spaces() {
|
||||
let text = "Hello world test";
|
||||
let chunks = chunk_text(text, 2);
|
||||
assert_eq!(chunks, vec!["Hello world", "test"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chunk_text_preserves_word_boundaries() {
|
||||
let text = "supercalifragilisticexpialidocious test";
|
||||
let chunks = chunk_text(text, 1);
|
||||
assert_eq!(chunks, vec!["supercalifragilisticexpialidocious", "test"]);
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// StreamController Tests
|
||||
// ========================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_stream_controller_basic() {
|
||||
let config = StreamConfig::new(2, 10);
|
||||
let controller = StreamController::new(config, 42);
|
||||
|
||||
let chunks = vec!["chunk1".to_string(), "chunk2".to_string(), "chunk3".to_string()];
|
||||
let stream = controller.stream_chunks(chunks);
|
||||
tokio::pin!(stream);
|
||||
|
||||
let mut results = Vec::new();
|
||||
while let Some(chunk) = stream.next().await {
|
||||
results.push(chunk);
|
||||
}
|
||||
|
||||
assert_eq!(results, vec!["chunk1", "chunk2", "chunk3"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_stream_controller_timing() {
|
||||
let config = StreamConfig::new(2, 50); // 50ms delay
|
||||
let controller = StreamController::new(config, 42);
|
||||
|
||||
let chunks = vec!["chunk1".to_string(), "chunk2".to_string()];
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let stream = controller.stream_chunks(chunks);
|
||||
tokio::pin!(stream);
|
||||
|
||||
let mut count = 0;
|
||||
while let Some(_chunk) = stream.next().await {
|
||||
count += 1;
|
||||
}
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
assert_eq!(count, 2);
|
||||
// Should take at least 50ms for the delay between chunks
|
||||
// (we allow some tolerance for timing variations)
|
||||
assert!(elapsed.as_millis() >= 40, "Expected at least 40ms, got {}ms", elapsed.as_millis());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_stream_controller_with_jitter() {
|
||||
let config = StreamConfig::with_jitter(2, 50, 10);
|
||||
let controller = StreamController::new(config, 42);
|
||||
|
||||
let chunks = vec!["chunk1".to_string(), "chunk2".to_string(), "chunk3".to_string()];
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let stream = controller.stream_chunks(chunks);
|
||||
tokio::pin!(stream);
|
||||
|
||||
let mut count = 0;
|
||||
while let Some(_chunk) = stream.next().await {
|
||||
count += 1;
|
||||
}
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
assert_eq!(count, 3);
|
||||
// With jitter of ±10ms on 50ms base, we expect roughly 80-120ms total
|
||||
// (2 delays between 3 chunks)
|
||||
assert!(elapsed.as_millis() >= 60, "Timing too short: {}ms", elapsed.as_millis());
|
||||
assert!(elapsed.as_millis() <= 200, "Timing too long: {}ms", elapsed.as_millis());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_stream_controller_cancellation() {
|
||||
let config = StreamConfig::new(2, 20);
|
||||
let controller = StreamController::new(config, 42);
|
||||
|
||||
let chunks = vec![
|
||||
"chunk1".to_string(),
|
||||
"chunk2".to_string(),
|
||||
"chunk3".to_string(),
|
||||
"chunk4".to_string(),
|
||||
];
|
||||
|
||||
let controller_clone = controller.clone();
|
||||
|
||||
// Spawn a task to cancel after 30ms
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(Duration::from_millis(30)).await;
|
||||
controller_clone.cancel();
|
||||
});
|
||||
|
||||
let stream = controller.stream_chunks(chunks);
|
||||
tokio::pin!(stream);
|
||||
|
||||
let mut results = Vec::new();
|
||||
while let Some(chunk) = stream.next().await {
|
||||
results.push(chunk);
|
||||
}
|
||||
|
||||
// Should be cancelled before all chunks are emitted
|
||||
assert!(results.len() < 4, "Expected cancellation, got {} chunks", results.len());
|
||||
assert!(results.len() >= 1, "Should have at least 1 chunk before cancellation");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_stream_controller_empty_chunks() {
|
||||
let config = StreamConfig::new(2, 10);
|
||||
let controller = StreamController::new(config, 42);
|
||||
|
||||
let chunks: Vec<String> = vec![];
|
||||
let stream = controller.stream_chunks(chunks);
|
||||
tokio::pin!(stream);
|
||||
|
||||
let mut count = 0;
|
||||
while let Some(_chunk) = stream.next().await {
|
||||
count += 1;
|
||||
}
|
||||
|
||||
assert_eq!(count, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_stream_controller_reset_cancellation() {
|
||||
let config = StreamConfig::new(2, 10);
|
||||
let controller = StreamController::new(config, 42);
|
||||
|
||||
controller.cancel();
|
||||
assert!(controller.is_cancelled());
|
||||
|
||||
controller.reset();
|
||||
assert!(!controller.is_cancelled());
|
||||
|
||||
// Should be able to stream after reset
|
||||
let chunks = vec!["chunk1".to_string(), "chunk2".to_string()];
|
||||
let stream = controller.stream_chunks(chunks);
|
||||
tokio::pin!(stream);
|
||||
|
||||
let mut count = 0;
|
||||
while let Some(_chunk) = stream.next().await {
|
||||
count += 1;
|
||||
}
|
||||
|
||||
assert_eq!(count, 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_stream_controller_jitter_reproducibility() {
|
||||
let config = StreamConfig::with_jitter(2, 50, 10);
|
||||
|
||||
// Same seed should produce same jitter sequence
|
||||
let controller1 = StreamController::new(config.clone(), 42);
|
||||
let controller2 = StreamController::new(config, 42);
|
||||
|
||||
let chunks1 = vec!["a".to_string(), "b".to_string(), "c".to_string()];
|
||||
let chunks2 = vec!["a".to_string(), "b".to_string(), "c".to_string()];
|
||||
|
||||
let start1 = std::time::Instant::now();
|
||||
let stream1 = controller1.stream_chunks(chunks1);
|
||||
tokio::pin!(stream1);
|
||||
while let Some(_) = stream1.next().await {}
|
||||
let elapsed1 = start1.elapsed();
|
||||
|
||||
let start2 = std::time::Instant::now();
|
||||
let stream2 = controller2.stream_chunks(chunks2);
|
||||
tokio::pin!(stream2);
|
||||
while let Some(_) = stream2.next().await {}
|
||||
let elapsed2 = start2.elapsed();
|
||||
|
||||
// Same seed should produce similar timing (within 20ms tolerance)
|
||||
let diff = if elapsed1 > elapsed2 {
|
||||
elapsed1 - elapsed2
|
||||
} else {
|
||||
elapsed2 - elapsed1
|
||||
};
|
||||
assert!(diff.as_millis() < 20, "Timing difference too large: {}ms", diff.as_millis());
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// StreamEvent Tests
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_stream_event_formatting() {
|
||||
let event = StreamEvent::new("message", r#"{"text":"hello"}"#);
|
||||
assert_eq!(
|
||||
event.to_sse(),
|
||||
"event: message\ndata: {\"text\":\"hello\"}\n\n"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
//! # dirigate
|
||||
//!
|
||||
//! Dirigate - ACP bridge and mock server CLI tool.
|
||||
//!
|
||||
//! This binary provides a CLI interface for:
|
||||
//! - Running an ACP mock server with fixture-based responses
|
||||
//! - Bridging stdio ACP clients to a Dirigent ACP Server via HTTP/SSE
|
||||
//! - Connecting to ACP agents as an interactive client
|
||||
//!
|
||||
//! ## Bridge Mode
|
||||
//!
|
||||
//! The bridge mode allows external ACP clients (like Claude Code configured for stdio)
|
||||
//! to connect to a Dirigent ACP Server over HTTP/SSE:
|
||||
//!
|
||||
//! ```text
|
||||
//! External ACP Client (Claude Code, etc.)
|
||||
//! |
|
||||
//! | stdio (stdin/stdout)
|
||||
//! v
|
||||
//! +-------------------+
|
||||
//! | Dirigate Bridge |
|
||||
//! | - stdin parser |
|
||||
//! | - HTTP client |
|
||||
//! | - SSE subscriber |
|
||||
//! +-------------------+
|
||||
//! |
|
||||
//! | HTTP/SSE
|
||||
//! v
|
||||
//! Dirigent ACP Server
|
||||
//! ```
|
||||
|
||||
use dirigate::{
|
||||
cli::{execute_command, parse_log_format, Cli},
|
||||
logging::init_logging,
|
||||
};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
// Parse CLI arguments
|
||||
let cli = Cli::parse_args();
|
||||
|
||||
// Set log level via RUST_LOG if not already set
|
||||
if std::env::var("RUST_LOG").is_err() {
|
||||
std::env::set_var("RUST_LOG", &cli.log_level);
|
||||
}
|
||||
|
||||
// Initialize logging (must be done after setting RUST_LOG)
|
||||
let log_format = parse_log_format(&cli.log_format);
|
||||
init_logging(log_format);
|
||||
|
||||
tracing::info!("dirigate v{}", env!("CARGO_PKG_VERSION"));
|
||||
|
||||
// Execute command and handle errors
|
||||
if let Err(e) = execute_command(cli.command).await {
|
||||
eprintln!("Error: {}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
+193
@@ -0,0 +1,193 @@
|
||||
//! CLI argument definitions.
|
||||
//!
|
||||
//! This module defines the command-line interface structure using `clap`.
|
||||
|
||||
use clap::{Parser, Subcommand, ValueEnum};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Default ACP Server URL for bridge mode.
|
||||
/// By default, ACP is nested at /acp on the main Dioxus server (port 8080).
|
||||
/// Use DIRIGENT_ACP_PORT env var to run ACP on a separate port.
|
||||
pub const DEFAULT_SERVER_URL: &str = "http://localhost:8080/acp";
|
||||
|
||||
/// Dirigate - ACP bridge and mock server for testing and proxying ACP connections.
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "dirigate")]
|
||||
#[command(version, about, long_about = None)]
|
||||
pub struct Cli {
|
||||
/// Log level (trace, debug, info, warn, error).
|
||||
#[arg(short, long, default_value = "info", global = true)]
|
||||
pub log_level: String,
|
||||
|
||||
/// Log format (pretty, json, compact).
|
||||
#[arg(long, default_value = "pretty", global = true)]
|
||||
pub log_format: String,
|
||||
|
||||
/// Subcommand to execute.
|
||||
#[command(subcommand)]
|
||||
pub command: Commands,
|
||||
}
|
||||
|
||||
/// Available commands.
|
||||
#[derive(Subcommand, Debug)]
|
||||
pub enum Commands {
|
||||
/// Start the mock server with fixture-based responses.
|
||||
Serve {
|
||||
/// Directory or file containing fixture YAML files.
|
||||
#[arg(short, long, default_value = "./fixtures")]
|
||||
fixtures: PathBuf,
|
||||
|
||||
/// Use stdin/stdout for JSON-RPC transport (for Zed and similar editors).
|
||||
#[arg(long)]
|
||||
stdio: bool,
|
||||
|
||||
/// Port to bind the server to (ignored if --stdio is set).
|
||||
#[arg(short, long, default_value_t = 8080)]
|
||||
port: u16,
|
||||
|
||||
/// Host address to bind to (ignored if --stdio is set).
|
||||
#[arg(long, default_value = "127.0.0.1")]
|
||||
host: String,
|
||||
|
||||
/// Tokens per chunk for streaming (overrides fixture setting).
|
||||
#[arg(long)]
|
||||
tokens_per_chunk: Option<usize>,
|
||||
|
||||
/// Milliseconds between chunks for streaming (overrides fixture setting).
|
||||
#[arg(long)]
|
||||
chunk_interval_ms: Option<u64>,
|
||||
|
||||
/// Enable or disable streaming (overrides fixture setting).
|
||||
#[arg(long)]
|
||||
streaming: Option<bool>,
|
||||
},
|
||||
|
||||
/// Validate fixture files without starting the server.
|
||||
Validate {
|
||||
/// Directory or file to validate. Can be specified multiple times.
|
||||
#[arg(value_name = "PATH")]
|
||||
paths: Vec<PathBuf>,
|
||||
},
|
||||
|
||||
/// Print fixture contents in various formats.
|
||||
Print {
|
||||
/// Path to the fixture file.
|
||||
#[arg(value_name = "FIXTURE")]
|
||||
fixture: PathBuf,
|
||||
|
||||
/// Output format.
|
||||
#[arg(short, long, value_enum, default_value_t = PrintFormat::Table)]
|
||||
format: PrintFormat,
|
||||
},
|
||||
|
||||
/// Ingest sessions from external sources (requires 'ingest' feature).
|
||||
#[cfg(feature = "ingest")]
|
||||
Ingest {
|
||||
/// Base URL of the OpenCode API.
|
||||
#[arg(short = 'u', long)]
|
||||
base_url: String,
|
||||
|
||||
/// Specific session ID to ingest (if not provided, use --all).
|
||||
#[arg(short = 's', long)]
|
||||
session_id: Option<String>,
|
||||
|
||||
/// Ingest all sessions from the API.
|
||||
#[arg(short = 'a', long)]
|
||||
all: bool,
|
||||
|
||||
/// Output file path for the generated fixture.
|
||||
#[arg(short = 'o', long)]
|
||||
output: PathBuf,
|
||||
|
||||
/// Merge with existing fixture file if it exists.
|
||||
#[arg(short = 'm', long)]
|
||||
merge: bool,
|
||||
},
|
||||
|
||||
/// Connect to an ACP agent as a client (interactive mode).
|
||||
Connect {
|
||||
/// Command to spawn for stdio transport (e.g., "claude --acp").
|
||||
#[arg(long, conflicts_with = "url")]
|
||||
command: Option<String>,
|
||||
|
||||
/// HTTP URL for HTTP+SSE transport (e.g., "http://localhost:8080").
|
||||
#[arg(long, conflicts_with = "command")]
|
||||
url: Option<String>,
|
||||
|
||||
/// Protocol version to use.
|
||||
#[arg(long, default_value = "2025-01-01")]
|
||||
protocol_version: String,
|
||||
|
||||
/// Automatically create a new session on connect.
|
||||
#[arg(long)]
|
||||
auto_session: bool,
|
||||
},
|
||||
|
||||
/// Bridge stdio ACP client to a Dirigent ACP Server via HTTP/SSE.
|
||||
///
|
||||
/// This mode allows external ACP clients (like Claude Code configured for stdio)
|
||||
/// to connect to a Dirigent ACP Server. The bridge reads JSON-RPC requests from
|
||||
/// stdin, forwards them to the server via HTTP, and writes responses and SSE
|
||||
/// notifications to stdout.
|
||||
///
|
||||
/// ## Example Usage
|
||||
///
|
||||
/// ```bash
|
||||
/// # Connect to local server (default: http://localhost:8080/acp)
|
||||
/// dirigate bridge
|
||||
///
|
||||
/// # Connect to remote server (use actual Dioxus server port)
|
||||
/// dirigate bridge --server-url http://remote:8080/acp
|
||||
///
|
||||
/// # Connect to ACP on separate port (if configured with DIRIGENT_ACP_PORT)
|
||||
/// dirigate bridge --server-url http://localhost:3001/acp
|
||||
///
|
||||
/// # Via environment variable
|
||||
/// DIRIGENT_SERVER_URL=http://remote:8080/acp dirigate bridge
|
||||
/// ```
|
||||
Bridge {
|
||||
/// ACP Server URL to connect to.
|
||||
///
|
||||
/// Can also be set via DIRIGENT_SERVER_URL environment variable.
|
||||
#[arg(short = 's', long, env = "DIRIGENT_SERVER_URL", default_value = DEFAULT_SERVER_URL)]
|
||||
server_url: String,
|
||||
|
||||
/// Enable verbose logging of JSON-RPC messages.
|
||||
#[arg(short, long)]
|
||||
verbose: bool,
|
||||
|
||||
/// Timeout in seconds for HTTP requests to the server.
|
||||
#[arg(long, default_value_t = 30)]
|
||||
timeout: u64,
|
||||
|
||||
/// Automatically reconnect SSE stream on disconnect.
|
||||
#[arg(long, default_value_t = true)]
|
||||
auto_reconnect: bool,
|
||||
|
||||
/// Select a specific connector by ID or agent type magic word.
|
||||
///
|
||||
/// Magic words: claude, codex, gemini
|
||||
/// When set, sessions will be routed directly to a connector of this type,
|
||||
/// bypassing the gateway.
|
||||
#[arg(long)]
|
||||
select_connector: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Output format for the print command.
|
||||
#[derive(Debug, Clone, Copy, ValueEnum)]
|
||||
pub enum PrintFormat {
|
||||
/// Human-readable table format.
|
||||
Table,
|
||||
/// JSON format.
|
||||
Json,
|
||||
/// YAML format.
|
||||
Yaml,
|
||||
}
|
||||
|
||||
impl Cli {
|
||||
/// Parse CLI arguments from the environment.
|
||||
pub fn parse_args() -> Self {
|
||||
Self::parse()
|
||||
}
|
||||
}
|
||||
+1274
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,18 @@
|
||||
//! Command-line interface for the ACP mocker.
|
||||
//!
|
||||
//! This module provides the CLI structure and command definitions for the
|
||||
//! dirigate binary. It uses `clap` for argument parsing and
|
||||
//! command dispatch.
|
||||
//!
|
||||
//! ## Commands
|
||||
//!
|
||||
//! - `serve` - Start the mock server
|
||||
//! - `validate` - Validate fixture files without starting server
|
||||
//! - `ingest` - Ingest sessions from external sources (feature-gated)
|
||||
|
||||
pub mod args;
|
||||
pub mod commands;
|
||||
|
||||
// Re-exports for convenience
|
||||
pub use args::*;
|
||||
pub use commands::*;
|
||||
@@ -0,0 +1,82 @@
|
||||
//! Error types for the ACP mocker.
|
||||
//!
|
||||
//! This module defines all error types that can occur during fixture loading,
|
||||
//! validation, server operation, and optional ingestion from external sources.
|
||||
|
||||
use thiserror::Error;
|
||||
|
||||
/// Result type alias using MockerError.
|
||||
pub type Result<T> = std::result::Result<T, MockerError>;
|
||||
|
||||
/// Errors that can occur in the ACP mocker.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum MockerError {
|
||||
/// Error loading fixture files from disk.
|
||||
#[error("Failed to load fixture: {0}")]
|
||||
FixtureLoad(String),
|
||||
|
||||
/// Error validating fixture structure or content.
|
||||
#[error("Invalid fixture: {0}")]
|
||||
FixtureValidation(String),
|
||||
|
||||
/// Error in ACP protocol handling.
|
||||
#[error("ACP protocol error: {0}")]
|
||||
AcpProtocol(String),
|
||||
|
||||
/// Requested session was not found in fixtures.
|
||||
#[error("Session not found: {session_id}")]
|
||||
SessionNotFound { session_id: String },
|
||||
|
||||
/// Error in transport layer (HTTP, WebSocket, etc.).
|
||||
#[error("Transport error: {0}")]
|
||||
Transport(#[from] std::io::Error),
|
||||
|
||||
/// Error during session ingestion from external sources.
|
||||
#[cfg(feature = "ingest")]
|
||||
#[error("Ingestion error: {0}")]
|
||||
Ingest(String),
|
||||
|
||||
/// Error parsing YAML fixtures.
|
||||
#[error("YAML parsing error: {0}")]
|
||||
YamlParse(#[from] serde_yaml::Error),
|
||||
|
||||
/// Error parsing JSON data.
|
||||
#[error("JSON parsing error: {0}")]
|
||||
JsonParse(#[from] serde_json::Error),
|
||||
|
||||
/// Generic error for cases not covered by specific variants.
|
||||
#[error("Internal error: {0}")]
|
||||
Internal(String),
|
||||
}
|
||||
|
||||
// Implement From for common error types
|
||||
impl From<String> for MockerError {
|
||||
fn from(s: String) -> Self {
|
||||
MockerError::Internal(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for MockerError {
|
||||
fn from(s: &str) -> Self {
|
||||
MockerError::Internal(s.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_error_display() {
|
||||
let err = MockerError::SessionNotFound {
|
||||
session_id: "test-123".to_string(),
|
||||
};
|
||||
assert_eq!(err.to_string(), "Session not found: test-123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fixture_load_error() {
|
||||
let err = MockerError::FixtureLoad("file not found".to_string());
|
||||
assert_eq!(err.to_string(), "Failed to load fixture: file not found");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,593 @@
|
||||
//! Fixture loading from YAML files.
|
||||
//!
|
||||
//! This module handles loading fixture definitions from the filesystem,
|
||||
//! parsing YAML, and validating fixture structure.
|
||||
|
||||
use crate::{
|
||||
fixture::{Fixture, Message, Session},
|
||||
MockerError, Result,
|
||||
};
|
||||
use std::{collections::HashSet, path::Path};
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Load a fixture from a YAML file.
|
||||
///
|
||||
/// This function reads the file, parses it as YAML, and returns the fixture
|
||||
/// without performing validation.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the YAML fixture file
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `MockerError::FixtureLoad` if:
|
||||
/// - The file cannot be read
|
||||
/// - The YAML is malformed
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use dirigate::fixture::load_fixture;
|
||||
///
|
||||
/// # async fn example() -> dirigate::Result<()> {
|
||||
/// let fixture = load_fixture("fixtures/basic_session.yaml").await?;
|
||||
/// println!("Loaded fixture version: {}", fixture.version);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub async fn load_fixture<P: AsRef<Path>>(path: P) -> Result<Fixture> {
|
||||
let path = path.as_ref();
|
||||
|
||||
// Read file contents
|
||||
let contents = tokio::fs::read_to_string(path)
|
||||
.await
|
||||
.map_err(|e| MockerError::FixtureLoad(format!("Failed to read file '{}': {}", path.display(), e)))?;
|
||||
|
||||
// Parse YAML
|
||||
let fixture: Fixture = serde_yaml::from_str(&contents)
|
||||
.map_err(|e| MockerError::FixtureLoad(format!("Failed to parse YAML in '{}': {}", path.display(), e)))?;
|
||||
|
||||
info!("Loaded fixture from '{}'", path.display());
|
||||
|
||||
Ok(fixture)
|
||||
}
|
||||
|
||||
/// Validate a fixture structure.
|
||||
///
|
||||
/// Performs comprehensive validation of the fixture including:
|
||||
/// - Version check (must be "0.1")
|
||||
/// - Duplicate ID detection (sessions, messages, participants)
|
||||
/// - Reference validation (session_id, parent_id, participant references)
|
||||
/// - Timestamp format validation (ISO8601)
|
||||
/// - Required field validation (non-empty strings)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `fixture` - The fixture to validate
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `MockerError::FixtureValidation` with a detailed message listing all validation errors.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use dirigate::fixture::{load_fixture, validate_fixture};
|
||||
///
|
||||
/// # async fn example() -> dirigate::Result<()> {
|
||||
/// let fixture = load_fixture("fixtures/basic_session.yaml").await?;
|
||||
/// validate_fixture(&fixture)?;
|
||||
/// println!("Fixture is valid!");
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn validate_fixture(fixture: &Fixture) -> Result<()> {
|
||||
let mut errors = Vec::new();
|
||||
|
||||
// 1. Validate version
|
||||
if fixture.version != "0.1" {
|
||||
errors.push(format!(
|
||||
"Invalid version '{}', expected '0.1'",
|
||||
fixture.version
|
||||
));
|
||||
}
|
||||
|
||||
// 2. Check for duplicate session IDs
|
||||
let mut session_ids = HashSet::new();
|
||||
for session in &fixture.sessions {
|
||||
if !session_ids.insert(&session.id) {
|
||||
errors.push(format!("Duplicate session ID '{}'", session.id));
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Validate each session
|
||||
for session in &fixture.sessions {
|
||||
validate_session(session, &mut errors);
|
||||
}
|
||||
|
||||
// If there are errors, return them all
|
||||
if !errors.is_empty() {
|
||||
let error_message = format!(
|
||||
"Fixture validation failed with {} error(s):\n - {}",
|
||||
errors.len(),
|
||||
errors.join("\n - ")
|
||||
);
|
||||
warn!("{}", error_message);
|
||||
return Err(MockerError::FixtureValidation(error_message));
|
||||
}
|
||||
|
||||
info!("Fixture validation passed");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load and validate a fixture in one operation.
|
||||
///
|
||||
/// This is a convenience function that combines `load_fixture` and `validate_fixture`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the YAML fixture file
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `MockerError::FixtureLoad` or `MockerError::FixtureValidation` as appropriate.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use dirigate::fixture::load_and_validate;
|
||||
///
|
||||
/// # async fn example() -> dirigate::Result<()> {
|
||||
/// let fixture = load_and_validate("fixtures/basic_session.yaml").await?;
|
||||
/// println!("Loaded and validated fixture with {} sessions", fixture.sessions.len());
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub async fn load_and_validate<P: AsRef<Path>>(path: P) -> Result<Fixture> {
|
||||
let path = path.as_ref();
|
||||
let fixture = load_fixture(path).await?;
|
||||
validate_fixture(&fixture)?;
|
||||
info!("Successfully loaded and validated fixture from '{}'", path.display());
|
||||
Ok(fixture)
|
||||
}
|
||||
|
||||
/// Load all fixtures from a directory.
|
||||
///
|
||||
/// Scans a directory for YAML fixture files (.yaml and .yml extensions),
|
||||
/// loads and validates each one, and returns a vector of valid fixtures.
|
||||
/// Invalid files are logged as warnings and skipped.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dir` - Directory containing YAML fixture files
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `MockerError::FixtureLoad` if the directory cannot be read.
|
||||
/// Individual file errors are logged but do not stop processing.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use dirigate::fixture::load_fixtures_from_dir;
|
||||
///
|
||||
/// # async fn example() -> dirigate::Result<()> {
|
||||
/// let fixtures = load_fixtures_from_dir("fixtures/").await?;
|
||||
/// println!("Loaded {} fixtures", fixtures.len());
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub async fn load_fixtures_from_dir<P: AsRef<Path>>(dir: P) -> Result<Vec<Fixture>> {
|
||||
let dir = dir.as_ref();
|
||||
|
||||
// Read directory entries
|
||||
let mut entries = tokio::fs::read_dir(dir)
|
||||
.await
|
||||
.map_err(|e| MockerError::FixtureLoad(format!("Failed to read directory '{}': {}", dir.display(), e)))?;
|
||||
|
||||
let mut fixtures = Vec::new();
|
||||
|
||||
// Process each entry
|
||||
while let Some(entry) = entries.next_entry().await
|
||||
.map_err(|e| MockerError::FixtureLoad(format!("Failed to read directory entry: {}", e)))? {
|
||||
|
||||
let path = entry.path();
|
||||
|
||||
// Skip non-files
|
||||
if !path.is_file() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check file extension
|
||||
let extension = path.extension().and_then(|e| e.to_str());
|
||||
if !matches!(extension, Some("yaml") | Some("yml")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Try to load and validate the fixture
|
||||
match load_and_validate(&path).await {
|
||||
Ok(fixture) => {
|
||||
info!("Successfully loaded fixture from '{}'", path.display());
|
||||
fixtures.push(fixture);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to load fixture from '{}': {}", path.display(), e);
|
||||
// Continue processing other files
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("Loaded {} fixture(s) from '{}'", fixtures.len(), dir.display());
|
||||
Ok(fixtures)
|
||||
}
|
||||
|
||||
/// Validate a single session.
|
||||
fn validate_session(session: &Session, errors: &mut Vec<String>) {
|
||||
// Validate required fields are non-empty
|
||||
if session.id.is_empty() {
|
||||
errors.push("Session has empty ID".to_string());
|
||||
}
|
||||
if session.title.is_empty() {
|
||||
errors.push(format!("Session '{}' has empty title", session.id));
|
||||
}
|
||||
|
||||
// Validate timestamp
|
||||
if !is_valid_iso8601(&session.created_at) {
|
||||
errors.push(format!(
|
||||
"Session '{}' has invalid ISO8601 timestamp: '{}'",
|
||||
session.id, session.created_at
|
||||
));
|
||||
}
|
||||
|
||||
// Check for duplicate participant IDs
|
||||
let mut participant_ids = HashSet::new();
|
||||
for participant in &session.participants {
|
||||
if participant.id.is_empty() {
|
||||
errors.push(format!(
|
||||
"Session '{}' has participant with empty ID",
|
||||
session.id
|
||||
));
|
||||
}
|
||||
if !participant_ids.insert(&participant.id) {
|
||||
errors.push(format!(
|
||||
"Session '{}' has duplicate participant ID '{}'",
|
||||
session.id, participant.id
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Check for duplicate message IDs
|
||||
let mut message_ids = HashSet::new();
|
||||
for message in &session.messages {
|
||||
if !message_ids.insert(&message.id) {
|
||||
errors.push(format!(
|
||||
"Session '{}' has duplicate message ID '{}'",
|
||||
session.id, message.id
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Validate each message
|
||||
for message in &session.messages {
|
||||
validate_message(message, session, &participant_ids, &message_ids, errors);
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate a single message.
|
||||
fn validate_message(
|
||||
message: &Message,
|
||||
session: &Session,
|
||||
_participant_ids: &HashSet<&String>,
|
||||
all_message_ids: &HashSet<&String>,
|
||||
errors: &mut Vec<String>,
|
||||
) {
|
||||
// Validate required fields
|
||||
if message.id.is_empty() {
|
||||
errors.push(format!(
|
||||
"Session '{}' has message with empty ID",
|
||||
session.id
|
||||
));
|
||||
}
|
||||
if message.content.is_empty() {
|
||||
errors.push(format!(
|
||||
"Message '{}' in session '{}' has empty content",
|
||||
message.id, session.id
|
||||
));
|
||||
}
|
||||
|
||||
// Validate session_id matches parent session
|
||||
if message.session_id != session.id {
|
||||
errors.push(format!(
|
||||
"Message '{}' has session_id '{}' but is in session '{}'",
|
||||
message.id, message.session_id, session.id
|
||||
));
|
||||
}
|
||||
|
||||
// Validate timestamp
|
||||
if !is_valid_iso8601(&message.created_at) {
|
||||
errors.push(format!(
|
||||
"Message '{}' has invalid ISO8601 timestamp: '{}'",
|
||||
message.id, message.created_at
|
||||
));
|
||||
}
|
||||
|
||||
// Validate parent_id reference if present
|
||||
if let Some(parent_id) = &message.parent_id {
|
||||
if !all_message_ids.contains(parent_id) {
|
||||
errors.push(format!(
|
||||
"Message '{}' references non-existent parent_id '{}'",
|
||||
message.id, parent_id
|
||||
));
|
||||
}
|
||||
// Check for self-reference
|
||||
if parent_id == &message.id {
|
||||
errors.push(format!(
|
||||
"Message '{}' cannot reference itself as parent",
|
||||
message.id
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Note: We don't validate that message role corresponds to a participant
|
||||
// because roles (user/assistant/system) are independent of participant IDs
|
||||
// in this schema. Participants are entities, roles are message types.
|
||||
}
|
||||
|
||||
/// Check if a string is a valid ISO8601 timestamp.
|
||||
///
|
||||
/// This performs a basic format validation. It checks for common ISO8601 patterns:
|
||||
/// - YYYY-MM-DDTHH:MM:SS
|
||||
/// - YYYY-MM-DDTHH:MM:SSZ
|
||||
/// - YYYY-MM-DDTHH:MM:SS+HH:MM
|
||||
/// - YYYY-MM-DDTHH:MM:SS.sss...
|
||||
///
|
||||
/// For production use, you might want to use a proper datetime parser like `chrono` or `time`.
|
||||
fn is_valid_iso8601(timestamp: &str) -> bool {
|
||||
// Basic regex-like validation without adding regex dependency
|
||||
// ISO8601 basic format: YYYY-MM-DDTHH:MM:SS with optional timezone/fractional seconds
|
||||
|
||||
if timestamp.len() < 19 {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check basic structure: YYYY-MM-DDTHH:MM:SS
|
||||
let chars: Vec<char> = timestamp.chars().collect();
|
||||
|
||||
// Year (4 digits)
|
||||
if !chars[0..4].iter().all(|c| c.is_ascii_digit()) {
|
||||
return false;
|
||||
}
|
||||
// Dash
|
||||
if chars.get(4) != Some(&'-') {
|
||||
return false;
|
||||
}
|
||||
// Month (2 digits)
|
||||
if !chars[5..7].iter().all(|c| c.is_ascii_digit()) {
|
||||
return false;
|
||||
}
|
||||
// Dash
|
||||
if chars.get(7) != Some(&'-') {
|
||||
return false;
|
||||
}
|
||||
// Day (2 digits)
|
||||
if !chars[8..10].iter().all(|c| c.is_ascii_digit()) {
|
||||
return false;
|
||||
}
|
||||
// T separator
|
||||
if chars.get(10) != Some(&'T') {
|
||||
return false;
|
||||
}
|
||||
// Hour (2 digits)
|
||||
if !chars[11..13].iter().all(|c| c.is_ascii_digit()) {
|
||||
return false;
|
||||
}
|
||||
// Colon
|
||||
if chars.get(13) != Some(&':') {
|
||||
return false;
|
||||
}
|
||||
// Minute (2 digits)
|
||||
if !chars[14..16].iter().all(|c| c.is_ascii_digit()) {
|
||||
return false;
|
||||
}
|
||||
// Colon
|
||||
if chars.get(16) != Some(&':') {
|
||||
return false;
|
||||
}
|
||||
// Second (2 digits)
|
||||
if !chars[17..19].iter().all(|c| c.is_ascii_digit()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// The rest is optional (fractional seconds, timezone)
|
||||
// We'll accept anything after the basic format
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::fixture::{
|
||||
Message, MessageRole, Participant, ParticipantKind, ResponderStrategy,
|
||||
Responders, Session, Streaming,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn create_minimal_valid_fixture() -> Fixture {
|
||||
Fixture {
|
||||
version: "0.1".to_string(),
|
||||
sessions: vec![Session {
|
||||
id: "session-1".to_string(),
|
||||
title: "Test Session".to_string(),
|
||||
created_at: "2024-01-01T00:00:00Z".to_string(),
|
||||
participants: vec![
|
||||
Participant {
|
||||
id: "user-1".to_string(),
|
||||
kind: ParticipantKind::User,
|
||||
display_name: Some("Test User".to_string()),
|
||||
},
|
||||
Participant {
|
||||
id: "assistant-1".to_string(),
|
||||
kind: ParticipantKind::Assistant,
|
||||
display_name: Some("Test Assistant".to_string()),
|
||||
},
|
||||
],
|
||||
messages: vec![Message {
|
||||
id: "msg-1".to_string(),
|
||||
session_id: "session-1".to_string(),
|
||||
role: MessageRole::User,
|
||||
content: "Hello".to_string(),
|
||||
created_at: "2024-01-01T00:00:01Z".to_string(),
|
||||
parent_id: None,
|
||||
metadata: None,
|
||||
}],
|
||||
behavior: None,
|
||||
}],
|
||||
responders: Responders {
|
||||
keyword_map: HashMap::new(),
|
||||
default_strategy: ResponderStrategy::Echo,
|
||||
random: None,
|
||||
},
|
||||
streaming: Streaming {
|
||||
enabled: false,
|
||||
tokens_per_chunk: 1,
|
||||
chunk_interval_ms: 100,
|
||||
jitter_ms: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_valid_fixture() {
|
||||
let fixture = create_minimal_valid_fixture();
|
||||
assert!(validate_fixture(&fixture).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_invalid_version() {
|
||||
let mut fixture = create_minimal_valid_fixture();
|
||||
fixture.version = "0.2".to_string();
|
||||
|
||||
let result = validate_fixture(&fixture);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(err.to_string().contains("Invalid version"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_duplicate_session_ids() {
|
||||
let mut fixture = create_minimal_valid_fixture();
|
||||
let session2 = fixture.sessions[0].clone();
|
||||
fixture.sessions.push(session2);
|
||||
|
||||
let result = validate_fixture(&fixture);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(err.to_string().contains("Duplicate session ID"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_duplicate_message_ids() {
|
||||
let mut fixture = create_minimal_valid_fixture();
|
||||
let message2 = fixture.sessions[0].messages[0].clone();
|
||||
fixture.sessions[0].messages.push(message2);
|
||||
|
||||
let result = validate_fixture(&fixture);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
let err_str = err.to_string();
|
||||
assert!(err_str.contains("duplicate message ID"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_mismatched_session_id() {
|
||||
let mut fixture = create_minimal_valid_fixture();
|
||||
fixture.sessions[0].messages[0].session_id = "wrong-session".to_string();
|
||||
|
||||
let result = validate_fixture(&fixture);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(err.to_string().contains("has session_id"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_invalid_parent_reference() {
|
||||
let mut fixture = create_minimal_valid_fixture();
|
||||
fixture.sessions[0].messages[0].parent_id = Some("non-existent".to_string());
|
||||
|
||||
let result = validate_fixture(&fixture);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(err.to_string().contains("non-existent parent_id"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_valid_parent_reference() {
|
||||
let mut fixture = create_minimal_valid_fixture();
|
||||
let mut message2 = fixture.sessions[0].messages[0].clone();
|
||||
message2.id = "msg-2".to_string();
|
||||
message2.parent_id = Some("msg-1".to_string());
|
||||
fixture.sessions[0].messages.push(message2);
|
||||
|
||||
assert!(validate_fixture(&fixture).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_self_reference_parent() {
|
||||
let mut fixture = create_minimal_valid_fixture();
|
||||
fixture.sessions[0].messages[0].parent_id = Some("msg-1".to_string());
|
||||
|
||||
let result = validate_fixture(&fixture);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(err.to_string().contains("cannot reference itself"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_empty_required_fields() {
|
||||
let mut fixture = create_minimal_valid_fixture();
|
||||
fixture.sessions[0].id = "".to_string();
|
||||
|
||||
let result = validate_fixture(&fixture);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(err.to_string().contains("empty ID"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_valid_iso8601() {
|
||||
assert!(is_valid_iso8601("2024-01-01T00:00:00Z"));
|
||||
assert!(is_valid_iso8601("2024-01-01T00:00:00"));
|
||||
assert!(is_valid_iso8601("2024-01-01T00:00:00.123Z"));
|
||||
assert!(is_valid_iso8601("2024-01-01T00:00:00+05:30"));
|
||||
assert!(is_valid_iso8601("2024-12-31T23:59:59.999999Z"));
|
||||
|
||||
assert!(!is_valid_iso8601("2024-01-01"));
|
||||
assert!(!is_valid_iso8601("not-a-date"));
|
||||
assert!(!is_valid_iso8601(""));
|
||||
assert!(!is_valid_iso8601("2024/01/01T00:00:00Z"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_invalid_timestamp() {
|
||||
let mut fixture = create_minimal_valid_fixture();
|
||||
fixture.sessions[0].created_at = "not-a-date".to_string();
|
||||
|
||||
let result = validate_fixture(&fixture);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(err.to_string().contains("invalid ISO8601 timestamp"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_duplicate_participant_ids() {
|
||||
let mut fixture = create_minimal_valid_fixture();
|
||||
let participant2 = fixture.sessions[0].participants[0].clone();
|
||||
fixture.sessions[0].participants.push(participant2);
|
||||
|
||||
let result = validate_fixture(&fixture);
|
||||
assert!(result.is_err());
|
||||
let err = result.unwrap_err();
|
||||
assert!(err.to_string().contains("duplicate participant ID"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,24 @@
|
||||
//! Fixture system for defining mock session behaviors.
|
||||
//!
|
||||
//! This module provides the core fixture system that defines how the mock server
|
||||
//! responds to client requests. Fixtures are loaded from YAML files and specify:
|
||||
//!
|
||||
//! - Session metadata (ID, title, timestamps)
|
||||
//! - Message sequences (user inputs and agent responses)
|
||||
//! - Response behaviors (static, random, sequential, pattern-based)
|
||||
//! - Tool call definitions and outcomes
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! - `types.rs` - Core fixture type definitions and validation
|
||||
//! - `loader.rs` - Loading fixtures from YAML files
|
||||
//! - `responders.rs` - Behavior logic for generating responses
|
||||
|
||||
pub mod loader;
|
||||
pub mod responders;
|
||||
pub mod types;
|
||||
|
||||
// Re-exports for convenience
|
||||
pub use loader::*;
|
||||
pub use responders::*;
|
||||
pub use types::*;
|
||||
@@ -0,0 +1,912 @@
|
||||
//! Response behavior logic for fixtures.
|
||||
//!
|
||||
//! This module implements different response strategies for generating mock assistant
|
||||
//! responses based on user input. Each strategy is represented by a struct implementing
|
||||
//! the `Responder` trait.
|
||||
//!
|
||||
//! ## Available Strategies
|
||||
//!
|
||||
//! - **Echo**: Simply echoes back the user's input with a prefix
|
||||
//! - **Keywords**: Matches keywords in user input to predefined responses
|
||||
//! - **Random**: Returns random responses from a corpus (seeded for reproducibility)
|
||||
//! - **FixtureOnly**: Replays assistant messages from fixture data in sequence
|
||||
//!
|
||||
//! ## Usage
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use dirigate::fixture::{ResponderFactory, ResponderStrategy, Responders};
|
||||
//!
|
||||
//! let responders_config = Responders {
|
||||
//! keyword_map: HashMap::new(),
|
||||
//! default_strategy: ResponderStrategy::Echo,
|
||||
//! random: None,
|
||||
//! };
|
||||
//!
|
||||
//! let responder = ResponderFactory::create_responder(
|
||||
//! &ResponderStrategy::Echo,
|
||||
//! &responders_config,
|
||||
//! &session,
|
||||
//! )?;
|
||||
//!
|
||||
//! let response = responder.respond("Hello!", &session);
|
||||
//! ```
|
||||
|
||||
use crate::error::{MockerError, Result};
|
||||
use crate::fixture::types::{MessageRole, Responders, ResponderStrategy, Session};
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand_chacha::ChaCha8Rng;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
/// Trait for response generation strategies.
|
||||
///
|
||||
/// Implementors of this trait define how the mock server generates responses
|
||||
/// to user messages. Different strategies can be used for different testing scenarios.
|
||||
pub trait Responder: Send + Sync {
|
||||
/// Generate a response based on user input and session context.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `user_input` - The message content from the user
|
||||
/// * `session` - The current session context (for accessing fixture data)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The generated response string that will be sent back to the user
|
||||
fn respond(&mut self, user_input: &str, session: &Session) -> String;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Echo Responder (Task 1.3)
|
||||
// ============================================================================
|
||||
|
||||
/// Responder that echoes back user input with a prefix.
|
||||
///
|
||||
/// This is the simplest responder strategy, useful for basic connectivity
|
||||
/// testing and debugging. It returns the user's message prefixed with "Echo: ".
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// let responder = EchoResponder;
|
||||
/// assert_eq!(responder.respond("Hello", &session), "Echo: Hello");
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EchoResponder;
|
||||
|
||||
impl Responder for EchoResponder {
|
||||
fn respond(&mut self, user_input: &str, _session: &Session) -> String {
|
||||
tracing::debug!(user_input, "Echo responder invoked");
|
||||
format!("Echo: {}", user_input)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Keywords Responder (Task 1.4)
|
||||
// ============================================================================
|
||||
|
||||
/// Responder that matches keywords in user input to predefined responses.
|
||||
///
|
||||
/// This strategy searches the user input for keywords and returns the corresponding
|
||||
/// response. Matching is case-insensitive and supports substring matching.
|
||||
/// If multiple keywords match, the first one found (by map iteration order) is used.
|
||||
/// If no keywords match, a default response is returned.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// let mut keyword_map = HashMap::new();
|
||||
/// keyword_map.insert("hello".to_string(), "Hi there!".to_string());
|
||||
/// keyword_map.insert("help".to_string(), "How can I assist?".to_string());
|
||||
///
|
||||
/// let responder = KeywordsResponder::new(
|
||||
/// keyword_map,
|
||||
/// "I don't understand.".to_string()
|
||||
/// );
|
||||
///
|
||||
/// assert_eq!(responder.respond("HELLO world", &session), "Hi there!");
|
||||
/// assert_eq!(responder.respond("random text", &session), "I don't understand.");
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KeywordsResponder {
|
||||
/// Map of keywords (lowercase) to their responses
|
||||
keyword_map: HashMap<String, String>,
|
||||
/// Response to use when no keyword matches
|
||||
default_response: String,
|
||||
}
|
||||
|
||||
impl KeywordsResponder {
|
||||
/// Create a new keywords responder.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `keyword_map` - Map of keywords to their responses (will be converted to lowercase)
|
||||
/// * `default_response` - Response to use when no keyword matches
|
||||
pub fn new(keyword_map: HashMap<String, String>, default_response: String) -> Self {
|
||||
// Convert all keywords to lowercase for case-insensitive matching
|
||||
let lowercase_map = keyword_map
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k.to_lowercase(), v))
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
keyword_map: lowercase_map,
|
||||
default_response,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Responder for KeywordsResponder {
|
||||
fn respond(&mut self, user_input: &str, _session: &Session) -> String {
|
||||
let input_lower = user_input.to_lowercase();
|
||||
|
||||
// Search for the first matching keyword
|
||||
for (keyword, response) in &self.keyword_map {
|
||||
if input_lower.contains(keyword) {
|
||||
tracing::debug!(
|
||||
keyword = %keyword,
|
||||
user_input,
|
||||
"Keywords responder matched keyword"
|
||||
);
|
||||
return response.clone();
|
||||
}
|
||||
}
|
||||
|
||||
tracing::debug!(user_input, "Keywords responder using default response");
|
||||
self.default_response.clone()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Random Responder (Task 1.5)
|
||||
// ============================================================================
|
||||
|
||||
/// Responder that returns random responses from a corpus.
|
||||
///
|
||||
/// This strategy uses a seeded random number generator to select responses
|
||||
/// from a predefined corpus. The seeded RNG ensures reproducible behavior
|
||||
/// for testing purposes.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// let corpus = vec!["Response A".to_string(), "Response B".to_string()];
|
||||
/// let responder = RandomResponder::new(42, corpus);
|
||||
///
|
||||
/// // Same seed produces same sequence
|
||||
/// let response1 = responder.respond("anything", &session);
|
||||
/// let response2 = responder.respond("anything", &session);
|
||||
/// ```
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the corpus is empty when `respond()` is called.
|
||||
#[derive(Debug)]
|
||||
pub struct RandomResponder {
|
||||
/// Corpus of possible responses
|
||||
corpus: Vec<String>,
|
||||
/// Seeded random number generator for reproducibility
|
||||
rng: ChaCha8Rng,
|
||||
}
|
||||
|
||||
impl RandomResponder {
|
||||
/// Create a new random responder with a seeded RNG.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `seed` - Seed for the random number generator (for reproducibility)
|
||||
/// * `corpus` - List of possible responses to randomly select from
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Will panic during `respond()` if corpus is empty.
|
||||
pub fn new(seed: u64, corpus: Vec<String>) -> Self {
|
||||
let rng = ChaCha8Rng::seed_from_u64(seed);
|
||||
Self { corpus, rng }
|
||||
}
|
||||
}
|
||||
|
||||
impl Responder for RandomResponder {
|
||||
fn respond(&mut self, _user_input: &str, _session: &Session) -> String {
|
||||
if self.corpus.is_empty() {
|
||||
tracing::error!("Random responder has empty corpus");
|
||||
panic!("RandomResponder corpus is empty - cannot generate response");
|
||||
}
|
||||
|
||||
let index = self.rng.gen_range(0..self.corpus.len());
|
||||
let response = self.corpus[index].clone();
|
||||
|
||||
tracing::debug!(
|
||||
index,
|
||||
corpus_size = self.corpus.len(),
|
||||
"Random responder selected corpus entry"
|
||||
);
|
||||
|
||||
response
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// FixtureOnly Responder (Task 1.6)
|
||||
// ============================================================================
|
||||
|
||||
/// Responder that replays assistant messages from fixture data in sequence.
|
||||
///
|
||||
/// This strategy returns pre-recorded assistant responses from the fixture's
|
||||
/// message history. It maintains a turn counter to track position in the sequence.
|
||||
/// When all fixture messages are exhausted, it returns an error message.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// let responder = FixtureOnlyResponder::new(&session);
|
||||
///
|
||||
/// // Returns first assistant message
|
||||
/// let response1 = responder.respond("user input 1", &session);
|
||||
///
|
||||
/// // Returns second assistant message
|
||||
/// let response2 = responder.respond("user input 2", &session);
|
||||
///
|
||||
/// // Returns error when exhausted
|
||||
/// let response3 = responder.respond("user input 3", &session);
|
||||
/// ```
|
||||
#[derive(Debug)]
|
||||
pub struct FixtureOnlyResponder {
|
||||
/// Pre-filtered list of assistant messages from the fixture
|
||||
assistant_messages: Vec<String>,
|
||||
/// Current position in the message sequence (atomic for thread-safety)
|
||||
turn_counter: AtomicUsize,
|
||||
}
|
||||
|
||||
impl FixtureOnlyResponder {
|
||||
/// Create a new fixture-only responder.
|
||||
///
|
||||
/// This extracts all assistant messages from the session's message history
|
||||
/// and prepares them for sequential replay.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `session` - The session containing fixture messages to replay
|
||||
pub fn new(session: &Session) -> Self {
|
||||
let assistant_messages: Vec<String> = session
|
||||
.messages
|
||||
.iter()
|
||||
.filter(|msg| msg.role == MessageRole::Assistant)
|
||||
.map(|msg| msg.content.clone())
|
||||
.collect();
|
||||
|
||||
tracing::debug!(
|
||||
message_count = assistant_messages.len(),
|
||||
session_id = %session.id,
|
||||
"FixtureOnly responder initialized"
|
||||
);
|
||||
|
||||
Self {
|
||||
assistant_messages,
|
||||
turn_counter: AtomicUsize::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Responder for FixtureOnlyResponder {
|
||||
fn respond(&mut self, _user_input: &str, _session: &Session) -> String {
|
||||
let current_turn = self.turn_counter.fetch_add(1, Ordering::SeqCst);
|
||||
|
||||
if current_turn >= self.assistant_messages.len() {
|
||||
tracing::warn!(
|
||||
turn = current_turn,
|
||||
total_messages = self.assistant_messages.len(),
|
||||
"FixtureOnly responder exhausted all messages"
|
||||
);
|
||||
return "[ERROR: No more fixture messages available]".to_string();
|
||||
}
|
||||
|
||||
let response = self.assistant_messages[current_turn].clone();
|
||||
|
||||
tracing::debug!(
|
||||
turn = current_turn,
|
||||
total_messages = self.assistant_messages.len(),
|
||||
"FixtureOnly responder replaying message"
|
||||
);
|
||||
|
||||
response
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Responder Factory (Task 1.7)
|
||||
// ============================================================================
|
||||
|
||||
/// Factory for creating responder instances based on strategy configuration.
|
||||
///
|
||||
/// This factory handles the logic of creating the appropriate responder type
|
||||
/// based on the strategy enum and associated configuration. It also handles
|
||||
/// session-level behavior overrides.
|
||||
pub struct ResponderFactory;
|
||||
|
||||
impl ResponderFactory {
|
||||
/// Create a responder instance based on strategy and configuration.
|
||||
///
|
||||
/// This function creates the appropriate responder type based on the strategy
|
||||
/// and validates that required configuration is present.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `strategy` - The responder strategy to use
|
||||
/// * `responders` - Global responder configuration (keyword_map, random config)
|
||||
/// * `session` - Session context (for FixtureOnly and session overrides)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boxed responder instance implementing the `Responder` trait
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `MockerError::FixtureValidation` if:
|
||||
/// - Random strategy is requested but no random config is provided
|
||||
/// - Random config has an empty corpus
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// let responder = ResponderFactory::create_responder(
|
||||
/// &ResponderStrategy::Echo,
|
||||
/// &responders_config,
|
||||
/// &session,
|
||||
/// )?;
|
||||
/// ```
|
||||
pub fn create_responder(
|
||||
strategy: &ResponderStrategy,
|
||||
responders: &Responders,
|
||||
session: &Session,
|
||||
) -> Result<Box<dyn Responder>> {
|
||||
tracing::info!(
|
||||
strategy = ?strategy,
|
||||
session_id = %session.id,
|
||||
"Creating responder"
|
||||
);
|
||||
|
||||
match strategy {
|
||||
ResponderStrategy::Echo => {
|
||||
tracing::info!("Selected Echo responder");
|
||||
Ok(Box::new(EchoResponder))
|
||||
}
|
||||
|
||||
ResponderStrategy::Keywords => {
|
||||
tracing::info!(
|
||||
keyword_count = responders.keyword_map.len(),
|
||||
"Selected Keywords responder"
|
||||
);
|
||||
|
||||
// Use keyword map with a default response
|
||||
let default_response = "[No matching keyword found]".to_string();
|
||||
Ok(Box::new(KeywordsResponder::new(
|
||||
responders.keyword_map.clone(),
|
||||
default_response,
|
||||
)))
|
||||
}
|
||||
|
||||
ResponderStrategy::Random => {
|
||||
let random_config = responders.random.as_ref().ok_or_else(|| {
|
||||
MockerError::FixtureValidation(
|
||||
"Random strategy requires 'random' configuration".to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
if random_config.corpus.is_empty() {
|
||||
return Err(MockerError::FixtureValidation(
|
||||
"Random responder corpus cannot be empty".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
seed = random_config.seed,
|
||||
corpus_size = random_config.corpus.len(),
|
||||
"Selected Random responder"
|
||||
);
|
||||
|
||||
Ok(Box::new(RandomResponder::new(
|
||||
random_config.seed,
|
||||
random_config.corpus.clone(),
|
||||
)))
|
||||
}
|
||||
|
||||
ResponderStrategy::FixtureOnly => {
|
||||
let assistant_count = session
|
||||
.messages
|
||||
.iter()
|
||||
.filter(|msg| msg.role == MessageRole::Assistant)
|
||||
.count();
|
||||
|
||||
tracing::info!(
|
||||
assistant_messages = assistant_count,
|
||||
"Selected FixtureOnly responder"
|
||||
);
|
||||
|
||||
Ok(Box::new(FixtureOnlyResponder::new(session)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a responder for a session, handling session-level behavior overrides.
|
||||
///
|
||||
/// This is a convenience method that checks for session-level responder overrides
|
||||
/// before falling back to the default strategy.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `responders` - Global responder configuration
|
||||
/// * `session` - Session context (may contain behavior overrides)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boxed responder instance
|
||||
pub fn create_for_session(responders: &Responders, session: &Session) -> Result<Box<dyn Responder>> {
|
||||
// Check if session has a behavior override
|
||||
let strategy = session
|
||||
.behavior
|
||||
.as_ref()
|
||||
.and_then(|b| b.responder.as_ref())
|
||||
.unwrap_or(&responders.default_strategy);
|
||||
|
||||
tracing::debug!(
|
||||
session_id = %session.id,
|
||||
strategy = ?strategy,
|
||||
has_override = session.behavior.is_some(),
|
||||
"Creating responder for session"
|
||||
);
|
||||
|
||||
Self::create_responder(strategy, responders, session)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::fixture::types::{Message, RandomConfig};
|
||||
|
||||
// Helper to create a minimal test session
|
||||
fn create_test_session(messages: Vec<Message>) -> Session {
|
||||
Session {
|
||||
id: "test-session".to_string(),
|
||||
title: "Test Session".to_string(),
|
||||
created_at: "2025-01-01T00:00:00Z".to_string(),
|
||||
participants: vec![],
|
||||
messages,
|
||||
behavior: None,
|
||||
}
|
||||
}
|
||||
|
||||
// Helper to create a test message
|
||||
fn create_message(id: &str, role: MessageRole, content: &str) -> Message {
|
||||
Message {
|
||||
id: id.to_string(),
|
||||
session_id: "test-session".to_string(),
|
||||
role,
|
||||
content: content.to_string(),
|
||||
created_at: "2025-01-01T00:00:00Z".to_string(),
|
||||
parent_id: None,
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Task 1.3: Echo Responder Tests
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_echo_responder_basic() {
|
||||
let session = create_test_session(vec![]);
|
||||
let mut responder = EchoResponder;
|
||||
|
||||
assert_eq!(responder.respond("Hello", &session), "Echo: Hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_echo_responder_empty_input() {
|
||||
let session = create_test_session(vec![]);
|
||||
let mut responder = EchoResponder;
|
||||
|
||||
assert_eq!(responder.respond("", &session), "Echo: ");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_echo_responder_special_characters() {
|
||||
let session = create_test_session(vec![]);
|
||||
let mut responder = EchoResponder;
|
||||
|
||||
assert_eq!(
|
||||
responder.respond("Hello, world! @#$%", &session),
|
||||
"Echo: Hello, world! @#$%"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_echo_responder_multiline() {
|
||||
let session = create_test_session(vec![]);
|
||||
let mut responder = EchoResponder;
|
||||
|
||||
let input = "Line 1\nLine 2\nLine 3";
|
||||
assert_eq!(responder.respond(input, &session), "Echo: Line 1\nLine 2\nLine 3");
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Task 1.4: Keywords Responder Tests
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_keywords_responder_exact_match() {
|
||||
let session = create_test_session(vec![]);
|
||||
let mut keyword_map = HashMap::new();
|
||||
keyword_map.insert("hello".to_string(), "Hi there!".to_string());
|
||||
|
||||
let mut responder = KeywordsResponder::new(keyword_map, "Default".to_string());
|
||||
|
||||
assert_eq!(responder.respond("hello", &session), "Hi there!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keywords_responder_partial_match() {
|
||||
let session = create_test_session(vec![]);
|
||||
let mut keyword_map = HashMap::new();
|
||||
keyword_map.insert("help".to_string(), "How can I assist?".to_string());
|
||||
|
||||
let mut responder = KeywordsResponder::new(keyword_map, "Default".to_string());
|
||||
|
||||
assert_eq!(
|
||||
responder.respond("I need help with something", &session),
|
||||
"How can I assist?"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keywords_responder_case_insensitive() {
|
||||
let session = create_test_session(vec![]);
|
||||
let mut keyword_map = HashMap::new();
|
||||
keyword_map.insert("hello".to_string(), "Hi!".to_string());
|
||||
|
||||
let mut responder = KeywordsResponder::new(keyword_map, "Default".to_string());
|
||||
|
||||
assert_eq!(responder.respond("HELLO", &session), "Hi!");
|
||||
assert_eq!(responder.respond("HeLLo", &session), "Hi!");
|
||||
assert_eq!(responder.respond("hello", &session), "Hi!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keywords_responder_no_match_returns_default() {
|
||||
let session = create_test_session(vec![]);
|
||||
let mut keyword_map = HashMap::new();
|
||||
keyword_map.insert("hello".to_string(), "Hi!".to_string());
|
||||
|
||||
let mut responder = KeywordsResponder::new(keyword_map, "I don't understand".to_string());
|
||||
|
||||
assert_eq!(responder.respond("goodbye", &session), "I don't understand");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keywords_responder_multiple_keywords() {
|
||||
let session = create_test_session(vec![]);
|
||||
let mut keyword_map = HashMap::new();
|
||||
keyword_map.insert("hello".to_string(), "Response A".to_string());
|
||||
keyword_map.insert("help".to_string(), "Response B".to_string());
|
||||
|
||||
let mut responder = KeywordsResponder::new(keyword_map, "Default".to_string());
|
||||
|
||||
// Should match one of them (order depends on HashMap iteration)
|
||||
let response = responder.respond("hello help", &session);
|
||||
assert!(response == "Response A" || response == "Response B");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keywords_responder_empty_map() {
|
||||
let session = create_test_session(vec![]);
|
||||
let keyword_map = HashMap::new();
|
||||
|
||||
let mut responder = KeywordsResponder::new(keyword_map, "Always default".to_string());
|
||||
|
||||
assert_eq!(responder.respond("anything", &session), "Always default");
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Task 1.5: Random Responder Tests
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_random_responder_same_seed_produces_same_sequence() {
|
||||
let session = create_test_session(vec![]);
|
||||
let corpus = vec!["A".to_string(), "B".to_string(), "C".to_string()];
|
||||
|
||||
let mut responder1 = RandomResponder::new(42, corpus.clone());
|
||||
let mut responder2 = RandomResponder::new(42, corpus);
|
||||
|
||||
// Same seed should produce same sequence
|
||||
for _ in 0..10 {
|
||||
assert_eq!(
|
||||
responder1.respond("test", &session),
|
||||
responder2.respond("test", &session)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_random_responder_different_seeds_differ() {
|
||||
let session = create_test_session(vec![]);
|
||||
let corpus = vec!["A".to_string(), "B".to_string(), "C".to_string()];
|
||||
|
||||
let mut responder1 = RandomResponder::new(42, corpus.clone());
|
||||
let mut responder2 = RandomResponder::new(123, corpus);
|
||||
|
||||
// Collect responses
|
||||
let responses1: Vec<String> = (0..20).map(|_| responder1.respond("test", &session)).collect();
|
||||
let responses2: Vec<String> = (0..20).map(|_| responder2.respond("test", &session)).collect();
|
||||
|
||||
// Different seeds should produce different sequences
|
||||
assert_ne!(responses1, responses2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_random_responder_all_corpus_entries_selected() {
|
||||
let session = create_test_session(vec![]);
|
||||
let corpus = vec!["A".to_string(), "B".to_string(), "C".to_string()];
|
||||
|
||||
let mut responder = RandomResponder::new(42, corpus.clone());
|
||||
|
||||
// Collect many responses
|
||||
let mut seen = std::collections::HashSet::new();
|
||||
for _ in 0..100 {
|
||||
seen.insert(responder.respond("test", &session));
|
||||
}
|
||||
|
||||
// All corpus entries should eventually be selected
|
||||
assert_eq!(seen.len(), corpus.len());
|
||||
for entry in corpus {
|
||||
assert!(seen.contains(&entry));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "RandomResponder corpus is empty")]
|
||||
fn test_random_responder_empty_corpus_panics() {
|
||||
let session = create_test_session(vec![]);
|
||||
let corpus = vec![];
|
||||
|
||||
let mut responder = RandomResponder::new(42, corpus);
|
||||
responder.respond("test", &session);
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Task 1.6: FixtureOnly Responder Tests
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_fixture_only_responder_sequential_replay() {
|
||||
let messages = vec![
|
||||
create_message("1", MessageRole::Assistant, "First response"),
|
||||
create_message("2", MessageRole::User, "User message"),
|
||||
create_message("3", MessageRole::Assistant, "Second response"),
|
||||
create_message("4", MessageRole::Assistant, "Third response"),
|
||||
];
|
||||
let session = create_test_session(messages);
|
||||
|
||||
let mut responder = FixtureOnlyResponder::new(&session);
|
||||
|
||||
assert_eq!(responder.respond("input 1", &session), "First response");
|
||||
assert_eq!(responder.respond("input 2", &session), "Second response");
|
||||
assert_eq!(responder.respond("input 3", &session), "Third response");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fixture_only_responder_end_of_fixture() {
|
||||
let messages = vec![create_message("1", MessageRole::Assistant, "Only response")];
|
||||
let session = create_test_session(messages);
|
||||
|
||||
let mut responder = FixtureOnlyResponder::new(&session);
|
||||
|
||||
assert_eq!(responder.respond("input 1", &session), "Only response");
|
||||
assert_eq!(
|
||||
responder.respond("input 2", &session),
|
||||
"[ERROR: No more fixture messages available]"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fixture_only_responder_empty_fixture() {
|
||||
let session = create_test_session(vec![]);
|
||||
|
||||
let mut responder = FixtureOnlyResponder::new(&session);
|
||||
|
||||
assert_eq!(
|
||||
responder.respond("input", &session),
|
||||
"[ERROR: No more fixture messages available]"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fixture_only_responder_filters_assistant_only() {
|
||||
let messages = vec![
|
||||
create_message("1", MessageRole::User, "User 1"),
|
||||
create_message("2", MessageRole::Assistant, "Assistant 1"),
|
||||
create_message("3", MessageRole::System, "System"),
|
||||
create_message("4", MessageRole::User, "User 2"),
|
||||
create_message("5", MessageRole::Assistant, "Assistant 2"),
|
||||
];
|
||||
let session = create_test_session(messages);
|
||||
|
||||
let mut responder = FixtureOnlyResponder::new(&session);
|
||||
|
||||
assert_eq!(responder.respond("input 1", &session), "Assistant 1");
|
||||
assert_eq!(responder.respond("input 2", &session), "Assistant 2");
|
||||
assert_eq!(
|
||||
responder.respond("input 3", &session),
|
||||
"[ERROR: No more fixture messages available]"
|
||||
);
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Task 1.7: Responder Factory Tests
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_factory_creates_echo_responder() {
|
||||
let session = create_test_session(vec![]);
|
||||
let responders = Responders {
|
||||
keyword_map: HashMap::new(),
|
||||
default_strategy: ResponderStrategy::Echo,
|
||||
random: None,
|
||||
};
|
||||
|
||||
let mut responder =
|
||||
ResponderFactory::create_responder(&ResponderStrategy::Echo, &responders, &session)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(responder.respond("test", &session), "Echo: test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_factory_creates_keywords_responder() {
|
||||
let session = create_test_session(vec![]);
|
||||
let mut keyword_map = HashMap::new();
|
||||
keyword_map.insert("test".to_string(), "Test response".to_string());
|
||||
|
||||
let responders = Responders {
|
||||
keyword_map,
|
||||
default_strategy: ResponderStrategy::Keywords,
|
||||
random: None,
|
||||
};
|
||||
|
||||
let mut responder =
|
||||
ResponderFactory::create_responder(&ResponderStrategy::Keywords, &responders, &session)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(responder.respond("test", &session), "Test response");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_factory_creates_random_responder() {
|
||||
let session = create_test_session(vec![]);
|
||||
let responders = Responders {
|
||||
keyword_map: HashMap::new(),
|
||||
default_strategy: ResponderStrategy::Random,
|
||||
random: Some(RandomConfig {
|
||||
seed: 42,
|
||||
corpus: vec!["Response 1".to_string(), "Response 2".to_string()],
|
||||
}),
|
||||
};
|
||||
|
||||
let mut responder =
|
||||
ResponderFactory::create_responder(&ResponderStrategy::Random, &responders, &session)
|
||||
.unwrap();
|
||||
|
||||
let response = responder.respond("test", &session);
|
||||
assert!(response == "Response 1" || response == "Response 2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_factory_creates_fixture_only_responder() {
|
||||
let messages = vec![create_message("1", MessageRole::Assistant, "Fixture response")];
|
||||
let session = create_test_session(messages);
|
||||
|
||||
let responders = Responders {
|
||||
keyword_map: HashMap::new(),
|
||||
default_strategy: ResponderStrategy::FixtureOnly,
|
||||
random: None,
|
||||
};
|
||||
|
||||
let mut responder = ResponderFactory::create_responder(
|
||||
&ResponderStrategy::FixtureOnly,
|
||||
&responders,
|
||||
&session,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(responder.respond("test", &session), "Fixture response");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_factory_random_without_config_fails() {
|
||||
let session = create_test_session(vec![]);
|
||||
let responders = Responders {
|
||||
keyword_map: HashMap::new(),
|
||||
default_strategy: ResponderStrategy::Random,
|
||||
random: None, // Missing required config
|
||||
};
|
||||
|
||||
let result =
|
||||
ResponderFactory::create_responder(&ResponderStrategy::Random, &responders, &session);
|
||||
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(MockerError::FixtureValidation(msg)) => {
|
||||
assert!(msg.contains("Random strategy requires"));
|
||||
}
|
||||
_ => panic!("Expected FixtureValidation error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_factory_random_with_empty_corpus_fails() {
|
||||
let session = create_test_session(vec![]);
|
||||
let responders = Responders {
|
||||
keyword_map: HashMap::new(),
|
||||
default_strategy: ResponderStrategy::Random,
|
||||
random: Some(RandomConfig {
|
||||
seed: 42,
|
||||
corpus: vec![], // Empty corpus
|
||||
}),
|
||||
};
|
||||
|
||||
let result =
|
||||
ResponderFactory::create_responder(&ResponderStrategy::Random, &responders, &session);
|
||||
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(MockerError::FixtureValidation(msg)) => {
|
||||
assert!(msg.contains("corpus cannot be empty"));
|
||||
}
|
||||
_ => panic!("Expected FixtureValidation error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_factory_create_for_session_uses_default() {
|
||||
let session = create_test_session(vec![]);
|
||||
let responders = Responders {
|
||||
keyword_map: HashMap::new(),
|
||||
default_strategy: ResponderStrategy::Echo,
|
||||
random: None,
|
||||
};
|
||||
|
||||
let mut responder = ResponderFactory::create_for_session(&responders, &session).unwrap();
|
||||
|
||||
assert_eq!(responder.respond("test", &session), "Echo: test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_factory_create_for_session_uses_override() {
|
||||
use crate::fixture::types::SessionBehavior;
|
||||
|
||||
let mut session = create_test_session(vec![]);
|
||||
session.behavior = Some(SessionBehavior {
|
||||
responder: Some(ResponderStrategy::Echo),
|
||||
streaming: None,
|
||||
});
|
||||
|
||||
let mut keyword_map = HashMap::new();
|
||||
keyword_map.insert("test".to_string(), "Keyword response".to_string());
|
||||
|
||||
let responders = Responders {
|
||||
keyword_map,
|
||||
default_strategy: ResponderStrategy::Keywords, // Default is Keywords
|
||||
random: None,
|
||||
};
|
||||
|
||||
let mut responder = ResponderFactory::create_for_session(&responders, &session).unwrap();
|
||||
|
||||
// Should use Echo (override) not Keywords (default)
|
||||
assert_eq!(responder.respond("test", &session), "Echo: test");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,255 @@
|
||||
//! Core fixture type definitions.
|
||||
//!
|
||||
//! This module defines the structure of fixture files and the types
|
||||
//! used to represent mock sessions and their behaviors.
|
||||
//!
|
||||
//! ## Fixture Schema v0.1
|
||||
//!
|
||||
//! Fixtures define mock sessions with messages, participants, and behavior configuration.
|
||||
//! They support:
|
||||
//! - Multiple sessions with participants and message histories
|
||||
//! - Configurable response strategies (echo, keywords, random, fixture-only)
|
||||
//! - Streaming simulation with configurable chunking
|
||||
//! - Per-session behavior overrides
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Top-level fixture definition.
|
||||
///
|
||||
/// A fixture represents a complete test scenario with sessions, response behavior,
|
||||
/// and streaming configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Fixture {
|
||||
/// Fixture version (must be "0.1").
|
||||
pub version: String,
|
||||
|
||||
/// List of sessions defined in this fixture.
|
||||
pub sessions: Vec<Session>,
|
||||
|
||||
/// Response behavior configuration.
|
||||
pub responders: Responders,
|
||||
|
||||
/// Streaming behavior configuration.
|
||||
pub streaming: Streaming,
|
||||
}
|
||||
|
||||
/// Session fixture defining a mock session.
|
||||
///
|
||||
/// Represents a single conversation session with participants and messages.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Session {
|
||||
/// Unique session identifier.
|
||||
pub id: String,
|
||||
|
||||
/// Human-readable session title.
|
||||
pub title: String,
|
||||
|
||||
/// ISO8601 timestamp when session was created.
|
||||
pub created_at: String,
|
||||
|
||||
/// Participants in this session.
|
||||
pub participants: Vec<Participant>,
|
||||
|
||||
/// Messages in this session.
|
||||
pub messages: Vec<Message>,
|
||||
|
||||
/// Optional per-session behavior overrides.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub behavior: Option<SessionBehavior>,
|
||||
}
|
||||
|
||||
/// Participant in a session.
|
||||
///
|
||||
/// Represents an entity that can send or receive messages.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Participant {
|
||||
/// Unique participant identifier.
|
||||
pub id: String,
|
||||
|
||||
/// Type of participant.
|
||||
pub kind: ParticipantKind,
|
||||
|
||||
/// Optional display name for the participant.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub display_name: Option<String>,
|
||||
}
|
||||
|
||||
/// Type of participant.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ParticipantKind {
|
||||
/// Human user.
|
||||
User,
|
||||
|
||||
/// AI assistant.
|
||||
Assistant,
|
||||
|
||||
/// System message source.
|
||||
System,
|
||||
|
||||
/// Tool or function.
|
||||
Tool,
|
||||
|
||||
/// Other participant type.
|
||||
Other,
|
||||
}
|
||||
|
||||
/// Message in a session.
|
||||
///
|
||||
/// Represents a single message with content, role, and metadata.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
/// Unique message identifier.
|
||||
pub id: String,
|
||||
|
||||
/// Session this message belongs to.
|
||||
pub session_id: String,
|
||||
|
||||
/// Message role.
|
||||
pub role: MessageRole,
|
||||
|
||||
/// Message content (text-only in v0.1).
|
||||
pub content: String,
|
||||
|
||||
/// ISO8601 timestamp when message was created.
|
||||
pub created_at: String,
|
||||
|
||||
/// Optional parent message ID for threading.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub parent_id: Option<String>,
|
||||
|
||||
/// Optional metadata (arbitrary JSON).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Message role.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum MessageRole {
|
||||
/// Message from user.
|
||||
User,
|
||||
|
||||
/// Message from assistant.
|
||||
Assistant,
|
||||
|
||||
/// System message.
|
||||
System,
|
||||
}
|
||||
|
||||
/// Response behavior configuration.
|
||||
///
|
||||
/// Defines how the mocker generates responses to incoming messages.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Responders {
|
||||
/// Map of keywords to response strings.
|
||||
/// When user input contains a keyword, the corresponding response is used.
|
||||
#[serde(default)]
|
||||
pub keyword_map: HashMap<String, String>,
|
||||
|
||||
/// Default response strategy when no keyword matches.
|
||||
pub default_strategy: ResponderStrategy,
|
||||
|
||||
/// Configuration for random responses (required if default_strategy is Random).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub random: Option<RandomConfig>,
|
||||
}
|
||||
|
||||
/// Response generation strategy.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ResponderStrategy {
|
||||
/// Echo back the user's message.
|
||||
Echo,
|
||||
|
||||
/// Use keyword matching from keyword_map.
|
||||
Keywords,
|
||||
|
||||
/// Return random responses from corpus.
|
||||
Random,
|
||||
|
||||
/// Only respond with messages from fixtures (no dynamic generation).
|
||||
FixtureOnly,
|
||||
}
|
||||
|
||||
/// Configuration for random response generation.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RandomConfig {
|
||||
/// Random seed for reproducibility.
|
||||
pub seed: u64,
|
||||
|
||||
/// Corpus of possible responses.
|
||||
pub corpus: Vec<String>,
|
||||
}
|
||||
|
||||
/// Streaming behavior configuration.
|
||||
///
|
||||
/// Controls how responses are streamed to clients.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Streaming {
|
||||
/// Whether streaming is enabled.
|
||||
pub enabled: bool,
|
||||
|
||||
/// Number of tokens per chunk.
|
||||
pub tokens_per_chunk: usize,
|
||||
|
||||
/// Interval between chunks in milliseconds.
|
||||
pub chunk_interval_ms: u64,
|
||||
|
||||
/// Optional random jitter to add to chunk intervals (in milliseconds).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub jitter_ms: Option<u64>,
|
||||
}
|
||||
|
||||
/// Per-session behavior overrides.
|
||||
///
|
||||
/// Allows specific sessions to override global responder and streaming settings.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SessionBehavior {
|
||||
/// Override default responder strategy for this session.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub responder: Option<ResponderStrategy>,
|
||||
|
||||
/// Override streaming configuration for this session.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub streaming: Option<Streaming>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_participant_kind_serde() {
|
||||
// Test snake_case serialization
|
||||
let kind = ParticipantKind::Assistant;
|
||||
let json = serde_json::to_string(&kind).unwrap();
|
||||
assert_eq!(json, "\"assistant\"");
|
||||
|
||||
let kind: ParticipantKind = serde_json::from_str("\"user\"").unwrap();
|
||||
assert_eq!(kind, ParticipantKind::User);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_role_serde() {
|
||||
// Test snake_case serialization
|
||||
let role = MessageRole::Assistant;
|
||||
let json = serde_json::to_string(&role).unwrap();
|
||||
assert_eq!(json, "\"assistant\"");
|
||||
|
||||
let role: MessageRole = serde_json::from_str("\"user\"").unwrap();
|
||||
assert_eq!(role, MessageRole::User);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_responder_strategy_serde() {
|
||||
// Test snake_case serialization
|
||||
let strategy = ResponderStrategy::FixtureOnly;
|
||||
let json = serde_json::to_string(&strategy).unwrap();
|
||||
assert_eq!(json, "\"fixture_only\"");
|
||||
|
||||
let strategy: ResponderStrategy = serde_json::from_str("\"echo\"").unwrap();
|
||||
assert_eq!(strategy, ResponderStrategy::Echo);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
//! Session ingestion from external sources (feature-gated).
|
||||
//!
|
||||
//! This module provides functionality to import sessions from live agent systems
|
||||
//! (like OpenCode.ai) and convert them into fixture format. This is useful for:
|
||||
//!
|
||||
//! - Creating realistic test fixtures from production sessions
|
||||
//! - Recording complex interaction flows for regression testing
|
||||
//! - Building fixture libraries from existing conversations
|
||||
//!
|
||||
//! **Note**: This module is only available when the `ingest` feature is enabled.
|
||||
|
||||
#[cfg(feature = "ingest")]
|
||||
pub mod opencode;
|
||||
|
||||
#[cfg(feature = "ingest")]
|
||||
pub use opencode::run_ingest;
|
||||
@@ -0,0 +1,569 @@
|
||||
//! OpenCode.ai session ingestion.
|
||||
//!
|
||||
//! This module provides functionality to fetch sessions from OpenCode.ai
|
||||
//! and convert them into fixture format.
|
||||
|
||||
use crate::{
|
||||
fixture::{
|
||||
types::{
|
||||
Fixture, Message, MessageRole, Participant, ParticipantKind, Responders,
|
||||
ResponderStrategy, Session, Streaming,
|
||||
},
|
||||
validate_fixture,
|
||||
},
|
||||
MockerError, Result,
|
||||
};
|
||||
use opencode_client::{MessageWithParts, OpenCodeClient};
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
// ============================================================================
|
||||
// Public API
|
||||
// ============================================================================
|
||||
|
||||
/// Run the complete ingestion workflow.
|
||||
///
|
||||
/// This is the main entry point for ingesting sessions from OpenCode.ai.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `base_url` - Base URL of the OpenCode API
|
||||
/// * `session_id` - Optional specific session ID to ingest
|
||||
/// * `all` - Whether to ingest all sessions
|
||||
/// * `output` - Output file path for the fixture
|
||||
/// * `merge` - Whether to merge with existing fixture
|
||||
///
|
||||
/// # Workflow
|
||||
///
|
||||
/// 1. Fetch session(s) from OpenCode API
|
||||
/// 2. Map to fixture format
|
||||
/// 3. Load existing fixture if merge is enabled
|
||||
/// 4. Merge fixtures if necessary
|
||||
/// 5. Validate the final fixture
|
||||
/// 6. Export to YAML file
|
||||
pub async fn run_ingest(
|
||||
base_url: &str,
|
||||
session_id: Option<String>,
|
||||
all: bool,
|
||||
output: &Path,
|
||||
merge: bool,
|
||||
) -> Result<()> {
|
||||
// Step 1: Create ingestor and fetch sessions
|
||||
tracing::info!("Creating OpenCode client");
|
||||
let ingestor = OpenCodeIngestor::new(base_url);
|
||||
|
||||
let opencode_sessions = if all {
|
||||
tracing::info!("Fetching all sessions from OpenCode");
|
||||
ingestor.fetch_all_sessions().await?
|
||||
} else if let Some(id) = session_id {
|
||||
tracing::info!("Fetching single session: {}", id);
|
||||
vec![ingestor.fetch_session(&id).await?]
|
||||
} else {
|
||||
return Err(MockerError::Ingest(
|
||||
"Must specify either session_id or all".to_string(),
|
||||
));
|
||||
};
|
||||
|
||||
tracing::info!("Fetched {} session(s)", opencode_sessions.len());
|
||||
|
||||
// Step 2: Map sessions to fixture format
|
||||
tracing::info!("Mapping sessions to fixture format");
|
||||
let new_fixture = map_sessions_to_fixture(opencode_sessions)?;
|
||||
|
||||
// Step 3: Load existing fixture if merge is enabled
|
||||
let final_fixture = if merge && output.exists() {
|
||||
tracing::info!("Loading existing fixture for merge: {}", output.display());
|
||||
let existing_fixture = load_or_create_fixture(output).await?;
|
||||
tracing::info!(
|
||||
"Merging {} existing sessions with {} new sessions",
|
||||
existing_fixture.sessions.len(),
|
||||
new_fixture.sessions.len()
|
||||
);
|
||||
merge_fixtures(existing_fixture, new_fixture)?
|
||||
} else {
|
||||
tracing::info!("Creating new fixture (no merge)");
|
||||
new_fixture
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
"Final fixture has {} session(s)",
|
||||
final_fixture.sessions.len()
|
||||
);
|
||||
|
||||
// Step 4: Validate the fixture
|
||||
tracing::info!("Validating fixture");
|
||||
validate_fixture(&final_fixture)?;
|
||||
|
||||
// Step 5: Export to YAML
|
||||
tracing::info!("Exporting fixture to: {}", output.display());
|
||||
export_fixture(&final_fixture, output).await?;
|
||||
|
||||
tracing::info!("Ingestion complete!");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// OpenCodeIngestor
|
||||
// ============================================================================
|
||||
|
||||
/// OpenCode session ingestor.
|
||||
///
|
||||
/// Handles fetching sessions and messages from OpenCode.ai API.
|
||||
struct OpenCodeIngestor {
|
||||
client: OpenCodeClient,
|
||||
}
|
||||
|
||||
impl OpenCodeIngestor {
|
||||
/// Create a new OpenCode ingestor.
|
||||
fn new(base_url: &str) -> Self {
|
||||
Self {
|
||||
client: OpenCodeClient::new(base_url),
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch a single session by ID.
|
||||
async fn fetch_session(&self, session_id: &str) -> Result<OpenCodeSession> {
|
||||
tracing::debug!("Fetching session: {}", session_id);
|
||||
|
||||
let session = self
|
||||
.client
|
||||
.get_session(session_id)
|
||||
.await
|
||||
.map_err(|e| MockerError::Ingest(format!("Failed to fetch session: {}", e)))?;
|
||||
|
||||
tracing::debug!("Fetching messages for session: {}", session_id);
|
||||
let messages = self
|
||||
.client
|
||||
.list_messages(session_id)
|
||||
.await
|
||||
.map_err(|e| MockerError::Ingest(format!("Failed to fetch messages: {}", e)))?;
|
||||
|
||||
tracing::debug!(
|
||||
"Fetched {} messages for session {}",
|
||||
messages.len(),
|
||||
session_id
|
||||
);
|
||||
|
||||
Ok(OpenCodeSession { session, messages })
|
||||
}
|
||||
|
||||
/// Fetch all sessions from the API.
|
||||
async fn fetch_all_sessions(&self) -> Result<Vec<OpenCodeSession>> {
|
||||
tracing::debug!("Fetching all sessions");
|
||||
|
||||
let sessions = self
|
||||
.client
|
||||
.list_sessions()
|
||||
.await
|
||||
.map_err(|e| MockerError::Ingest(format!("Failed to list sessions: {}", e)))?;
|
||||
|
||||
tracing::info!("Found {} sessions to ingest", sessions.len());
|
||||
|
||||
let mut results = Vec::new();
|
||||
for session_info in sessions {
|
||||
match self.fetch_session(&session_info.id).await {
|
||||
Ok(session) => results.push(session),
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to fetch session {}: {}", session_info.id, e);
|
||||
// Continue with other sessions
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
/// OpenCode session with messages.
|
||||
struct OpenCodeSession {
|
||||
session: opencode_client::Session,
|
||||
messages: Vec<MessageWithParts>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Session Mapping
|
||||
// ============================================================================
|
||||
|
||||
/// Map OpenCode sessions to fixture format.
|
||||
fn map_sessions_to_fixture(opencode_sessions: Vec<OpenCodeSession>) -> Result<Fixture> {
|
||||
let sessions: Vec<Session> = opencode_sessions
|
||||
.into_iter()
|
||||
.map(|oc_session| map_session(oc_session))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
|
||||
// Create default responders and streaming config
|
||||
let responders = create_default_responders();
|
||||
let streaming = create_default_streaming();
|
||||
|
||||
Ok(Fixture {
|
||||
version: "0.1".to_string(),
|
||||
sessions,
|
||||
responders,
|
||||
streaming,
|
||||
})
|
||||
}
|
||||
|
||||
/// Map a single OpenCode session to fixture format.
|
||||
fn map_session(oc_session: OpenCodeSession) -> Result<Session> {
|
||||
let session_id = oc_session.session.id.clone();
|
||||
|
||||
// Create participants (user and assistant)
|
||||
let participants = vec![
|
||||
Participant {
|
||||
id: "user".to_string(),
|
||||
kind: ParticipantKind::User,
|
||||
display_name: Some("User".to_string()),
|
||||
},
|
||||
Participant {
|
||||
id: "assistant".to_string(),
|
||||
kind: ParticipantKind::Assistant,
|
||||
display_name: Some("Assistant".to_string()),
|
||||
},
|
||||
];
|
||||
|
||||
// Map messages
|
||||
let messages: Vec<Message> = oc_session
|
||||
.messages
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, msg_with_parts)| map_message(&session_id, idx, msg_with_parts))
|
||||
.collect();
|
||||
|
||||
// Convert Unix timestamp (milliseconds) to ISO8601
|
||||
let created_at = chrono::DateTime::from_timestamp_millis(
|
||||
oc_session.session.time.created as i64
|
||||
)
|
||||
.map(|dt| dt.to_rfc3339())
|
||||
.unwrap_or_else(|| chrono::Utc::now().to_rfc3339());
|
||||
|
||||
Ok(Session {
|
||||
id: session_id,
|
||||
title: oc_session.session.title,
|
||||
created_at,
|
||||
participants,
|
||||
messages,
|
||||
behavior: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Map an OpenCode message to fixture format.
|
||||
///
|
||||
/// Returns None if the message has no text content.
|
||||
fn map_message(
|
||||
session_id: &str,
|
||||
index: usize,
|
||||
msg_with_parts: MessageWithParts,
|
||||
) -> Option<Message> {
|
||||
// Extract role and timestamp from message info
|
||||
let (role, created_ms) = match msg_with_parts.info {
|
||||
opencode_client::Message::User(user_msg) => {
|
||||
(MessageRole::User, user_msg.time.created)
|
||||
}
|
||||
opencode_client::Message::Assistant(assistant_msg) => {
|
||||
(MessageRole::Assistant, assistant_msg.time.created)
|
||||
}
|
||||
};
|
||||
|
||||
// Extract text content from parts
|
||||
// We concatenate all text and reasoning parts for simplicity in v0.1
|
||||
let mut content_parts = Vec::new();
|
||||
|
||||
for part in msg_with_parts.parts {
|
||||
match part {
|
||||
opencode_client::Part::Text(text_part) => {
|
||||
content_parts.push(text_part.text);
|
||||
}
|
||||
opencode_client::Part::Reasoning(reasoning_part) => {
|
||||
// Include reasoning in content for v0.1
|
||||
content_parts.push(format!("[Reasoning]\n{}", reasoning_part.text));
|
||||
}
|
||||
opencode_client::Part::Tool(tool_part) => {
|
||||
// Include tool information for context
|
||||
let tool_info = match tool_part.state {
|
||||
opencode_client::ToolState::Completed { output, title, .. } => {
|
||||
format!("[Tool: {}]\n{}\nOutput: {}", tool_part.tool, title, output)
|
||||
}
|
||||
opencode_client::ToolState::Error { error, .. } => {
|
||||
format!("[Tool: {} - Error]\n{}", tool_part.tool, error)
|
||||
}
|
||||
opencode_client::ToolState::Running { title, .. } => {
|
||||
format!(
|
||||
"[Tool: {} - Running]\n{}",
|
||||
tool_part.tool,
|
||||
title.unwrap_or_default()
|
||||
)
|
||||
}
|
||||
opencode_client::ToolState::Pending => {
|
||||
format!("[Tool: {} - Pending]", tool_part.tool)
|
||||
}
|
||||
};
|
||||
content_parts.push(tool_info);
|
||||
}
|
||||
// Ignore other part types for v0.1
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// If no content, skip this message
|
||||
if content_parts.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let content = content_parts.join("\n\n");
|
||||
|
||||
// Convert timestamp
|
||||
let created_at = chrono::DateTime::from_timestamp_millis(created_ms as i64)
|
||||
.map(|dt| dt.to_rfc3339())
|
||||
.unwrap_or_else(|| chrono::Utc::now().to_rfc3339());
|
||||
|
||||
// Generate message ID
|
||||
let message_id = format!("msg-{}", index);
|
||||
|
||||
// Parent ID for threading (assistant messages follow user messages)
|
||||
let parent_id = if role == MessageRole::Assistant && index > 0 {
|
||||
Some(format!("msg-{}", index - 1))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Some(Message {
|
||||
id: message_id,
|
||||
session_id: session_id.to_string(),
|
||||
role,
|
||||
content,
|
||||
created_at,
|
||||
parent_id,
|
||||
metadata: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create default responder configuration.
|
||||
///
|
||||
/// Uses Echo strategy for simplicity.
|
||||
fn create_default_responders() -> Responders {
|
||||
Responders {
|
||||
keyword_map: HashMap::new(),
|
||||
default_strategy: ResponderStrategy::Echo,
|
||||
random: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create default streaming configuration.
|
||||
fn create_default_streaming() -> Streaming {
|
||||
Streaming {
|
||||
enabled: true,
|
||||
tokens_per_chunk: 5,
|
||||
chunk_interval_ms: 50,
|
||||
jitter_ms: Some(10),
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Fixture Merging
|
||||
// ============================================================================
|
||||
|
||||
/// Merge two fixtures together.
|
||||
///
|
||||
/// Combines sessions from both fixtures, handling duplicate session IDs
|
||||
/// by renaming with a numeric suffix.
|
||||
fn merge_fixtures(existing: Fixture, new: Fixture) -> Result<Fixture> {
|
||||
let mut merged = existing.clone();
|
||||
let mut existing_ids: HashMap<String, usize> = HashMap::new();
|
||||
|
||||
// Track existing session IDs
|
||||
for session in &merged.sessions {
|
||||
existing_ids.insert(session.id.clone(), 0);
|
||||
}
|
||||
|
||||
// Add new sessions, renaming duplicates
|
||||
for mut session in new.sessions {
|
||||
if existing_ids.contains_key(&session.id) {
|
||||
// Find a unique ID by appending a suffix
|
||||
let original_id = session.id.clone();
|
||||
let mut suffix = 1;
|
||||
let mut new_id = format!("{}-{}", original_id, suffix);
|
||||
|
||||
while existing_ids.contains_key(&new_id) {
|
||||
suffix += 1;
|
||||
new_id = format!("{}-{}", original_id, suffix);
|
||||
}
|
||||
|
||||
tracing::info!("Renaming duplicate session {} to {}", original_id, new_id);
|
||||
|
||||
// Update session ID and message session IDs
|
||||
session.id = new_id.clone();
|
||||
for message in &mut session.messages {
|
||||
message.session_id = new_id.clone();
|
||||
}
|
||||
|
||||
existing_ids.insert(new_id, 0);
|
||||
} else {
|
||||
existing_ids.insert(session.id.clone(), 0);
|
||||
}
|
||||
|
||||
merged.sessions.push(session);
|
||||
}
|
||||
|
||||
// Merge keyword maps (new takes precedence)
|
||||
for (keyword, response) in new.responders.keyword_map {
|
||||
merged.responders.keyword_map.insert(keyword, response);
|
||||
}
|
||||
|
||||
// Use new fixture's responder strategy and streaming if different
|
||||
// (In practice, we use the existing fixture's settings)
|
||||
|
||||
Ok(merged)
|
||||
}
|
||||
|
||||
/// Load an existing fixture from a file, or create a new empty one.
|
||||
async fn load_or_create_fixture(path: &Path) -> Result<Fixture> {
|
||||
if path.exists() {
|
||||
tracing::debug!("Loading existing fixture from: {}", path.display());
|
||||
let content = tokio::fs::read_to_string(path).await.map_err(|e| {
|
||||
MockerError::FixtureLoad(format!("Failed to read fixture file: {}", e))
|
||||
})?;
|
||||
|
||||
let fixture: Fixture = serde_yaml::from_str(&content).map_err(|e| {
|
||||
MockerError::FixtureLoad(format!("Failed to parse fixture YAML: {}", e))
|
||||
})?;
|
||||
|
||||
Ok(fixture)
|
||||
} else {
|
||||
tracing::debug!("Creating new empty fixture");
|
||||
Ok(Fixture {
|
||||
version: "0.1".to_string(),
|
||||
sessions: Vec::new(),
|
||||
responders: create_default_responders(),
|
||||
streaming: create_default_streaming(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// YAML Export
|
||||
// ============================================================================
|
||||
|
||||
/// Export a fixture to a YAML file.
|
||||
///
|
||||
/// Creates parent directories if they don't exist.
|
||||
async fn export_fixture(fixture: &Fixture, path: &Path) -> Result<()> {
|
||||
// Create parent directories if necessary
|
||||
if let Some(parent) = path.parent() {
|
||||
tokio::fs::create_dir_all(parent).await.map_err(|e| {
|
||||
MockerError::FixtureLoad(format!("Failed to create parent directories: {}", e))
|
||||
})?;
|
||||
}
|
||||
|
||||
// Serialize to YAML with pretty formatting
|
||||
let yaml = serde_yaml::to_string(fixture).map_err(|e| {
|
||||
MockerError::FixtureLoad(format!("Failed to serialize fixture to YAML: {}", e))
|
||||
})?;
|
||||
|
||||
// Write to file
|
||||
tokio::fs::write(path, yaml).await.map_err(|e| {
|
||||
MockerError::Transport(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
format!("Failed to write fixture file: {}", e),
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_create_default_responders() {
|
||||
let responders = create_default_responders();
|
||||
assert_eq!(responders.default_strategy, ResponderStrategy::Echo);
|
||||
assert!(responders.keyword_map.is_empty());
|
||||
assert!(responders.random.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_create_default_streaming() {
|
||||
let streaming = create_default_streaming();
|
||||
assert!(streaming.enabled);
|
||||
assert_eq!(streaming.tokens_per_chunk, 5);
|
||||
assert_eq!(streaming.chunk_interval_ms, 50);
|
||||
assert_eq!(streaming.jitter_ms, Some(10));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_fixtures_no_duplicates() {
|
||||
let existing = Fixture {
|
||||
version: "0.1".to_string(),
|
||||
sessions: vec![Session {
|
||||
id: "session-1".to_string(),
|
||||
title: "Session 1".to_string(),
|
||||
created_at: "2025-01-01T00:00:00Z".to_string(),
|
||||
participants: vec![],
|
||||
messages: vec![],
|
||||
behavior: None,
|
||||
}],
|
||||
responders: create_default_responders(),
|
||||
streaming: create_default_streaming(),
|
||||
};
|
||||
|
||||
let new = Fixture {
|
||||
version: "0.1".to_string(),
|
||||
sessions: vec![Session {
|
||||
id: "session-2".to_string(),
|
||||
title: "Session 2".to_string(),
|
||||
created_at: "2025-01-01T00:00:00Z".to_string(),
|
||||
participants: vec![],
|
||||
messages: vec![],
|
||||
behavior: None,
|
||||
}],
|
||||
responders: create_default_responders(),
|
||||
streaming: create_default_streaming(),
|
||||
};
|
||||
|
||||
let merged = merge_fixtures(existing, new).unwrap();
|
||||
assert_eq!(merged.sessions.len(), 2);
|
||||
assert_eq!(merged.sessions[0].id, "session-1");
|
||||
assert_eq!(merged.sessions[1].id, "session-2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_fixtures_with_duplicates() {
|
||||
let existing = Fixture {
|
||||
version: "0.1".to_string(),
|
||||
sessions: vec![Session {
|
||||
id: "session-1".to_string(),
|
||||
title: "Session 1".to_string(),
|
||||
created_at: "2025-01-01T00:00:00Z".to_string(),
|
||||
participants: vec![],
|
||||
messages: vec![],
|
||||
behavior: None,
|
||||
}],
|
||||
responders: create_default_responders(),
|
||||
streaming: create_default_streaming(),
|
||||
};
|
||||
|
||||
let new = Fixture {
|
||||
version: "0.1".to_string(),
|
||||
sessions: vec![Session {
|
||||
id: "session-1".to_string(),
|
||||
title: "Session 1 (New)".to_string(),
|
||||
created_at: "2025-01-02T00:00:00Z".to_string(),
|
||||
participants: vec![],
|
||||
messages: vec![],
|
||||
behavior: None,
|
||||
}],
|
||||
responders: create_default_responders(),
|
||||
streaming: create_default_streaming(),
|
||||
};
|
||||
|
||||
let merged = merge_fixtures(existing, new).unwrap();
|
||||
assert_eq!(merged.sessions.len(), 2);
|
||||
assert_eq!(merged.sessions[0].id, "session-1");
|
||||
assert_eq!(merged.sessions[1].id, "session-1-1"); // Renamed
|
||||
}
|
||||
}
|
||||
+894
@@ -0,0 +1,894 @@
|
||||
//! # dirigate
|
||||
//!
|
||||
//! ACP (Agent-Client Protocol) mock agent server for testing clients without real agents.
|
||||
//!
|
||||
//! This library provides a configurable mock server that responds to ACP requests
|
||||
//! based on YAML fixture definitions. It supports:
|
||||
//!
|
||||
//! - **Fixture-based responses**: Define sessions and message flows in YAML
|
||||
//! - **Multiple response modes**: Static, random, sequential, and pattern-based
|
||||
//! - **Session ingestion**: Import sessions from OpenCode.ai or other sources (feature-gated)
|
||||
//! - **ACP compliance**: Full protocol implementation for testing clients
|
||||
//!
|
||||
//! ## Usage
|
||||
//!
|
||||
//! As a library:
|
||||
//! ```rust,no_run
|
||||
//! use dirigate::Result;
|
||||
//!
|
||||
//! # async fn example() -> Result<()> {
|
||||
//! // TODO: Add usage example once API is implemented
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
//!
|
||||
//! As a CLI tool:
|
||||
//! ```bash
|
||||
//! dirigate serve --fixtures ./fixtures --port 8080
|
||||
//! ```
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{Mutex, RwLock};
|
||||
use rand::SeedableRng;
|
||||
use rand_chacha::ChaCha8Rng;
|
||||
|
||||
// Core modules
|
||||
pub mod error;
|
||||
pub mod logging;
|
||||
|
||||
// ACP server implementation
|
||||
pub mod acp;
|
||||
|
||||
// Fixture system
|
||||
pub mod fixture;
|
||||
|
||||
// Optional ingestion module (feature-gated)
|
||||
#[cfg(feature = "ingest")]
|
||||
pub mod ingest;
|
||||
|
||||
// CLI module
|
||||
pub mod cli;
|
||||
|
||||
// Re-export key types
|
||||
pub use error::{MockerError, Result};
|
||||
pub use acp::stream::{chunk_text, StreamConfig, StreamController, StreamEvent};
|
||||
pub use fixture::types::{Fixture, Message, Participant};
|
||||
pub use fixture::responders::Responder;
|
||||
|
||||
// ============================================================================
|
||||
// Configuration Types
|
||||
// ============================================================================
|
||||
|
||||
/// Server configuration for the ACP mocker.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MockerConfig {
|
||||
/// Port to bind the server to.
|
||||
pub port: u16,
|
||||
|
||||
/// Host address to bind to.
|
||||
pub host: String,
|
||||
|
||||
/// Default streaming configuration.
|
||||
pub default_stream_config: StreamConfig,
|
||||
}
|
||||
|
||||
impl Default for MockerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
port: 8080,
|
||||
host: "127.0.0.1".to_string(),
|
||||
default_stream_config: StreamConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MockerConfig {
|
||||
/// Create a new mocker configuration.
|
||||
pub fn new(port: u16, host: impl Into<String>) -> Self {
|
||||
Self {
|
||||
port,
|
||||
host: host.into(),
|
||||
default_stream_config: StreamConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the default streaming configuration.
|
||||
pub fn with_stream_config(mut self, config: StreamConfig) -> Self {
|
||||
self.default_stream_config = config;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Session State
|
||||
// ============================================================================
|
||||
|
||||
/// Runtime state for a single session.
|
||||
///
|
||||
/// Contains all information about an active session, including its messages,
|
||||
/// participants, and the responder used to generate responses.
|
||||
pub struct SessionState {
|
||||
/// Session identifier.
|
||||
pub id: String,
|
||||
|
||||
/// Human-readable session title.
|
||||
pub title: String,
|
||||
|
||||
/// ISO8601 timestamp when session was created.
|
||||
pub created_at: String,
|
||||
|
||||
/// Participants in this session.
|
||||
pub participants: Vec<Participant>,
|
||||
|
||||
/// Messages in this session (grows as turns progress).
|
||||
pub messages: Vec<Message>,
|
||||
|
||||
/// Responder assigned to this session.
|
||||
responder: Box<dyn Responder>,
|
||||
|
||||
/// Stream controller for ongoing streams (if any).
|
||||
pub stream_controller: Option<StreamController>,
|
||||
}
|
||||
|
||||
impl SessionState {
|
||||
/// Create a new session state.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `id` - Unique session identifier
|
||||
/// * `title` - Human-readable session title
|
||||
/// * `created_at` - ISO8601 timestamp
|
||||
/// * `participants` - List of session participants
|
||||
/// * `messages` - Initial message list
|
||||
/// * `responder` - Responder for generating responses
|
||||
pub fn new(
|
||||
id: String,
|
||||
title: String,
|
||||
created_at: String,
|
||||
participants: Vec<Participant>,
|
||||
messages: Vec<Message>,
|
||||
responder: Box<dyn Responder>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id,
|
||||
title,
|
||||
created_at,
|
||||
participants,
|
||||
messages,
|
||||
responder,
|
||||
stream_controller: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a mutable reference to the responder.
|
||||
pub fn responder_mut(&mut self) -> &mut Box<dyn Responder> {
|
||||
&mut self.responder
|
||||
}
|
||||
|
||||
/// Add a message to this session.
|
||||
pub fn add_message(&mut self, message: Message) {
|
||||
tracing::debug!(
|
||||
session_id = %self.id,
|
||||
message_id = %message.id,
|
||||
role = ?message.role,
|
||||
"Adding message to session"
|
||||
);
|
||||
self.messages.push(message);
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for SessionState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("SessionState")
|
||||
.field("id", &self.id)
|
||||
.field("title", &self.title)
|
||||
.field("created_at", &self.created_at)
|
||||
.field("participants", &self.participants)
|
||||
.field("message_count", &self.messages.len())
|
||||
.field("has_stream_controller", &self.stream_controller.is_some())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Mocker State
|
||||
// ============================================================================
|
||||
|
||||
/// Runtime state for the ACP mocker.
|
||||
///
|
||||
/// Manages active sessions, fixture data, and configuration. This is the
|
||||
/// central coordination point for the mock server.
|
||||
#[derive(Clone)]
|
||||
pub struct MockerState {
|
||||
/// Active sessions (thread-safe).
|
||||
sessions: Arc<RwLock<HashMap<String, SessionState>>>,
|
||||
|
||||
/// Loaded fixtures (immutable after creation).
|
||||
fixtures: Arc<Fixture>,
|
||||
|
||||
/// Server configuration.
|
||||
config: MockerConfig,
|
||||
|
||||
/// Global random number generator for ID generation.
|
||||
_global_rng: Arc<Mutex<ChaCha8Rng>>,
|
||||
|
||||
/// Broadcast channel for SSE notifications.
|
||||
/// Clients subscribe to this channel to receive session updates.
|
||||
event_tx: Arc<tokio::sync::broadcast::Sender<SseNotification>>,
|
||||
}
|
||||
|
||||
/// SSE notification message.
|
||||
///
|
||||
/// Represents a server-sent event that is broadcast to all connected clients.
|
||||
/// Clients filter events based on session_id to only process relevant updates.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SseNotification {
|
||||
/// Session ID this event relates to.
|
||||
pub session_id: String,
|
||||
|
||||
/// JSON-RPC notification payload.
|
||||
pub notification: String,
|
||||
}
|
||||
|
||||
impl MockerState {
|
||||
/// Create a new mocker state.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config` - Server configuration
|
||||
/// * `fixtures` - Loaded fixture data
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new mocker state instance ready for use
|
||||
pub fn new(config: MockerConfig, fixtures: Fixture) -> Self {
|
||||
// Create broadcast channel for SSE events
|
||||
// Read buffer size from environment variable, default to 100
|
||||
let capacity = std::env::var("CONDUCTOR_BUFFER_SIZE")
|
||||
.ok()
|
||||
.and_then(|s| s.parse::<usize>().ok())
|
||||
.unwrap_or(100);
|
||||
let (event_tx, _) = tokio::sync::broadcast::channel(capacity);
|
||||
|
||||
Self {
|
||||
sessions: Arc::new(RwLock::new(HashMap::new())),
|
||||
fixtures: Arc::new(fixtures),
|
||||
config,
|
||||
_global_rng: Arc::new(Mutex::new(ChaCha8Rng::from_entropy())),
|
||||
event_tx: Arc::new(event_tx),
|
||||
}
|
||||
}
|
||||
|
||||
/// Subscribe to SSE notifications.
|
||||
///
|
||||
/// Returns a receiver that can be used to listen for session update events.
|
||||
pub fn subscribe_events(&self) -> tokio::sync::broadcast::Receiver<SseNotification> {
|
||||
self.event_tx.subscribe()
|
||||
}
|
||||
|
||||
/// Broadcast an SSE notification to all connected clients.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `notification` - The SSE notification to broadcast
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// This is a fire-and-forget operation. If no clients are connected,
|
||||
/// the notification is simply dropped.
|
||||
pub fn broadcast_event(&self, notification: SseNotification) {
|
||||
// send() returns Err if there are no receivers, which is fine
|
||||
let _ = self.event_tx.send(notification);
|
||||
}
|
||||
|
||||
/// Create a new session.
|
||||
///
|
||||
/// Creates a new session, optionally based on a fixture template.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `template_id` - Optional fixture template ID to base the session on
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The ID of the newly created session
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the template is not found or session creation fails
|
||||
pub async fn create_session(&self, template_id: Option<String>) -> Result<String> {
|
||||
// Generate a new session ID
|
||||
let session_id = self.generate_session_id().await;
|
||||
|
||||
tracing::info!(
|
||||
session_id = %session_id,
|
||||
template_id = ?template_id,
|
||||
"Creating new session"
|
||||
);
|
||||
|
||||
// Load fixture template if specified
|
||||
let session_fixture = if let Some(template_id) = &template_id {
|
||||
self.fixtures
|
||||
.sessions
|
||||
.iter()
|
||||
.find(|s| &s.id == template_id)
|
||||
.ok_or_else(|| {
|
||||
MockerError::FixtureValidation(format!(
|
||||
"Template session not found: {}",
|
||||
template_id
|
||||
))
|
||||
})?
|
||||
.clone()
|
||||
} else {
|
||||
// Create a minimal default session
|
||||
fixture::types::Session {
|
||||
id: session_id.clone(),
|
||||
title: "New Session".to_string(),
|
||||
created_at: chrono::Utc::now().to_rfc3339(),
|
||||
participants: vec![],
|
||||
messages: vec![],
|
||||
behavior: None,
|
||||
}
|
||||
};
|
||||
|
||||
// Create responder for this session
|
||||
let responder = fixture::responders::ResponderFactory::create_for_session(
|
||||
&self.fixtures.responders,
|
||||
&session_fixture,
|
||||
)?;
|
||||
|
||||
// Create session state
|
||||
let session_state = SessionState::new(
|
||||
session_id.clone(),
|
||||
session_fixture.title,
|
||||
session_fixture.created_at,
|
||||
session_fixture.participants,
|
||||
session_fixture.messages,
|
||||
responder,
|
||||
);
|
||||
|
||||
// Store the session
|
||||
let mut sessions = self.sessions.write().await;
|
||||
sessions.insert(session_id.clone(), session_state);
|
||||
|
||||
Ok(session_id)
|
||||
}
|
||||
|
||||
/// Load a session from fixtures.
|
||||
///
|
||||
/// Loads an existing session from the fixture data. The session must
|
||||
/// exist in the fixtures.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `session_id` - ID of the session to load from fixtures
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A clone of the session state (not a reference, to avoid lock issues)
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `SessionNotFound` if the session doesn't exist in fixtures
|
||||
pub async fn load_session(&self, session_id: &str) -> Result<SessionState> {
|
||||
tracing::info!(session_id, "Loading session from fixtures");
|
||||
|
||||
// Find session in fixtures
|
||||
let session_fixture = self
|
||||
.fixtures
|
||||
.sessions
|
||||
.iter()
|
||||
.find(|s| s.id == session_id)
|
||||
.ok_or_else(|| MockerError::SessionNotFound {
|
||||
session_id: session_id.to_string(),
|
||||
})?;
|
||||
|
||||
// Create responder for this session
|
||||
let responder = fixture::responders::ResponderFactory::create_for_session(
|
||||
&self.fixtures.responders,
|
||||
session_fixture,
|
||||
)?;
|
||||
|
||||
// Create session state
|
||||
let session_state = SessionState::new(
|
||||
session_fixture.id.clone(),
|
||||
session_fixture.title.clone(),
|
||||
session_fixture.created_at.clone(),
|
||||
session_fixture.participants.clone(),
|
||||
session_fixture.messages.clone(),
|
||||
responder,
|
||||
);
|
||||
|
||||
// Store in active sessions
|
||||
let mut sessions = self.sessions.write().await;
|
||||
sessions.insert(session_id.to_string(), session_state);
|
||||
|
||||
// Return a clone (we need to drop the write lock first)
|
||||
drop(sessions);
|
||||
self.get_session(session_id).await
|
||||
}
|
||||
|
||||
/// Get a session from active sessions.
|
||||
///
|
||||
/// Retrieves a session that has been created or loaded.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `session_id` - ID of the session to retrieve
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A clone of the session state
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `SessionNotFound` if the session is not active
|
||||
pub async fn get_session(&self, session_id: &str) -> Result<SessionState> {
|
||||
let sessions = self.sessions.read().await;
|
||||
|
||||
// We can't return a reference due to the lock, so we need to clone
|
||||
// In a real implementation, you might want to use Arc<SessionState> or
|
||||
// return specific data rather than the whole state
|
||||
sessions
|
||||
.get(session_id)
|
||||
.ok_or_else(|| MockerError::SessionNotFound {
|
||||
session_id: session_id.to_string(),
|
||||
})
|
||||
.map(|s| {
|
||||
// Create a shallow clone for return
|
||||
// Note: This is a temporary solution. In practice, you'd want
|
||||
// to either use Arc or return specific fields
|
||||
SessionState {
|
||||
id: s.id.clone(),
|
||||
title: s.title.clone(),
|
||||
created_at: s.created_at.clone(),
|
||||
participants: s.participants.clone(),
|
||||
messages: s.messages.clone(),
|
||||
responder: fixture::responders::ResponderFactory::create_for_session(
|
||||
&self.fixtures.responders,
|
||||
&fixture::types::Session {
|
||||
id: s.id.clone(),
|
||||
title: s.title.clone(),
|
||||
created_at: s.created_at.clone(),
|
||||
participants: s.participants.clone(),
|
||||
messages: s.messages.clone(),
|
||||
behavior: None,
|
||||
},
|
||||
)
|
||||
.unwrap_or_else(|_| {
|
||||
Box::new(fixture::responders::EchoResponder)
|
||||
}),
|
||||
stream_controller: None,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Add a message to a session.
|
||||
///
|
||||
/// Adds a message to an active session's message history.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `session_id` - ID of the session
|
||||
/// * `message` - The message to add
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `SessionNotFound` if the session is not active
|
||||
pub async fn add_message(&self, session_id: &str, message: Message) -> Result<()> {
|
||||
let mut sessions = self.sessions.write().await;
|
||||
|
||||
let session = sessions
|
||||
.get_mut(session_id)
|
||||
.ok_or_else(|| MockerError::SessionNotFound {
|
||||
session_id: session_id.to_string(),
|
||||
})?;
|
||||
|
||||
session.add_message(message);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Cancel an ongoing stream for a session.
|
||||
///
|
||||
/// Sets the cancellation flag on the session's stream controller,
|
||||
/// causing streaming to stop.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `session_id` - ID of the session to cancel
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `SessionNotFound` if the session is not active
|
||||
pub async fn cancel_stream(&self, session_id: &str) -> Result<()> {
|
||||
let sessions = self.sessions.read().await;
|
||||
|
||||
let session = sessions
|
||||
.get(session_id)
|
||||
.ok_or_else(|| MockerError::SessionNotFound {
|
||||
session_id: session_id.to_string(),
|
||||
})?;
|
||||
|
||||
if let Some(controller) = &session.stream_controller {
|
||||
tracing::info!(session_id, "Cancelling stream");
|
||||
controller.cancel();
|
||||
} else {
|
||||
tracing::warn!(session_id, "No active stream to cancel");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the fixture data.
|
||||
pub fn fixtures(&self) -> &Fixture {
|
||||
&self.fixtures
|
||||
}
|
||||
|
||||
/// Get the server configuration.
|
||||
pub fn config(&self) -> &MockerConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Generate a new session ID.
|
||||
async fn generate_session_id(&self) -> String {
|
||||
uuid::Uuid::new_v4().to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for MockerState {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("MockerState")
|
||||
.field("config", &self.config)
|
||||
.field("fixture_count", &self.fixtures.sessions.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use fixture::types::{
|
||||
MessageRole, Participant, ParticipantKind, Responders, ResponderStrategy,
|
||||
Session, Streaming,
|
||||
};
|
||||
|
||||
// Helper to create a minimal test fixture
|
||||
fn create_test_fixture() -> Fixture {
|
||||
Fixture {
|
||||
version: "0.1".to_string(),
|
||||
sessions: vec![
|
||||
Session {
|
||||
id: "test-session-1".to_string(),
|
||||
title: "Test Session 1".to_string(),
|
||||
created_at: "2025-01-01T00:00:00Z".to_string(),
|
||||
participants: vec![Participant {
|
||||
id: "user-1".to_string(),
|
||||
kind: ParticipantKind::User,
|
||||
display_name: Some("Test User".to_string()),
|
||||
}],
|
||||
messages: vec![],
|
||||
behavior: None,
|
||||
},
|
||||
Session {
|
||||
id: "test-session-2".to_string(),
|
||||
title: "Test Session 2".to_string(),
|
||||
created_at: "2025-01-01T00:00:00Z".to_string(),
|
||||
participants: vec![],
|
||||
messages: vec![Message {
|
||||
id: "msg-1".to_string(),
|
||||
session_id: "test-session-2".to_string(),
|
||||
role: MessageRole::Assistant,
|
||||
content: "Hello!".to_string(),
|
||||
created_at: "2025-01-01T00:00:00Z".to_string(),
|
||||
parent_id: None,
|
||||
metadata: None,
|
||||
}],
|
||||
behavior: None,
|
||||
},
|
||||
],
|
||||
responders: Responders {
|
||||
keyword_map: std::collections::HashMap::new(),
|
||||
default_strategy: ResponderStrategy::Echo,
|
||||
random: None,
|
||||
},
|
||||
streaming: Streaming {
|
||||
enabled: true,
|
||||
tokens_per_chunk: 5,
|
||||
chunk_interval_ms: 100,
|
||||
jitter_ms: Some(10),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// MockerConfig Tests
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_mocker_config_default() {
|
||||
let config = MockerConfig::default();
|
||||
assert_eq!(config.port, 8080);
|
||||
assert_eq!(config.host, "127.0.0.1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mocker_config_new() {
|
||||
let config = MockerConfig::new(3000, "0.0.0.0");
|
||||
assert_eq!(config.port, 3000);
|
||||
assert_eq!(config.host, "0.0.0.0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mocker_config_with_stream_config() {
|
||||
let stream_config = StreamConfig::new(10, 200);
|
||||
let config = MockerConfig::default().with_stream_config(stream_config);
|
||||
assert_eq!(config.default_stream_config.tokens_per_chunk, 10);
|
||||
assert_eq!(config.default_stream_config.chunk_interval_ms, 200);
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// SessionState Tests
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_session_state_creation() {
|
||||
let responder = Box::new(fixture::responders::EchoResponder);
|
||||
let session = SessionState::new(
|
||||
"test-id".to_string(),
|
||||
"Test Title".to_string(),
|
||||
"2025-01-01T00:00:00Z".to_string(),
|
||||
vec![],
|
||||
vec![],
|
||||
responder,
|
||||
);
|
||||
|
||||
assert_eq!(session.id, "test-id");
|
||||
assert_eq!(session.title, "Test Title");
|
||||
assert_eq!(session.messages.len(), 0);
|
||||
assert!(session.stream_controller.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_state_add_message() {
|
||||
let responder = Box::new(fixture::responders::EchoResponder);
|
||||
let mut session = SessionState::new(
|
||||
"test-id".to_string(),
|
||||
"Test Title".to_string(),
|
||||
"2025-01-01T00:00:00Z".to_string(),
|
||||
vec![],
|
||||
vec![],
|
||||
responder,
|
||||
);
|
||||
|
||||
let message = Message {
|
||||
id: "msg-1".to_string(),
|
||||
session_id: "test-id".to_string(),
|
||||
role: MessageRole::User,
|
||||
content: "Hello".to_string(),
|
||||
created_at: "2025-01-01T00:00:00Z".to_string(),
|
||||
parent_id: None,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
session.add_message(message);
|
||||
assert_eq!(session.messages.len(), 1);
|
||||
assert_eq!(session.messages[0].content, "Hello");
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// MockerState Tests
|
||||
// ========================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mocker_state_creation() {
|
||||
let config = MockerConfig::default();
|
||||
let fixtures = create_test_fixture();
|
||||
let state = MockerState::new(config, fixtures);
|
||||
|
||||
assert_eq!(state.fixtures().sessions.len(), 2);
|
||||
assert_eq!(state.config().port, 8080);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mocker_state_create_session_without_template() {
|
||||
let config = MockerConfig::default();
|
||||
let fixtures = create_test_fixture();
|
||||
let state = MockerState::new(config, fixtures);
|
||||
|
||||
let session_id = state.create_session(None).await.unwrap();
|
||||
|
||||
assert!(!session_id.is_empty());
|
||||
|
||||
// Session should be retrievable
|
||||
let session = state.get_session(&session_id).await.unwrap();
|
||||
assert_eq!(session.id, session_id);
|
||||
assert_eq!(session.title, "New Session");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mocker_state_create_session_with_template() {
|
||||
let config = MockerConfig::default();
|
||||
let fixtures = create_test_fixture();
|
||||
let state = MockerState::new(config, fixtures);
|
||||
|
||||
let session_id = state
|
||||
.create_session(Some("test-session-1".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!session_id.is_empty());
|
||||
|
||||
// Session should have template data
|
||||
let session = state.get_session(&session_id).await.unwrap();
|
||||
assert_eq!(session.title, "Test Session 1");
|
||||
assert_eq!(session.participants.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mocker_state_create_session_invalid_template() {
|
||||
let config = MockerConfig::default();
|
||||
let fixtures = create_test_fixture();
|
||||
let state = MockerState::new(config, fixtures);
|
||||
|
||||
let result = state
|
||||
.create_session(Some("non-existent-template".to_string()))
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(MockerError::FixtureValidation(msg)) => {
|
||||
assert!(msg.contains("Template session not found"));
|
||||
}
|
||||
_ => panic!("Expected FixtureValidation error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mocker_state_load_session() {
|
||||
let config = MockerConfig::default();
|
||||
let fixtures = create_test_fixture();
|
||||
let state = MockerState::new(config, fixtures);
|
||||
|
||||
let session = state.load_session("test-session-1").await.unwrap();
|
||||
|
||||
assert_eq!(session.id, "test-session-1");
|
||||
assert_eq!(session.title, "Test Session 1");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mocker_state_load_session_not_found() {
|
||||
let config = MockerConfig::default();
|
||||
let fixtures = create_test_fixture();
|
||||
let state = MockerState::new(config, fixtures);
|
||||
|
||||
let result = state.load_session("non-existent").await;
|
||||
|
||||
assert!(result.is_err());
|
||||
match result {
|
||||
Err(MockerError::SessionNotFound { session_id }) => {
|
||||
assert_eq!(session_id, "non-existent");
|
||||
}
|
||||
_ => panic!("Expected SessionNotFound error"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mocker_state_get_session() {
|
||||
let config = MockerConfig::default();
|
||||
let fixtures = create_test_fixture();
|
||||
let state = MockerState::new(config, fixtures);
|
||||
|
||||
// Create a session first
|
||||
let session_id = state.create_session(None).await.unwrap();
|
||||
|
||||
// Should be able to retrieve it
|
||||
let session = state.get_session(&session_id).await.unwrap();
|
||||
assert_eq!(session.id, session_id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mocker_state_get_session_not_found() {
|
||||
let config = MockerConfig::default();
|
||||
let fixtures = create_test_fixture();
|
||||
let state = MockerState::new(config, fixtures);
|
||||
|
||||
let result = state.get_session("non-existent").await;
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mocker_state_add_message() {
|
||||
let config = MockerConfig::default();
|
||||
let fixtures = create_test_fixture();
|
||||
let state = MockerState::new(config, fixtures);
|
||||
|
||||
// Create a session
|
||||
let session_id = state.create_session(None).await.unwrap();
|
||||
|
||||
// Add a message
|
||||
let message = Message {
|
||||
id: "msg-1".to_string(),
|
||||
session_id: session_id.clone(),
|
||||
role: MessageRole::User,
|
||||
content: "Test message".to_string(),
|
||||
created_at: "2025-01-01T00:00:00Z".to_string(),
|
||||
parent_id: None,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
state.add_message(&session_id, message).await.unwrap();
|
||||
|
||||
// Verify message was added
|
||||
let session = state.get_session(&session_id).await.unwrap();
|
||||
assert_eq!(session.messages.len(), 1);
|
||||
assert_eq!(session.messages[0].content, "Test message");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mocker_state_add_message_not_found() {
|
||||
let config = MockerConfig::default();
|
||||
let fixtures = create_test_fixture();
|
||||
let state = MockerState::new(config, fixtures);
|
||||
|
||||
let message = Message {
|
||||
id: "msg-1".to_string(),
|
||||
session_id: "non-existent".to_string(),
|
||||
role: MessageRole::User,
|
||||
content: "Test message".to_string(),
|
||||
created_at: "2025-01-01T00:00:00Z".to_string(),
|
||||
parent_id: None,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let result = state.add_message("non-existent", message).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mocker_state_cancel_stream() {
|
||||
let config = MockerConfig::default();
|
||||
let fixtures = create_test_fixture();
|
||||
let state = MockerState::new(config, fixtures);
|
||||
|
||||
// Create a session
|
||||
let session_id = state.create_session(None).await.unwrap();
|
||||
|
||||
// Cancel stream (should not error even if no stream exists)
|
||||
let result = state.cancel_stream(&session_id).await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mocker_state_cancel_stream_not_found() {
|
||||
let config = MockerConfig::default();
|
||||
let fixtures = create_test_fixture();
|
||||
let state = MockerState::new(config, fixtures);
|
||||
|
||||
let result = state.cancel_stream("non-existent").await;
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_id_generation_is_unique() {
|
||||
let config = MockerConfig::default();
|
||||
let fixtures = create_test_fixture();
|
||||
let state = MockerState::new(config, fixtures);
|
||||
|
||||
let id1 = state.create_session(None).await.unwrap();
|
||||
let id2 = state.create_session(None).await.unwrap();
|
||||
let id3 = state.create_session(None).await.unwrap();
|
||||
|
||||
assert_ne!(id1, id2);
|
||||
assert_ne!(id2, id3);
|
||||
assert_ne!(id1, id3);
|
||||
}
|
||||
}
|
||||
+168
@@ -0,0 +1,168 @@
|
||||
//! Logging infrastructure for the ACP mocker.
|
||||
//!
|
||||
//! This module provides utilities for initializing and configuring structured logging
|
||||
//! using the `tracing` ecosystem. It supports multiple output formats and log levels
|
||||
//! configurable via environment variables.
|
||||
//!
|
||||
//! ## Important: stdio Mode
|
||||
//!
|
||||
//! When using stdio transport (the primary ACP communication method), ALL logs are
|
||||
//! automatically written to stderr to keep stdout clean for JSON-RPC messages.
|
||||
|
||||
use std::fs::OpenOptions;
|
||||
use std::sync::Arc;
|
||||
use tracing_subscriber::{fmt, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
||||
|
||||
/// Output format for logs.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum LogFormat {
|
||||
/// Human-readable pretty format with colors (default for development).
|
||||
#[default]
|
||||
Pretty,
|
||||
/// JSON format for structured logging (recommended for production/CI).
|
||||
Json,
|
||||
/// Compact format for minimal output.
|
||||
Compact,
|
||||
}
|
||||
|
||||
/// Initialize the logging system with the specified format.
|
||||
///
|
||||
/// # Log Levels
|
||||
///
|
||||
/// Log levels can be configured via the `RUST_LOG` environment variable:
|
||||
/// - `RUST_LOG=trace` - Most verbose, includes all internal details
|
||||
/// - `RUST_LOG=debug` - Detailed debugging information
|
||||
/// - `RUST_LOG=info` - General informational messages (default)
|
||||
/// - `RUST_LOG=warn` - Warning messages only
|
||||
/// - `RUST_LOG=error` - Error messages only
|
||||
///
|
||||
/// You can also filter by module:
|
||||
/// ```bash
|
||||
/// RUST_LOG=dirigate=debug,axum=info
|
||||
/// ```
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use dirigate::logging::{init_logging, LogFormat};
|
||||
///
|
||||
/// // Initialize with pretty format (default)
|
||||
/// init_logging(LogFormat::Pretty);
|
||||
///
|
||||
/// // Initialize with JSON format for production
|
||||
/// init_logging(LogFormat::Json);
|
||||
/// ```
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if logging has already been initialized. This should be called
|
||||
/// exactly once at the start of the application.
|
||||
pub fn init_logging(format: LogFormat) {
|
||||
let env_filter = EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| EnvFilter::new("dirigate=info,axum=info"));
|
||||
|
||||
// Try to create log file in system temp directory
|
||||
// If this fails, we'll just log to stderr only
|
||||
let log_file_path = std::env::temp_dir().join("dirigate.log");
|
||||
let file_writer = OpenOptions::new()
|
||||
.create(true)
|
||||
.write(true)
|
||||
.truncate(true)
|
||||
.open(&log_file_path)
|
||||
.ok()
|
||||
.map(Arc::new);
|
||||
|
||||
// CRITICAL: All logs go to stderr (and optionally to file if available)
|
||||
// In stdio mode, stdout is reserved for JSON-RPC messages only
|
||||
match (format, file_writer) {
|
||||
(LogFormat::Pretty, Some(file_writer)) => {
|
||||
tracing_subscriber::registry()
|
||||
.with(env_filter)
|
||||
.with(fmt::layer().pretty().with_writer(std::io::stderr))
|
||||
.with(fmt::layer().with_ansi(false).with_writer(file_writer))
|
||||
.init();
|
||||
tracing::info!(
|
||||
"Logging initialized with Pretty format, writing to {:?}",
|
||||
log_file_path
|
||||
);
|
||||
}
|
||||
(LogFormat::Pretty, None) => {
|
||||
tracing_subscriber::registry()
|
||||
.with(env_filter)
|
||||
.with(fmt::layer().pretty().with_writer(std::io::stderr))
|
||||
.init();
|
||||
tracing::info!("Logging initialized with Pretty format (file logging unavailable)");
|
||||
}
|
||||
(LogFormat::Json, Some(file_writer)) => {
|
||||
tracing_subscriber::registry()
|
||||
.with(env_filter)
|
||||
.with(fmt::layer().json().with_writer(std::io::stderr))
|
||||
.with(
|
||||
fmt::layer()
|
||||
.json()
|
||||
.with_ansi(false)
|
||||
.with_writer(file_writer),
|
||||
)
|
||||
.init();
|
||||
tracing::info!(
|
||||
"Logging initialized with Json format, writing to {:?}",
|
||||
log_file_path
|
||||
);
|
||||
}
|
||||
(LogFormat::Json, None) => {
|
||||
tracing_subscriber::registry()
|
||||
.with(env_filter)
|
||||
.with(fmt::layer().json().with_writer(std::io::stderr))
|
||||
.init();
|
||||
tracing::info!("Logging initialized with Json format (file logging unavailable)");
|
||||
}
|
||||
(LogFormat::Compact, Some(file_writer)) => {
|
||||
tracing_subscriber::registry()
|
||||
.with(env_filter)
|
||||
.with(fmt::layer().compact().with_writer(std::io::stderr))
|
||||
.with(
|
||||
fmt::layer()
|
||||
.compact()
|
||||
.with_ansi(false)
|
||||
.with_writer(file_writer),
|
||||
)
|
||||
.init();
|
||||
tracing::info!(
|
||||
"Logging initialized with Compact format, writing to {:?}",
|
||||
log_file_path
|
||||
);
|
||||
}
|
||||
(LogFormat::Compact, None) => {
|
||||
tracing_subscriber::registry()
|
||||
.with(env_filter)
|
||||
.with(fmt::layer().compact().with_writer(std::io::stderr))
|
||||
.init();
|
||||
tracing::info!("Logging initialized with Compact format (file logging unavailable)");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize logging with default settings (pretty format, info level).
|
||||
///
|
||||
/// This is a convenience function equivalent to `init_logging(LogFormat::Pretty)`.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use dirigate::logging::init_default_logging;
|
||||
///
|
||||
/// init_default_logging();
|
||||
/// ```
|
||||
pub fn init_default_logging() {
|
||||
init_logging(LogFormat::default());
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_log_format_default() {
|
||||
assert_eq!(LogFormat::default(), LogFormat::Pretty);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user