sync from monorepo @ 2452e92e
This commit is contained in:
@@ -0,0 +1,338 @@
|
||||
//! Common test utilities for bidirectional flow testing.
|
||||
//!
|
||||
//! This module provides mock implementations and test helpers for testing
|
||||
//! the bidirectional request/response flow in the ACP Server.
|
||||
|
||||
use anyhow::Result;
|
||||
use serde_json::{json, Value};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, oneshot, Mutex};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Mock event for testing event bridge
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MockEvent {
|
||||
pub event_type: String,
|
||||
pub data: Value,
|
||||
}
|
||||
|
||||
/// Mock SSE client for testing
|
||||
///
|
||||
/// Simulates an HTTP client that receives SSE events and posts responses.
|
||||
pub struct MockSseClient {
|
||||
pub client_id: String,
|
||||
/// Events received via SSE
|
||||
pub received_events: Arc<Mutex<Vec<MockEvent>>>,
|
||||
/// Sender for simulating HTTP POST responses
|
||||
pub response_tx: mpsc::UnboundedSender<(Value, oneshot::Sender<Result<()>>)>,
|
||||
}
|
||||
|
||||
impl MockSseClient {
|
||||
pub fn new(client_id: String) -> (Self, mpsc::UnboundedReceiver<(Value, oneshot::Sender<Result<()>>)>) {
|
||||
let (response_tx, response_rx) = mpsc::unbounded_channel();
|
||||
|
||||
(
|
||||
Self {
|
||||
client_id,
|
||||
received_events: Arc::new(Mutex::new(Vec::new())),
|
||||
response_tx,
|
||||
},
|
||||
response_rx,
|
||||
)
|
||||
}
|
||||
|
||||
/// Simulate receiving an SSE event
|
||||
pub async fn receive_sse(&self, event_type: String, data: Value) {
|
||||
let mut events = self.received_events.lock().await;
|
||||
events.push(MockEvent { event_type, data });
|
||||
}
|
||||
|
||||
/// Get all received events
|
||||
pub async fn get_events(&self) -> Vec<MockEvent> {
|
||||
let events = self.received_events.lock().await;
|
||||
events.clone()
|
||||
}
|
||||
|
||||
/// Get the most recent event of a specific type
|
||||
pub async fn get_latest_event(&self, event_type: &str) -> Option<MockEvent> {
|
||||
let events = self.received_events.lock().await;
|
||||
events.iter()
|
||||
.filter(|e| e.event_type == event_type)
|
||||
.last()
|
||||
.cloned()
|
||||
}
|
||||
|
||||
/// Clear received events
|
||||
pub async fn clear_events(&self) {
|
||||
let mut events = self.received_events.lock().await;
|
||||
events.clear();
|
||||
}
|
||||
|
||||
/// Simulate sending a response to /acp/agent_response
|
||||
pub async fn send_response(&self, response: Value) -> Result<()> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.response_tx.send((response, tx))
|
||||
.map_err(|_| anyhow::anyhow!("Failed to send response"))?;
|
||||
rx.await?
|
||||
}
|
||||
}
|
||||
|
||||
/// Mock connector for testing
|
||||
///
|
||||
/// Simulates an ACP connector that can send agent requests and receive responses.
|
||||
pub struct MockConnector {
|
||||
pub connector_id: String,
|
||||
/// Channel for receiving agent requests from this connector
|
||||
pub request_tx: mpsc::UnboundedSender<(Value, String, String, Value)>,
|
||||
/// Pending responses (request_id -> sender)
|
||||
pub pending_responses: Arc<Mutex<HashMap<Value, oneshot::Sender<Value>>>>,
|
||||
}
|
||||
|
||||
impl MockConnector {
|
||||
pub fn new(
|
||||
connector_id: String,
|
||||
) -> (Self, mpsc::UnboundedReceiver<(Value, String, String, Value)>) {
|
||||
let (request_tx, request_rx) = mpsc::unbounded_channel();
|
||||
|
||||
(
|
||||
Self {
|
||||
connector_id,
|
||||
request_tx,
|
||||
pending_responses: Arc::new(Mutex::new(HashMap::new())),
|
||||
},
|
||||
request_rx,
|
||||
)
|
||||
}
|
||||
|
||||
/// Simulate sending an agent request
|
||||
pub async fn send_agent_request(
|
||||
&self,
|
||||
session_id: String,
|
||||
request_id: Value,
|
||||
method: String,
|
||||
params: Value,
|
||||
) -> oneshot::Receiver<Value> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
// Store the response sender
|
||||
let mut pending = self.pending_responses.lock().await;
|
||||
pending.insert(request_id.clone(), tx);
|
||||
drop(pending);
|
||||
|
||||
// Send the request
|
||||
self.request_tx.send((request_id, session_id, method, params))
|
||||
.expect("Failed to send agent request");
|
||||
|
||||
rx
|
||||
}
|
||||
|
||||
/// Complete a pending response (simulating response from ACP Server)
|
||||
pub async fn complete_response(&self, request_id: Value, response: Value) -> Result<()> {
|
||||
let mut pending = self.pending_responses.lock().await;
|
||||
|
||||
if let Some(tx) = pending.remove(&request_id) {
|
||||
tx.send(response)
|
||||
.map_err(|_| anyhow::anyhow!("Failed to send response to connector"))?;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(anyhow::anyhow!("No pending response for request_id: {}", request_id))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Test context for integration tests
|
||||
///
|
||||
/// Provides a complete test environment with mocked components.
|
||||
pub struct TestContext {
|
||||
/// Mock SSE clients by client_id
|
||||
pub clients: Arc<Mutex<HashMap<String, MockSseClient>>>,
|
||||
/// Mock connectors by connector_id
|
||||
pub connectors: Arc<Mutex<HashMap<String, MockConnector>>>,
|
||||
}
|
||||
|
||||
impl TestContext {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
clients: Arc::new(Mutex::new(HashMap::new())),
|
||||
connectors: Arc::new(Mutex::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a mock SSE client
|
||||
pub async fn create_client(
|
||||
&self,
|
||||
client_id: String,
|
||||
) -> (MockSseClient, mpsc::UnboundedReceiver<(Value, oneshot::Sender<Result<()>>)>) {
|
||||
let (client, response_rx) = MockSseClient::new(client_id.clone());
|
||||
|
||||
let mut clients = self.clients.lock().await;
|
||||
clients.insert(client_id.clone(), client.clone());
|
||||
|
||||
(client, response_rx)
|
||||
}
|
||||
|
||||
/// Create a mock connector
|
||||
pub async fn create_connector(
|
||||
&self,
|
||||
connector_id: String,
|
||||
) -> (MockConnector, mpsc::UnboundedReceiver<(Value, String, String, Value)>) {
|
||||
let (connector, request_rx) = MockConnector::new(connector_id.clone());
|
||||
|
||||
let mut connectors = self.connectors.lock().await;
|
||||
connectors.insert(connector_id.clone(), connector.clone());
|
||||
|
||||
(connector, request_rx)
|
||||
}
|
||||
|
||||
/// Get a client by ID
|
||||
pub async fn get_client(&self, client_id: &str) -> Option<MockSseClient> {
|
||||
let clients = self.clients.lock().await;
|
||||
clients.get(client_id).cloned()
|
||||
}
|
||||
|
||||
/// Get a connector by ID
|
||||
pub async fn get_connector(&self, connector_id: &str) -> Option<MockConnector> {
|
||||
let connectors = self.connectors.lock().await;
|
||||
connectors.get(connector_id).cloned()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TestContext {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to create a sample permission request
|
||||
pub fn sample_permission_request(request_id: u64) -> Value {
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"method": "session/request_permission",
|
||||
"params": {
|
||||
"sessionId": "test-session",
|
||||
"tool": "Write",
|
||||
"parameters": {
|
||||
"path": "/tmp/test.txt",
|
||||
"content": "test"
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Helper to create a sample permission response
|
||||
pub fn sample_permission_response(request_id: u64, allow: bool) -> Value {
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": {
|
||||
"selectedOptionId": if allow { "allow" } else { "deny" }
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Helper to extract agent_request data from SSE event
|
||||
pub fn extract_agent_request(event: &MockEvent) -> Option<(Value, String, Value)> {
|
||||
if event.event_type != "session/update" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let update = event.data.get("update")?;
|
||||
if update.get("sessionUpdate")?.as_str()? != "agent_request" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let request_id = update.get("requestId")?.clone();
|
||||
let method = update.get("method")?.as_str()?.to_string();
|
||||
let params = update.get("params")?.clone();
|
||||
|
||||
Some((request_id, method, params))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mock_client() {
|
||||
let (client, _response_rx) = MockSseClient::new("test-client".to_string());
|
||||
|
||||
// Simulate receiving an event
|
||||
client.receive_sse("session/update".to_string(), json!({"test": "data"})).await;
|
||||
|
||||
// Verify event was received
|
||||
let events = client.get_events().await;
|
||||
assert_eq!(events.len(), 1);
|
||||
assert_eq!(events[0].event_type, "session/update");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mock_connector() {
|
||||
let (connector, mut request_rx) = MockConnector::new("test-connector".to_string());
|
||||
|
||||
// Simulate sending an agent request
|
||||
let response_fut = connector.send_agent_request(
|
||||
"session-1".to_string(),
|
||||
json!(0),
|
||||
"session/request_permission".to_string(),
|
||||
json!({"tool": "Write"}),
|
||||
).await;
|
||||
|
||||
// Verify request was sent
|
||||
let request = request_rx.recv().await.unwrap();
|
||||
assert_eq!(request.0, json!(0));
|
||||
assert_eq!(request.1, "session-1");
|
||||
assert_eq!(request.2, "session/request_permission");
|
||||
|
||||
// Simulate completing the response
|
||||
connector.complete_response(json!(0), json!({"result": "success"})).await.unwrap();
|
||||
|
||||
// Verify response was received
|
||||
let response = response_fut.await.unwrap();
|
||||
assert_eq!(response, json!({"result": "success"}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_test_context() {
|
||||
let ctx = TestContext::new();
|
||||
|
||||
// Create client and connector
|
||||
let (client, _) = ctx.create_client("client-1".to_string()).await;
|
||||
let (connector, _) = ctx.create_connector("connector-1".to_string()).await;
|
||||
|
||||
// Verify we can retrieve them
|
||||
assert!(ctx.get_client("client-1").await.is_some());
|
||||
assert!(ctx.get_connector("connector-1").await.is_some());
|
||||
assert!(ctx.get_client("non-existent").await.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sample_helpers() {
|
||||
let request = sample_permission_request(0);
|
||||
assert_eq!(request["method"], "session/request_permission");
|
||||
|
||||
let response = sample_permission_response(0, true);
|
||||
assert_eq!(response["result"]["selectedOptionId"], "allow");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_agent_request() {
|
||||
let event = MockEvent {
|
||||
event_type: "session/update".to_string(),
|
||||
data: json!({
|
||||
"sessionId": "session-1",
|
||||
"update": {
|
||||
"sessionUpdate": "agent_request",
|
||||
"requestId": 0,
|
||||
"method": "session/request_permission",
|
||||
"params": {"tool": "Write"}
|
||||
}
|
||||
}),
|
||||
};
|
||||
|
||||
let (request_id, method, params) = extract_agent_request(&event).unwrap();
|
||||
assert_eq!(request_id, json!(0));
|
||||
assert_eq!(method, "session/request_permission");
|
||||
assert_eq!(params["tool"], "Write");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,327 @@
|
||||
//! Integration test for concurrent agent requests (T050)
|
||||
//!
|
||||
//! This test verifies that the system can handle multiple agent requests
|
||||
//! simultaneously without cross-contamination.
|
||||
//!
|
||||
//! Test scenario:
|
||||
//! 1. Register multiple pending requests concurrently
|
||||
//! 2. Complete them in random/different order
|
||||
//! 3. Verify each response goes to the correct request
|
||||
//! 4. Verify no cross-contamination
|
||||
|
||||
use dirigent_acp_api::agent_requests::AgentRequestTracker;
|
||||
use serde_json::json;
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_requests_basic() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
|
||||
// Register 5 concurrent requests
|
||||
let mut receivers = Vec::new();
|
||||
for i in 0..5 {
|
||||
let rx = tracker.register(client_id, json!(i));
|
||||
receivers.push((i, rx));
|
||||
}
|
||||
|
||||
assert_eq!(tracker.pending_count(), 5);
|
||||
|
||||
// Complete them in reverse order
|
||||
for i in (0..5).rev() {
|
||||
let response = json!({"request": i, "result": "success"});
|
||||
tracker.complete(client_id, json!(i), response).unwrap();
|
||||
}
|
||||
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Verify each receiver got the correct response
|
||||
for (i, rx) in receivers {
|
||||
let response = rx.await.unwrap();
|
||||
assert_eq!(response["request"], i);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_requests_multiple_clients() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
// Two clients, each with 3 requests
|
||||
let client1 = "client-1";
|
||||
let client2 = "client-2";
|
||||
|
||||
let mut receivers1 = Vec::new();
|
||||
let mut receivers2 = Vec::new();
|
||||
|
||||
for i in 0..3 {
|
||||
let rx1 = tracker.register(client1, json!(i));
|
||||
let rx2 = tracker.register(client2, json!(i));
|
||||
receivers1.push((i, rx1));
|
||||
receivers2.push((i, rx2));
|
||||
}
|
||||
|
||||
assert_eq!(tracker.pending_count(), 6);
|
||||
assert_eq!(tracker.client_pending_count(client1), 3);
|
||||
assert_eq!(tracker.client_pending_count(client2), 3);
|
||||
|
||||
// Complete client1's requests
|
||||
for i in 0..3 {
|
||||
let response = json!({"client": 1, "request": i});
|
||||
tracker.complete(client1, json!(i), response).unwrap();
|
||||
}
|
||||
|
||||
assert_eq!(tracker.client_pending_count(client1), 0);
|
||||
assert_eq!(tracker.client_pending_count(client2), 3);
|
||||
|
||||
// Complete client2's requests
|
||||
for i in 0..3 {
|
||||
let response = json!({"client": 2, "request": i});
|
||||
tracker.complete(client2, json!(i), response).unwrap();
|
||||
}
|
||||
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Verify each receiver got the correct response
|
||||
for (i, rx) in receivers1 {
|
||||
let response = rx.await.unwrap();
|
||||
assert_eq!(response["client"], 1);
|
||||
assert_eq!(response["request"], i);
|
||||
}
|
||||
|
||||
for (i, rx) in receivers2 {
|
||||
let response = rx.await.unwrap();
|
||||
assert_eq!(response["client"], 2);
|
||||
assert_eq!(response["request"], i);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_requests_same_id_different_clients() {
|
||||
// Test that same request_id for different clients are handled independently
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client1 = "client-1";
|
||||
let client2 = "client-2";
|
||||
let request_id = json!(0); // Same ID for both
|
||||
|
||||
let rx1 = tracker.register(client1, request_id.clone());
|
||||
let rx2 = tracker.register(client2, request_id.clone());
|
||||
|
||||
assert_eq!(tracker.pending_count(), 2);
|
||||
|
||||
// Complete client1's request
|
||||
let response1 = json!({"client": "client-1"});
|
||||
tracker.complete(client1, request_id.clone(), response1.clone()).unwrap();
|
||||
|
||||
// Complete client2's request
|
||||
let response2 = json!({"client": "client-2"});
|
||||
tracker.complete(client2, request_id, response2.clone()).unwrap();
|
||||
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Verify each got the correct response
|
||||
let received1 = rx1.await.unwrap();
|
||||
let received2 = rx2.await.unwrap();
|
||||
|
||||
assert_eq!(received1["client"], "client-1");
|
||||
assert_eq!(received2["client"], "client-2");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_async_completion() {
|
||||
// Test completing requests from multiple async tasks concurrently
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
let num_requests = 10;
|
||||
|
||||
// Register requests
|
||||
let mut receivers = Vec::new();
|
||||
for i in 0..num_requests {
|
||||
let rx = tracker.register(client_id, json!(i));
|
||||
receivers.push((i, rx));
|
||||
}
|
||||
|
||||
assert_eq!(tracker.pending_count(), num_requests);
|
||||
|
||||
// Spawn tasks to complete requests concurrently
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
for i in 0..num_requests {
|
||||
let tracker_clone = tracker.clone();
|
||||
join_set.spawn(async move {
|
||||
// Small delay to ensure concurrency
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(((i % 3) * 10) as u64)).await;
|
||||
let response = json!({"request": i, "result": "success"});
|
||||
tracker_clone.complete(client_id, json!(i), response)
|
||||
});
|
||||
}
|
||||
|
||||
// Wait for all completions
|
||||
while let Some(result) = join_set.join_next().await {
|
||||
assert!(result.unwrap().is_ok());
|
||||
}
|
||||
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Verify all receivers got correct responses
|
||||
for (i, rx) in receivers {
|
||||
let response = rx.await.unwrap();
|
||||
assert_eq!(response["request"], i);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_register_and_complete() {
|
||||
// Test registering and completing requests concurrently
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
let num_requests = 20;
|
||||
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
// Spawn tasks to register and complete requests
|
||||
for i in 0..num_requests {
|
||||
let tracker_clone = tracker.clone();
|
||||
join_set.spawn(async move {
|
||||
// Register
|
||||
let rx = tracker_clone.register(client_id, json!(i));
|
||||
|
||||
// Small delay
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
|
||||
|
||||
// Complete
|
||||
let response = json!({"request": i});
|
||||
tracker_clone.complete(client_id, json!(i), response.clone()).unwrap();
|
||||
|
||||
// Wait for response
|
||||
let received = rx.await.unwrap();
|
||||
assert_eq!(received["request"], i);
|
||||
|
||||
i
|
||||
});
|
||||
}
|
||||
|
||||
// Wait for all tasks
|
||||
let mut completed = Vec::new();
|
||||
while let Some(result) = join_set.join_next().await {
|
||||
completed.push(result.unwrap());
|
||||
}
|
||||
|
||||
// All requests should have completed
|
||||
assert_eq!(completed.len(), num_requests);
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_mixed_operations() {
|
||||
// Test mix of register, complete, and timeout operations
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
// Spawn 15 tasks with different behaviors
|
||||
for i in 0..15 {
|
||||
let tracker_clone = tracker.clone();
|
||||
join_set.spawn(async move {
|
||||
let rx = tracker_clone.register(client_id, json!(i));
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
|
||||
|
||||
match i % 3 {
|
||||
0 => {
|
||||
// Complete normally
|
||||
let response = json!({"request": i, "type": "complete"});
|
||||
tracker_clone.complete(client_id, json!(i), response.clone()).unwrap();
|
||||
let received = rx.await.unwrap();
|
||||
assert_eq!(received["type"], "complete");
|
||||
"completed"
|
||||
}
|
||||
1 => {
|
||||
// Timeout
|
||||
tracker_clone.timeout(client_id, json!(i));
|
||||
assert!(rx.await.is_err());
|
||||
"timeout"
|
||||
}
|
||||
_ => {
|
||||
// Complete with delay
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(20)).await;
|
||||
let response = json!({"request": i, "type": "delayed"});
|
||||
tracker_clone.complete(client_id, json!(i), response.clone()).unwrap();
|
||||
let received = rx.await.unwrap();
|
||||
assert_eq!(received["type"], "delayed");
|
||||
"delayed"
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Wait for all tasks
|
||||
let mut results = Vec::new();
|
||||
while let Some(result) = join_set.join_next().await {
|
||||
results.push(result.unwrap());
|
||||
}
|
||||
|
||||
assert_eq!(results.len(), 15);
|
||||
|
||||
// Count outcomes
|
||||
let completed = results.iter().filter(|&r| r == &"completed").count();
|
||||
let timeout = results.iter().filter(|&r| r == &"timeout").count();
|
||||
let delayed = results.iter().filter(|&r| r == &"delayed").count();
|
||||
|
||||
assert_eq!(completed, 5);
|
||||
assert_eq!(timeout, 5);
|
||||
assert_eq!(delayed, 5);
|
||||
|
||||
// All should be cleaned up
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_high_concurrency() {
|
||||
// Stress test with many concurrent requests
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
let num_requests = 100;
|
||||
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
for i in 0..num_requests {
|
||||
let tracker_clone = tracker.clone();
|
||||
join_set.spawn(async move {
|
||||
let rx = tracker_clone.register(client_id, json!(i));
|
||||
|
||||
// Random-ish delay
|
||||
let delay = ((i * 7) % 20) as u64;
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(delay)).await;
|
||||
|
||||
let response = json!({"request": i});
|
||||
tracker_clone.complete(client_id, json!(i), response).unwrap();
|
||||
|
||||
rx.await.unwrap()["request"].as_u64().unwrap()
|
||||
});
|
||||
}
|
||||
|
||||
// Collect all results
|
||||
let mut results = Vec::new();
|
||||
while let Some(result) = join_set.join_next().await {
|
||||
results.push(result.unwrap());
|
||||
}
|
||||
|
||||
// Verify all requests completed
|
||||
assert_eq!(results.len(), num_requests);
|
||||
|
||||
// Verify all request IDs are present
|
||||
results.sort();
|
||||
for (idx, &val) in results.iter().enumerate() {
|
||||
assert_eq!(val, idx as u64);
|
||||
}
|
||||
|
||||
// All cleaned up
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
@@ -0,0 +1,286 @@
|
||||
//! Integration test for client disconnection (T051)
|
||||
//!
|
||||
//! This test verifies that pending agent requests are cleaned up when a client
|
||||
//! disconnects, and that other clients are unaffected.
|
||||
//!
|
||||
//! Test scenario:
|
||||
//! 1. Register pending requests for multiple clients
|
||||
//! 2. Simulate client disconnection
|
||||
//! 3. Verify cleanup occurs for disconnected client
|
||||
//! 4. Verify other clients are unaffected
|
||||
|
||||
use dirigent_acp_api::agent_requests::AgentRequestTracker;
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_disconnect_single_client() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
|
||||
// Register 3 pending requests
|
||||
let rx1 = tracker.register(client_id, json!(0));
|
||||
let rx2 = tracker.register(client_id, json!(1));
|
||||
let rx3 = tracker.register(client_id, json!(2));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 3);
|
||||
assert_eq!(tracker.client_pending_count(client_id), 3);
|
||||
|
||||
// Simulate client disconnection - clear all requests for this client
|
||||
tracker.clear(Some(client_id));
|
||||
|
||||
// All requests should be removed
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
assert_eq!(tracker.client_pending_count(client_id), 0);
|
||||
|
||||
// All receivers should get errors (channels closed)
|
||||
assert!(rx1.await.is_err());
|
||||
assert!(rx2.await.is_err());
|
||||
assert!(rx3.await.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_disconnect_multiple_clients() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client1 = "client-1";
|
||||
let client2 = "client-2";
|
||||
let client3 = "client-3";
|
||||
|
||||
// Register requests for all clients
|
||||
let rx1_1 = tracker.register(client1, json!(0));
|
||||
let rx1_2 = tracker.register(client1, json!(1));
|
||||
|
||||
let rx2_1 = tracker.register(client2, json!(0));
|
||||
let rx2_2 = tracker.register(client2, json!(1));
|
||||
let rx2_3 = tracker.register(client2, json!(2));
|
||||
|
||||
let rx3_1 = tracker.register(client3, json!(0));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 6);
|
||||
assert_eq!(tracker.client_pending_count(client1), 2);
|
||||
assert_eq!(tracker.client_pending_count(client2), 3);
|
||||
assert_eq!(tracker.client_pending_count(client3), 1);
|
||||
|
||||
// Disconnect client2
|
||||
tracker.clear(Some(client2));
|
||||
|
||||
// Only client2's requests should be removed
|
||||
assert_eq!(tracker.pending_count(), 3);
|
||||
assert_eq!(tracker.client_pending_count(client1), 2);
|
||||
assert_eq!(tracker.client_pending_count(client2), 0);
|
||||
assert_eq!(tracker.client_pending_count(client3), 1);
|
||||
|
||||
// Client2's receivers should error
|
||||
assert!(rx2_1.await.is_err());
|
||||
assert!(rx2_2.await.is_err());
|
||||
assert!(rx2_3.await.is_err());
|
||||
|
||||
// Client1 and client3 should still work
|
||||
tracker.complete(client1, json!(0), json!({"result": "client1-0"})).unwrap();
|
||||
assert_eq!(rx1_1.await.unwrap()["result"], "client1-0");
|
||||
|
||||
tracker.complete(client3, json!(0), json!({"result": "client3-0"})).unwrap();
|
||||
assert_eq!(rx3_1.await.unwrap()["result"], "client3-0");
|
||||
|
||||
// Complete remaining client1 request
|
||||
tracker.complete(client1, json!(1), json!({"result": "client1-1"})).unwrap();
|
||||
assert_eq!(rx1_2.await.unwrap()["result"], "client1-1");
|
||||
|
||||
// All cleaned up
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_disconnect_then_reconnect() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
|
||||
// First connection - register requests
|
||||
let rx1 = tracker.register(client_id, json!(0));
|
||||
let rx2 = tracker.register(client_id, json!(1));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 2);
|
||||
|
||||
// Disconnect - cleanup
|
||||
tracker.clear(Some(client_id));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Old receivers should error
|
||||
assert!(rx1.await.is_err());
|
||||
assert!(rx2.await.is_err());
|
||||
|
||||
// Reconnect - register new requests (same client_id, same request_ids)
|
||||
let rx3 = tracker.register(client_id, json!(0));
|
||||
let rx4 = tracker.register(client_id, json!(1));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 2);
|
||||
|
||||
// Complete new requests
|
||||
tracker.complete(client_id, json!(0), json!({"result": "new-0"})).unwrap();
|
||||
tracker.complete(client_id, json!(1), json!({"result": "new-1"})).unwrap();
|
||||
|
||||
// New receivers should get responses
|
||||
assert_eq!(rx3.await.unwrap()["result"], "new-0");
|
||||
assert_eq!(rx4.await.unwrap()["result"], "new-1");
|
||||
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_disconnect_no_pending_requests() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
|
||||
// No pending requests
|
||||
assert_eq!(tracker.client_pending_count(client_id), 0);
|
||||
|
||||
// Disconnect should be no-op
|
||||
tracker.clear(Some(client_id));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_clear_all_clients() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client1 = "client-1";
|
||||
let client2 = "client-2";
|
||||
let client3 = "client-3";
|
||||
|
||||
// Register requests for multiple clients
|
||||
let rx1 = tracker.register(client1, json!(0));
|
||||
let rx2 = tracker.register(client2, json!(0));
|
||||
let rx3 = tracker.register(client3, json!(0));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 3);
|
||||
|
||||
// Clear all (simulating server shutdown)
|
||||
tracker.clear(None);
|
||||
|
||||
// All should be removed
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
assert_eq!(tracker.client_pending_count(client1), 0);
|
||||
assert_eq!(tracker.client_pending_count(client2), 0);
|
||||
assert_eq!(tracker.client_pending_count(client3), 0);
|
||||
|
||||
// All receivers should error
|
||||
assert!(rx1.await.is_err());
|
||||
assert!(rx2.await.is_err());
|
||||
assert!(rx3.await.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_disconnect_race_with_completion() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
|
||||
// Register multiple requests
|
||||
let rx1 = tracker.register(client_id, json!(0));
|
||||
let rx2 = tracker.register(client_id, json!(1));
|
||||
let rx3 = tracker.register(client_id, json!(2));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 3);
|
||||
|
||||
// Complete one request
|
||||
tracker.complete(client_id, json!(0), json!({"result": "0"})).unwrap();
|
||||
|
||||
// Verify it was removed
|
||||
assert_eq!(tracker.pending_count(), 2);
|
||||
|
||||
// Now disconnect (should only clear remaining requests)
|
||||
tracker.clear(Some(client_id));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// First receiver should have gotten response
|
||||
assert_eq!(rx1.await.unwrap()["result"], "0");
|
||||
|
||||
// Other receivers should error
|
||||
assert!(rx2.await.is_err());
|
||||
assert!(rx3.await.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_partial_disconnect_completion() {
|
||||
// Test that completing a request after disconnect fails gracefully
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
|
||||
let _rx = tracker.register(client_id, json!(0));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 1);
|
||||
|
||||
// Disconnect
|
||||
tracker.clear(Some(client_id));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Try to complete after disconnect - should fail
|
||||
let result = tracker.complete(client_id, json!(0), json!({"result": "late"}));
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_disconnect_and_complete() {
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
|
||||
// Register many requests
|
||||
for i in 0..50 {
|
||||
tracker.register(client_id, json!(i));
|
||||
}
|
||||
|
||||
assert_eq!(tracker.pending_count(), 50);
|
||||
|
||||
let mut join_set = JoinSet::new();
|
||||
|
||||
// Spawn task to disconnect after delay
|
||||
{
|
||||
let tracker_clone = tracker.clone();
|
||||
join_set.spawn(async move {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
|
||||
tracker_clone.clear(Some(client_id));
|
||||
"disconnected"
|
||||
});
|
||||
}
|
||||
|
||||
// Spawn tasks to complete requests
|
||||
for i in 0..50 {
|
||||
let tracker_clone = tracker.clone();
|
||||
join_set.spawn(async move {
|
||||
// Small delay to ensure some complete before disconnect
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis((i % 5) as u64)).await;
|
||||
let response = json!({"request": i});
|
||||
match tracker_clone.complete(client_id, json!(i), response) {
|
||||
Ok(_) => "completed",
|
||||
Err(_) => "failed",
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Wait for all tasks
|
||||
let mut results = Vec::new();
|
||||
while let Some(result) = join_set.join_next().await {
|
||||
results.push(result.unwrap());
|
||||
}
|
||||
|
||||
// Should have 1 disconnect + 50 completion attempts
|
||||
assert_eq!(results.len(), 51);
|
||||
|
||||
// Some completions succeeded, some failed (after disconnect)
|
||||
let disconnects = results.iter().filter(|r| r == &&"disconnected").count();
|
||||
assert_eq!(disconnects, 1);
|
||||
|
||||
// All requests should be cleaned up
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
@@ -0,0 +1,153 @@
|
||||
//! Integration tests for session/list RPC method
|
||||
|
||||
use dirigent_acp_api::{
|
||||
NoOpConnectorOperations, RpcHandler, SessionManager,
|
||||
};
|
||||
use dirigent_acp_api::agent_requests::AgentRequestTracker;
|
||||
use dirigent_acp_api::sse::SseNotifier;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
fn create_test_handler() -> RpcHandler<NoOpConnectorOperations> {
|
||||
RpcHandler::new(
|
||||
SessionManager::new(),
|
||||
NoOpConnectorOperations,
|
||||
SseNotifier::new(),
|
||||
Arc::new(AgentRequestTracker::new()),
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_list_returns_sessions() {
|
||||
let handler = create_test_handler();
|
||||
|
||||
let request_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "session/list",
|
||||
"params": {
|
||||
"connectorId": "stub-connector"
|
||||
}
|
||||
});
|
||||
|
||||
let response = handler.handle_request(&request_body.to_string(), None).await;
|
||||
let response_json = serde_json::to_value(&response).unwrap();
|
||||
|
||||
// NoOp returns one stub session
|
||||
let result = &response_json["result"];
|
||||
assert!(result["sessions"].is_array());
|
||||
let sessions = result["sessions"].as_array().unwrap();
|
||||
assert!(!sessions.is_empty());
|
||||
assert!(sessions[0]["sessionId"].is_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_list_no_params() {
|
||||
let handler = create_test_handler();
|
||||
|
||||
// session/list with no params should use default connector
|
||||
let request_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"method": "session/list"
|
||||
});
|
||||
|
||||
let response = handler.handle_request(&request_body.to_string(), None).await;
|
||||
let response_json = serde_json::to_value(&response).unwrap();
|
||||
|
||||
// Should succeed using default connector from NoOp
|
||||
let result = &response_json["result"];
|
||||
assert!(result["sessions"].is_array());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_list_creates_mappings() {
|
||||
let session_manager = SessionManager::new();
|
||||
let handler = RpcHandler::new(
|
||||
session_manager.clone(),
|
||||
NoOpConnectorOperations,
|
||||
SseNotifier::new(),
|
||||
Arc::new(AgentRequestTracker::new()),
|
||||
);
|
||||
|
||||
// First, initialize to register a client
|
||||
let init_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 0,
|
||||
"method": "initialize",
|
||||
"params": {}
|
||||
});
|
||||
let init_response = handler.handle_request(&init_body.to_string(), Some("test-client")).await;
|
||||
let init_json = serde_json::to_value(&init_response).unwrap();
|
||||
let client_id = init_json["result"]["clientId"].as_str().unwrap();
|
||||
|
||||
// Now list sessions
|
||||
let request_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "session/list",
|
||||
"params": {
|
||||
"connectorId": "stub-connector"
|
||||
}
|
||||
});
|
||||
|
||||
let response = handler.handle_request(&request_body.to_string(), Some(client_id)).await;
|
||||
let response_json = serde_json::to_value(&response).unwrap();
|
||||
|
||||
let sessions = response_json["result"]["sessions"].as_array().unwrap();
|
||||
assert!(!sessions.is_empty());
|
||||
|
||||
// Session mapping should exist for the returned session
|
||||
let session_id = sessions[0]["sessionId"].as_str().unwrap();
|
||||
let mapping = session_manager.get_mapping(session_id);
|
||||
assert!(mapping.is_some(), "Session mapping should be created for listed sessions");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_load_without_connector_id() {
|
||||
let handler = create_test_handler();
|
||||
|
||||
// Standard ACP: only sessionId + cwd + mcpServers, no connectorId
|
||||
let request_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "session/load",
|
||||
"params": {
|
||||
"sessionId": "sess-abc",
|
||||
"cwd": "G:\\dev\\projects\\test",
|
||||
"mcpServers": []
|
||||
}
|
||||
});
|
||||
|
||||
let response = handler.handle_request(&request_body.to_string(), None).await;
|
||||
let response_json = serde_json::to_value(&response).unwrap();
|
||||
|
||||
let result = &response_json["result"];
|
||||
assert!(result["sessionId"].is_string(), "session/load should succeed without connectorId");
|
||||
assert!(result["createdAt"].is_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_initialize_advertises_list_sessions() {
|
||||
let handler = create_test_handler();
|
||||
|
||||
let request_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {}
|
||||
});
|
||||
|
||||
let response = handler.handle_request(&request_body.to_string(), None).await;
|
||||
let response_json = serde_json::to_value(&response).unwrap();
|
||||
|
||||
let caps = &response_json["result"]["agentCapabilities"];
|
||||
assert_eq!(caps["listSessions"], true, "listSessions capability should be advertised");
|
||||
assert_eq!(caps["loadSession"], true, "loadSession capability should still be advertised");
|
||||
|
||||
// Verify nested sessionCapabilities.list is advertised (required by Zed v0.9.4+)
|
||||
assert!(
|
||||
caps["sessionCapabilities"]["list"].is_object(),
|
||||
"sessionCapabilities.list should be advertised as an empty object"
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,134 @@
|
||||
//! Integration tests for session/resume RPC method
|
||||
|
||||
use dirigent_acp_api::{
|
||||
NoOpConnectorOperations, RpcHandler, SessionManager,
|
||||
};
|
||||
use dirigent_acp_api::agent_requests::AgentRequestTracker;
|
||||
use dirigent_acp_api::sse::SseNotifier;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
fn create_test_handler() -> RpcHandler<NoOpConnectorOperations> {
|
||||
RpcHandler::new(
|
||||
SessionManager::new(),
|
||||
NoOpConnectorOperations,
|
||||
SseNotifier::new(),
|
||||
Arc::new(AgentRequestTracker::new()),
|
||||
)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_resume_returns_session() {
|
||||
let handler = create_test_handler();
|
||||
|
||||
let request_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "session/resume",
|
||||
"params": {
|
||||
"sessionId": "sess-123",
|
||||
"connectorId": "stub-connector"
|
||||
}
|
||||
});
|
||||
|
||||
let response = handler.handle_request(&request_body.to_string(), None).await;
|
||||
let response_json = serde_json::to_value(&response).unwrap();
|
||||
|
||||
let result = &response_json["result"];
|
||||
assert!(result["sessionId"].is_string());
|
||||
assert!(result["connectorId"].is_string());
|
||||
assert!(result["createdAt"].is_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_resume_missing_params() {
|
||||
let handler = create_test_handler();
|
||||
|
||||
let request_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "session/resume"
|
||||
});
|
||||
|
||||
let response = handler.handle_request(&request_body.to_string(), None).await;
|
||||
let response_json = serde_json::to_value(&response).unwrap();
|
||||
|
||||
// Should return error for missing params
|
||||
assert!(response_json["error"].is_object());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_resume_creates_mapping() {
|
||||
let session_manager = SessionManager::new();
|
||||
let handler = RpcHandler::new(
|
||||
session_manager.clone(),
|
||||
NoOpConnectorOperations,
|
||||
SseNotifier::new(),
|
||||
Arc::new(AgentRequestTracker::new()),
|
||||
);
|
||||
|
||||
let request_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "session/resume",
|
||||
"params": {
|
||||
"sessionId": "sess-456",
|
||||
"connectorId": "stub-connector"
|
||||
}
|
||||
});
|
||||
|
||||
let response = handler.handle_request(&request_body.to_string(), None).await;
|
||||
let response_json = serde_json::to_value(&response).unwrap();
|
||||
|
||||
let session_id = response_json["result"]["sessionId"].as_str().unwrap();
|
||||
let mapping = session_manager.get_mapping(session_id);
|
||||
assert!(mapping.is_some(), "Session mapping should be created for resumed session");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_session_resume_without_connector_id() {
|
||||
let handler = create_test_handler();
|
||||
|
||||
// Standard ACP: only sessionId, no connectorId — should resolve via default connector
|
||||
let request_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "session/resume",
|
||||
"params": {
|
||||
"sessionId": "sess-789",
|
||||
"cwd": "G:\\dev\\projects\\test"
|
||||
}
|
||||
});
|
||||
|
||||
let response = handler.handle_request(&request_body.to_string(), None).await;
|
||||
let response_json = serde_json::to_value(&response).unwrap();
|
||||
|
||||
let result = &response_json["result"];
|
||||
assert!(result["sessionId"].is_string(), "Should succeed without connectorId");
|
||||
assert!(result["createdAt"].is_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_initialize_advertises_session_resume() {
|
||||
let handler = create_test_handler();
|
||||
|
||||
let request_body = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {}
|
||||
});
|
||||
|
||||
let response = handler.handle_request(&request_body.to_string(), None).await;
|
||||
let response_json = serde_json::to_value(&response).unwrap();
|
||||
|
||||
let caps = &response_json["result"]["agentCapabilities"];
|
||||
assert!(
|
||||
caps["sessionCapabilities"]["list"].is_object(),
|
||||
"sessionCapabilities.list should be advertised as an empty object"
|
||||
);
|
||||
assert!(
|
||||
caps["sessionCapabilities"]["resume"].is_object(),
|
||||
"sessionCapabilities.resume should be advertised as an empty object"
|
||||
);
|
||||
}
|
||||
@@ -0,0 +1,229 @@
|
||||
//! Integration test for timeout handling (T049)
|
||||
//!
|
||||
//! This test verifies that the system handles timeouts gracefully when a client
|
||||
//! fails to respond to an agent request within the timeout period.
|
||||
//!
|
||||
//! Test scenario:
|
||||
//! 1. Register a pending agent request
|
||||
//! 2. Wait for timeout (using reduced timeout for testing)
|
||||
//! 3. Verify timeout occurs
|
||||
//! 4. Verify cleanup happens correctly
|
||||
//! 5. Verify no resource leaks
|
||||
|
||||
use dirigent_acp_api::agent_requests::AgentRequestTracker;
|
||||
use serde_json::json;
|
||||
use tokio::time::{timeout, Duration};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_basic() {
|
||||
// Create tracker
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
let request_id = json!(0);
|
||||
|
||||
// Register request
|
||||
let receiver = tracker.register(client_id, request_id.clone());
|
||||
assert_eq!(tracker.pending_count(), 1);
|
||||
|
||||
// Wait for timeout (use short timeout for testing)
|
||||
let result = timeout(Duration::from_millis(100), receiver).await;
|
||||
|
||||
// Should timeout
|
||||
assert!(result.is_err(), "Expected timeout but request completed");
|
||||
|
||||
// Manually trigger cleanup (in production, event bridge does this)
|
||||
tracker.timeout(client_id, request_id);
|
||||
|
||||
// Verify cleanup
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_cleanup() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
let request_id = json!(123);
|
||||
|
||||
// Register request
|
||||
let receiver = tracker.register(client_id, request_id.clone());
|
||||
assert_eq!(tracker.pending_count(), 1);
|
||||
|
||||
// Trigger timeout before waiting
|
||||
tracker.timeout(client_id, request_id);
|
||||
|
||||
// Verify cleanup happened
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Receiver should get error (channel closed)
|
||||
let result = receiver.await;
|
||||
assert!(result.is_err(), "Expected receiver to get error after timeout");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_multiple_clients() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client1 = "client-1";
|
||||
let client2 = "client-2";
|
||||
|
||||
// Register requests for both clients
|
||||
let rx1 = tracker.register(client1, json!(0));
|
||||
let rx2 = tracker.register(client2, json!(0));
|
||||
|
||||
assert_eq!(tracker.pending_count(), 2);
|
||||
assert_eq!(tracker.client_pending_count(client1), 1);
|
||||
assert_eq!(tracker.client_pending_count(client2), 1);
|
||||
|
||||
// Timeout only client1's request
|
||||
tracker.timeout(client1, json!(0));
|
||||
|
||||
// Verify only client1's request is removed
|
||||
assert_eq!(tracker.pending_count(), 1);
|
||||
assert_eq!(tracker.client_pending_count(client1), 0);
|
||||
assert_eq!(tracker.client_pending_count(client2), 1);
|
||||
|
||||
// Client1's receiver should error
|
||||
assert!(rx1.await.is_err());
|
||||
|
||||
// Complete client2's request normally
|
||||
let response = json!({"result": "success"});
|
||||
let result = tracker.complete(client2, json!(0), response);
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Client2's receiver should get response
|
||||
let received = rx2.await.unwrap();
|
||||
assert_eq!(received, json!({"result": "success"}));
|
||||
|
||||
// All cleaned up
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_no_double_cleanup() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
let request_id = json!(0);
|
||||
|
||||
// Register request
|
||||
let _receiver = tracker.register(client_id, request_id.clone());
|
||||
assert_eq!(tracker.pending_count(), 1);
|
||||
|
||||
// First timeout - should remove
|
||||
tracker.timeout(client_id, request_id.clone());
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Second timeout - should be no-op (not panic)
|
||||
tracker.timeout(client_id, request_id);
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_race_with_complete() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
let request_id = json!(0);
|
||||
|
||||
// Register request
|
||||
let receiver = tracker.register(client_id, request_id.clone());
|
||||
assert_eq!(tracker.pending_count(), 1);
|
||||
|
||||
// Complete the request
|
||||
let response = json!({"result": "success"});
|
||||
let result = tracker.complete(client_id, request_id.clone(), response.clone());
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Try to timeout after completion - should be no-op
|
||||
tracker.timeout(client_id, request_id);
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// Receiver should still get the response
|
||||
let received = receiver.await.unwrap();
|
||||
assert_eq!(received, response);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_timeouts() {
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
|
||||
// Register 10 requests
|
||||
let mut receivers = Vec::new();
|
||||
for i in 0..10 {
|
||||
let rx = tracker.register(client_id, json!(i));
|
||||
receivers.push((i, rx));
|
||||
}
|
||||
|
||||
assert_eq!(tracker.pending_count(), 10);
|
||||
|
||||
// Spawn tasks to timeout each request after random delays
|
||||
let tracker_clone = tracker.clone();
|
||||
let timeout_handles: Vec<_> = (0..10)
|
||||
.map(|i| {
|
||||
let tracker = tracker_clone.clone();
|
||||
tokio::spawn(async move {
|
||||
// Small random-ish delay based on index
|
||||
tokio::time::sleep(Duration::from_millis((i * 10) as u64)).await;
|
||||
tracker.timeout(client_id, json!(i));
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Wait for all timeouts to complete
|
||||
for handle in timeout_handles {
|
||||
handle.await.unwrap();
|
||||
}
|
||||
|
||||
// All should be cleaned up
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
|
||||
// All receivers should get errors
|
||||
for (_i, rx) in receivers {
|
||||
assert!(rx.await.is_err());
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_timeout_with_actual_delay() {
|
||||
// This test uses actual time delays to verify timeout behavior more realistically
|
||||
let tracker = AgentRequestTracker::new();
|
||||
|
||||
let client_id = "test-client";
|
||||
let request_id = json!(0);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
// Register request
|
||||
let receiver = tracker.register(client_id, request_id.clone());
|
||||
|
||||
// Spawn task to timeout after 200ms
|
||||
let tracker_clone = tracker.clone();
|
||||
tokio::spawn(async move {
|
||||
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||
tracker_clone.timeout(client_id, json!(0));
|
||||
});
|
||||
|
||||
// Wait on receiver with longer timeout
|
||||
let result = timeout(Duration::from_secs(1), receiver).await;
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
// Should complete due to timeout() call, not tokio::time::timeout
|
||||
assert!(result.is_ok(), "Should complete when timeout() is called");
|
||||
assert!(result.unwrap().is_err(), "Receiver should get error");
|
||||
|
||||
// Should take approximately 200ms
|
||||
assert!(
|
||||
elapsed >= Duration::from_millis(180) && elapsed < Duration::from_millis(300),
|
||||
"Expected ~200ms but took {:?}",
|
||||
elapsed
|
||||
);
|
||||
|
||||
// Should be cleaned up
|
||||
assert_eq!(tracker.pending_count(), 0);
|
||||
}
|
||||
Reference in New Issue
Block a user