sync from monorepo @ 2452e92e
This commit is contained in:
@@ -0,0 +1,717 @@
|
||||
//! Protocol-level message accumulator for incremental message assembly.
|
||||
//!
|
||||
//! Handles streaming message deltas and assembles them into complete
|
||||
//! [`AccumulatedMessage`] values using protocol types. Unlike the archivist's
|
||||
//! accumulator, this module produces protocol-native output without UUID
|
||||
//! parsing, markdown generation, or storage metadata.
|
||||
//!
|
||||
//! The accumulator preserves the order of content parts (text, thinking, tool
|
||||
//! calls) as they arrive in the event stream, enabling inline tool rendering
|
||||
//! and faithful forwarding to downstream consumers.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::conversation::MessagePart;
|
||||
use crate::types::ContentBlock;
|
||||
|
||||
/// Tool call data accumulated during streaming.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolCallData {
|
||||
pub id: String,
|
||||
pub tool_name: String,
|
||||
pub input: Value,
|
||||
pub output: Option<Value>,
|
||||
}
|
||||
|
||||
/// A single accumulated content part, preserving event-stream order.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AccumulatedPart {
|
||||
Text { text: String },
|
||||
Thinking { text: String },
|
||||
Tool { data: ToolCallData },
|
||||
}
|
||||
|
||||
/// A fully accumulated message assembled from streaming chunks.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AccumulatedMessage {
|
||||
pub message_id: String,
|
||||
pub session_id: String,
|
||||
pub connector_id: String,
|
||||
pub role: String,
|
||||
pub parts: Vec<AccumulatedPart>,
|
||||
pub created_at: Option<DateTime<Utc>>,
|
||||
pub last_activity: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl AccumulatedMessage {
|
||||
/// Build an `AccumulatedMessage` directly from a slice of [`MessagePart`]s.
|
||||
///
|
||||
/// This is the path for non-streaming clients that deliver content in a
|
||||
/// single `MessageCompleted` event rather than via incremental chunks.
|
||||
pub fn from_message_parts(
|
||||
message_id: String,
|
||||
session_id: String,
|
||||
connector_id: String,
|
||||
role: String,
|
||||
parts: &[MessagePart],
|
||||
) -> Self {
|
||||
let now = Utc::now();
|
||||
let mut accumulated_parts = Vec::new();
|
||||
|
||||
for part in parts {
|
||||
match part {
|
||||
MessagePart::Text { text } => {
|
||||
if !text.is_empty() {
|
||||
accumulated_parts.push(AccumulatedPart::Text { text: text.clone() });
|
||||
}
|
||||
}
|
||||
MessagePart::Thinking { text } => {
|
||||
if !text.is_empty() {
|
||||
accumulated_parts.push(AccumulatedPart::Thinking { text: text.clone() });
|
||||
}
|
||||
}
|
||||
MessagePart::Code { language, code } => {
|
||||
if !code.is_empty() {
|
||||
let fenced = format!("```{}\n{}\n```", language, code);
|
||||
accumulated_parts.push(AccumulatedPart::Text { text: fenced });
|
||||
}
|
||||
}
|
||||
MessagePart::Tool {
|
||||
tool,
|
||||
tool_call_id,
|
||||
input,
|
||||
output,
|
||||
} => {
|
||||
accumulated_parts.push(AccumulatedPart::Tool {
|
||||
data: ToolCallData {
|
||||
id: tool_call_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| String::new()),
|
||||
tool_name: tool.clone(),
|
||||
input: input.clone(),
|
||||
output: output.clone(),
|
||||
},
|
||||
});
|
||||
}
|
||||
MessagePart::File { path, content: _ } => {
|
||||
let label = format!("\u{1f4c4} File: {}", path);
|
||||
accumulated_parts.push(AccumulatedPart::Text { text: label });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
message_id,
|
||||
session_id,
|
||||
connector_id,
|
||||
role,
|
||||
parts: accumulated_parts,
|
||||
created_at: Some(now),
|
||||
last_activity: now,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert accumulated parts back to protocol [`MessagePart`]s.
|
||||
pub fn to_message_parts(&self) -> Vec<MessagePart> {
|
||||
self.parts
|
||||
.iter()
|
||||
.map(|part| match part {
|
||||
AccumulatedPart::Text { text } => MessagePart::Text { text: text.clone() },
|
||||
AccumulatedPart::Thinking { text } => {
|
||||
MessagePart::Thinking { text: text.clone() }
|
||||
}
|
||||
AccumulatedPart::Tool { data } => MessagePart::Tool {
|
||||
tool: data.tool_name.clone(),
|
||||
tool_call_id: if data.id.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(data.id.clone())
|
||||
},
|
||||
input: data.input.clone(),
|
||||
output: data.output.clone(),
|
||||
},
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Returns `true` when the message has no accumulated content.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.parts.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Internal buffer
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Buffer for accumulating streaming chunks into a complete message.
|
||||
#[derive(Debug)]
|
||||
struct MessageBuffer {
|
||||
message_id: String,
|
||||
session_id: String,
|
||||
connector_id: String,
|
||||
role: String,
|
||||
parts: Vec<AccumulatedPart>,
|
||||
created_at: Option<DateTime<Utc>>,
|
||||
last_activity: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl MessageBuffer {
|
||||
fn new(message_id: String, session_id: String, connector_id: String, role: String) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
message_id,
|
||||
session_id,
|
||||
connector_id,
|
||||
role,
|
||||
parts: Vec::new(),
|
||||
created_at: None,
|
||||
last_activity: now,
|
||||
}
|
||||
}
|
||||
|
||||
fn touch(&mut self) {
|
||||
let now = Utc::now();
|
||||
if self.created_at.is_none() {
|
||||
self.created_at = Some(now);
|
||||
}
|
||||
self.last_activity = now;
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// MessageAccumulator
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Accumulator for assembling streaming message deltas into complete messages.
|
||||
///
|
||||
/// Each in-flight message is identified by its `message_id` and tracked in an
|
||||
/// internal buffer. Text and thinking chunks are coalesced when consecutive;
|
||||
/// tool calls are deduplicated by `tool_call_id`.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct MessageAccumulator {
|
||||
buffers: HashMap<String, MessageBuffer>,
|
||||
}
|
||||
|
||||
impl MessageAccumulator {
|
||||
/// Create a new, empty accumulator.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buffers: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a content chunk to the message buffer.
|
||||
///
|
||||
/// Consecutive text chunks are coalesced into a single `AccumulatedPart::Text`.
|
||||
pub fn add_chunk(
|
||||
&mut self,
|
||||
message_id: &str,
|
||||
session_id: &str,
|
||||
connector_id: &str,
|
||||
role: &str,
|
||||
content: ContentBlock,
|
||||
) {
|
||||
let buffer = self
|
||||
.buffers
|
||||
.entry(message_id.to_string())
|
||||
.or_insert_with(|| {
|
||||
MessageBuffer::new(
|
||||
message_id.to_string(),
|
||||
session_id.to_string(),
|
||||
connector_id.to_string(),
|
||||
role.to_string(),
|
||||
)
|
||||
});
|
||||
|
||||
buffer.touch();
|
||||
|
||||
match content {
|
||||
ContentBlock::Text { text } => {
|
||||
if let Some(AccumulatedPart::Text { text: existing }) = buffer.parts.last_mut() {
|
||||
existing.push_str(&text);
|
||||
} else {
|
||||
buffer.parts.push(AccumulatedPart::Text { text });
|
||||
}
|
||||
}
|
||||
ContentBlock::ResourceLink { .. } => {
|
||||
// ResourceLink is not accumulated as text content for now.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Add thinking content to the message buffer.
|
||||
///
|
||||
/// Consecutive thinking chunks are coalesced into a single
|
||||
/// `AccumulatedPart::Thinking`.
|
||||
pub fn add_thinking(
|
||||
&mut self,
|
||||
message_id: &str,
|
||||
session_id: &str,
|
||||
connector_id: &str,
|
||||
content: &str,
|
||||
) {
|
||||
let buffer = self
|
||||
.buffers
|
||||
.entry(message_id.to_string())
|
||||
.or_insert_with(|| {
|
||||
MessageBuffer::new(
|
||||
message_id.to_string(),
|
||||
session_id.to_string(),
|
||||
connector_id.to_string(),
|
||||
"assistant".to_string(),
|
||||
)
|
||||
});
|
||||
|
||||
buffer.touch();
|
||||
|
||||
if let Some(AccumulatedPart::Thinking { text: existing }) = buffer.parts.last_mut() {
|
||||
existing.push_str(content);
|
||||
} else {
|
||||
buffer
|
||||
.parts
|
||||
.push(AccumulatedPart::Thinking { text: content.to_string() });
|
||||
}
|
||||
}
|
||||
|
||||
/// Add or update a tool call in the message buffer.
|
||||
///
|
||||
/// If a tool call with the same `id` already exists in the buffer, the
|
||||
/// existing entry is updated (input is overwritten only when non-empty;
|
||||
/// output is overwritten when `Some`). Otherwise a new entry is appended,
|
||||
/// preserving event-stream ordering.
|
||||
pub fn add_or_update_tool_call(&mut self, message_id: &str, tool_call: ToolCallData) {
|
||||
if let Some(buffer) = self.buffers.get_mut(message_id) {
|
||||
buffer.last_activity = Utc::now();
|
||||
|
||||
// Try to find and update an existing tool call with the same id.
|
||||
for part in buffer.parts.iter_mut() {
|
||||
if let AccumulatedPart::Tool { data } = part {
|
||||
if data.id == tool_call.id {
|
||||
data.tool_name = tool_call.tool_name;
|
||||
|
||||
if tool_call.input != Value::Null
|
||||
&& tool_call.input != serde_json::json!({})
|
||||
{
|
||||
data.input = tool_call.input;
|
||||
}
|
||||
|
||||
if tool_call.output.is_some() {
|
||||
data.output = tool_call.output;
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// First time seeing this tool_call_id -- append.
|
||||
buffer
|
||||
.parts
|
||||
.push(AccumulatedPart::Tool { data: tool_call });
|
||||
}
|
||||
}
|
||||
|
||||
/// Finalize a message and return its accumulated content.
|
||||
///
|
||||
/// The internal buffer for `message_id` is removed. Returns `None` if no
|
||||
/// buffer exists for the given id.
|
||||
pub fn finalize(&mut self, message_id: &str) -> Option<AccumulatedMessage> {
|
||||
let buffer = self.buffers.remove(message_id)?;
|
||||
|
||||
Some(AccumulatedMessage {
|
||||
message_id: buffer.message_id,
|
||||
session_id: buffer.session_id,
|
||||
connector_id: buffer.connector_id,
|
||||
role: buffer.role,
|
||||
parts: buffer.parts,
|
||||
created_at: buffer.created_at,
|
||||
last_activity: buffer.last_activity,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns `true` if a buffer exists for the given message id.
|
||||
pub fn has_buffer(&self, message_id: &str) -> bool {
|
||||
self.buffers.contains_key(message_id)
|
||||
}
|
||||
|
||||
/// Return all buffered message ids that belong to `session_id`.
|
||||
pub fn message_ids_for_session(&self, session_id: &str) -> Vec<String> {
|
||||
self.buffers
|
||||
.iter()
|
||||
.filter(|(_, buf)| buf.session_id == session_id)
|
||||
.map(|(id, _)| id.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Return all currently buffered message ids.
|
||||
pub fn active_message_ids(&self) -> Vec<String> {
|
||||
self.buffers.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Return message ids whose buffers have not been touched for longer than
|
||||
/// `threshold`.
|
||||
pub fn stale_message_ids(&self, threshold: std::time::Duration) -> Vec<String> {
|
||||
let now = Utc::now();
|
||||
self.buffers
|
||||
.iter()
|
||||
.filter(|(_, buf)| {
|
||||
let inactive = now.signed_duration_since(buf.last_activity);
|
||||
inactive
|
||||
.to_std()
|
||||
.unwrap_or(std::time::Duration::ZERO)
|
||||
> threshold
|
||||
})
|
||||
.map(|(id, _)| id.clone())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_empty_accumulator() {
|
||||
let mut acc = MessageAccumulator::new();
|
||||
assert!(acc.finalize("nonexistent").is_none());
|
||||
assert!(acc.active_message_ids().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_text_chunk_coalescing() {
|
||||
let mut acc = MessageAccumulator::new();
|
||||
|
||||
acc.add_chunk("msg1", "s1", "c1", "user", ContentBlock::Text {
|
||||
text: "Hello, ".to_string(),
|
||||
});
|
||||
acc.add_chunk("msg1", "s1", "c1", "user", ContentBlock::Text {
|
||||
text: "world!".to_string(),
|
||||
});
|
||||
|
||||
let msg = acc.finalize("msg1").unwrap();
|
||||
assert_eq!(msg.parts.len(), 1);
|
||||
match &msg.parts[0] {
|
||||
AccumulatedPart::Text { text } => assert_eq!(text, "Hello, world!"),
|
||||
other => panic!("Expected Text, got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_thinking_coalescing() {
|
||||
let mut acc = MessageAccumulator::new();
|
||||
|
||||
acc.add_thinking("msg1", "s1", "c1", "First. ");
|
||||
acc.add_thinking("msg1", "s1", "c1", "Second.");
|
||||
|
||||
let msg = acc.finalize("msg1").unwrap();
|
||||
assert_eq!(msg.parts.len(), 1);
|
||||
match &msg.parts[0] {
|
||||
AccumulatedPart::Thinking { text } => assert_eq!(text, "First. Second."),
|
||||
other => panic!("Expected Thinking, got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_interleaved_parts_preserve_order() {
|
||||
let mut acc = MessageAccumulator::new();
|
||||
|
||||
// text, tool, text -- should produce 3 distinct parts
|
||||
acc.add_chunk("msg1", "s1", "c1", "assistant", ContentBlock::Text {
|
||||
text: "Before tool.".to_string(),
|
||||
});
|
||||
|
||||
acc.add_or_update_tool_call("msg1", ToolCallData {
|
||||
id: "tc1".to_string(),
|
||||
tool_name: "grep".to_string(),
|
||||
input: serde_json::json!({"q": "x"}),
|
||||
output: None,
|
||||
});
|
||||
|
||||
acc.add_chunk("msg1", "s1", "c1", "assistant", ContentBlock::Text {
|
||||
text: "After tool.".to_string(),
|
||||
});
|
||||
|
||||
let msg = acc.finalize("msg1").unwrap();
|
||||
assert_eq!(msg.parts.len(), 3);
|
||||
assert!(matches!(&msg.parts[0], AccumulatedPart::Text { .. }));
|
||||
assert!(matches!(&msg.parts[1], AccumulatedPart::Tool { .. }));
|
||||
assert!(matches!(&msg.parts[2], AccumulatedPart::Text { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_deduplication() {
|
||||
let mut acc = MessageAccumulator::new();
|
||||
|
||||
// Create buffer first
|
||||
acc.add_chunk("msg1", "s1", "c1", "assistant", ContentBlock::Text {
|
||||
text: "hi".to_string(),
|
||||
});
|
||||
|
||||
// Initial tool call
|
||||
acc.add_or_update_tool_call("msg1", ToolCallData {
|
||||
id: "tc1".to_string(),
|
||||
tool_name: "read".to_string(),
|
||||
input: serde_json::json!({"path": "foo.rs"}),
|
||||
output: None,
|
||||
});
|
||||
|
||||
// Update same tool call with output (empty input should NOT overwrite)
|
||||
acc.add_or_update_tool_call("msg1", ToolCallData {
|
||||
id: "tc1".to_string(),
|
||||
tool_name: "read".to_string(),
|
||||
input: serde_json::json!({}),
|
||||
output: Some(serde_json::json!({"content": "fn main() {}"})),
|
||||
});
|
||||
|
||||
let msg = acc.finalize("msg1").unwrap();
|
||||
// text + 1 tool (not 2)
|
||||
assert_eq!(msg.parts.len(), 2);
|
||||
|
||||
match &msg.parts[1] {
|
||||
AccumulatedPart::Tool { data } => {
|
||||
assert_eq!(data.id, "tc1");
|
||||
// Input preserved from first call (non-empty), not overwritten by empty update
|
||||
assert_eq!(data.input, serde_json::json!({"path": "foo.rs"}));
|
||||
// Output set from update
|
||||
assert_eq!(
|
||||
data.output,
|
||||
Some(serde_json::json!({"content": "fn main() {}"}))
|
||||
);
|
||||
}
|
||||
other => panic!("Expected Tool, got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_message_parts_non_streaming() {
|
||||
let parts = vec![
|
||||
MessagePart::Text {
|
||||
text: "Hello".to_string(),
|
||||
},
|
||||
MessagePart::Thinking {
|
||||
text: "hmm".to_string(),
|
||||
},
|
||||
MessagePart::Code {
|
||||
language: "rs".to_string(),
|
||||
code: "fn main() {}".to_string(),
|
||||
},
|
||||
MessagePart::Tool {
|
||||
tool: "grep".to_string(),
|
||||
tool_call_id: Some("tc1".to_string()),
|
||||
input: serde_json::json!({"q": "x"}),
|
||||
output: Some(serde_json::json!("found")),
|
||||
},
|
||||
MessagePart::File {
|
||||
path: "README.md".to_string(),
|
||||
content: "# Title".to_string(),
|
||||
},
|
||||
// Empty text and code should be skipped
|
||||
MessagePart::Text {
|
||||
text: String::new(),
|
||||
},
|
||||
MessagePart::Code {
|
||||
language: "py".to_string(),
|
||||
code: String::new(),
|
||||
},
|
||||
];
|
||||
|
||||
let msg = AccumulatedMessage::from_message_parts(
|
||||
"msg1".into(),
|
||||
"s1".into(),
|
||||
"c1".into(),
|
||||
"assistant".into(),
|
||||
&parts,
|
||||
);
|
||||
|
||||
// 5 non-empty parts: text, thinking, code-as-text, tool, file-as-text
|
||||
assert_eq!(msg.parts.len(), 5);
|
||||
|
||||
match &msg.parts[0] {
|
||||
AccumulatedPart::Text { text } => assert_eq!(text, "Hello"),
|
||||
other => panic!("Expected Text, got {:?}", other),
|
||||
}
|
||||
match &msg.parts[1] {
|
||||
AccumulatedPart::Thinking { text } => assert_eq!(text, "hmm"),
|
||||
other => panic!("Expected Thinking, got {:?}", other),
|
||||
}
|
||||
match &msg.parts[2] {
|
||||
AccumulatedPart::Text { text } => {
|
||||
assert!(text.contains("```rs"));
|
||||
assert!(text.contains("fn main() {}"));
|
||||
}
|
||||
other => panic!("Expected Text (code), got {:?}", other),
|
||||
}
|
||||
match &msg.parts[3] {
|
||||
AccumulatedPart::Tool { data } => {
|
||||
assert_eq!(data.tool_name, "grep");
|
||||
assert_eq!(data.id, "tc1");
|
||||
assert_eq!(data.output, Some(serde_json::json!("found")));
|
||||
}
|
||||
other => panic!("Expected Tool, got {:?}", other),
|
||||
}
|
||||
match &msg.parts[4] {
|
||||
AccumulatedPart::Text { text } => {
|
||||
assert!(text.contains("File: README.md"));
|
||||
}
|
||||
other => panic!("Expected Text (file), got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_and_stale_queries() {
|
||||
let mut acc = MessageAccumulator::new();
|
||||
|
||||
acc.add_chunk("msg1", "s1", "c1", "user", ContentBlock::Text {
|
||||
text: "a".to_string(),
|
||||
});
|
||||
acc.add_chunk("msg2", "s1", "c1", "assistant", ContentBlock::Text {
|
||||
text: "b".to_string(),
|
||||
});
|
||||
acc.add_chunk("msg3", "s2", "c1", "user", ContentBlock::Text {
|
||||
text: "c".to_string(),
|
||||
});
|
||||
|
||||
// message_ids_for_session
|
||||
let mut s1_ids = acc.message_ids_for_session("s1");
|
||||
s1_ids.sort();
|
||||
assert_eq!(s1_ids, vec!["msg1", "msg2"]);
|
||||
|
||||
let s2_ids = acc.message_ids_for_session("s2");
|
||||
assert_eq!(s2_ids, vec!["msg3"]);
|
||||
|
||||
assert!(acc.message_ids_for_session("s3").is_empty());
|
||||
|
||||
// active_message_ids
|
||||
let mut all = acc.active_message_ids();
|
||||
all.sort();
|
||||
assert_eq!(all, vec!["msg1", "msg2", "msg3"]);
|
||||
|
||||
// has_buffer
|
||||
assert!(acc.has_buffer("msg1"));
|
||||
assert!(!acc.has_buffer("msg99"));
|
||||
|
||||
// stale_message_ids with zero threshold -- everything is stale
|
||||
// (last_activity <= now)
|
||||
let _stale_zero = acc.stale_message_ids(std::time::Duration::ZERO);
|
||||
// All three should be considered stale since last_activity <= now
|
||||
// (Due to timing, they might not all be strictly < now, so we check
|
||||
// with a small threshold instead.)
|
||||
let stale_lenient = acc.stale_message_ids(std::time::Duration::from_secs(0));
|
||||
// At minimum, none should be stale with a huge threshold
|
||||
let not_stale = acc.stale_message_ids(std::time::Duration::from_secs(3600));
|
||||
assert!(not_stale.is_empty());
|
||||
|
||||
// Verify stale detection works by checking the lenient case doesn't
|
||||
// return more ids than we have buffers
|
||||
assert!(stale_lenient.len() <= 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_message_parts() {
|
||||
let mut acc = MessageAccumulator::new();
|
||||
|
||||
acc.add_chunk("msg1", "s1", "c1", "assistant", ContentBlock::Text {
|
||||
text: "Hello ".to_string(),
|
||||
});
|
||||
acc.add_chunk("msg1", "s1", "c1", "assistant", ContentBlock::Text {
|
||||
text: "world.".to_string(),
|
||||
});
|
||||
acc.add_thinking("msg1", "s1", "c1", "thinking...");
|
||||
acc.add_or_update_tool_call("msg1", ToolCallData {
|
||||
id: "tc1".to_string(),
|
||||
tool_name: "search".to_string(),
|
||||
input: serde_json::json!({"q": "test"}),
|
||||
output: Some(serde_json::json!("result")),
|
||||
});
|
||||
|
||||
let msg = acc.finalize("msg1").unwrap();
|
||||
let parts = msg.to_message_parts();
|
||||
|
||||
assert_eq!(parts.len(), 3);
|
||||
|
||||
// Coalesced text
|
||||
match &parts[0] {
|
||||
MessagePart::Text { text } => assert_eq!(text, "Hello world."),
|
||||
other => panic!("Expected Text, got {:?}", other),
|
||||
}
|
||||
|
||||
// Thinking
|
||||
match &parts[1] {
|
||||
MessagePart::Thinking { text } => assert_eq!(text, "thinking..."),
|
||||
other => panic!("Expected Thinking, got {:?}", other),
|
||||
}
|
||||
|
||||
// Tool roundtrip
|
||||
match &parts[2] {
|
||||
MessagePart::Tool {
|
||||
tool,
|
||||
tool_call_id,
|
||||
input,
|
||||
output,
|
||||
} => {
|
||||
assert_eq!(tool, "search");
|
||||
assert_eq!(tool_call_id, &Some("tc1".to_string()));
|
||||
assert_eq!(input, &serde_json::json!({"q": "test"}));
|
||||
assert_eq!(output, &Some(serde_json::json!("result")));
|
||||
}
|
||||
other => panic!("Expected Tool, got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_empty() {
|
||||
let msg = AccumulatedMessage::from_message_parts(
|
||||
"msg1".into(),
|
||||
"s1".into(),
|
||||
"c1".into(),
|
||||
"user".into(),
|
||||
&[],
|
||||
);
|
||||
assert!(msg.is_empty());
|
||||
|
||||
let msg2 = AccumulatedMessage::from_message_parts(
|
||||
"msg2".into(),
|
||||
"s1".into(),
|
||||
"c1".into(),
|
||||
"user".into(),
|
||||
&[MessagePart::Text {
|
||||
text: "hi".to_string(),
|
||||
}],
|
||||
);
|
||||
assert!(!msg2.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_message_parts_empty_tool_id() {
|
||||
// Tool with no tool_call_id should roundtrip as None
|
||||
let parts = vec![MessagePart::Tool {
|
||||
tool: "bash".to_string(),
|
||||
tool_call_id: None,
|
||||
input: serde_json::json!({"cmd": "ls"}),
|
||||
output: None,
|
||||
}];
|
||||
|
||||
let msg = AccumulatedMessage::from_message_parts(
|
||||
"msg1".into(),
|
||||
"s1".into(),
|
||||
"c1".into(),
|
||||
"assistant".into(),
|
||||
&parts,
|
||||
);
|
||||
|
||||
let roundtripped = msg.to_message_parts();
|
||||
match &roundtripped[0] {
|
||||
MessagePart::Tool { tool_call_id, .. } => {
|
||||
assert_eq!(tool_call_id, &None);
|
||||
}
|
||||
other => panic!("Expected Tool, got {:?}", other),
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,11 @@
|
||||
pub mod acp;
|
||||
|
||||
#[cfg(feature = "adapters")]
|
||||
pub mod opencode;
|
||||
#[cfg(feature = "adapters")]
|
||||
pub mod rest;
|
||||
|
||||
pub use acp::{AcpAdapter, AcpTranslationError};
|
||||
|
||||
#[cfg(feature = "adapters")]
|
||||
pub use opencode::{OpenCodeAdapter, TranslationError};
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,167 @@
|
||||
/// REST API conversion helpers
|
||||
///
|
||||
/// Converts OpenCode REST API responses to Dirigent protocol types
|
||||
use crate::{
|
||||
Message, MessageMetadata, MessagePart, MessageRole, MessageStatus, Session, SessionMetadata,
|
||||
};
|
||||
use chrono::{DateTime, TimeZone, Utc};
|
||||
use opencode_client::types as oc;
|
||||
|
||||
/// Convert OpenCode Session to Dirigent Session
|
||||
pub fn convert_session(oc_session: oc::Session) -> Session {
|
||||
Session {
|
||||
id: oc_session.id,
|
||||
title: oc_session.title,
|
||||
created_at: timestamp_to_datetime(oc_session.time.created),
|
||||
updated_at: timestamp_to_datetime(oc_session.time.updated),
|
||||
metadata: SessionMetadata {
|
||||
project_path: oc_session.directory,
|
||||
model: None, // Not available in session info
|
||||
total_messages: 0, // Would need to be calculated separately
|
||||
system_message: None, // Will be set from first assistant message
|
||||
current_mode_id: None,
|
||||
_meta: None,
|
||||
project_id: None,
|
||||
},
|
||||
cwd: None, // OpenCode REST doesn't expose cwd separately from project_path
|
||||
models: None, // OpenCode doesn't provide ACP model state
|
||||
modes: None, // OpenCode doesn't provide ACP mode state
|
||||
config_options: None,
|
||||
acp_client_id: None, // OpenCode doesn't have ACP client ID
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert OpenCode Message to Dirigent Message
|
||||
pub fn convert_message(oc_msg: oc::Message) -> Message {
|
||||
let (id, session_id, role, created_at, status, metadata) = match oc_msg {
|
||||
oc::Message::User(u) => (
|
||||
u.id,
|
||||
u.session_id,
|
||||
MessageRole::User,
|
||||
timestamp_to_datetime(u.time.created),
|
||||
MessageStatus::Completed,
|
||||
None, // User messages don't have metadata
|
||||
),
|
||||
oc::Message::Assistant(a) => {
|
||||
let status = if let Some(err) = a.error {
|
||||
MessageStatus::Failed {
|
||||
error: format_message_error(&err),
|
||||
}
|
||||
} else if a.time.completed.is_some() {
|
||||
MessageStatus::Completed
|
||||
} else {
|
||||
MessageStatus::Streaming
|
||||
};
|
||||
|
||||
// Extract metadata from assistant message
|
||||
let metadata = Some(MessageMetadata {
|
||||
cost: Some(a.cost),
|
||||
tokens_input: Some(a.tokens.input),
|
||||
tokens_output: Some(a.tokens.output),
|
||||
response_time_ms: None,
|
||||
latency_ms: None,
|
||||
model: a.model_id.clone(),
|
||||
other: None,
|
||||
});
|
||||
|
||||
(
|
||||
a.id,
|
||||
a.session_id,
|
||||
MessageRole::Assistant,
|
||||
timestamp_to_datetime(a.time.created),
|
||||
status,
|
||||
metadata,
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
Message {
|
||||
id,
|
||||
session_id,
|
||||
role,
|
||||
created_at,
|
||||
content: vec![], // Parts are separate
|
||||
status,
|
||||
metadata,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert OpenCode MessageWithParts to Dirigent Message with parts
|
||||
pub fn convert_message_with_parts(oc_msg: oc::MessageWithParts) -> Message {
|
||||
let mut message = convert_message(oc_msg.info);
|
||||
|
||||
// Convert parts
|
||||
message.content = oc_msg
|
||||
.parts
|
||||
.into_iter()
|
||||
.filter_map(|part| convert_part(part))
|
||||
.collect();
|
||||
|
||||
message
|
||||
}
|
||||
|
||||
/// Convert OpenCode Part to Dirigent MessagePart
|
||||
fn convert_part(oc_part: oc::Part) -> Option<MessagePart> {
|
||||
match oc_part {
|
||||
oc::Part::Text(t) => Some(MessagePart::Text { text: t.text }),
|
||||
oc::Part::Reasoning(r) => Some(MessagePart::Thinking { text: r.text }),
|
||||
oc::Part::Tool(t) => {
|
||||
let (input, output) = match t.state {
|
||||
oc::ToolState::Pending => (serde_json::Value::Null, None),
|
||||
oc::ToolState::Running { input, .. } => (input, None),
|
||||
oc::ToolState::Completed { input, output, .. } => {
|
||||
(input, Some(serde_json::Value::String(output)))
|
||||
}
|
||||
oc::ToolState::Error { input, error, .. } => {
|
||||
(input, Some(serde_json::json!({ "error": error })))
|
||||
}
|
||||
};
|
||||
|
||||
Some(MessagePart::Tool {
|
||||
tool: t.tool,
|
||||
tool_call_id: None,
|
||||
input,
|
||||
output,
|
||||
})
|
||||
}
|
||||
oc::Part::File(f) => Some(MessagePart::File {
|
||||
path: f.filename.unwrap_or_else(|| f.url.clone()),
|
||||
content: f.url,
|
||||
}),
|
||||
// Skip unsupported part types
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert Unix timestamp (milliseconds) to DateTime<Utc>
|
||||
fn timestamp_to_datetime(timestamp: u64) -> DateTime<Utc> {
|
||||
Utc.timestamp_millis_opt(timestamp as i64)
|
||||
.single()
|
||||
.unwrap_or_else(|| Utc::now())
|
||||
}
|
||||
|
||||
/// Format a message error into a user-friendly string
|
||||
fn format_message_error(error: &oc::MessageError) -> String {
|
||||
match error {
|
||||
oc::MessageError::ProviderAuthError { data } => {
|
||||
format!(
|
||||
"Authentication error for {}: {}",
|
||||
data.provider_id, data.message
|
||||
)
|
||||
}
|
||||
oc::MessageError::UnknownError { data } => {
|
||||
format!("Unknown error: {}", data.message)
|
||||
}
|
||||
oc::MessageError::MessageOutputLengthError => "Message output length exceeded".to_string(),
|
||||
oc::MessageError::MessageAbortedError { data } => {
|
||||
format!("Message aborted: {}", data.message)
|
||||
}
|
||||
oc::MessageError::ApiError { data } => {
|
||||
if let Some(status) = data.status_code {
|
||||
format!("API error ({}): {}", status, data.message)
|
||||
} else {
|
||||
format!("API error: {}", data.message)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct Message {
|
||||
pub id: String,
|
||||
pub session_id: String,
|
||||
pub role: MessageRole,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub content: Vec<MessagePart>,
|
||||
pub status: MessageStatus,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<MessageMetadata>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct MessageMetadata {
|
||||
// Cost information
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cost: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tokens_input: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tokens_output: Option<u64>,
|
||||
|
||||
// Performance metrics
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub response_time_ms: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub latency_ms: Option<u64>,
|
||||
|
||||
// Model information
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
|
||||
// Arbitrary metadata from connector clients
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub other: Option<Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum MessageRole {
|
||||
User,
|
||||
Assistant,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum MessageStatus {
|
||||
Pending,
|
||||
Streaming,
|
||||
Completed,
|
||||
Failed { error: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum MessagePart {
|
||||
Text {
|
||||
text: String,
|
||||
},
|
||||
Thinking {
|
||||
text: String,
|
||||
},
|
||||
Code {
|
||||
language: String,
|
||||
code: String,
|
||||
},
|
||||
Tool {
|
||||
tool: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_call_id: Option<String>,
|
||||
input: Value,
|
||||
output: Option<Value>,
|
||||
},
|
||||
File {
|
||||
path: String,
|
||||
content: String,
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,974 @@
|
||||
use crate::session::{ConfigOption, SessionModeState, SessionModelState};
|
||||
use crate::{Message, Session, SessionUpdate};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Reason why the turn was marked complete (for debugging/observability)
|
||||
///
|
||||
/// This enum indicates **how** the system determined that a turn has completed.
|
||||
/// Different connector types use different strategies:
|
||||
///
|
||||
/// - **OpenCode Connector**: Uses `ExplicitSignal` (upstream session.idle event)
|
||||
/// - **ACP Connector (stdio)**: Uses `ResponseReceived` (JSON-RPC response is final)
|
||||
/// - **Gateway Connector**: Uses `OperationsComplete` (tracks pending tool calls)
|
||||
/// - **Fallback**: Uses `IdleTimeout` when no other signal available
|
||||
///
|
||||
/// # Consumer Usage
|
||||
///
|
||||
/// Most consumers should treat all triggers the same (turn is complete).
|
||||
/// The trigger type is primarily for debugging and observability.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub enum TurnCompleteTrigger {
|
||||
/// Explicit signal from upstream provider (e.g., OpenCode session.idle event)
|
||||
///
|
||||
/// This is the most reliable trigger as it comes directly from the agent system.
|
||||
ExplicitSignal,
|
||||
|
||||
/// JSON-RPC response received (ACP stdio transport)
|
||||
///
|
||||
/// In ACP stdio mode, the response message is the last message in the turn.
|
||||
ResponseReceived,
|
||||
|
||||
/// All tracked operations completed (e.g., pending tool calls resolved)
|
||||
///
|
||||
/// Used when the connector tracks operation state and can determine
|
||||
/// completion by monitoring tool call statuses.
|
||||
OperationsComplete,
|
||||
|
||||
/// Timeout-based idle detection (fallback mechanism)
|
||||
///
|
||||
/// Used when no other completion signal is available.
|
||||
/// The duration indicates how long the system waited before declaring completion.
|
||||
IdleTimeout { duration_ms: u64 },
|
||||
}
|
||||
|
||||
/// A single node in an inspector snapshot (protocol-level DTO).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct InspectorSnapshotNode {
|
||||
pub id: String,
|
||||
pub parent: Option<String>,
|
||||
pub children: Vec<String>,
|
||||
pub label: String,
|
||||
pub kind: String,
|
||||
pub state: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub state_detail: Option<String>,
|
||||
pub properties: std::collections::BTreeMap<String, serde_json::Value>,
|
||||
pub created_at: String,
|
||||
pub last_updated: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "event", content = "data")]
|
||||
pub enum Event {
|
||||
// Session events
|
||||
SessionsListed {
|
||||
connector_id: String,
|
||||
sessions: Vec<Session>,
|
||||
},
|
||||
SessionCreated {
|
||||
connector_id: String,
|
||||
session: Session,
|
||||
},
|
||||
SessionUpdated {
|
||||
connector_id: String,
|
||||
session: Session,
|
||||
},
|
||||
SessionMetadataUpdated {
|
||||
connector_id: String,
|
||||
session_id: String,
|
||||
title: Option<String>,
|
||||
total_messages: Option<u32>,
|
||||
model: Option<String>,
|
||||
},
|
||||
SessionDeleted {
|
||||
session_id: String,
|
||||
},
|
||||
/// Session was closed (agent released resources, session remains in list).
|
||||
/// The session can be loaded again later via session/load.
|
||||
SessionClosed {
|
||||
connector_id: String,
|
||||
session_id: String,
|
||||
},
|
||||
SessionSystemMessageSet {
|
||||
session_id: String,
|
||||
system_message: String,
|
||||
},
|
||||
SessionIdle {
|
||||
connector_id: String,
|
||||
session_id: String,
|
||||
},
|
||||
/// Session mode/model metadata received from an ACP connector
|
||||
///
|
||||
/// Emitted when metadata is received from a connector (e.g., after session/new or session/load).
|
||||
/// This event is separate from SessionCreated to support:
|
||||
/// - Session takeover scenarios (session already exists, but metadata is new)
|
||||
/// - Specific subscriptions to metadata changes
|
||||
/// - Connectors that provide metadata asynchronously
|
||||
///
|
||||
/// # Fields
|
||||
/// - `models`: UNSTABLE in ACP spec but used by Claude-ACP
|
||||
/// - `modes`: Stable in ACP spec
|
||||
///
|
||||
/// Both fields are optional since not all connectors provide this data.
|
||||
SessionMetadataReceived {
|
||||
/// Connector that provided the metadata
|
||||
connector_id: String,
|
||||
/// Session the metadata belongs to
|
||||
session_id: String,
|
||||
/// Available models and current model (UNSTABLE in ACP spec)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
models: Option<SessionModelState>,
|
||||
/// Available modes and current mode
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
modes: Option<SessionModeState>,
|
||||
/// ACP config options (replaces modes/models in future ACP versions)
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
config_options: Option<Vec<ConfigOption>>,
|
||||
},
|
||||
|
||||
/// **All content for this turn/message has been received.**
|
||||
///
|
||||
/// This is the **primary signal** for finalization actions (archiving, UI state lock).
|
||||
/// Emitted BEFORE `SessionIdle` to ensure proper event ordering.
|
||||
///
|
||||
/// # Event Semantics
|
||||
///
|
||||
/// - **`MessageCompleted`**: Message metadata is ready (informational)
|
||||
/// - Purpose: UI status updates ("Assistant is typing" → "Complete")
|
||||
/// - Timing: Sent when message record exists, content may still be streaming
|
||||
/// - Consumer action: Update UI state, show completion status
|
||||
///
|
||||
/// - **`TurnComplete`**: All content received (actionable)
|
||||
/// - Purpose: Signal that the entire turn is finalized
|
||||
/// - Timing: Sent AFTER all content chunks, tool calls, and metadata
|
||||
/// - Consumer action: Finalize storage, lock state, trigger post-processing
|
||||
///
|
||||
/// - **`SessionIdle`**: No recent activity (informational)
|
||||
/// - Purpose: UI spinner control, activity indication
|
||||
/// - Timing: Sent AFTER `TurnComplete` as final safety signal
|
||||
/// - Consumer action: Hide spinners, update activity indicators
|
||||
///
|
||||
/// # Event Ordering
|
||||
///
|
||||
/// ```text
|
||||
/// 1. MessageStarted (message created)
|
||||
/// 2. SessionUpdate::*Chunk (content streaming)
|
||||
/// 3. SessionUpdate::ToolCall* (tool execution)
|
||||
/// 4. MessageCompleted (metadata ready)
|
||||
/// 5. TurnComplete ← YOU ARE HERE (finalize!)
|
||||
/// 6. SessionIdle (activity stopped)
|
||||
/// ```
|
||||
///
|
||||
/// # Consumer Behavior
|
||||
///
|
||||
/// | Consumer | MessageCompleted | TurnComplete | SessionIdle |
|
||||
/// |----------|------------------|--------------|-------------|
|
||||
/// | **Archivist** | Ignore | **Finalize and write** | Safety net |
|
||||
/// | **UI Cache** | Update status | **Lock state** | Hide spinner |
|
||||
/// | **Conductor Bridge** | - | **Flush response** | Fallback flush |
|
||||
///
|
||||
/// # Example Usage
|
||||
///
|
||||
/// ```rust
|
||||
/// use dirigent_protocol::{Event, TurnCompleteTrigger};
|
||||
///
|
||||
/// match event {
|
||||
/// Event::TurnComplete { session_id, message_id, trigger, .. } => {
|
||||
/// // Finalize the message in your storage
|
||||
/// archivist.finalize_message(&session_id, &message_id).await?;
|
||||
///
|
||||
/// // Lock UI state
|
||||
/// ui_cache.lock_message(&message_id);
|
||||
///
|
||||
/// // Log trigger for debugging
|
||||
/// println!("Turn complete via {:?}", trigger);
|
||||
/// }
|
||||
/// _ => {}
|
||||
/// }
|
||||
/// ```
|
||||
TurnComplete {
|
||||
connector_id: String,
|
||||
session_id: String,
|
||||
message_id: String,
|
||||
trigger: TurnCompleteTrigger,
|
||||
},
|
||||
/// Session-level error that can be displayed in the chat UI.
|
||||
/// Used when a connector encounters an error during session operations.
|
||||
///
|
||||
/// # Fields
|
||||
///
|
||||
/// - `error_message`: Human-readable error summary
|
||||
/// - `is_recoverable`: Whether the session can continue after this error
|
||||
/// - `error_code`: Optional categorization code (e.g., "TRANSPORT_PARSE_FAILED")
|
||||
/// - `technical_details`: Optional full technical details (truncated if large)
|
||||
/// - `context`: Optional JSON blob with structured error context for debugging
|
||||
SessionError {
|
||||
connector_id: String,
|
||||
session_id: String,
|
||||
error_message: String,
|
||||
/// Whether the session can continue after this error.
|
||||
/// If false, the session should be considered terminated.
|
||||
is_recoverable: bool,
|
||||
/// Error categorization code for UI grouping and filtering.
|
||||
/// Examples: "TRANSPORT_PARSE_FAILED", "SESSION_NOT_FOUND", "TIMEOUT"
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
error_code: Option<String>,
|
||||
/// Full technical details including stack traces, received content, etc.
|
||||
/// May be truncated if the original content was very large.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
technical_details: Option<String>,
|
||||
/// Structured error context for debug view (JSON blob).
|
||||
/// Contains machine-readable error information.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
context: Option<serde_json::Value>,
|
||||
},
|
||||
/// Session was transferred from one connector to another
|
||||
///
|
||||
/// Emitted by CoreRuntime when a session transfer completes successfully.
|
||||
/// ACP Server should update client mappings on receiving this event.
|
||||
SessionTransferred {
|
||||
/// Source connector ID (where transfer originated)
|
||||
from_connector: String,
|
||||
/// Source session ID
|
||||
from_session: String,
|
||||
/// Target connector ID (where session is now active)
|
||||
to_connector: String,
|
||||
/// New session ID in target connector
|
||||
to_session: String,
|
||||
/// Whether a new session was created (true) or existing loaded (false)
|
||||
is_new_session: bool,
|
||||
/// Available models and current model from the new connector (optional)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
models: Option<SessionModelState>,
|
||||
/// Available modes and current mode from the new connector (optional)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
modes: Option<SessionModeState>,
|
||||
},
|
||||
/// Emitted by archivist when a session registration is durable and list-stable.
|
||||
///
|
||||
/// Frontend can use this to refresh the session list with confidence that
|
||||
/// the session will appear (it's been written to the archive index).
|
||||
/// This replaces any timeout-based delay hacks after session creation.
|
||||
///
|
||||
/// # Fields
|
||||
///
|
||||
/// - `connector_id`: The connector that owns this session
|
||||
/// - `session_id`: The native session ID from the connector
|
||||
/// - `scroll_id`: The archivist's canonical scroll_id for this session
|
||||
SessionRegistered {
|
||||
connector_id: String,
|
||||
session_id: String,
|
||||
/// The archivist's canonical scroll_id for this session
|
||||
scroll_id: String,
|
||||
},
|
||||
|
||||
/// A forwarded session encountered a failure
|
||||
///
|
||||
/// Emitted when a connector that received a transferred session fails.
|
||||
/// Clients should be routed back to Gateway automatically.
|
||||
ForwardingPanic {
|
||||
/// Connector that failed
|
||||
connector_id: String,
|
||||
/// Session that was affected
|
||||
session_id: String,
|
||||
/// Human-readable reason for the failure
|
||||
reason: String,
|
||||
/// ID of the Gateway session to fall back to (if available)
|
||||
fallback_gateway_session: Option<String>,
|
||||
},
|
||||
/// New ACP-style session update (replaces MessagePartAdded for new consumers)
|
||||
SessionUpdate {
|
||||
connector_id: String,
|
||||
session_id: String,
|
||||
update: SessionUpdate,
|
||||
},
|
||||
/// Agent-initiated request requiring client response (e.g., permission prompt)
|
||||
///
|
||||
/// Emitted when an agent sends a request (like session/request_permission) that
|
||||
/// requires user input. The client should respond via the appropriate API endpoint.
|
||||
///
|
||||
/// # Permission Flow Routing
|
||||
///
|
||||
/// - `is_forwarded: false` - Internal session (UI-owned) → Show modal in web UI
|
||||
/// - `is_forwarded: true` - External session (ACP client-owned) → Forward to EventBridge, DO NOT show UI modal
|
||||
///
|
||||
/// The `is_forwarded` field determines whether this permission request should be handled
|
||||
/// by the Dirigent web UI or forwarded to an external ACP client that owns the session.
|
||||
AgentRequest {
|
||||
/// ID of the connector that received the request
|
||||
connector_id: String,
|
||||
/// Session ID from the request parameters
|
||||
session_id: String,
|
||||
/// Request ID from the agent (for correlating the response)
|
||||
request_id: serde_json::Value,
|
||||
/// Method being requested (e.g., "session/request_permission")
|
||||
method: String,
|
||||
/// Request parameters from the agent
|
||||
params: serde_json::Value,
|
||||
/// Whether this is a forwarded (external) session.
|
||||
///
|
||||
/// If `true`, the UI MUST NOT show a permission modal. Instead, the EventBridge
|
||||
/// should forward this request to the external ACP client that owns the session.
|
||||
///
|
||||
/// If `false`, this is an internal session and the UI should show the permission modal.
|
||||
is_forwarded: bool,
|
||||
},
|
||||
|
||||
// ACP Client Connection Events (for UI visibility of incoming connections)
|
||||
/// An ACP client has connected to the server
|
||||
///
|
||||
/// Emitted when a new client connects via the ACP Server.
|
||||
/// Used by UI to show incoming connections in the sidebar.
|
||||
AcpClientConnected {
|
||||
/// Unique client identifier (UUID7)
|
||||
client_id: String,
|
||||
/// When the client connected (ISO 8601 timestamp)
|
||||
connected_at: String,
|
||||
/// Optional client capabilities from the initialize handshake
|
||||
capabilities: Option<serde_json::Value>,
|
||||
/// The Acceptor connector's UID (for archivist meta session creation)
|
||||
connector_uid: String,
|
||||
},
|
||||
/// An ACP client has disconnected from the server
|
||||
///
|
||||
/// Emitted when a client disconnects (explicitly or due to connection loss).
|
||||
/// The client record should be marked as disconnected, not removed (for history).
|
||||
AcpClientDisconnected {
|
||||
/// Unique client identifier
|
||||
client_id: String,
|
||||
/// When the client disconnected (ISO 8601 timestamp)
|
||||
disconnected_at: String,
|
||||
/// Optional reason for disconnection
|
||||
reason: Option<String>,
|
||||
},
|
||||
/// An ACP client has opened a new session via Gateway
|
||||
///
|
||||
/// Emitted when a client creates a new session through the ACP Server.
|
||||
/// This adds an entry to the connection history.
|
||||
AcpClientSessionOpened {
|
||||
/// Client that opened the session
|
||||
client_id: String,
|
||||
/// The Gateway session ID (or initial session before routing)
|
||||
gateway_session_id: String,
|
||||
/// The client-facing session ID
|
||||
client_session_id: String,
|
||||
/// When this occurred (ISO 8601 timestamp)
|
||||
timestamp: String,
|
||||
},
|
||||
/// An ACP client's session was routed to a different connector
|
||||
///
|
||||
/// Emitted when a session is transferred from Gateway to another connector.
|
||||
/// This adds an entry to the connection history showing the route change.
|
||||
AcpClientSessionRouted {
|
||||
/// Client whose session was routed
|
||||
client_id: String,
|
||||
/// Original session ID (typically Gateway session)
|
||||
from_session_id: String,
|
||||
/// New session ID in the target connector
|
||||
to_session_id: String,
|
||||
/// Target connector ID
|
||||
connector_id: String,
|
||||
/// Target connector title (for display)
|
||||
connector_title: String,
|
||||
/// Connector kind (e.g., "opencode", "acp", "gateway")
|
||||
#[serde(default)]
|
||||
connector_kind: Option<String>,
|
||||
/// Current model being used (if known)
|
||||
#[serde(default)]
|
||||
model: Option<String>,
|
||||
/// Agent version/name info (if available)
|
||||
#[serde(default)]
|
||||
agent_info: Option<String>,
|
||||
/// When this occurred (ISO 8601 timestamp)
|
||||
timestamp: String,
|
||||
},
|
||||
|
||||
// Message events
|
||||
MessagesListed {
|
||||
messages: Vec<Message>,
|
||||
},
|
||||
MessageStarted {
|
||||
connector_id: String,
|
||||
message: Message,
|
||||
},
|
||||
MessageCompleted {
|
||||
connector_id: String,
|
||||
message: Message,
|
||||
},
|
||||
MessageFailed {
|
||||
message_id: String,
|
||||
error: String,
|
||||
},
|
||||
|
||||
// Connector lifecycle events
|
||||
ConnectorCreated {
|
||||
connector_id: String,
|
||||
kind: String,
|
||||
title: String,
|
||||
},
|
||||
ConnectorRemoved {
|
||||
connector_id: String,
|
||||
},
|
||||
ConnectorStateChanged {
|
||||
connector_id: String,
|
||||
state: String,
|
||||
/// Machine-readable error classification ("offline", "unstable", "connection_failed")
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
error_kind: Option<String>,
|
||||
},
|
||||
|
||||
// System events
|
||||
Connected,
|
||||
Disconnected,
|
||||
Error {
|
||||
message: String,
|
||||
},
|
||||
|
||||
// Inspector events (runtime tree visualization)
|
||||
/// Full snapshot of the inspector tree — sent on initial connection
|
||||
/// and can be requested via server function.
|
||||
InspectorSnapshot {
|
||||
/// ISO 8601 timestamp of the snapshot
|
||||
timestamp: String,
|
||||
/// All nodes in the tree
|
||||
nodes: Vec<InspectorSnapshotNode>,
|
||||
/// Total node count
|
||||
node_count: usize,
|
||||
},
|
||||
/// A new node was registered in the inspector tree
|
||||
InspectorNodeRegistered {
|
||||
id: String,
|
||||
parent: String,
|
||||
kind: String,
|
||||
},
|
||||
/// A node was removed from the inspector tree
|
||||
InspectorNodeRemoved {
|
||||
id: String,
|
||||
},
|
||||
/// A node's lifecycle state changed
|
||||
InspectorStateChanged {
|
||||
id: String,
|
||||
old: String,
|
||||
new: String,
|
||||
},
|
||||
/// A node's properties were updated
|
||||
InspectorPropertiesUpdated {
|
||||
id: String,
|
||||
keys: Vec<String>,
|
||||
},
|
||||
|
||||
// System task events
|
||||
/// A background system task changed status (completed, failed, cancelled).
|
||||
///
|
||||
/// Emitted by the SystemTaskRegistry when a task reaches a terminal state.
|
||||
/// Allows the UI to react to task completion without polling.
|
||||
SystemTaskStatusChanged {
|
||||
/// Unique task identifier (UUIDv7)
|
||||
task_id: String,
|
||||
/// What kind of operation (e.g., "ClaudeImport")
|
||||
kind: String,
|
||||
/// Terminal status: "completed", "failed", or "cancelled"
|
||||
status: String,
|
||||
/// JSON result payload (present when status == "completed")
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
result_json: Option<String>,
|
||||
/// Error message (present when status == "failed")
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
error: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::types::ContentBlock;
|
||||
|
||||
#[test]
|
||||
fn test_session_update_variant_serialization() {
|
||||
let update = SessionUpdate::UserMessageChunk {
|
||||
message_id: "msg_123".to_string(),
|
||||
content: ContentBlock::Text {
|
||||
text: "Hello from event".to_string(),
|
||||
},
|
||||
_meta: None,
|
||||
};
|
||||
|
||||
let event = Event::SessionUpdate {
|
||||
connector_id: "conn_123".to_string(),
|
||||
session_id: "session_456".to_string(),
|
||||
update,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&event).unwrap();
|
||||
assert!(json.contains(r#""event":"SessionUpdate"#));
|
||||
assert!(json.contains(r#""session_id":"session_456"#));
|
||||
assert!(json.contains(r#""type":"user_message_chunk"#));
|
||||
assert!(json.contains(r#""message_id":"msg_123"#));
|
||||
assert!(json.contains(r#""text":"Hello from event"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_update_variant_deserialization() {
|
||||
let json = r#"{
|
||||
"event": "SessionUpdate",
|
||||
"data": {
|
||||
"connector_id": "conn_123",
|
||||
"session_id": "session_789",
|
||||
"update": {
|
||||
"type": "agent_message_chunk",
|
||||
"message_id": "msg_789",
|
||||
"content": {
|
||||
"type": "text",
|
||||
"text": "Agent response"
|
||||
}
|
||||
}
|
||||
}
|
||||
}"#;
|
||||
|
||||
let event: Event = serde_json::from_str(json).unwrap();
|
||||
match event {
|
||||
Event::SessionUpdate {
|
||||
connector_id,
|
||||
session_id,
|
||||
update,
|
||||
} => {
|
||||
assert_eq!(connector_id, "conn_123");
|
||||
assert_eq!(session_id, "session_789");
|
||||
match update {
|
||||
SessionUpdate::AgentMessageChunk {
|
||||
message_id,
|
||||
content,
|
||||
_meta,
|
||||
} => {
|
||||
assert_eq!(message_id, "msg_789");
|
||||
assert_eq!(
|
||||
content,
|
||||
ContentBlock::Text {
|
||||
text: "Agent response".to_string()
|
||||
}
|
||||
);
|
||||
assert_eq!(_meta, None);
|
||||
}
|
||||
_ => panic!("Expected AgentMessageChunk"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Expected SessionUpdate event"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_update_variant_roundtrip() {
|
||||
let original = Event::SessionUpdate {
|
||||
connector_id: "conn_roundtrip".to_string(),
|
||||
session_id: "session_roundtrip".to_string(),
|
||||
update: SessionUpdate::AgentThoughtChunk {
|
||||
message_id: "msg_thought".to_string(),
|
||||
content: ContentBlock::Text {
|
||||
text: "Thinking...".to_string(),
|
||||
},
|
||||
_meta: None,
|
||||
},
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&original).unwrap();
|
||||
let deserialized: Event = serde_json::from_str(&json).unwrap();
|
||||
|
||||
match (&original, &deserialized) {
|
||||
(
|
||||
Event::SessionUpdate {
|
||||
connector_id: cid1,
|
||||
session_id: sid1,
|
||||
update: update1,
|
||||
},
|
||||
Event::SessionUpdate {
|
||||
connector_id: cid2,
|
||||
session_id: sid2,
|
||||
update: update2,
|
||||
},
|
||||
) => {
|
||||
assert_eq!(cid1, cid2);
|
||||
assert_eq!(sid1, sid2);
|
||||
assert_eq!(update1, update2);
|
||||
}
|
||||
_ => panic!("Roundtrip failed"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_existing_events_still_work() {
|
||||
// Verify that existing event variants are not affected
|
||||
use crate::SessionMetadata;
|
||||
use chrono::Utc;
|
||||
|
||||
let now = Utc::now();
|
||||
let session_created = Event::SessionCreated {
|
||||
connector_id: "conn_test".to_string(),
|
||||
session: Session {
|
||||
id: "session_123".to_string(),
|
||||
title: "Test Session".to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
metadata: SessionMetadata {
|
||||
project_path: "/test".to_string(),
|
||||
model: Some("gpt-4".to_string()),
|
||||
total_messages: 0,
|
||||
system_message: None,
|
||||
current_mode_id: None,
|
||||
_meta: None,
|
||||
project_id: None,
|
||||
},
|
||||
cwd: None,
|
||||
models: None,
|
||||
modes: None,
|
||||
config_options: None,
|
||||
acp_client_id: None,
|
||||
},
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&session_created).unwrap();
|
||||
let _deserialized: Event = serde_json::from_str(&json).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_error_serialization() {
|
||||
let event = Event::SessionError {
|
||||
connector_id: "acp_conn_1".to_string(),
|
||||
session_id: "session_456".to_string(),
|
||||
error_message: "Session not found".to_string(),
|
||||
is_recoverable: false,
|
||||
error_code: None,
|
||||
technical_details: None,
|
||||
context: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&event).unwrap();
|
||||
assert!(json.contains(r#""event":"SessionError"#));
|
||||
assert!(json.contains(r#""connector_id":"acp_conn_1"#));
|
||||
assert!(json.contains(r#""session_id":"session_456"#));
|
||||
assert!(json.contains(r#""error_message":"Session not found"#));
|
||||
assert!(json.contains(r#""is_recoverable":false"#));
|
||||
// Optional fields should not be present when None
|
||||
assert!(!json.contains(r#""error_code"#));
|
||||
assert!(!json.contains(r#""technical_details"#));
|
||||
assert!(!json.contains(r#""context"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_error_with_details_serialization() {
|
||||
let event = Event::SessionError {
|
||||
connector_id: "acp_conn_1".to_string(),
|
||||
session_id: "session_456".to_string(),
|
||||
error_message: "Transport parse failed".to_string(),
|
||||
is_recoverable: true,
|
||||
error_code: Some("TRANSPORT_PARSE_FAILED".to_string()),
|
||||
technical_details: Some("Failed to parse JSON: expected value at line 1".to_string()),
|
||||
context: Some(serde_json::json!({
|
||||
"received_bytes": 1024,
|
||||
"received_preview": "option { key: ...",
|
||||
"expected": "JSON-RPC message"
|
||||
})),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&event).unwrap();
|
||||
assert!(json.contains(r#""error_code":"TRANSPORT_PARSE_FAILED"#));
|
||||
assert!(json.contains(r#""technical_details"#));
|
||||
assert!(json.contains(r#""context"#));
|
||||
assert!(json.contains(r#""received_bytes":1024"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_error_deserialization() {
|
||||
// Test backward compatibility - old format without new fields
|
||||
let json = r#"{
|
||||
"event": "SessionError",
|
||||
"data": {
|
||||
"connector_id": "conn_test",
|
||||
"session_id": "session_789",
|
||||
"error_message": "Connection timeout",
|
||||
"is_recoverable": true
|
||||
}
|
||||
}"#;
|
||||
|
||||
let event: Event = serde_json::from_str(json).unwrap();
|
||||
match event {
|
||||
Event::SessionError {
|
||||
connector_id,
|
||||
session_id,
|
||||
error_message,
|
||||
is_recoverable,
|
||||
error_code,
|
||||
technical_details,
|
||||
context,
|
||||
} => {
|
||||
assert_eq!(connector_id, "conn_test");
|
||||
assert_eq!(session_id, "session_789");
|
||||
assert_eq!(error_message, "Connection timeout");
|
||||
assert!(is_recoverable);
|
||||
assert!(error_code.is_none());
|
||||
assert!(technical_details.is_none());
|
||||
assert!(context.is_none());
|
||||
}
|
||||
_ => panic!("Expected SessionError event"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_error_roundtrip() {
|
||||
let original = Event::SessionError {
|
||||
connector_id: "roundtrip_conn".to_string(),
|
||||
session_id: "roundtrip_session".to_string(),
|
||||
error_message: "API rate limit exceeded".to_string(),
|
||||
is_recoverable: true,
|
||||
error_code: Some("RATE_LIMITED".to_string()),
|
||||
technical_details: Some("429 Too Many Requests".to_string()),
|
||||
context: Some(serde_json::json!({"retry_after": 60})),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&original).unwrap();
|
||||
let deserialized: Event = serde_json::from_str(&json).unwrap();
|
||||
|
||||
match (&original, &deserialized) {
|
||||
(
|
||||
Event::SessionError {
|
||||
connector_id: cid1,
|
||||
session_id: sid1,
|
||||
error_message: err1,
|
||||
is_recoverable: rec1,
|
||||
error_code: code1,
|
||||
technical_details: details1,
|
||||
context: ctx1,
|
||||
},
|
||||
Event::SessionError {
|
||||
connector_id: cid2,
|
||||
session_id: sid2,
|
||||
error_message: err2,
|
||||
is_recoverable: rec2,
|
||||
error_code: code2,
|
||||
technical_details: details2,
|
||||
context: ctx2,
|
||||
},
|
||||
) => {
|
||||
assert_eq!(cid1, cid2);
|
||||
assert_eq!(sid1, sid2);
|
||||
assert_eq!(code1, code2);
|
||||
assert_eq!(details1, details2);
|
||||
assert_eq!(ctx1, ctx2);
|
||||
assert_eq!(err1, err2);
|
||||
assert_eq!(rec1, rec2);
|
||||
}
|
||||
_ => panic!("Roundtrip failed"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_transferred_serialization() {
|
||||
let event = Event::SessionTransferred {
|
||||
from_connector: "gateway-1".to_string(),
|
||||
from_session: "session-old".to_string(),
|
||||
to_connector: "opencode-1".to_string(),
|
||||
to_session: "session-new".to_string(),
|
||||
is_new_session: true,
|
||||
models: None,
|
||||
modes: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&event).unwrap();
|
||||
assert!(json.contains(r#""event":"SessionTransferred"#));
|
||||
assert!(json.contains(r#""from_connector":"gateway-1"#));
|
||||
assert!(json.contains(r#""to_connector":"opencode-1"#));
|
||||
|
||||
let deserialized: Event = serde_json::from_str(&json).unwrap();
|
||||
// Verify roundtrip
|
||||
match deserialized {
|
||||
Event::SessionTransferred {
|
||||
from_connector,
|
||||
from_session,
|
||||
to_connector,
|
||||
to_session,
|
||||
is_new_session,
|
||||
models,
|
||||
modes,
|
||||
} => {
|
||||
assert_eq!(from_connector, "gateway-1");
|
||||
assert_eq!(from_session, "session-old");
|
||||
assert_eq!(to_connector, "opencode-1");
|
||||
assert_eq!(to_session, "session-new");
|
||||
assert!(is_new_session);
|
||||
assert!(models.is_none());
|
||||
assert!(modes.is_none());
|
||||
}
|
||||
_ => panic!("Expected SessionTransferred event"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_forwarding_panic_serialization() {
|
||||
let event = Event::ForwardingPanic {
|
||||
connector_id: "opencode-1".to_string(),
|
||||
session_id: "session-123".to_string(),
|
||||
reason: "Connection lost".to_string(),
|
||||
fallback_gateway_session: Some("gateway-session-1".to_string()),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&event).unwrap();
|
||||
assert!(json.contains(r#""event":"ForwardingPanic"#));
|
||||
assert!(json.contains(r#""connector_id":"opencode-1"#));
|
||||
assert!(json.contains(r#""session_id":"session-123"#));
|
||||
assert!(json.contains(r#""reason":"Connection lost"#));
|
||||
|
||||
let deserialized: Event = serde_json::from_str(&json).unwrap();
|
||||
// Verify roundtrip
|
||||
match deserialized {
|
||||
Event::ForwardingPanic {
|
||||
connector_id,
|
||||
session_id,
|
||||
reason,
|
||||
fallback_gateway_session,
|
||||
} => {
|
||||
assert_eq!(connector_id, "opencode-1");
|
||||
assert_eq!(session_id, "session-123");
|
||||
assert_eq!(reason, "Connection lost");
|
||||
assert_eq!(
|
||||
fallback_gateway_session,
|
||||
Some("gateway-session-1".to_string())
|
||||
);
|
||||
}
|
||||
_ => panic!("Expected ForwardingPanic event"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_metadata_received_full() {
|
||||
use crate::session::{ModelInfo, SessionMode, SessionModeState, SessionModelState};
|
||||
|
||||
let event = Event::SessionMetadataReceived {
|
||||
connector_id: "claude-acp-1".to_string(),
|
||||
session_id: "session-123".to_string(),
|
||||
models: Some(SessionModelState {
|
||||
available_models: vec![
|
||||
ModelInfo {
|
||||
model_id: "default".to_string(),
|
||||
name: "Default (recommended)".to_string(),
|
||||
description: Some("Opus 4.5".to_string()),
|
||||
},
|
||||
ModelInfo {
|
||||
model_id: "sonnet".to_string(),
|
||||
name: "Sonnet".to_string(),
|
||||
description: None,
|
||||
},
|
||||
],
|
||||
current_model_id: "default".to_string(),
|
||||
}),
|
||||
modes: Some(SessionModeState {
|
||||
current_mode_id: "default".to_string(),
|
||||
available_modes: vec![
|
||||
SessionMode {
|
||||
id: "default".to_string(),
|
||||
name: "Always Ask".to_string(),
|
||||
description: Some("Prompts for permission".to_string()),
|
||||
},
|
||||
SessionMode {
|
||||
id: "plan".to_string(),
|
||||
name: "Plan Mode".to_string(),
|
||||
description: None,
|
||||
},
|
||||
],
|
||||
}),
|
||||
config_options: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&event).unwrap();
|
||||
assert!(json.contains(r#""event":"SessionMetadataReceived"#));
|
||||
assert!(json.contains(r#""connector_id":"claude-acp-1"#));
|
||||
assert!(json.contains(r#""session_id":"session-123"#));
|
||||
// Check camelCase in nested types
|
||||
assert!(json.contains("availableModels"));
|
||||
assert!(json.contains("currentModelId"));
|
||||
assert!(json.contains("availableModes"));
|
||||
assert!(json.contains("currentModeId"));
|
||||
|
||||
let deserialized: Event = serde_json::from_str(&json).unwrap();
|
||||
match deserialized {
|
||||
Event::SessionMetadataReceived {
|
||||
connector_id,
|
||||
session_id,
|
||||
models,
|
||||
modes,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(connector_id, "claude-acp-1");
|
||||
assert_eq!(session_id, "session-123");
|
||||
assert!(models.is_some());
|
||||
assert!(modes.is_some());
|
||||
let models = models.unwrap();
|
||||
assert_eq!(models.current_model_id, "default");
|
||||
assert_eq!(models.available_models.len(), 2);
|
||||
let modes = modes.unwrap();
|
||||
assert_eq!(modes.current_mode_id, "default");
|
||||
assert_eq!(modes.available_modes.len(), 2);
|
||||
}
|
||||
_ => panic!("Expected SessionMetadataReceived event"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_metadata_received_partial() {
|
||||
// Test with only modes (models is None)
|
||||
use crate::session::{SessionMode, SessionModeState};
|
||||
|
||||
let event = Event::SessionMetadataReceived {
|
||||
connector_id: "gateway-1".to_string(),
|
||||
session_id: "session-456".to_string(),
|
||||
models: None,
|
||||
modes: Some(SessionModeState {
|
||||
current_mode_id: "default".to_string(),
|
||||
available_modes: vec![SessionMode {
|
||||
id: "default".to_string(),
|
||||
name: "Default".to_string(),
|
||||
description: None,
|
||||
}],
|
||||
}),
|
||||
config_options: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&event).unwrap();
|
||||
// models should be skipped when None
|
||||
assert!(!json.contains("availableModels"));
|
||||
assert!(json.contains("availableModes"));
|
||||
|
||||
let deserialized: Event = serde_json::from_str(&json).unwrap();
|
||||
match deserialized {
|
||||
Event::SessionMetadataReceived { models, modes, .. } => {
|
||||
assert!(models.is_none());
|
||||
assert!(modes.is_some());
|
||||
}
|
||||
_ => panic!("Expected SessionMetadataReceived event"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_metadata_received_empty() {
|
||||
// Test with both None (connector provides no metadata)
|
||||
let event = Event::SessionMetadataReceived {
|
||||
connector_id: "generic-1".to_string(),
|
||||
session_id: "session-789".to_string(),
|
||||
models: None,
|
||||
modes: None,
|
||||
config_options: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&event).unwrap();
|
||||
// Both should be skipped when None
|
||||
assert!(!json.contains("models"));
|
||||
assert!(!json.contains("modes"));
|
||||
|
||||
let deserialized: Event = serde_json::from_str(&json).unwrap();
|
||||
match deserialized {
|
||||
Event::SessionMetadataReceived { models, modes, .. } => {
|
||||
assert!(models.is_none());
|
||||
assert!(modes.is_none());
|
||||
}
|
||||
_ => panic!("Expected SessionMetadataReceived event"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
//! Inspector node types for the process tree.
|
||||
//!
|
||||
//! These types represent nodes in the inspector's hierarchical process tree,
|
||||
//! providing a canonical definition that can be shared between the server-side
|
||||
//! inspector and WASM-based UI.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
|
||||
/// Hierarchical identifier for a node in the inspector tree.
|
||||
///
|
||||
/// Uses `/`-separated segments (e.g., `"root/connector-1/process-a"`).
|
||||
#[derive(Clone, Debug, Hash, Eq, PartialEq, Serialize, Deserialize)]
|
||||
pub struct NodeId(pub String);
|
||||
|
||||
impl NodeId {
|
||||
pub fn new(id: impl Into<String>) -> Self {
|
||||
Self(id.into())
|
||||
}
|
||||
|
||||
/// Create a child node ID by appending a segment.
|
||||
pub fn child(&self, segment: &str) -> Self {
|
||||
Self(format!("{}/{}", self.0, segment))
|
||||
}
|
||||
|
||||
/// Get the parent node ID (everything before the last `/`).
|
||||
pub fn parent(&self) -> Option<Self> {
|
||||
self.0.rfind('/').map(|idx| Self(self.0[..idx].to_string()))
|
||||
}
|
||||
|
||||
/// Get the last segment of the path (the node's own name).
|
||||
pub fn name(&self) -> &str {
|
||||
self.0.rsplit('/').next().unwrap_or(&self.0)
|
||||
}
|
||||
|
||||
pub fn as_str(&self) -> &str {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for NodeId {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for NodeId {
|
||||
fn from(s: &str) -> Self {
|
||||
Self(s.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for NodeId {
|
||||
fn from(s: String) -> Self {
|
||||
Self(s)
|
||||
}
|
||||
}
|
||||
|
||||
/// The kind of node in the inspector tree.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
|
||||
pub enum NodeKind {
|
||||
Root,
|
||||
Connector,
|
||||
Process,
|
||||
Service,
|
||||
AsyncTask,
|
||||
System,
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
impl fmt::Display for NodeKind {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
NodeKind::Root => write!(f, "Root"),
|
||||
NodeKind::Connector => write!(f, "Connector"),
|
||||
NodeKind::Process => write!(f, "Process"),
|
||||
NodeKind::Service => write!(f, "Service"),
|
||||
NodeKind::AsyncTask => write!(f, "AsyncTask"),
|
||||
NodeKind::System => write!(f, "System"),
|
||||
NodeKind::Custom(name) => write!(f, "Custom({})", name),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The runtime state of an inspector node.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
|
||||
pub enum NodeState {
|
||||
Initializing,
|
||||
Running,
|
||||
Idle,
|
||||
Busy(String),
|
||||
Degraded(String),
|
||||
Error(String),
|
||||
Stopped,
|
||||
}
|
||||
|
||||
impl fmt::Display for NodeState {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
NodeState::Initializing => write!(f, "Initializing"),
|
||||
NodeState::Running => write!(f, "Running"),
|
||||
NodeState::Idle => write!(f, "Idle"),
|
||||
NodeState::Busy(desc) => write!(f, "Busy({})", desc),
|
||||
NodeState::Degraded(reason) => write!(f, "Degraded({})", reason),
|
||||
NodeState::Error(msg) => write!(f, "Error({})", msg),
|
||||
NodeState::Stopped => write!(f, "Stopped"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Metadata associated with an inspector node.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct NodeMetadata {
|
||||
pub kind: NodeKind,
|
||||
pub label: String,
|
||||
pub state: NodeState,
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
pub last_updated: chrono::DateTime<chrono::Utc>,
|
||||
pub properties: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
impl NodeMetadata {
|
||||
pub fn new(kind: NodeKind, label: impl Into<String>) -> Self {
|
||||
let now = chrono::Utc::now();
|
||||
Self {
|
||||
kind,
|
||||
label: label.into(),
|
||||
state: NodeState::Initializing,
|
||||
created_at: now,
|
||||
last_updated: now,
|
||||
properties: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_state(mut self, state: NodeState) -> Self {
|
||||
self.state = state;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_property(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
|
||||
self.properties.insert(key.into(), value);
|
||||
self
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,187 @@
|
||||
pub mod accumulator;
|
||||
pub mod adapters;
|
||||
pub mod conversation;
|
||||
pub mod events;
|
||||
pub mod inspector;
|
||||
pub mod log_utils;
|
||||
pub mod project;
|
||||
pub mod session;
|
||||
pub mod sharing;
|
||||
pub mod streaming;
|
||||
pub mod types;
|
||||
|
||||
pub use conversation::{Message, MessageMetadata, MessagePart, MessageRole, MessageStatus};
|
||||
pub use events::{Event, InspectorSnapshotNode, TurnCompleteTrigger};
|
||||
pub use inspector::{NodeId, NodeKind, NodeMetadata, NodeState};
|
||||
pub use session::{
|
||||
ConfigOption, ConfigOptionType, ConfigOptionValue, ModelId, ModelInfo, Session,
|
||||
SessionMetadata, SessionMode, SessionModeId, SessionModeState, SessionModelState,
|
||||
SessionOrigin, SessionOwnership, ToolHandler,
|
||||
};
|
||||
pub use types::{
|
||||
ContentBlock, Meta, PermissionOption, PermissionOptionKind, PermissionToolCallStatus,
|
||||
ProviderMeta, RequestPermissionOutcome, RequestPermissionResponse, SessionUpdate, ToolCall,
|
||||
ToolCallContent, ToolCallId, ToolCallInfo, ToolCallLocation, ToolCallStatus, ToolKind,
|
||||
};
|
||||
pub use sharing::{SessionShare, ShareId, ShareSummary};
|
||||
pub use accumulator::{AccumulatedMessage, AccumulatedPart, MessageAccumulator, ToolCallData as AccumulatorToolCallData};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_public_api_imports() {
|
||||
// Test that all types are accessible from the crate root
|
||||
|
||||
// ContentBlock
|
||||
let _content = ContentBlock::Text {
|
||||
text: "test".to_string(),
|
||||
};
|
||||
|
||||
// Meta and ProviderMeta
|
||||
let _meta = Meta::default();
|
||||
let _provider = ProviderMeta {
|
||||
name: "test".to_string(),
|
||||
original_ids: None,
|
||||
raw_excerpt: None,
|
||||
};
|
||||
|
||||
// ToolCall, ToolCallId, ToolCallStatus
|
||||
let _tool_call_id: ToolCallId = "call_123".to_string();
|
||||
let _status = ToolCallStatus::Pending;
|
||||
let _tool_call = ToolCall {
|
||||
id: "call_123".to_string(),
|
||||
tool_name: "test".to_string(),
|
||||
status: ToolCallStatus::Running,
|
||||
content: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
title: None,
|
||||
error: None,
|
||||
metadata: None,
|
||||
origin: None,
|
||||
};
|
||||
|
||||
// SessionUpdate
|
||||
let _update = SessionUpdate::UserMessageChunk {
|
||||
message_id: "msg_123".to_string(),
|
||||
content: ContentBlock::Text {
|
||||
text: "Hello".to_string(),
|
||||
},
|
||||
_meta: None,
|
||||
};
|
||||
|
||||
// If this compiles, all types are accessible via use dirigent_protocol::{...}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_metadata_types_accessible() {
|
||||
// Test that new ACP session metadata types are accessible from crate root
|
||||
|
||||
// ModelId and SessionModeId (type aliases)
|
||||
let _model_id: ModelId = "default".to_string();
|
||||
let _mode_id: SessionModeId = "plan".to_string();
|
||||
|
||||
// ModelInfo
|
||||
let _model_info = ModelInfo {
|
||||
model_id: "default".to_string(),
|
||||
name: "Default".to_string(),
|
||||
description: Some("Main model".to_string()),
|
||||
};
|
||||
|
||||
// SessionModelState
|
||||
let _model_state = SessionModelState {
|
||||
available_models: vec![_model_info],
|
||||
current_model_id: "default".to_string(),
|
||||
};
|
||||
|
||||
// SessionMode
|
||||
let _mode = SessionMode {
|
||||
id: "default".to_string(),
|
||||
name: "Default".to_string(),
|
||||
description: None,
|
||||
};
|
||||
|
||||
// SessionModeState
|
||||
let _mode_state = SessionModeState {
|
||||
current_mode_id: "default".to_string(),
|
||||
available_modes: vec![_mode],
|
||||
};
|
||||
|
||||
// SessionMetadataReceived event
|
||||
let _event = Event::SessionMetadataReceived {
|
||||
connector_id: "test".to_string(),
|
||||
session_id: "test".to_string(),
|
||||
models: Some(_model_state),
|
||||
modes: Some(_mode_state),
|
||||
config_options: None,
|
||||
};
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_permission_types_accessible() {
|
||||
// Test that ACP permission types are accessible from crate root
|
||||
|
||||
// PermissionOption and PermissionOptionKind
|
||||
let _option = PermissionOption {
|
||||
option_id: "allow_1".to_string(),
|
||||
name: "Allow once".to_string(),
|
||||
kind: PermissionOptionKind::AllowOnce,
|
||||
};
|
||||
|
||||
// RequestPermissionResponse and RequestPermissionOutcome
|
||||
let _response = RequestPermissionResponse {
|
||||
outcome: RequestPermissionOutcome::Selected {
|
||||
option_id: "allow_1".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
// ToolCallInfo and related types
|
||||
let _info = ToolCallInfo {
|
||||
tool_call_id: "call_123".to_string(),
|
||||
title: "Read file".to_string(),
|
||||
kind: Some(ToolKind::Read),
|
||||
status: Some(PermissionToolCallStatus::Pending),
|
||||
locations: Some(vec![ToolCallLocation {
|
||||
path: "/test.txt".to_string(),
|
||||
line: Some(10),
|
||||
}]),
|
||||
raw_input: None,
|
||||
};
|
||||
|
||||
// If this compiles, all permission types are accessible
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_ownership_types_accessible() {
|
||||
// Test that Session Ownership Model types are accessible from crate root
|
||||
|
||||
// SessionOrigin
|
||||
let _origin_internal = SessionOrigin::Internal;
|
||||
let _origin_external = SessionOrigin::External {
|
||||
client_id: "test".to_string(),
|
||||
client_capabilities: None,
|
||||
};
|
||||
|
||||
// ToolHandler
|
||||
let _handler_agent = ToolHandler::Agent;
|
||||
let _handler_dirigent = ToolHandler::Dirigent;
|
||||
let _handler_forward = ToolHandler::ForwardToClient;
|
||||
|
||||
// SessionOwnership and its constructors
|
||||
let _ownership_default = SessionOwnership::default();
|
||||
let _ownership_internal = SessionOwnership::internal();
|
||||
let _ownership_forwarded =
|
||||
SessionOwnership::external_forwarded("client-123".to_string(), None);
|
||||
let _ownership_handled = SessionOwnership::external_handled("client-456".to_string(), None);
|
||||
|
||||
// Test helper methods
|
||||
assert!(!_ownership_internal.is_external());
|
||||
assert!(_ownership_forwarded.is_external());
|
||||
assert_eq!(_ownership_internal.client_id(), None);
|
||||
assert_eq!(_ownership_forwarded.client_id(), Some("client-123"));
|
||||
|
||||
// If this compiles, all ownership types are accessible
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,250 @@
|
||||
/// Utilities for masking sensitive or verbose content in logs
|
||||
use serde_json::Value;
|
||||
|
||||
/// Truncate long strings to a reasonable length for logging
|
||||
const MAX_LOG_LENGTH: usize = 100;
|
||||
|
||||
/// Extract filename from a file path (handles both Unix and Windows paths)
|
||||
fn extract_filename(path: &str) -> &str {
|
||||
// Try Unix-style path separator first, then Windows
|
||||
if path.contains('/') {
|
||||
path.split('/').last().unwrap_or("file")
|
||||
} else if path.contains('\\') {
|
||||
path.split('\\').last().unwrap_or("file")
|
||||
} else {
|
||||
// No path separator, just a filename
|
||||
path
|
||||
}
|
||||
}
|
||||
|
||||
/// Mask text content in JSON values for concise logging
|
||||
///
|
||||
/// This function recursively processes JSON and replaces long text fields
|
||||
/// with truncated versions or placeholders, while preserving structure
|
||||
/// and non-text metadata.
|
||||
pub fn mask_content(value: &Value) -> Value {
|
||||
match value {
|
||||
Value::String(s) => {
|
||||
if s.len() > MAX_LOG_LENGTH {
|
||||
Value::String(format!("... ({} chars)", s.len()))
|
||||
} else {
|
||||
Value::String(s.clone())
|
||||
}
|
||||
}
|
||||
Value::Array(arr) => {
|
||||
Value::Array(arr.iter().map(mask_content).collect())
|
||||
}
|
||||
Value::Object(obj) => {
|
||||
let mut masked = serde_json::Map::new();
|
||||
for (k, v) in obj {
|
||||
// Mask known content fields (only if they're strings)
|
||||
if k == "text" || k == "content_md" || k == "message" || k == "thinking" {
|
||||
if let Value::String(s) = v {
|
||||
if s.len() > 50 {
|
||||
masked.insert(k.clone(), Value::String(format!("... ({} chars)", s.len())));
|
||||
} else if s.len() > 0 {
|
||||
masked.insert(k.clone(), Value::String("...".to_string()));
|
||||
} else {
|
||||
masked.insert(k.clone(), Value::String("".to_string()));
|
||||
}
|
||||
} else {
|
||||
// Not a string, recurse (e.g., content as object)
|
||||
masked.insert(k.clone(), mask_content(v));
|
||||
}
|
||||
}
|
||||
// Mask raw input/output fields (large blobs of data)
|
||||
else if k == "rawOutput" || k == "rawInput" || k == "raw_output" || k == "raw_input" {
|
||||
if let Value::String(s) = v {
|
||||
if s.len() > 50 {
|
||||
masked.insert(k.clone(), Value::String(format!("... ({} chars)", s.len())));
|
||||
} else if s.len() > 0 {
|
||||
masked.insert(k.clone(), Value::String("...".to_string()));
|
||||
} else {
|
||||
masked.insert(k.clone(), Value::String("".to_string()));
|
||||
}
|
||||
} else {
|
||||
// Not a string, recurse
|
||||
masked.insert(k.clone(), mask_content(v));
|
||||
}
|
||||
}
|
||||
// Mask filename/path fields
|
||||
else if k == "filename" || k == "file" || k == "path" || k == "file_path" || k == "filepath" {
|
||||
if let Value::String(s) = v {
|
||||
// Extract just the filename from path for debugging context
|
||||
let filename = extract_filename(s);
|
||||
masked.insert(k.clone(), Value::String(format!("<{}>", filename)));
|
||||
} else if let Value::Array(arr) = v {
|
||||
// Array of filenames
|
||||
let masked_arr: Vec<Value> = arr.iter().map(|item| {
|
||||
if let Value::String(s) = item {
|
||||
let filename = extract_filename(s);
|
||||
Value::String(format!("<{}>", filename))
|
||||
} else {
|
||||
item.clone()
|
||||
}
|
||||
}).collect();
|
||||
masked.insert(k.clone(), Value::Array(masked_arr));
|
||||
} else {
|
||||
masked.insert(k.clone(), mask_content(v));
|
||||
}
|
||||
}
|
||||
// Mask arrays of filenames
|
||||
else if k == "filenames" || k == "files" || k == "paths" {
|
||||
if let Value::Array(arr) = v {
|
||||
let masked_arr: Vec<Value> = arr.iter().map(|item| {
|
||||
if let Value::String(s) = item {
|
||||
let filename = extract_filename(s);
|
||||
Value::String(format!("<{}>", filename))
|
||||
} else {
|
||||
mask_content(item)
|
||||
}
|
||||
}).collect();
|
||||
masked.insert(k.clone(), Value::Array(masked_arr));
|
||||
} else {
|
||||
masked.insert(k.clone(), mask_content(v));
|
||||
}
|
||||
}
|
||||
else {
|
||||
masked.insert(k.clone(), mask_content(v));
|
||||
}
|
||||
}
|
||||
Value::Object(masked)
|
||||
}
|
||||
_ => value.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Format a Value for logging with content masked
|
||||
pub fn format_for_log(value: &Value) -> String {
|
||||
let masked = mask_content(value);
|
||||
serde_json::to_string(&masked).unwrap_or_else(|_| "{}".to_string())
|
||||
}
|
||||
|
||||
/// Mask content in a JSON string, returning a masked JSON string
|
||||
pub fn mask_json_string(json_str: &str) -> String {
|
||||
match serde_json::from_str::<Value>(json_str) {
|
||||
Ok(value) => format_for_log(&value),
|
||||
Err(_) => {
|
||||
// If not valid JSON, just truncate
|
||||
if json_str.len() > MAX_LOG_LENGTH {
|
||||
format!("{}... ({} bytes)", &json_str[..MAX_LOG_LENGTH], json_str.len())
|
||||
} else {
|
||||
json_str.to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_mask_short_string() {
|
||||
let value = Value::String("short".to_string());
|
||||
let masked = mask_content(&value);
|
||||
assert_eq!(masked, Value::String("short".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mask_long_string() {
|
||||
let long_str = "a".repeat(200);
|
||||
let value = Value::String(long_str.clone());
|
||||
let masked = mask_content(&value);
|
||||
if let Value::String(s) = masked {
|
||||
assert!(s.contains("(200 chars)"));
|
||||
assert!(s.len() < long_str.len());
|
||||
} else {
|
||||
panic!("Expected string");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mask_text_field() {
|
||||
let value = json!({
|
||||
"text": "This is some long message content that should be masked",
|
||||
"message_id": "123",
|
||||
"role": "user"
|
||||
});
|
||||
let masked = mask_content(&value);
|
||||
assert_eq!(masked["message_id"], "123");
|
||||
assert_eq!(masked["role"], "user");
|
||||
// Text is 56 chars, which is > 50, so should be masked with char count
|
||||
assert_eq!(masked["text"], "... (55 chars)");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mask_nested_content() {
|
||||
let value = json!({
|
||||
"sessionId": "abc-123",
|
||||
"update": {
|
||||
"content": [{
|
||||
"type": "content",
|
||||
"content": {
|
||||
"type": "text",
|
||||
"text": "A".repeat(500)
|
||||
}
|
||||
}],
|
||||
"sessionUpdate": "tool_call_update",
|
||||
"status": "completed"
|
||||
}
|
||||
});
|
||||
let masked = mask_content(&value);
|
||||
assert_eq!(masked["sessionId"], "abc-123");
|
||||
assert_eq!(masked["update"]["sessionUpdate"], "tool_call_update");
|
||||
assert_eq!(masked["update"]["status"], "completed");
|
||||
|
||||
// Check that the text field is masked
|
||||
let text_value = &masked["update"]["content"][0]["content"]["text"];
|
||||
if let Value::String(s) = text_value {
|
||||
assert!(s.contains("(500 chars)"));
|
||||
} else {
|
||||
panic!("Expected text to be masked");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mask_content_as_object() {
|
||||
// Ensure "content" as object is NOT masked, only strings
|
||||
let value = json!({
|
||||
"content": {
|
||||
"type": "text",
|
||||
"text": "Some message"
|
||||
},
|
||||
"message_id": "123"
|
||||
});
|
||||
let masked = mask_content(&value);
|
||||
assert_eq!(masked["message_id"], "123");
|
||||
assert_eq!(masked["content"]["type"], "text");
|
||||
assert_eq!(masked["content"]["text"], "..."); // text field is masked
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mask_filename() {
|
||||
let value = json!({
|
||||
"filename": "/Users/name/Projects/dirigent/packages/web/src/main.rs",
|
||||
"operation": "read"
|
||||
});
|
||||
let masked = mask_content(&value);
|
||||
assert_eq!(masked["operation"], "read");
|
||||
assert_eq!(masked["filename"], "<main.rs>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mask_filenames_array() {
|
||||
let value = json!({
|
||||
"filenames": [
|
||||
"/Users/name/Projects/dirigent/packages/web/src/main.rs",
|
||||
"/Users/name/Projects/dirigent/packages/api/src/core.rs",
|
||||
"C:\\Users\\name\\Documents\\file.txt"
|
||||
],
|
||||
"count": 3
|
||||
});
|
||||
let masked = mask_content(&value);
|
||||
assert_eq!(masked["count"], 3);
|
||||
assert_eq!(masked["filenames"][0], "<main.rs>");
|
||||
assert_eq!(masked["filenames"][1], "<core.rs>");
|
||||
assert_eq!(masked["filenames"][2], "<file.txt>");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,267 @@
|
||||
//! Project types for the Dirigent system.
|
||||
//!
|
||||
//! WASM-compatible shared types for the Projects module. These types are
|
||||
//! used by both server and client (web UI) code.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// A project in the Dirigent system.
|
||||
///
|
||||
/// Projects organize work across repositories, sessions, and connectors.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Project {
|
||||
/// Unique project identifier (UUID v7)
|
||||
pub id: Uuid,
|
||||
/// Human-readable project name
|
||||
pub name: String,
|
||||
/// Project description (empty by default)
|
||||
#[serde(default)]
|
||||
pub description: String,
|
||||
/// Optional icon (emoji or abbreviation)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub icon: Option<String>,
|
||||
/// Owner user ID
|
||||
pub owner: Uuid,
|
||||
/// Member user IDs
|
||||
#[serde(default)]
|
||||
pub members: Vec<Uuid>,
|
||||
/// Categorization tags
|
||||
#[serde(default)]
|
||||
pub tags: Vec<String>,
|
||||
/// Programming languages used
|
||||
#[serde(default)]
|
||||
pub languages: Vec<String>,
|
||||
/// Linked project IDs (for multi-project setups)
|
||||
#[serde(default)]
|
||||
pub linked_projects: Vec<Uuid>,
|
||||
/// Arbitrary metadata
|
||||
#[serde(default = "default_metadata")]
|
||||
pub metadata: serde_json::Value,
|
||||
/// When this project was created
|
||||
pub created_at: DateTime<Utc>,
|
||||
/// When this project was last updated
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
fn default_metadata() -> serde_json::Value {
|
||||
serde_json::Value::Object(serde_json::Map::new())
|
||||
}
|
||||
|
||||
/// A local git repository associated with a project.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ProjectRepository {
|
||||
/// Unique repository identifier (UUID v7)
|
||||
pub id: Uuid,
|
||||
/// Project this repository belongs to
|
||||
pub project_id: Uuid,
|
||||
/// Local filesystem path
|
||||
pub path: PathBuf,
|
||||
/// Whether this is the primary repository
|
||||
#[serde(default)]
|
||||
pub is_primary: bool,
|
||||
/// Optional human-readable label
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub label: Option<String>,
|
||||
/// Access mode
|
||||
#[serde(default)]
|
||||
pub access: AccessMode,
|
||||
/// When this repository was added
|
||||
pub created_at: DateTime<Utc>,
|
||||
/// When this repository was last updated
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Repository access mode.
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub enum AccessMode {
|
||||
/// Read-only access
|
||||
Read,
|
||||
/// Read and write access
|
||||
#[default]
|
||||
ReadWrite,
|
||||
}
|
||||
|
||||
/// A git worktree linked to a repository.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Worktree {
|
||||
/// Unique worktree identifier (UUID v7)
|
||||
pub id: Uuid,
|
||||
/// Repository this worktree belongs to
|
||||
pub repository_id: Uuid,
|
||||
/// Local filesystem path
|
||||
pub path: PathBuf,
|
||||
/// Branch name
|
||||
pub branch: String,
|
||||
/// Optional work branch name
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub work_branch: Option<String>,
|
||||
/// Optional naming strategy
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub naming_strategy: Option<String>,
|
||||
/// When this worktree was created
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Binding between a project and a connector/session.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ProjectBinding {
|
||||
/// Unique binding identifier (UUID v7)
|
||||
pub id: Uuid,
|
||||
/// Project this binding belongs to
|
||||
pub project_id: Uuid,
|
||||
/// Optional connector ID
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub connector_id: Option<String>,
|
||||
/// Optional session ID
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub session_id: Option<Uuid>,
|
||||
/// Optional working directory override
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub working_dir: Option<PathBuf>,
|
||||
}
|
||||
|
||||
/// Runtime git state (not persisted, computed on demand).
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
pub struct GitState {
|
||||
/// Current branch name
|
||||
pub branch: String,
|
||||
/// Whether there are uncommitted changes
|
||||
#[serde(default)]
|
||||
pub is_dirty: bool,
|
||||
/// Commits ahead of remote
|
||||
#[serde(default)]
|
||||
pub ahead: u32,
|
||||
/// Commits behind remote
|
||||
#[serde(default)]
|
||||
pub behind: u32,
|
||||
/// Remote names
|
||||
#[serde(default)]
|
||||
pub remotes: Vec<String>,
|
||||
/// Active worktrees
|
||||
#[serde(default)]
|
||||
pub worktrees: Vec<WorktreeInfo>,
|
||||
/// Unexpected conditions
|
||||
#[serde(default)]
|
||||
pub unexpected: Vec<GitWarning>,
|
||||
}
|
||||
|
||||
/// Information about an active worktree.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct WorktreeInfo {
|
||||
/// Worktree filesystem path
|
||||
pub path: PathBuf,
|
||||
/// Branch checked out (None if detached)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub branch: Option<String>,
|
||||
/// Whether HEAD is detached
|
||||
#[serde(default)]
|
||||
pub is_detached: bool,
|
||||
}
|
||||
|
||||
/// A warning about unexpected git state.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct GitWarning {
|
||||
/// Warning code for programmatic handling
|
||||
pub code: String,
|
||||
/// Human-readable warning message
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_project_serialization_roundtrip() {
|
||||
let now = Utc::now();
|
||||
let project = Project {
|
||||
id: Uuid::now_v7(),
|
||||
name: "Test Project".to_string(),
|
||||
description: "A test project".to_string(),
|
||||
icon: Some("🚀".to_string()),
|
||||
owner: Uuid::now_v7(),
|
||||
members: vec![],
|
||||
tags: vec!["rust".to_string()],
|
||||
languages: vec!["Rust".to_string()],
|
||||
linked_projects: vec![],
|
||||
metadata: serde_json::json!({"key": "value"}),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&project).expect("serialize");
|
||||
let deser: Project = serde_json::from_str(&json).expect("deserialize");
|
||||
assert_eq!(deser.id, project.id);
|
||||
assert_eq!(deser.name, project.name);
|
||||
assert_eq!(deser.icon, project.icon);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_project_defaults() {
|
||||
let json = r#"{
|
||||
"id": "019504a0-0000-7000-8000-000000000001",
|
||||
"name": "Minimal",
|
||||
"owner": "019504a0-0000-7000-8000-000000000002",
|
||||
"created_at": "2026-01-01T00:00:00Z",
|
||||
"updated_at": "2026-01-01T00:00:00Z"
|
||||
}"#;
|
||||
|
||||
let project: Project = serde_json::from_str(json).expect("deserialize");
|
||||
assert_eq!(project.description, "");
|
||||
assert!(project.tags.is_empty());
|
||||
assert!(project.members.is_empty());
|
||||
assert!(project.metadata.is_object());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_access_mode_default() {
|
||||
assert_eq!(AccessMode::default(), AccessMode::ReadWrite);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_project_repository_roundtrip() {
|
||||
let now = Utc::now();
|
||||
let repo = ProjectRepository {
|
||||
id: Uuid::now_v7(),
|
||||
project_id: Uuid::now_v7(),
|
||||
path: PathBuf::from("/home/user/project"),
|
||||
is_primary: true,
|
||||
label: Some("main".to_string()),
|
||||
access: AccessMode::ReadWrite,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&repo).expect("serialize");
|
||||
let deser: ProjectRepository = serde_json::from_str(&json).expect("deserialize");
|
||||
assert_eq!(deser.id, repo.id);
|
||||
assert!(deser.is_primary);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_git_state_default() {
|
||||
let state = GitState::default();
|
||||
assert_eq!(state.branch, "");
|
||||
assert!(!state.is_dirty);
|
||||
assert_eq!(state.ahead, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binding_roundtrip() {
|
||||
let binding = ProjectBinding {
|
||||
id: Uuid::now_v7(),
|
||||
project_id: Uuid::now_v7(),
|
||||
connector_id: Some("conn-1".to_string()),
|
||||
session_id: None,
|
||||
working_dir: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&binding).expect("serialize");
|
||||
let deser: ProjectBinding = serde_json::from_str(&json).expect("deserialize");
|
||||
assert_eq!(deser.id, binding.id);
|
||||
assert!(deser.session_id.is_none());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,972 @@
|
||||
use crate::types::meta::Meta;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct Session {
|
||||
pub id: String,
|
||||
pub title: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
pub metadata: SessionMetadata,
|
||||
/// Working directory for this session (if known)
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub cwd: Option<String>,
|
||||
/// ACP model state (available models and current model)
|
||||
/// Populated from archivist for archived sessions, or from SSE events for live sessions.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub models: Option<SessionModelState>,
|
||||
/// ACP mode state (available modes and current mode)
|
||||
/// Populated from archivist for archived sessions, or from SSE events for live sessions.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub modes: Option<SessionModeState>,
|
||||
/// ACP config options (replaces modes/models in future ACP versions)
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub config_options: Option<Vec<ConfigOption>>,
|
||||
/// ACP client ID that owns this session.
|
||||
/// For sessions created via ACP Server (incoming connections), this identifies
|
||||
/// which connected client created/owns this session.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub acp_client_id: Option<String>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ACP Session Mode/Model Types
|
||||
// ============================================================================
|
||||
// These types match the Agent-Client Protocol (ACP) specification exactly.
|
||||
// They use camelCase serialization to match Claude-ACP's JSON format.
|
||||
// See: docs/architecture/agent_client_protocol/schema.md
|
||||
|
||||
/// Type alias for session mode identifiers (e.g., "default", "plan", "bypassPermissions")
|
||||
pub type SessionModeId = String;
|
||||
|
||||
/// Type alias for model identifiers (e.g., "default", "sonnet", "haiku", "opus")
|
||||
pub type ModelId = String;
|
||||
|
||||
/// Session mode state from ACP `session/new` response
|
||||
///
|
||||
/// Contains the list of available modes and the currently active mode.
|
||||
/// This is part of the stable ACP specification.
|
||||
///
|
||||
/// # Example (from Claude-ACP)
|
||||
/// ```json
|
||||
/// {
|
||||
/// "currentModeId": "default",
|
||||
/// "availableModes": [
|
||||
/// {"id": "default", "name": "Always Ask", "description": "..."},
|
||||
/// {"id": "plan", "name": "Plan Mode", "description": "..."}
|
||||
/// ]
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionModeState {
|
||||
/// The currently active mode ID
|
||||
pub current_mode_id: SessionModeId,
|
||||
/// List of all available modes for this session
|
||||
pub available_modes: Vec<SessionMode>,
|
||||
}
|
||||
|
||||
/// A single session mode definition
|
||||
///
|
||||
/// Modes affect agent behavior, tool availability, and permission handling.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionMode {
|
||||
/// Unique identifier for this mode
|
||||
pub id: SessionModeId,
|
||||
/// Human-readable display name
|
||||
pub name: String,
|
||||
/// Optional description of what this mode does
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
/// Session model state from ACP `session/new` response
|
||||
///
|
||||
/// Contains the list of available models and the currently selected model.
|
||||
/// Note: This field is marked UNSTABLE in the ACP spec but is used by Claude-ACP.
|
||||
///
|
||||
/// # Example (from Claude-ACP)
|
||||
/// ```json
|
||||
/// {
|
||||
/// "availableModels": [
|
||||
/// {"modelId": "default", "name": "Default (recommended)", "description": "..."},
|
||||
/// {"modelId": "sonnet", "name": "Sonnet", "description": "..."}
|
||||
/// ],
|
||||
/// "currentModelId": "default"
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SessionModelState {
|
||||
/// List of all available models for this session
|
||||
pub available_models: Vec<ModelInfo>,
|
||||
/// The currently selected model ID
|
||||
pub current_model_id: ModelId,
|
||||
}
|
||||
|
||||
/// Information about a single model
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ModelInfo {
|
||||
/// Unique identifier for this model
|
||||
pub model_id: ModelId,
|
||||
/// Human-readable display name
|
||||
pub name: String,
|
||||
/// Optional description of the model's capabilities
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ACP Config Options (replaces modes/models in future ACP versions)
|
||||
// ============================================================================
|
||||
|
||||
/// A configuration option for a session (ACP configOptions).
|
||||
///
|
||||
/// Agents provide config options in session/new and session/load responses.
|
||||
/// Clients should use these instead of the legacy `modes`/`models` fields.
|
||||
/// See: docs/architecture/agent_client_protocol/session-config-options.md
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ConfigOption {
|
||||
/// Unique identifier (e.g., "mode", "model")
|
||||
pub id: String,
|
||||
/// Human-readable label
|
||||
pub name: String,
|
||||
/// Optional description
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
/// Semantic category for UX grouping (e.g., "mode", "model", "thought_level")
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub category: Option<String>,
|
||||
/// Input type (currently only "select" is defined)
|
||||
#[serde(rename = "type")]
|
||||
pub option_type: ConfigOptionType,
|
||||
/// Currently selected value
|
||||
pub current_value: String,
|
||||
/// Available values for select-type options
|
||||
pub options: Vec<ConfigOptionValue>,
|
||||
}
|
||||
|
||||
/// Type of configuration option input
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ConfigOptionType {
|
||||
Select,
|
||||
}
|
||||
|
||||
/// A single value choice within a config option
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ConfigOptionValue {
|
||||
/// Value identifier (sent back when setting this option)
|
||||
pub value: String,
|
||||
/// Human-readable display name
|
||||
pub name: String,
|
||||
/// Optional description
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct SessionMetadata {
|
||||
pub project_path: String,
|
||||
pub model: Option<String>,
|
||||
/// Total count of user and assistant messages in the session (excludes system messages).
|
||||
/// A value of 0 may indicate either an empty session or that the count has not yet been calculated.
|
||||
/// Counts are populated lazily when messages are loaded for a session.
|
||||
/// See `docs/architecture/session_message_counts.md` for details.
|
||||
pub total_messages: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system_message: Option<String>,
|
||||
/// Current mode identifier for future mode tracking
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub current_mode_id: Option<String>,
|
||||
/// Provider metadata for tracking original IDs and debugging information
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub _meta: Option<Meta>,
|
||||
/// Optional project ID linking this session to a dirigent_projects Project.
|
||||
/// When set, the session belongs to the specified project for organizational purposes.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub project_id: Option<uuid::Uuid>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Session Ownership Model
|
||||
// ============================================================================
|
||||
// These types define how sessions are owned and how tool execution is routed.
|
||||
// See: docs/architecture/session_ownership.md (Phase 7)
|
||||
|
||||
/// The origin of a session - who initiated it
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum SessionOrigin {
|
||||
/// Session created by Dirigent UI user
|
||||
Internal,
|
||||
|
||||
/// Session forwarded from external ACP client
|
||||
External {
|
||||
/// The ACP client ID that owns this session
|
||||
client_id: String,
|
||||
/// Cached client capabilities (from initialization)
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
client_capabilities: Option<serde_json::Value>,
|
||||
},
|
||||
|
||||
/// Session representing a subagent or internal task (future)
|
||||
Subagent {
|
||||
/// Parent session that spawned this subagent
|
||||
parent_session_id: String,
|
||||
/// Task identifier
|
||||
task_id: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl Default for SessionOrigin {
|
||||
fn default() -> Self {
|
||||
Self::Internal
|
||||
}
|
||||
}
|
||||
|
||||
/// Who handles tool execution for this session
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ToolHandler {
|
||||
/// Agent handles its own tools (default)
|
||||
#[default]
|
||||
Agent,
|
||||
|
||||
/// Dirigent intercepts and handles tools via dirigent_tools (future)
|
||||
Dirigent,
|
||||
|
||||
/// Forward tool requests to originating client (External sessions only)
|
||||
ForwardToClient,
|
||||
}
|
||||
|
||||
/// Complete ownership model for a session
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
pub struct SessionOwnership {
|
||||
/// Where this session originated
|
||||
#[serde(default)]
|
||||
pub origin: SessionOrigin,
|
||||
|
||||
/// How tool requests are handled
|
||||
#[serde(default)]
|
||||
pub tool_handler: ToolHandler,
|
||||
}
|
||||
|
||||
impl SessionOwnership {
|
||||
/// Internal session with agent handling tools (default UI case)
|
||||
pub fn internal() -> Self {
|
||||
Self {
|
||||
origin: SessionOrigin::Internal,
|
||||
tool_handler: ToolHandler::Agent,
|
||||
}
|
||||
}
|
||||
|
||||
/// External session with tools forwarded to client
|
||||
pub fn external_forwarded(client_id: String, capabilities: Option<serde_json::Value>) -> Self {
|
||||
Self {
|
||||
origin: SessionOrigin::External {
|
||||
client_id,
|
||||
client_capabilities: capabilities,
|
||||
},
|
||||
tool_handler: ToolHandler::ForwardToClient,
|
||||
}
|
||||
}
|
||||
|
||||
/// External session but Dirigent handles tools
|
||||
pub fn external_handled(client_id: String, capabilities: Option<serde_json::Value>) -> Self {
|
||||
Self {
|
||||
origin: SessionOrigin::External {
|
||||
client_id,
|
||||
client_capabilities: capabilities,
|
||||
},
|
||||
tool_handler: ToolHandler::Dirigent,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get capabilities to advertise to agent based on ownership
|
||||
pub fn capabilities_for_agent(&self) -> serde_json::Value {
|
||||
match (&self.origin, &self.tool_handler) {
|
||||
// External + ForwardToClient: use client's capabilities
|
||||
(
|
||||
SessionOrigin::External {
|
||||
client_capabilities: Some(caps),
|
||||
..
|
||||
},
|
||||
ToolHandler::ForwardToClient,
|
||||
) => caps.clone(),
|
||||
// Dirigent handles tools: advertise dirigent_tools capabilities
|
||||
(_, ToolHandler::Dirigent) => {
|
||||
serde_json::json!({
|
||||
"fs": { "readTextFile": true, "writeTextFile": true },
|
||||
"terminal": true
|
||||
})
|
||||
}
|
||||
// Agent handles tools or no client caps: empty (agent uses its own)
|
||||
_ => serde_json::json!({}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the client ID if this should forward requests to a client
|
||||
pub fn forward_to_client(&self) -> Option<&str> {
|
||||
match (&self.origin, &self.tool_handler) {
|
||||
(SessionOrigin::External { client_id, .. }, ToolHandler::ForwardToClient) => {
|
||||
Some(client_id.as_str())
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if this is an external (forwarded) session
|
||||
pub fn is_external(&self) -> bool {
|
||||
matches!(self.origin, SessionOrigin::External { .. })
|
||||
}
|
||||
|
||||
/// Get the originating client ID if external
|
||||
pub fn client_id(&self) -> Option<&str> {
|
||||
match &self.origin {
|
||||
SessionOrigin::External { client_id, .. } => Some(client_id.as_str()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::types::meta::{Meta, ProviderMeta};
|
||||
use std::collections::HashMap;
|
||||
|
||||
// ========================================================================
|
||||
// ACP Session Mode/Model Type Tests
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_session_mode_state_serialization_camel_case() {
|
||||
// Verify camelCase serialization matches Claude-ACP format
|
||||
let mode_state = SessionModeState {
|
||||
current_mode_id: "default".to_string(),
|
||||
available_modes: vec![
|
||||
SessionMode {
|
||||
id: "default".to_string(),
|
||||
name: "Always Ask".to_string(),
|
||||
description: Some(
|
||||
"Prompts for permission on first use of each tool".to_string(),
|
||||
),
|
||||
},
|
||||
SessionMode {
|
||||
id: "plan".to_string(),
|
||||
name: "Plan Mode".to_string(),
|
||||
description: Some("Claude can analyze but not modify files".to_string()),
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&mode_state).unwrap();
|
||||
// Verify camelCase field names
|
||||
assert!(json.contains("currentModeId"));
|
||||
assert!(json.contains("availableModes"));
|
||||
// Verify content
|
||||
assert!(json.contains(r#""currentModeId":"default"#));
|
||||
assert!(json.contains(r#""name":"Always Ask"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_mode_state_deserialization_from_claude_format() {
|
||||
// Test deserialization of actual Claude-ACP format
|
||||
let json = r#"{
|
||||
"currentModeId": "default",
|
||||
"availableModes": [
|
||||
{
|
||||
"id": "default",
|
||||
"name": "Always Ask",
|
||||
"description": "Prompts for permission on first use of each tool"
|
||||
},
|
||||
{
|
||||
"id": "acceptEdits",
|
||||
"name": "Accept Edits",
|
||||
"description": "Automatically accepts file edit permissions for the session"
|
||||
},
|
||||
{
|
||||
"id": "plan",
|
||||
"name": "Plan Mode",
|
||||
"description": "Claude can analyze but not modify files or execute commands"
|
||||
},
|
||||
{
|
||||
"id": "bypassPermissions",
|
||||
"name": "Bypass Permissions",
|
||||
"description": "Skips all permission prompts"
|
||||
}
|
||||
]
|
||||
}"#;
|
||||
|
||||
let mode_state: SessionModeState = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(mode_state.current_mode_id, "default");
|
||||
assert_eq!(mode_state.available_modes.len(), 4);
|
||||
assert_eq!(mode_state.available_modes[0].id, "default");
|
||||
assert_eq!(mode_state.available_modes[0].name, "Always Ask");
|
||||
assert_eq!(mode_state.available_modes[3].id, "bypassPermissions");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_mode_state_roundtrip() {
|
||||
let original = SessionModeState {
|
||||
current_mode_id: "plan".to_string(),
|
||||
available_modes: vec![
|
||||
SessionMode {
|
||||
id: "default".to_string(),
|
||||
name: "Default".to_string(),
|
||||
description: None,
|
||||
},
|
||||
SessionMode {
|
||||
id: "plan".to_string(),
|
||||
name: "Plan".to_string(),
|
||||
description: Some("Planning mode".to_string()),
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&original).unwrap();
|
||||
let deserialized: SessionModeState = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(original, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_mode_skip_none_description() {
|
||||
// Test that None descriptions are not serialized
|
||||
let mode = SessionMode {
|
||||
id: "test".to_string(),
|
||||
name: "Test".to_string(),
|
||||
description: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&mode).unwrap();
|
||||
assert!(!json.contains("description"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_model_state_serialization_camel_case() {
|
||||
// Verify camelCase serialization matches Claude-ACP format
|
||||
let model_state = SessionModelState {
|
||||
available_models: vec![
|
||||
ModelInfo {
|
||||
model_id: "default".to_string(),
|
||||
name: "Default (recommended)".to_string(),
|
||||
description: Some("Opus 4.5 · Most capable for complex work".to_string()),
|
||||
},
|
||||
ModelInfo {
|
||||
model_id: "sonnet".to_string(),
|
||||
name: "Sonnet".to_string(),
|
||||
description: Some("Sonnet 4.5 · Best for everyday tasks".to_string()),
|
||||
},
|
||||
],
|
||||
current_model_id: "default".to_string(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&model_state).unwrap();
|
||||
// Verify camelCase field names
|
||||
assert!(json.contains("availableModels"));
|
||||
assert!(json.contains("currentModelId"));
|
||||
assert!(json.contains("modelId"));
|
||||
// Verify content
|
||||
assert!(json.contains(r#""currentModelId":"default"#));
|
||||
assert!(json.contains(r#""name":"Sonnet"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_model_state_deserialization_from_claude_format() {
|
||||
// Test deserialization of actual Claude-ACP format (from zed_claude_code_direct_acp_log.txt)
|
||||
let json = r#"{
|
||||
"availableModels": [
|
||||
{
|
||||
"modelId": "default",
|
||||
"name": "Default (recommended)",
|
||||
"description": "Opus 4.5 · Most capable for complex work"
|
||||
},
|
||||
{
|
||||
"modelId": "sonnet",
|
||||
"name": "Sonnet",
|
||||
"description": "Sonnet 4.5 · Best for everyday tasks"
|
||||
},
|
||||
{
|
||||
"modelId": "haiku",
|
||||
"name": "Haiku",
|
||||
"description": "Haiku 4.5 · Fastest for quick answers"
|
||||
},
|
||||
{
|
||||
"modelId": "opus",
|
||||
"name": "opus",
|
||||
"description": "Custom model"
|
||||
}
|
||||
],
|
||||
"currentModelId": "default"
|
||||
}"#;
|
||||
|
||||
let model_state: SessionModelState = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(model_state.current_model_id, "default");
|
||||
assert_eq!(model_state.available_models.len(), 4);
|
||||
assert_eq!(model_state.available_models[0].model_id, "default");
|
||||
assert_eq!(
|
||||
model_state.available_models[0].name,
|
||||
"Default (recommended)"
|
||||
);
|
||||
assert_eq!(model_state.available_models[2].model_id, "haiku");
|
||||
assert_eq!(model_state.available_models[3].model_id, "opus");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_model_state_roundtrip() {
|
||||
let original = SessionModelState {
|
||||
available_models: vec![ModelInfo {
|
||||
model_id: "default".to_string(),
|
||||
name: "Default".to_string(),
|
||||
description: None,
|
||||
}],
|
||||
current_model_id: "default".to_string(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&original).unwrap();
|
||||
let deserialized: SessionModelState = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(original, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_info_skip_none_description() {
|
||||
// Test that None descriptions are not serialized
|
||||
let model = ModelInfo {
|
||||
model_id: "test".to_string(),
|
||||
name: "Test".to_string(),
|
||||
description: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&model).unwrap();
|
||||
assert!(!json.contains("description"));
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// SessionMetadata Tests (existing)
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_session_metadata_backward_compatibility() {
|
||||
// Test that existing SessionMetadata without new fields can be deserialized
|
||||
let json = r#"{
|
||||
"project_path": "/test/path",
|
||||
"model": "gpt-4",
|
||||
"total_messages": 10,
|
||||
"system_message": "System prompt"
|
||||
}"#;
|
||||
|
||||
let metadata: SessionMetadata = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(metadata.project_path, "/test/path");
|
||||
assert_eq!(metadata.model, Some("gpt-4".to_string()));
|
||||
assert_eq!(metadata.total_messages, 10);
|
||||
assert_eq!(metadata.system_message, Some("System prompt".to_string()));
|
||||
assert_eq!(metadata.current_mode_id, None);
|
||||
assert_eq!(metadata._meta, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_metadata_skip_serializing_none() {
|
||||
// Test that None values are skipped during serialization
|
||||
let metadata = SessionMetadata {
|
||||
project_path: "/test".to_string(),
|
||||
model: Some("gpt-4".to_string()),
|
||||
total_messages: 0,
|
||||
system_message: None,
|
||||
current_mode_id: None,
|
||||
_meta: None,
|
||||
project_id: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&metadata).unwrap();
|
||||
// Should not contain system_message, current_mode_id, or _meta fields
|
||||
assert!(!json.contains("system_message"));
|
||||
assert!(!json.contains("current_mode_id"));
|
||||
assert!(!json.contains("_meta"));
|
||||
// Should contain the present fields
|
||||
assert!(json.contains("project_path"));
|
||||
assert!(json.contains("model"));
|
||||
assert!(json.contains("total_messages"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_metadata_with_current_mode_id() {
|
||||
// Test serialization/deserialization with current_mode_id
|
||||
let metadata = SessionMetadata {
|
||||
project_path: "/test".to_string(),
|
||||
model: Some("gpt-4".to_string()),
|
||||
total_messages: 5,
|
||||
system_message: None,
|
||||
current_mode_id: Some("code_mode".to_string()),
|
||||
_meta: None,
|
||||
project_id: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&metadata).unwrap();
|
||||
assert!(json.contains("current_mode_id"));
|
||||
assert!(json.contains("code_mode"));
|
||||
|
||||
let deserialized: SessionMetadata = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(deserialized.current_mode_id, Some("code_mode".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_metadata_with_meta() {
|
||||
// Test serialization/deserialization with provider metadata
|
||||
let meta = Meta {
|
||||
provider: Some(ProviderMeta {
|
||||
name: "opencode".to_string(),
|
||||
original_ids: Some(HashMap::from([(
|
||||
"session_id".to_string(),
|
||||
"ses_abc123".to_string(),
|
||||
)])),
|
||||
raw_excerpt: None,
|
||||
}),
|
||||
extra: HashMap::new(),
|
||||
};
|
||||
|
||||
let metadata = SessionMetadata {
|
||||
project_path: "/test".to_string(),
|
||||
model: Some("gpt-4".to_string()),
|
||||
total_messages: 5,
|
||||
system_message: None,
|
||||
current_mode_id: None,
|
||||
_meta: Some(meta),
|
||||
project_id: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&metadata).unwrap();
|
||||
assert!(json.contains("_meta"));
|
||||
assert!(json.contains("opencode"));
|
||||
assert!(json.contains("ses_abc123"));
|
||||
|
||||
let deserialized: SessionMetadata = serde_json::from_str(&json).unwrap();
|
||||
assert!(deserialized._meta.is_some());
|
||||
let deserialized_meta = deserialized._meta.unwrap();
|
||||
assert!(deserialized_meta.provider.is_some());
|
||||
assert_eq!(deserialized_meta.provider.unwrap().name, "opencode");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_metadata_with_all_fields() {
|
||||
// Test with all fields populated
|
||||
let meta = Meta {
|
||||
provider: Some(ProviderMeta {
|
||||
name: "anthropic".to_string(),
|
||||
original_ids: Some(HashMap::from([(
|
||||
"conversation_id".to_string(),
|
||||
"conv_xyz".to_string(),
|
||||
)])),
|
||||
raw_excerpt: Some(serde_json::json!({"version": "1.0"})),
|
||||
}),
|
||||
extra: HashMap::new(),
|
||||
};
|
||||
|
||||
let metadata = SessionMetadata {
|
||||
project_path: "/project".to_string(),
|
||||
model: Some("claude-3".to_string()),
|
||||
total_messages: 42,
|
||||
system_message: Some("Be helpful".to_string()),
|
||||
current_mode_id: Some("architect".to_string()),
|
||||
_meta: Some(meta),
|
||||
project_id: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&metadata).unwrap();
|
||||
let deserialized: SessionMetadata = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(metadata, deserialized);
|
||||
assert!(json.contains("system_message"));
|
||||
assert!(json.contains("current_mode_id"));
|
||||
assert!(json.contains("_meta"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_roundtrip_with_new_fields() {
|
||||
// Test that a Session with new metadata fields survives roundtrip
|
||||
let now = Utc::now();
|
||||
let session = Session {
|
||||
id: "ses_test123".to_string(),
|
||||
title: "Test Session".to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
metadata: SessionMetadata {
|
||||
project_path: "/workspace".to_string(),
|
||||
model: Some("gpt-4-turbo".to_string()),
|
||||
total_messages: 7,
|
||||
system_message: Some("You are a coding assistant".to_string()),
|
||||
current_mode_id: Some("debug_mode".to_string()),
|
||||
_meta: Some(Meta {
|
||||
provider: Some(ProviderMeta {
|
||||
name: "test_provider".to_string(),
|
||||
original_ids: None,
|
||||
raw_excerpt: None,
|
||||
}),
|
||||
extra: HashMap::new(),
|
||||
}),
|
||||
project_id: None,
|
||||
},
|
||||
cwd: None,
|
||||
models: None,
|
||||
modes: None,
|
||||
config_options: None,
|
||||
acp_client_id: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&session).unwrap();
|
||||
let deserialized: Session = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(session, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_with_models_and_modes() {
|
||||
// Test Session with models and modes populated
|
||||
let now = Utc::now();
|
||||
let session = Session {
|
||||
id: "ses_test456".to_string(),
|
||||
title: "Test Session with ACP".to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
metadata: SessionMetadata {
|
||||
project_path: "/workspace".to_string(),
|
||||
model: Some("default".to_string()),
|
||||
total_messages: 5,
|
||||
system_message: None,
|
||||
current_mode_id: Some("default".to_string()),
|
||||
_meta: None,
|
||||
project_id: None,
|
||||
},
|
||||
cwd: None,
|
||||
models: Some(SessionModelState {
|
||||
available_models: vec![
|
||||
ModelInfo {
|
||||
model_id: "default".to_string(),
|
||||
name: "Default".to_string(),
|
||||
description: Some("Default model".to_string()),
|
||||
},
|
||||
ModelInfo {
|
||||
model_id: "sonnet".to_string(),
|
||||
name: "Sonnet".to_string(),
|
||||
description: None,
|
||||
},
|
||||
],
|
||||
current_model_id: "default".to_string(),
|
||||
}),
|
||||
modes: Some(SessionModeState {
|
||||
current_mode_id: "default".to_string(),
|
||||
available_modes: vec![SessionMode {
|
||||
id: "default".to_string(),
|
||||
name: "Always Ask".to_string(),
|
||||
description: None,
|
||||
}],
|
||||
}),
|
||||
config_options: None,
|
||||
acp_client_id: Some("test-client-123".to_string()),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&session).unwrap();
|
||||
assert!(json.contains("models"));
|
||||
assert!(json.contains("modes"));
|
||||
assert!(json.contains("acp_client_id"));
|
||||
assert!(json.contains("availableModels"));
|
||||
assert!(json.contains("currentModeId"));
|
||||
|
||||
let deserialized: Session = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(session, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_backward_compatibility_no_models_modes() {
|
||||
// Test that old Session JSON without models/modes can be deserialized
|
||||
let json = r#"{
|
||||
"id": "ses_old",
|
||||
"title": "Old Session",
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
"updated_at": "2024-01-01T00:00:00Z",
|
||||
"metadata": {
|
||||
"project_path": "/test",
|
||||
"model": "gpt-4",
|
||||
"total_messages": 10
|
||||
}
|
||||
}"#;
|
||||
|
||||
let session: Session = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(session.id, "ses_old");
|
||||
assert!(session.models.is_none());
|
||||
assert!(session.modes.is_none());
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Session Ownership Model Tests
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_session_origin_default() {
|
||||
// Verify Internal is the default
|
||||
let origin = SessionOrigin::default();
|
||||
assert_eq!(origin, SessionOrigin::Internal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_origin_serialization() {
|
||||
// Test Internal variant
|
||||
let internal = SessionOrigin::Internal;
|
||||
let json = serde_json::to_string(&internal).unwrap();
|
||||
assert!(json.contains(r#""type":"internal"#));
|
||||
|
||||
// Test External variant
|
||||
let external = SessionOrigin::External {
|
||||
client_id: "client-123".to_string(),
|
||||
client_capabilities: Some(serde_json::json!({"tools": ["bash"]})),
|
||||
};
|
||||
let json = serde_json::to_string(&external).unwrap();
|
||||
assert!(json.contains(r#""type":"external"#));
|
||||
assert!(json.contains(r#""client_id":"client-123"#));
|
||||
assert!(json.contains("tools"));
|
||||
|
||||
// Test Subagent variant
|
||||
let subagent = SessionOrigin::Subagent {
|
||||
parent_session_id: "parent-456".to_string(),
|
||||
task_id: "task-789".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string(&subagent).unwrap();
|
||||
assert!(json.contains(r#""type":"subagent"#));
|
||||
assert!(json.contains(r#""parent_session_id":"parent-456"#));
|
||||
assert!(json.contains(r#""task_id":"task-789"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_handler_default() {
|
||||
// Verify Agent is the default
|
||||
let handler = ToolHandler::default();
|
||||
assert_eq!(handler, ToolHandler::Agent);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_handler_serialization() {
|
||||
let agent = ToolHandler::Agent;
|
||||
let json = serde_json::to_string(&agent).unwrap();
|
||||
assert_eq!(json, r#""agent""#);
|
||||
|
||||
let dirigent = ToolHandler::Dirigent;
|
||||
let json = serde_json::to_string(&dirigent).unwrap();
|
||||
assert_eq!(json, r#""dirigent""#);
|
||||
|
||||
let forward = ToolHandler::ForwardToClient;
|
||||
let json = serde_json::to_string(&forward).unwrap();
|
||||
assert_eq!(json, r#""forward_to_client""#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_ownership_internal() {
|
||||
let ownership = SessionOwnership::internal();
|
||||
assert_eq!(ownership.origin, SessionOrigin::Internal);
|
||||
assert_eq!(ownership.tool_handler, ToolHandler::Agent);
|
||||
assert!(!ownership.is_external());
|
||||
assert_eq!(ownership.client_id(), None);
|
||||
assert_eq!(ownership.forward_to_client(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_ownership_external_forwarded() {
|
||||
let caps = serde_json::json!({"tools": ["bash", "edit"]});
|
||||
let ownership =
|
||||
SessionOwnership::external_forwarded("client-123".to_string(), Some(caps.clone()));
|
||||
|
||||
match &ownership.origin {
|
||||
SessionOrigin::External {
|
||||
client_id,
|
||||
client_capabilities,
|
||||
} => {
|
||||
assert_eq!(client_id, "client-123");
|
||||
assert_eq!(client_capabilities.as_ref().unwrap(), &caps);
|
||||
}
|
||||
_ => panic!("Expected External origin"),
|
||||
}
|
||||
|
||||
assert_eq!(ownership.tool_handler, ToolHandler::ForwardToClient);
|
||||
assert!(ownership.is_external());
|
||||
assert_eq!(ownership.client_id(), Some("client-123"));
|
||||
assert_eq!(ownership.forward_to_client(), Some("client-123"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_ownership_external_handled() {
|
||||
let ownership = SessionOwnership::external_handled("client-456".to_string(), None);
|
||||
|
||||
match &ownership.origin {
|
||||
SessionOrigin::External {
|
||||
client_id,
|
||||
client_capabilities,
|
||||
} => {
|
||||
assert_eq!(client_id, "client-456");
|
||||
assert!(client_capabilities.is_none());
|
||||
}
|
||||
_ => panic!("Expected External origin"),
|
||||
}
|
||||
|
||||
assert_eq!(ownership.tool_handler, ToolHandler::Dirigent);
|
||||
assert!(ownership.is_external());
|
||||
assert_eq!(ownership.client_id(), Some("client-456"));
|
||||
assert_eq!(ownership.forward_to_client(), None); // Dirigent handles, not forwarded
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_capabilities_for_agent_external_forwarded() {
|
||||
let client_caps = serde_json::json!({"fs": true, "terminal": true});
|
||||
let ownership = SessionOwnership::external_forwarded(
|
||||
"client-123".to_string(),
|
||||
Some(client_caps.clone()),
|
||||
);
|
||||
|
||||
let caps = ownership.capabilities_for_agent();
|
||||
assert_eq!(caps, client_caps);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_capabilities_for_agent_dirigent() {
|
||||
let ownership = SessionOwnership {
|
||||
origin: SessionOrigin::Internal,
|
||||
tool_handler: ToolHandler::Dirigent,
|
||||
};
|
||||
|
||||
let caps = ownership.capabilities_for_agent();
|
||||
assert!(caps.is_object());
|
||||
assert!(caps.get("fs").is_some());
|
||||
assert!(caps.get("terminal").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_capabilities_for_agent_agent_handled() {
|
||||
let ownership = SessionOwnership::internal();
|
||||
let caps = ownership.capabilities_for_agent();
|
||||
assert_eq!(caps, serde_json::json!({}));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_ownership_serialization_roundtrip() {
|
||||
let original = SessionOwnership::external_forwarded(
|
||||
"test-client".to_string(),
|
||||
Some(serde_json::json!({"test": true})),
|
||||
);
|
||||
|
||||
let json = serde_json::to_string(&original).unwrap();
|
||||
let deserialized: SessionOwnership = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(deserialized.tool_handler, original.tool_handler);
|
||||
assert_eq!(deserialized.client_id(), Some("test-client"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_ownership_default() {
|
||||
let ownership = SessionOwnership::default();
|
||||
assert_eq!(ownership.origin, SessionOrigin::Internal);
|
||||
assert_eq!(ownership.tool_handler, ToolHandler::Agent);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
//! Session sharing abstraction
|
||||
//!
|
||||
//! The `SessionShare` trait abstracts bidirectional bridges between Dirigent
|
||||
//! sessions and external communication systems (Matrix, Slack, etc.).
|
||||
//!
|
||||
//! A share attaches to a (connector_id, session_id) pair without taking
|
||||
//! ownership of the session. Multiple shares can coexist on the same session.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Unique identifier for a share instance.
|
||||
pub type ShareId = String;
|
||||
|
||||
/// Summary info about an active share.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ShareSummary {
|
||||
/// Share identifier (e.g., "matrix:connector-1:session-abc")
|
||||
pub id: ShareId,
|
||||
/// Connector this share is attached to
|
||||
pub connector_id: String,
|
||||
/// Session this share is attached to
|
||||
pub session_id: String,
|
||||
/// Backend type (e.g., "matrix", "slack")
|
||||
pub backend: String,
|
||||
/// Backend-specific destination (e.g., Matrix room ID)
|
||||
pub destination: String,
|
||||
/// Whether the share is currently active
|
||||
pub active: bool,
|
||||
}
|
||||
|
||||
/// Trait for session share backends.
|
||||
///
|
||||
/// Implementors provide bidirectional bridging between a Dirigent session
|
||||
/// and an external system. The trait is deliberately minimal — the concrete
|
||||
/// implementation handles all backend-specific details.
|
||||
///
|
||||
/// # Design Notes
|
||||
///
|
||||
/// - Shares do NOT modify the Connector trait or require special connector support
|
||||
/// - Shares use existing channels: `connector.subscribe()` for events,
|
||||
/// `connector.command_tx()` for sending messages
|
||||
/// - Multiple shares can be active on the same session simultaneously
|
||||
#[async_trait]
|
||||
pub trait SessionShare: Send + Sync {
|
||||
/// Get summary information about this share.
|
||||
fn summary(&self) -> ShareSummary;
|
||||
|
||||
/// Check if the share is actively forwarding.
|
||||
fn is_active(&self) -> bool;
|
||||
|
||||
/// Gracefully shut down this share.
|
||||
async fn shutdown(&self);
|
||||
}
|
||||
@@ -0,0 +1,353 @@
|
||||
//! Bus event envelope: wraps `Event` with routing context for subscribers.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::Event;
|
||||
|
||||
/// Full bus event envelope: the `Event` plus routing context derived at emit time.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BusEvent {
|
||||
pub routing: EventRouting,
|
||||
pub origin: EventOrigin,
|
||||
pub event: Arc<Event>,
|
||||
}
|
||||
|
||||
/// Routing metadata attached to every `BusEvent`.
|
||||
///
|
||||
/// `scroll_id` is intentionally left `None` at construction time and filled in
|
||||
/// later by the bus cache once the archivist has registered the session.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct EventRouting {
|
||||
pub connector_uid: Option<Uuid>,
|
||||
pub scroll_id: Option<Uuid>,
|
||||
pub connector_id: Option<String>,
|
||||
pub native_session_id: Option<String>,
|
||||
pub kind: EventKind,
|
||||
}
|
||||
|
||||
/// High-level classification of a `BusEvent`, used for subscriber filtering.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum EventKind {
|
||||
#[default]
|
||||
SessionLifecycle,
|
||||
Message,
|
||||
Update,
|
||||
System,
|
||||
}
|
||||
|
||||
/// Records the subsystem that originally produced the event.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum EventOrigin {
|
||||
Connector {
|
||||
connector_uid: Option<Uuid>,
|
||||
connector_id: String,
|
||||
},
|
||||
Archivist,
|
||||
Runtime,
|
||||
Replay {
|
||||
replay_id: Uuid,
|
||||
},
|
||||
}
|
||||
|
||||
// ─── BusEvent constructors ───────────────────────────────────────────────────
|
||||
|
||||
impl BusEvent {
|
||||
/// Wrap a connector-sourced `Event` in a `BusEvent`.
|
||||
///
|
||||
/// Routing metadata is derived from the event fields; `scroll_id` is left
|
||||
/// `None` and must be patched by the bus cache after archivist registration.
|
||||
pub fn from_connector_event(
|
||||
event: Event,
|
||||
connector_uid: Option<Uuid>,
|
||||
connector_id: String,
|
||||
) -> Self {
|
||||
let routing = EventRouting::derive(&event, connector_uid, &connector_id);
|
||||
Self {
|
||||
routing,
|
||||
origin: EventOrigin::Connector {
|
||||
connector_uid,
|
||||
connector_id,
|
||||
},
|
||||
event: Arc::new(event),
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrap an archivist-sourced `Event` (e.g. `SessionRegistered`) in a
|
||||
/// `BusEvent`. The archivist knows the canonical `(connector_id,
|
||||
/// native_session_id, scroll_id)` triple; we pass it directly so the
|
||||
/// bus does not have to consult its scroll-id cache to route the
|
||||
/// event. Origin is set to `EventOrigin::Archivist`.
|
||||
pub fn from_archivist_event(
|
||||
event: Event,
|
||||
connector_id: &str,
|
||||
native_session_id: &str,
|
||||
scroll_id: Option<Uuid>,
|
||||
) -> Self {
|
||||
let (kind, _) = classify(&event);
|
||||
let routing = EventRouting {
|
||||
connector_uid: None,
|
||||
scroll_id,
|
||||
connector_id: Some(connector_id.to_string()),
|
||||
native_session_id: Some(native_session_id.to_string()),
|
||||
kind,
|
||||
};
|
||||
Self {
|
||||
routing,
|
||||
origin: EventOrigin::Archivist,
|
||||
event: Arc::new(event),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─── EventRouting::derive ────────────────────────────────────────────────────
|
||||
|
||||
impl EventRouting {
|
||||
/// Derive routing from an `Event` plus the emitting connector's identity.
|
||||
///
|
||||
/// `scroll_id` is always `None` here; the bus cache fills it in later.
|
||||
pub fn derive(event: &Event, connector_uid: Option<Uuid>, connector_id: &str) -> Self {
|
||||
let (kind, native_session_id) = classify(event);
|
||||
Self {
|
||||
connector_uid,
|
||||
scroll_id: None,
|
||||
connector_id: Some(connector_id.to_string()),
|
||||
native_session_id,
|
||||
kind,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the `(EventKind, Option<native_session_id>)` for a given `Event`.
|
||||
///
|
||||
/// Every current variant is matched explicitly; the `_` arm is a safety net for
|
||||
/// future additions and classifies them as `System` with no session context.
|
||||
fn classify(event: &Event) -> (EventKind, Option<String>) {
|
||||
use EventKind::*;
|
||||
|
||||
match event {
|
||||
// ── SessionLifecycle ─────────────────────────────────────────────────
|
||||
Event::SessionsListed { .. } => (SessionLifecycle, None),
|
||||
|
||||
Event::SessionCreated { session, .. } => {
|
||||
(SessionLifecycle, Some(session.id.clone()))
|
||||
}
|
||||
|
||||
Event::SessionUpdated { session, .. } => {
|
||||
(SessionLifecycle, Some(session.id.clone()))
|
||||
}
|
||||
|
||||
Event::SessionMetadataUpdated { session_id, .. } => {
|
||||
(SessionLifecycle, Some(session_id.clone()))
|
||||
}
|
||||
|
||||
Event::SessionDeleted { session_id } => {
|
||||
(SessionLifecycle, Some(session_id.clone()))
|
||||
}
|
||||
|
||||
Event::SessionClosed { session_id, .. } => {
|
||||
(SessionLifecycle, Some(session_id.clone()))
|
||||
}
|
||||
|
||||
Event::SessionSystemMessageSet { session_id, .. } => {
|
||||
(SessionLifecycle, Some(session_id.clone()))
|
||||
}
|
||||
|
||||
Event::SessionIdle { session_id, .. } => {
|
||||
(SessionLifecycle, Some(session_id.clone()))
|
||||
}
|
||||
|
||||
Event::SessionMetadataReceived { session_id, .. } => {
|
||||
(SessionLifecycle, Some(session_id.clone()))
|
||||
}
|
||||
|
||||
Event::SessionTransferred { from_session, .. } => {
|
||||
(SessionLifecycle, Some(from_session.clone()))
|
||||
}
|
||||
|
||||
Event::SessionRegistered { session_id, .. } => {
|
||||
(SessionLifecycle, Some(session_id.clone()))
|
||||
}
|
||||
|
||||
Event::ForwardingPanic { session_id, .. } => {
|
||||
(SessionLifecycle, Some(session_id.clone()))
|
||||
}
|
||||
|
||||
Event::Connected => (SessionLifecycle, None),
|
||||
Event::Disconnected => (SessionLifecycle, None),
|
||||
|
||||
Event::ConnectorCreated { .. } => (SessionLifecycle, None),
|
||||
Event::ConnectorRemoved { .. } => (SessionLifecycle, None),
|
||||
Event::ConnectorStateChanged { .. } => (SessionLifecycle, None),
|
||||
|
||||
Event::AcpClientConnected { .. } => (SessionLifecycle, None),
|
||||
Event::AcpClientDisconnected { .. } => (SessionLifecycle, None),
|
||||
|
||||
Event::AcpClientSessionOpened {
|
||||
client_session_id, ..
|
||||
} => (SessionLifecycle, Some(client_session_id.clone())),
|
||||
|
||||
Event::AcpClientSessionRouted { from_session_id, .. } => {
|
||||
(SessionLifecycle, Some(from_session_id.clone()))
|
||||
}
|
||||
|
||||
Event::AgentRequest { session_id, .. } => {
|
||||
(SessionLifecycle, Some(session_id.clone()))
|
||||
}
|
||||
|
||||
// ── Message ──────────────────────────────────────────────────────────
|
||||
Event::MessagesListed { .. } => (Message, None),
|
||||
|
||||
Event::MessageStarted { message, .. } => {
|
||||
(Message, Some(message.session_id.clone()))
|
||||
}
|
||||
|
||||
Event::MessageCompleted { message, .. } => {
|
||||
(Message, Some(message.session_id.clone()))
|
||||
}
|
||||
|
||||
Event::MessageFailed { .. } => (Message, None),
|
||||
|
||||
Event::TurnComplete { session_id, .. } => (Message, Some(session_id.clone())),
|
||||
|
||||
// ── Update ───────────────────────────────────────────────────────────
|
||||
Event::SessionUpdate { session_id, .. } => (Update, Some(session_id.clone())),
|
||||
|
||||
// ── System ───────────────────────────────────────────────────────────
|
||||
Event::Error { .. } => (System, None),
|
||||
Event::SessionError { session_id, .. } => (System, Some(session_id.clone())),
|
||||
Event::InspectorSnapshot { .. } => (System, None),
|
||||
Event::InspectorNodeRegistered { .. } => (System, None),
|
||||
Event::InspectorNodeRemoved { .. } => (System, None),
|
||||
Event::InspectorStateChanged { .. } => (System, None),
|
||||
Event::InspectorPropertiesUpdated { .. } => (System, None),
|
||||
Event::SystemTaskStatusChanged { .. } => (System, None),
|
||||
|
||||
// Safety net: future variants not yet listed above.
|
||||
// #[allow] is intentional — this arm exists to catch additions to Event.
|
||||
#[allow(unreachable_patterns)]
|
||||
_ => (System, None),
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Tests ───────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{Event, Session, SessionMetadata};
|
||||
use chrono::Utc;
|
||||
|
||||
fn minimal_session(id: &str) -> Session {
|
||||
let now = Utc::now();
|
||||
Session {
|
||||
id: id.to_string(),
|
||||
title: "test".to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
metadata: SessionMetadata {
|
||||
project_path: "/tmp".to_string(),
|
||||
model: None,
|
||||
total_messages: 0,
|
||||
system_message: None,
|
||||
current_mode_id: None,
|
||||
_meta: None,
|
||||
project_id: None,
|
||||
},
|
||||
cwd: None,
|
||||
models: None,
|
||||
modes: None,
|
||||
config_options: None,
|
||||
acp_client_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn derive_extracts_session_id_on_session_created() {
|
||||
let event = Event::SessionCreated {
|
||||
connector_id: "conn-1".to_string(),
|
||||
session: minimal_session("ses-abc"),
|
||||
};
|
||||
let routing = EventRouting::derive(&event, None, "conn-1");
|
||||
assert_eq!(routing.native_session_id.as_deref(), Some("ses-abc"));
|
||||
assert_eq!(routing.kind, EventKind::SessionLifecycle);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn derive_sets_kind_update_on_session_update() {
|
||||
use crate::SessionUpdate as SU;
|
||||
let event = Event::SessionUpdate {
|
||||
connector_id: "conn-1".to_string(),
|
||||
session_id: "ses-xyz".to_string(),
|
||||
update: SU::Unknown {
|
||||
data: serde_json::json!({"type": "unknown_future"}),
|
||||
},
|
||||
};
|
||||
let routing = EventRouting::derive(&event, None, "conn-1");
|
||||
assert_eq!(routing.kind, EventKind::Update);
|
||||
assert_eq!(routing.native_session_id.as_deref(), Some("ses-xyz"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn derive_sets_kind_message_for_message_started() {
|
||||
use crate::conversation::{Message, MessageRole, MessageStatus};
|
||||
let now = Utc::now();
|
||||
let msg = Message {
|
||||
id: "msg-1".to_string(),
|
||||
session_id: "ses-msg".to_string(),
|
||||
role: MessageRole::User,
|
||||
created_at: now,
|
||||
content: vec![],
|
||||
status: MessageStatus::Pending,
|
||||
metadata: None,
|
||||
};
|
||||
let event = Event::MessageStarted {
|
||||
connector_id: "conn-1".to_string(),
|
||||
message: msg,
|
||||
};
|
||||
let routing = EventRouting::derive(&event, None, "conn-1");
|
||||
assert_eq!(routing.kind, EventKind::Message);
|
||||
assert_eq!(routing.native_session_id.as_deref(), Some("ses-msg"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn derive_sets_kind_system_for_error() {
|
||||
let event = Event::Error {
|
||||
message: "something went wrong".to_string(),
|
||||
};
|
||||
let routing = EventRouting::derive(&event, None, "conn-1");
|
||||
assert_eq!(routing.kind, EventKind::System);
|
||||
assert!(routing.native_session_id.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn event_origin_replay_roundtrips_replay_id() {
|
||||
let id = Uuid::new_v4();
|
||||
let origin = EventOrigin::Replay { replay_id: id };
|
||||
match &origin {
|
||||
EventOrigin::Replay { replay_id } => assert_eq!(*replay_id, id),
|
||||
_ => panic!("wrong variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn from_connector_event_produces_connector_origin() {
|
||||
let event = Event::Connected;
|
||||
let uid = Uuid::nil();
|
||||
let bus = BusEvent::from_connector_event(event, Some(uid), "conn-test".to_string());
|
||||
match &bus.origin {
|
||||
EventOrigin::Connector {
|
||||
connector_uid,
|
||||
connector_id,
|
||||
} => {
|
||||
assert_eq!(*connector_uid, Some(uid));
|
||||
assert_eq!(connector_id, "conn-test");
|
||||
}
|
||||
_ => panic!("expected Connector origin"),
|
||||
}
|
||||
assert_eq!(bus.routing.connector_id.as_deref(), Some("conn-test"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,244 @@
|
||||
//! Filters applied on the subscriber side of the SharingBus.
|
||||
|
||||
use std::ops::{BitOr, BitOrAssign};
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::bus_event::{BusEvent, EventKind};
|
||||
|
||||
/// A subscriber-side predicate that selects which `BusEvent`s to forward.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum EventFilter {
|
||||
/// Accept every event unconditionally.
|
||||
All,
|
||||
/// Accept only events whose `routing.scroll_id` matches the given UUID.
|
||||
ScrollId(Uuid),
|
||||
/// Accept only events whose `routing.connector_uid` matches the given UUID.
|
||||
ConnectorUid(Uuid),
|
||||
/// Accept only events whose `routing.kind` is set in the mask.
|
||||
Kinds(EventKindMask),
|
||||
/// Accept events that satisfy at least one of the inner filters.
|
||||
AnyOf(Vec<EventFilter>),
|
||||
/// Accept events that satisfy all of the inner filters.
|
||||
AllOf(Vec<EventFilter>),
|
||||
}
|
||||
|
||||
/// Bit-mask over `EventKind` variants for efficient kind-based filtering.
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
||||
pub struct EventKindMask(pub u8);
|
||||
|
||||
impl EventKindMask {
|
||||
pub const SESSION_LIFECYCLE: Self = Self(1 << 0);
|
||||
pub const MESSAGE: Self = Self(1 << 1);
|
||||
pub const UPDATE: Self = Self(1 << 2);
|
||||
pub const SYSTEM: Self = Self(1 << 3);
|
||||
pub const ALL: Self = Self(0b1111);
|
||||
|
||||
/// Returns `true` if `kind` is set in this mask.
|
||||
pub fn contains(self, kind: EventKind) -> bool {
|
||||
let bit = match kind {
|
||||
EventKind::SessionLifecycle => Self::SESSION_LIFECYCLE,
|
||||
EventKind::Message => Self::MESSAGE,
|
||||
EventKind::Update => Self::UPDATE,
|
||||
EventKind::System => Self::SYSTEM,
|
||||
};
|
||||
(self.0 & bit.0) != 0
|
||||
}
|
||||
}
|
||||
|
||||
impl BitOr for EventKindMask {
|
||||
type Output = Self;
|
||||
|
||||
fn bitor(self, rhs: Self) -> Self::Output {
|
||||
Self(self.0 | rhs.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl BitOrAssign for EventKindMask {
|
||||
fn bitor_assign(&mut self, rhs: Self) {
|
||||
self.0 |= rhs.0;
|
||||
}
|
||||
}
|
||||
|
||||
impl EventFilter {
|
||||
/// Returns `true` if this filter accepts the given `BusEvent`.
|
||||
pub fn matches(&self, event: &BusEvent) -> bool {
|
||||
match self {
|
||||
EventFilter::All => true,
|
||||
EventFilter::ScrollId(s) => event.routing.scroll_id == Some(*s),
|
||||
EventFilter::ConnectorUid(u) => event.routing.connector_uid == Some(*u),
|
||||
EventFilter::Kinds(m) => m.contains(event.routing.kind),
|
||||
EventFilter::AnyOf(filters) => filters.iter().any(|f| f.matches(event)),
|
||||
EventFilter::AllOf(filters) => filters.iter().all(|f| f.matches(event)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Tests ───────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::Arc;
|
||||
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::*;
|
||||
use crate::{
|
||||
streaming::bus_event::{EventOrigin, EventRouting},
|
||||
Event,
|
||||
};
|
||||
|
||||
/// Build a minimal `BusEvent` for testing.
|
||||
fn make_event(
|
||||
scroll_id: Option<Uuid>,
|
||||
connector_uid: Option<Uuid>,
|
||||
kind: EventKind,
|
||||
) -> BusEvent {
|
||||
BusEvent {
|
||||
routing: EventRouting {
|
||||
scroll_id,
|
||||
connector_uid,
|
||||
kind,
|
||||
..Default::default()
|
||||
},
|
||||
origin: EventOrigin::Runtime,
|
||||
event: Arc::new(Event::Connected),
|
||||
}
|
||||
}
|
||||
|
||||
// 1. EventFilter::All matches any BusEvent.
|
||||
#[test]
|
||||
fn all_matches_any_event() {
|
||||
let ev = make_event(None, None, EventKind::System);
|
||||
assert!(EventFilter::All.matches(&ev));
|
||||
|
||||
let ev2 = make_event(Some(Uuid::new_v4()), Some(Uuid::new_v4()), EventKind::Message);
|
||||
assert!(EventFilter::All.matches(&ev2));
|
||||
}
|
||||
|
||||
// 2. EventFilter::ScrollId matches Some(x), rejects Some(y), rejects None.
|
||||
#[test]
|
||||
fn scroll_id_matches_correct_uuid_only() {
|
||||
let x = Uuid::new_v4();
|
||||
let y = Uuid::new_v4();
|
||||
|
||||
let filter = EventFilter::ScrollId(x);
|
||||
|
||||
let ev_match = make_event(Some(x), None, EventKind::SessionLifecycle);
|
||||
assert!(filter.matches(&ev_match), "should match Some(x)");
|
||||
|
||||
let ev_other = make_event(Some(y), None, EventKind::SessionLifecycle);
|
||||
assert!(!filter.matches(&ev_other), "should reject Some(y)");
|
||||
|
||||
let ev_none = make_event(None, None, EventKind::SessionLifecycle);
|
||||
assert!(!filter.matches(&ev_none), "should reject None");
|
||||
}
|
||||
|
||||
// 3. EventFilter::ConnectorUid matches only when routing.connector_uid == Some(u).
|
||||
#[test]
|
||||
fn connector_uid_matches_correct_uuid_only() {
|
||||
let u = Uuid::new_v4();
|
||||
let other = Uuid::new_v4();
|
||||
|
||||
let filter = EventFilter::ConnectorUid(u);
|
||||
|
||||
let ev_match = make_event(None, Some(u), EventKind::Update);
|
||||
assert!(filter.matches(&ev_match));
|
||||
|
||||
let ev_other = make_event(None, Some(other), EventKind::Update);
|
||||
assert!(!filter.matches(&ev_other));
|
||||
|
||||
let ev_none = make_event(None, None, EventKind::Update);
|
||||
assert!(!filter.matches(&ev_none));
|
||||
}
|
||||
|
||||
// 4. EventFilter::Kinds(MESSAGE) matches Message, rejects Update.
|
||||
#[test]
|
||||
fn kinds_mask_message_matches_message_only() {
|
||||
let filter = EventFilter::Kinds(EventKindMask::MESSAGE);
|
||||
|
||||
let ev_msg = make_event(None, None, EventKind::Message);
|
||||
assert!(filter.matches(&ev_msg));
|
||||
|
||||
let ev_upd = make_event(None, None, EventKind::Update);
|
||||
assert!(!filter.matches(&ev_upd));
|
||||
}
|
||||
|
||||
// 5. AnyOf([ScrollId(X), ConnectorUid(Y)]) matches when either matches, rejects otherwise.
|
||||
#[test]
|
||||
fn any_of_matches_when_at_least_one_sub_filter_matches() {
|
||||
let x = Uuid::new_v4();
|
||||
let y = Uuid::new_v4();
|
||||
let z = Uuid::new_v4();
|
||||
|
||||
let filter = EventFilter::AnyOf(vec![
|
||||
EventFilter::ScrollId(x),
|
||||
EventFilter::ConnectorUid(y),
|
||||
]);
|
||||
|
||||
// scroll_id matches
|
||||
let ev_scroll = make_event(Some(x), None, EventKind::Message);
|
||||
assert!(filter.matches(&ev_scroll));
|
||||
|
||||
// connector_uid matches
|
||||
let ev_conn = make_event(None, Some(y), EventKind::Message);
|
||||
assert!(filter.matches(&ev_conn));
|
||||
|
||||
// both match
|
||||
let ev_both = make_event(Some(x), Some(y), EventKind::Message);
|
||||
assert!(filter.matches(&ev_both));
|
||||
|
||||
// neither matches
|
||||
let ev_neither = make_event(Some(z), Some(z), EventKind::System);
|
||||
assert!(!filter.matches(&ev_neither));
|
||||
}
|
||||
|
||||
// 6. AllOf([ScrollId(X), Kinds(MESSAGE)]) matches only when both hold.
|
||||
#[test]
|
||||
fn all_of_matches_only_when_all_sub_filters_match() {
|
||||
let x = Uuid::new_v4();
|
||||
|
||||
let filter = EventFilter::AllOf(vec![
|
||||
EventFilter::ScrollId(x),
|
||||
EventFilter::Kinds(EventKindMask::MESSAGE),
|
||||
]);
|
||||
|
||||
// both conditions satisfied
|
||||
let ev_both = make_event(Some(x), None, EventKind::Message);
|
||||
assert!(filter.matches(&ev_both));
|
||||
|
||||
// scroll_id matches but wrong kind
|
||||
let ev_wrong_kind = make_event(Some(x), None, EventKind::Update);
|
||||
assert!(!filter.matches(&ev_wrong_kind));
|
||||
|
||||
// right kind but wrong scroll_id
|
||||
let ev_wrong_scroll = make_event(Some(Uuid::new_v4()), None, EventKind::Message);
|
||||
assert!(!filter.matches(&ev_wrong_scroll));
|
||||
|
||||
// neither matches
|
||||
let ev_neither = make_event(None, None, EventKind::System);
|
||||
assert!(!filter.matches(&ev_neither));
|
||||
}
|
||||
|
||||
// 7. BitOr combining masks: (MESSAGE | UPDATE).contains(kind) is true for both.
|
||||
#[test]
|
||||
fn bitor_combines_masks_correctly() {
|
||||
let combined = EventKindMask::MESSAGE | EventKindMask::UPDATE;
|
||||
|
||||
assert!(combined.contains(EventKind::Message));
|
||||
assert!(combined.contains(EventKind::Update));
|
||||
assert!(!combined.contains(EventKind::SessionLifecycle));
|
||||
assert!(!combined.contains(EventKind::System));
|
||||
}
|
||||
|
||||
// Bonus: verify BitOrAssign works the same way.
|
||||
#[test]
|
||||
fn bitor_assign_accumulates_bits() {
|
||||
let mut mask = EventKindMask::MESSAGE;
|
||||
mask |= EventKindMask::SYSTEM;
|
||||
|
||||
assert!(mask.contains(EventKind::Message));
|
||||
assert!(mask.contains(EventKind::System));
|
||||
assert!(!mask.contains(EventKind::Update));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
//! Streaming primitives shared across the runtime, archivist, and sink crates.
|
||||
//!
|
||||
//! `BusEvent` wraps the existing `Event` with routing context; `SessionStream`
|
||||
//! is the trait every uni-directional sink implements. The existing
|
||||
//! `SessionShare` trait (bi-directional, Matrix) lives in `crate::sharing`
|
||||
//! and is not superseded.
|
||||
|
||||
pub mod bus_event;
|
||||
pub mod filter;
|
||||
pub mod receiver;
|
||||
pub mod stream;
|
||||
|
||||
pub use bus_event::{BusEvent, EventKind, EventOrigin, EventRouting};
|
||||
pub use filter::{EventFilter, EventKindMask};
|
||||
pub use receiver::BusReceiver;
|
||||
pub use stream::{
|
||||
SessionStream, StreamError, StreamKind, StreamOutcome, StreamScope, StreamSummary,
|
||||
};
|
||||
@@ -0,0 +1,27 @@
|
||||
//! `BusReceiver`: subscriber handle returned by a SharingBus-like fan-out.
|
||||
//!
|
||||
//! This type lives in `dirigent_protocol` (rather than next to its only
|
||||
//! producer in `dirigent_core::sharing::bus`) so that downstream consumers
|
||||
//! such as `dirigent_archivist` can accept a `BusReceiver` without taking
|
||||
//! on a dependency on `dirigent_core` — which would be a dependency cycle.
|
||||
//!
|
||||
//! It is intentionally a dumb data container: just `id`, `rx`, and the
|
||||
//! `lagged` counter that the producer task increments when it has to drop
|
||||
//! events for a slow subscriber. No logic lives here.
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::streaming::BusEvent;
|
||||
|
||||
/// Receiver handle returned to subscribers by a SharingBus-style fan-out.
|
||||
///
|
||||
/// `lagged` counts how many events were dropped because the underlying
|
||||
/// mpsc queue was full when the worker tried to deliver.
|
||||
pub struct BusReceiver {
|
||||
pub id: u64,
|
||||
pub rx: mpsc::Receiver<BusEvent>,
|
||||
pub lagged: Arc<AtomicU64>,
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
//! SessionStream: uni-directional sink trait. Archive backends use
|
||||
//! ArchiveBackend; live-write sinks like Langfuse use SessionStream.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::bus_event::BusEvent;
|
||||
|
||||
#[async_trait]
|
||||
pub trait SessionStream: Send + Sync {
|
||||
fn summary(&self) -> StreamSummary;
|
||||
fn scope(&self) -> StreamScope;
|
||||
async fn on_event(&self, event: &BusEvent) -> StreamOutcome;
|
||||
async fn shutdown(&self);
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StreamSummary {
|
||||
pub name: String,
|
||||
pub kind: StreamKind,
|
||||
pub target: String, // human-readable ("langfuse: https://…", "matrix: #room:server")
|
||||
pub active_since: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum StreamKind {
|
||||
Matrix,
|
||||
Langfuse,
|
||||
Slack,
|
||||
Webhook,
|
||||
Custom,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case", tag = "kind")]
|
||||
pub enum StreamScope {
|
||||
Session { scroll_id: Uuid },
|
||||
Connector { connector_uid: Uuid },
|
||||
ArchiveWide { acknowledged: bool },
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum StreamOutcome {
|
||||
Ok,
|
||||
Skipped,
|
||||
Failed(StreamError),
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum StreamError {
|
||||
#[error("transport: {0}")] Transport(String),
|
||||
#[error("serialisation: {0}")] Serialisation(String),
|
||||
#[error("rejected: {0}")] Rejected(String),
|
||||
#[error("shutdown")] Shutdown,
|
||||
}
|
||||
@@ -0,0 +1,55 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// MCP-style content block for displayable content
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ContentBlock {
|
||||
Text {
|
||||
text: String,
|
||||
},
|
||||
ResourceLink {
|
||||
uri: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
mime_type: Option<String>,
|
||||
},
|
||||
// Future: Resource, Image, Audio (marked as out-of-scope for phase 1)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_text_serialization() {
|
||||
let block = ContentBlock::Text {
|
||||
text: "Hello".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string(&block).unwrap();
|
||||
assert!(json.contains(r#""type":"text"#));
|
||||
assert!(json.contains(r#""text":"Hello"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resource_link_serialization() {
|
||||
let block = ContentBlock::ResourceLink {
|
||||
uri: "file:///path/to/file".to_string(),
|
||||
name: Some("file.txt".to_string()),
|
||||
mime_type: Some("text/plain".to_string()),
|
||||
};
|
||||
let json = serde_json::to_string(&block).unwrap();
|
||||
assert!(json.contains(r#""type":"resource_link"#));
|
||||
assert!(json.contains(r#""uri":"file:///path/to/file"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_roundtrip() {
|
||||
let block = ContentBlock::Text {
|
||||
text: "Test".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string(&block).unwrap();
|
||||
let deserialized: ContentBlock = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(block, deserialized);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,231 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
|
||||
pub struct Meta {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub provider: Option<ProviderMeta>,
|
||||
|
||||
/// Arbitrary extra fields
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct ProviderMeta {
|
||||
/// Provider name (e.g., "opencode", "anthropic")
|
||||
pub name: String,
|
||||
|
||||
/// Original provider-specific IDs
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub original_ids: Option<HashMap<String, String>>,
|
||||
|
||||
/// Minimal raw payload excerpts for debugging (optional)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub raw_excerpt: Option<Value>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_meta_default() {
|
||||
let meta = Meta::default();
|
||||
assert_eq!(meta.provider, None);
|
||||
assert!(meta.extra.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_meta_with_provider() {
|
||||
let meta = Meta {
|
||||
provider: Some(ProviderMeta {
|
||||
name: "opencode".to_string(),
|
||||
original_ids: Some(HashMap::from([
|
||||
("session_id".to_string(), "abc123".to_string()),
|
||||
])),
|
||||
raw_excerpt: None,
|
||||
}),
|
||||
extra: HashMap::new(),
|
||||
};
|
||||
let json = serde_json::to_string(&meta).unwrap();
|
||||
assert!(json.contains(r#""name":"opencode"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_skip_serializing_if_none() {
|
||||
let meta = Meta::default();
|
||||
let json = serde_json::to_string(&meta).unwrap();
|
||||
// Should be empty object since provider is None
|
||||
assert_eq!(json, "{}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_roundtrip() {
|
||||
let meta = Meta {
|
||||
provider: Some(ProviderMeta {
|
||||
name: "test".to_string(),
|
||||
original_ids: None,
|
||||
raw_excerpt: None,
|
||||
}),
|
||||
extra: HashMap::new(),
|
||||
};
|
||||
let json = serde_json::to_string(&meta).unwrap();
|
||||
let deserialized: Meta = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(meta, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_meta_preserves_claude_code_tool_response() {
|
||||
use serde_json::json;
|
||||
|
||||
// Test that complex nested _meta data is preserved (T013 requirement)
|
||||
let meta_json = json!({
|
||||
"claudeCode": {
|
||||
"toolResponse": {
|
||||
"mode": "content",
|
||||
"numFiles": 0,
|
||||
"filenames": [],
|
||||
"content": "some grep output here",
|
||||
"numLines": 58,
|
||||
"appliedLimit": 100
|
||||
},
|
||||
"toolName": "Grep"
|
||||
}
|
||||
});
|
||||
|
||||
// Deserialize into Meta
|
||||
let meta: Meta = serde_json::from_value(meta_json.clone()).unwrap();
|
||||
|
||||
// Verify claudeCode is in extra
|
||||
assert!(meta.extra.contains_key("claudeCode"));
|
||||
|
||||
// Serialize back to JSON
|
||||
let serialized = serde_json::to_value(&meta).unwrap();
|
||||
|
||||
// Verify all fields are preserved
|
||||
assert_eq!(serialized["claudeCode"]["toolName"], "Grep");
|
||||
assert!(serialized["claudeCode"]["toolResponse"].is_object());
|
||||
assert_eq!(serialized["claudeCode"]["toolResponse"]["mode"], "content");
|
||||
assert_eq!(serialized["claudeCode"]["toolResponse"]["numFiles"], 0);
|
||||
assert_eq!(serialized["claudeCode"]["toolResponse"]["numLines"], 58);
|
||||
assert_eq!(serialized["claudeCode"]["toolResponse"]["appliedLimit"], 100);
|
||||
assert_eq!(
|
||||
serialized["claudeCode"]["toolResponse"]["content"],
|
||||
"some grep output here"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_meta_round_trip_preservation() {
|
||||
use serde_json::json;
|
||||
|
||||
// Test incoming → serialize → deserialize → serialize preserves all fields (T013)
|
||||
let original_meta_json = json!({
|
||||
"claudeCode": {
|
||||
"toolResponse": {
|
||||
"mode": "content",
|
||||
"numFiles": 3,
|
||||
"filenames": ["file1.rs", "file2.rs", "file3.rs"],
|
||||
"content": "match results",
|
||||
"numLines": 42,
|
||||
"appliedLimit": 100,
|
||||
"customField": "should be preserved"
|
||||
},
|
||||
"toolName": "Grep",
|
||||
"additionalField": "also preserved"
|
||||
},
|
||||
"provider": {
|
||||
"name": "anthropic",
|
||||
"original_ids": {
|
||||
"session_id": "sess_123"
|
||||
}
|
||||
},
|
||||
"customTopLevel": "preserved too"
|
||||
});
|
||||
|
||||
// First round trip
|
||||
let meta1: Meta = serde_json::from_value(original_meta_json.clone()).unwrap();
|
||||
let json1 = serde_json::to_value(&meta1).unwrap();
|
||||
|
||||
// Second round trip
|
||||
let meta2: Meta = serde_json::from_value(json1.clone()).unwrap();
|
||||
let json2 = serde_json::to_value(&meta2).unwrap();
|
||||
|
||||
// Verify both serializations are identical (stable)
|
||||
assert_eq!(json1, json2);
|
||||
|
||||
// Verify all nested fields preserved
|
||||
assert_eq!(json2["claudeCode"]["toolResponse"]["customField"], "should be preserved");
|
||||
assert_eq!(json2["claudeCode"]["additionalField"], "also preserved");
|
||||
assert_eq!(json2["customTopLevel"], "preserved too");
|
||||
assert_eq!(json2["provider"]["name"], "anthropic");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_meta_extra_fields_with_flatten() {
|
||||
use serde_json::json;
|
||||
|
||||
// Test that #[serde(flatten)] correctly captures arbitrary fields
|
||||
let json = json!({
|
||||
"provider": {
|
||||
"name": "test_provider"
|
||||
},
|
||||
"arbitraryField1": "value1",
|
||||
"arbitraryField2": {
|
||||
"nested": "structure"
|
||||
},
|
||||
"arbitraryField3": [1, 2, 3]
|
||||
});
|
||||
|
||||
let meta: Meta = serde_json::from_value(json.clone()).unwrap();
|
||||
|
||||
// Verify provider is parsed correctly
|
||||
assert!(meta.provider.is_some());
|
||||
assert_eq!(meta.provider.as_ref().unwrap().name, "test_provider");
|
||||
|
||||
// Verify extra fields are captured
|
||||
assert_eq!(meta.extra.len(), 3);
|
||||
assert!(meta.extra.contains_key("arbitraryField1"));
|
||||
assert!(meta.extra.contains_key("arbitraryField2"));
|
||||
assert!(meta.extra.contains_key("arbitraryField3"));
|
||||
|
||||
// Verify round-trip preserves all fields
|
||||
let serialized = serde_json::to_value(&meta).unwrap();
|
||||
assert_eq!(serialized["arbitraryField1"], "value1");
|
||||
assert_eq!(serialized["arbitraryField2"]["nested"], "structure");
|
||||
assert_eq!(serialized["arbitraryField3"], json!([1, 2, 3]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_meta_no_data_loss_on_unknown_fields() {
|
||||
use serde_json::json;
|
||||
|
||||
// Simulate receiving _meta from Claude with fields we don't know about yet
|
||||
let future_meta = json!({
|
||||
"claudeCode": {
|
||||
"toolName": "FutureTool",
|
||||
"futureFeature1": "some value",
|
||||
"futureFeature2": {
|
||||
"deeplyNested": {
|
||||
"data": [1, 2, 3]
|
||||
}
|
||||
}
|
||||
},
|
||||
"unknownTopLevel": "should not be lost"
|
||||
});
|
||||
|
||||
let meta: Meta = serde_json::from_value(future_meta.clone()).unwrap();
|
||||
let serialized = serde_json::to_value(&meta).unwrap();
|
||||
|
||||
// Verify NO data loss - all unknown fields preserved
|
||||
assert_eq!(serialized["claudeCode"]["toolName"], "FutureTool");
|
||||
assert_eq!(serialized["claudeCode"]["futureFeature1"], "some value");
|
||||
assert_eq!(
|
||||
serialized["claudeCode"]["futureFeature2"]["deeplyNested"]["data"],
|
||||
json!([1, 2, 3])
|
||||
);
|
||||
assert_eq!(serialized["unknownTopLevel"], "should not be lost");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
pub mod content;
|
||||
pub mod meta;
|
||||
pub mod permission;
|
||||
pub mod tool;
|
||||
pub mod updates;
|
||||
|
||||
pub use content::ContentBlock;
|
||||
pub use meta::{Meta, ProviderMeta};
|
||||
pub use permission::{
|
||||
PermissionOption, PermissionOptionKind, RequestPermissionOutcome, RequestPermissionResponse,
|
||||
ToolCallInfo, ToolCallLocation, ToolCallStatus as PermissionToolCallStatus, ToolKind,
|
||||
};
|
||||
pub use tool::{ToolCall, ToolCallContent, ToolCallId, ToolCallStatus, ToolOrigin};
|
||||
pub use updates::SessionUpdate;
|
||||
@@ -0,0 +1,454 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// ACP permission option presented to the user
|
||||
///
|
||||
/// When a tool requires permission, the system presents the user with a list of
|
||||
/// options that define how they want to handle the request. Each option has a kind
|
||||
/// that determines the scope of the permission grant or rejection.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// use dirigent_protocol::{PermissionOption, PermissionOptionKind};
|
||||
///
|
||||
/// let allow_once = PermissionOption {
|
||||
/// option_id: "allow_once_1".to_string(),
|
||||
/// name: "Allow this time".to_string(),
|
||||
/// kind: PermissionOptionKind::AllowOnce,
|
||||
/// };
|
||||
///
|
||||
/// let allow_always = PermissionOption {
|
||||
/// option_id: "allow_always_1".to_string(),
|
||||
/// name: "Always allow for this session".to_string(),
|
||||
/// kind: PermissionOptionKind::AllowAlways,
|
||||
/// };
|
||||
/// ```
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
|
||||
pub struct PermissionOption {
|
||||
/// Unique identifier for this permission option
|
||||
#[serde(rename = "optionId")]
|
||||
pub option_id: String,
|
||||
/// User-facing name/label for this option
|
||||
pub name: String,
|
||||
/// The kind of permission action this option represents
|
||||
pub kind: PermissionOptionKind,
|
||||
}
|
||||
|
||||
/// Kind of permission option defining scope of grant/rejection
|
||||
///
|
||||
/// These variants define how the user's permission decision should be applied:
|
||||
/// - **Once**: Applies to the current request only
|
||||
/// - **Always**: Applies to all similar requests in the session/context
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PermissionOptionKind {
|
||||
/// Grant permission for this single request
|
||||
AllowOnce,
|
||||
/// Grant permission for all similar requests
|
||||
AllowAlways,
|
||||
/// Reject this single request
|
||||
RejectOnce,
|
||||
/// Reject all similar requests
|
||||
RejectAlways,
|
||||
}
|
||||
|
||||
/// Response from a permission request
|
||||
///
|
||||
/// When the system requests permission from the user, the response indicates
|
||||
/// either that the user selected a specific option, or that they cancelled
|
||||
/// the request entirely.
|
||||
///
|
||||
/// # ACP Wire Format
|
||||
///
|
||||
/// The response is structured with the outcome nested inside an `outcome` field:
|
||||
/// ```json
|
||||
/// {"outcome": {"outcome": "selected", "optionId": "allow_once_1"}}
|
||||
/// ```
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// use dirigent_protocol::{RequestPermissionResponse, RequestPermissionOutcome};
|
||||
///
|
||||
/// // User selected an option
|
||||
/// let selected = RequestPermissionResponse {
|
||||
/// outcome: RequestPermissionOutcome::Selected {
|
||||
/// option_id: "allow_once_1".to_string(),
|
||||
/// },
|
||||
/// };
|
||||
///
|
||||
/// // User cancelled the request
|
||||
/// let cancelled = RequestPermissionResponse {
|
||||
/// outcome: RequestPermissionOutcome::Cancelled,
|
||||
/// };
|
||||
/// ```
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct RequestPermissionResponse {
|
||||
/// The outcome of the permission request (contains optionId if selected)
|
||||
pub outcome: RequestPermissionOutcome,
|
||||
}
|
||||
|
||||
/// Outcome of a permission request
|
||||
///
|
||||
/// Uses internal tagging to produce the ACP wire format:
|
||||
/// - Selected: `{"outcome": "selected", "optionId": "..."}`
|
||||
/// - Cancelled: `{"outcome": "cancelled"}`
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(tag = "outcome", rename_all = "snake_case")]
|
||||
pub enum RequestPermissionOutcome {
|
||||
/// User selected one of the provided options
|
||||
Selected {
|
||||
/// The ID of the selected option
|
||||
#[serde(rename = "optionId")]
|
||||
option_id: String,
|
||||
},
|
||||
/// User cancelled the permission request
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
/// Information about a tool call for permission requests
|
||||
///
|
||||
/// When requesting permission for a tool execution, this provides context
|
||||
/// about what the tool will do, including its kind, status, and affected locations.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// use dirigent_protocol::{ToolCallInfo, ToolKind, ToolCallStatus, ToolCallLocation};
|
||||
///
|
||||
/// let info = ToolCallInfo {
|
||||
/// tool_call_id: "call_123".to_string(),
|
||||
/// title: "Read configuration file".to_string(),
|
||||
/// kind: Some(ToolKind::Read),
|
||||
/// status: Some(ToolCallStatus::Pending),
|
||||
/// locations: Some(vec![
|
||||
/// ToolCallLocation {
|
||||
/// path: "/etc/config.toml".to_string(),
|
||||
/// line: None,
|
||||
/// }
|
||||
/// ]),
|
||||
/// raw_input: None,
|
||||
/// };
|
||||
/// ```
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
|
||||
pub struct ToolCallInfo {
|
||||
/// Unique identifier for this tool call
|
||||
#[serde(rename = "toolCallId")]
|
||||
pub tool_call_id: String,
|
||||
/// User-facing title describing what the tool will do
|
||||
pub title: String,
|
||||
/// The kind/category of operation this tool performs
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub kind: Option<ToolKind>,
|
||||
/// Current status of the tool call
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub status: Option<ToolCallStatus>,
|
||||
/// File/resource locations affected by this tool call
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub locations: Option<Vec<ToolCallLocation>>,
|
||||
/// Raw input parameters for debugging/inspection
|
||||
#[serde(rename = "rawInput", skip_serializing_if = "Option::is_none")]
|
||||
pub raw_input: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Category of tool operation
|
||||
///
|
||||
/// Provides semantic categorization of tool functionality to help users
|
||||
/// understand the impact and risk level of granting permission.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ToolKind {
|
||||
/// Read-only operations (viewing files, searching)
|
||||
Read,
|
||||
/// Modify existing content
|
||||
Edit,
|
||||
/// Remove content
|
||||
Delete,
|
||||
/// Relocate content
|
||||
Move,
|
||||
/// Search operations
|
||||
Search,
|
||||
/// Execute commands or scripts
|
||||
Execute,
|
||||
/// Internal reasoning/planning (no external effects)
|
||||
Think,
|
||||
/// Fetch remote resources
|
||||
Fetch,
|
||||
/// Other/uncategorized operations
|
||||
Other,
|
||||
}
|
||||
|
||||
/// Status of a tool call
|
||||
///
|
||||
/// Note: This duplicates `ToolCallStatus` from `tool.rs` for now.
|
||||
/// In the future, we may consolidate these types.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ToolCallStatus {
|
||||
/// Tool call created but not yet started
|
||||
Pending,
|
||||
/// Tool call is currently executing
|
||||
#[serde(rename = "in_progress")]
|
||||
Running,
|
||||
/// Tool call completed successfully
|
||||
Completed,
|
||||
/// Tool call failed with an error
|
||||
#[serde(rename = "failed")]
|
||||
Failed,
|
||||
}
|
||||
|
||||
/// Location affected by a tool call
|
||||
///
|
||||
/// Represents a file or resource that will be read, modified, or otherwise
|
||||
/// affected by the tool execution. May optionally include a specific line number.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// use dirigent_protocol::ToolCallLocation;
|
||||
///
|
||||
/// // File-level location
|
||||
/// let file_location = ToolCallLocation {
|
||||
/// path: "/src/main.rs".to_string(),
|
||||
/// line: None,
|
||||
/// };
|
||||
///
|
||||
/// // Specific line location
|
||||
/// let line_location = ToolCallLocation {
|
||||
/// path: "/src/lib.rs".to_string(),
|
||||
/// line: Some(42),
|
||||
/// };
|
||||
/// ```
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
|
||||
pub struct ToolCallLocation {
|
||||
/// File path or resource identifier
|
||||
pub path: String,
|
||||
/// Optional line number within the file
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub line: Option<i32>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_permission_option_serialization() {
|
||||
let option = PermissionOption {
|
||||
option_id: "allow_once_1".to_string(),
|
||||
name: "Allow this time".to_string(),
|
||||
kind: PermissionOptionKind::AllowOnce,
|
||||
};
|
||||
|
||||
let json = serde_json::to_value(&option).unwrap();
|
||||
assert_eq!(json["optionId"], "allow_once_1");
|
||||
assert_eq!(json["name"], "Allow this time");
|
||||
assert_eq!(json["kind"], "allow_once");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_permission_option_kind_variants() {
|
||||
let kinds = vec![
|
||||
(PermissionOptionKind::AllowOnce, "allow_once"),
|
||||
(PermissionOptionKind::AllowAlways, "allow_always"),
|
||||
(PermissionOptionKind::RejectOnce, "reject_once"),
|
||||
(PermissionOptionKind::RejectAlways, "reject_always"),
|
||||
];
|
||||
|
||||
for (kind, expected) in kinds {
|
||||
let json = serde_json::to_string(&kind).unwrap();
|
||||
assert_eq!(json, format!(r#""{}""#, expected));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_permission_response_selected() {
|
||||
let response = RequestPermissionResponse {
|
||||
outcome: RequestPermissionOutcome::Selected {
|
||||
option_id: "allow_once_1".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let json = serde_json::to_value(&response).unwrap();
|
||||
// The outcome field should contain an object with nested outcome and optionId
|
||||
assert_eq!(json["outcome"]["outcome"], "selected");
|
||||
assert_eq!(json["outcome"]["optionId"], "allow_once_1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_permission_response_cancelled() {
|
||||
let response = RequestPermissionResponse {
|
||||
outcome: RequestPermissionOutcome::Cancelled,
|
||||
};
|
||||
|
||||
let json = serde_json::to_value(&response).unwrap();
|
||||
// The outcome field should contain an object with just outcome
|
||||
assert_eq!(json["outcome"]["outcome"], "cancelled");
|
||||
// optionId should not be present for cancelled
|
||||
assert!(json["outcome"].get("optionId").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_request_permission_response_wire_format() {
|
||||
// Test that the serialization matches ACP wire format exactly
|
||||
let response = RequestPermissionResponse {
|
||||
outcome: RequestPermissionOutcome::Selected {
|
||||
option_id: "allow".to_string(),
|
||||
},
|
||||
};
|
||||
|
||||
let json = serde_json::to_value(&response).unwrap();
|
||||
// Should produce: {"outcome": {"outcome": "selected", "optionId": "allow"}}
|
||||
let expected = json!({
|
||||
"outcome": {
|
||||
"outcome": "selected",
|
||||
"optionId": "allow"
|
||||
}
|
||||
});
|
||||
assert_eq!(json, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_info_serialization() {
|
||||
let info = ToolCallInfo {
|
||||
tool_call_id: "call_123".to_string(),
|
||||
title: "Read file".to_string(),
|
||||
kind: Some(ToolKind::Read),
|
||||
status: Some(ToolCallStatus::Pending),
|
||||
locations: Some(vec![ToolCallLocation {
|
||||
path: "/test.txt".to_string(),
|
||||
line: Some(10),
|
||||
}]),
|
||||
raw_input: Some(json!({"path": "/test.txt"})),
|
||||
};
|
||||
|
||||
let json = serde_json::to_value(&info).unwrap();
|
||||
assert_eq!(json["toolCallId"], "call_123");
|
||||
assert_eq!(json["title"], "Read file");
|
||||
assert_eq!(json["kind"], "read");
|
||||
assert_eq!(json["status"], "pending");
|
||||
assert!(json["locations"].is_array());
|
||||
assert!(json["rawInput"].is_object());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_info_minimal() {
|
||||
let info = ToolCallInfo {
|
||||
tool_call_id: "call_456".to_string(),
|
||||
title: "Execute command".to_string(),
|
||||
kind: None,
|
||||
status: None,
|
||||
locations: None,
|
||||
raw_input: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_value(&info).unwrap();
|
||||
assert_eq!(json["toolCallId"], "call_456");
|
||||
assert_eq!(json["title"], "Execute command");
|
||||
// Optional fields should be omitted
|
||||
assert!(json.get("kind").is_none());
|
||||
assert!(json.get("status").is_none());
|
||||
assert!(json.get("locations").is_none());
|
||||
assert!(json.get("rawInput").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_kind_variants() {
|
||||
let kinds = vec![
|
||||
(ToolKind::Read, "read"),
|
||||
(ToolKind::Edit, "edit"),
|
||||
(ToolKind::Delete, "delete"),
|
||||
(ToolKind::Move, "move"),
|
||||
(ToolKind::Search, "search"),
|
||||
(ToolKind::Execute, "execute"),
|
||||
(ToolKind::Think, "think"),
|
||||
(ToolKind::Fetch, "fetch"),
|
||||
(ToolKind::Other, "other"),
|
||||
];
|
||||
|
||||
for (kind, expected) in kinds {
|
||||
let json = serde_json::to_string(&kind).unwrap();
|
||||
assert_eq!(json, format!(r#""{}""#, expected));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_status_variants() {
|
||||
let statuses = vec![
|
||||
(ToolCallStatus::Pending, "pending"),
|
||||
(ToolCallStatus::Running, "in_progress"),
|
||||
(ToolCallStatus::Completed, "completed"),
|
||||
(ToolCallStatus::Failed, "failed"),
|
||||
];
|
||||
|
||||
for (status, expected) in statuses {
|
||||
let json = serde_json::to_string(&status).unwrap();
|
||||
assert_eq!(json, format!(r#""{}""#, expected));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_location_with_line() {
|
||||
let location = ToolCallLocation {
|
||||
path: "/src/main.rs".to_string(),
|
||||
line: Some(42),
|
||||
};
|
||||
|
||||
let json = serde_json::to_value(&location).unwrap();
|
||||
assert_eq!(json["path"], "/src/main.rs");
|
||||
assert_eq!(json["line"], 42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_location_without_line() {
|
||||
let location = ToolCallLocation {
|
||||
path: "/config.toml".to_string(),
|
||||
line: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_value(&location).unwrap();
|
||||
assert_eq!(json["path"], "/config.toml");
|
||||
assert!(json.get("line").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_roundtrip_permission_option() {
|
||||
let original = PermissionOption {
|
||||
option_id: "test_id".to_string(),
|
||||
name: "Test Option".to_string(),
|
||||
kind: PermissionOptionKind::AllowAlways,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&original).unwrap();
|
||||
let deserialized: PermissionOption = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(original, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_roundtrip_tool_call_info() {
|
||||
let original = ToolCallInfo {
|
||||
tool_call_id: "call_789".to_string(),
|
||||
title: "Complex operation".to_string(),
|
||||
kind: Some(ToolKind::Edit),
|
||||
status: Some(ToolCallStatus::Running),
|
||||
locations: Some(vec![
|
||||
ToolCallLocation {
|
||||
path: "/file1.rs".to_string(),
|
||||
line: Some(10),
|
||||
},
|
||||
ToolCallLocation {
|
||||
path: "/file2.rs".to_string(),
|
||||
line: None,
|
||||
},
|
||||
]),
|
||||
raw_input: Some(json!({"key": "value"})),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&original).unwrap();
|
||||
let deserialized: ToolCallInfo = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(original, deserialized);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,513 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::types::content::ContentBlock;
|
||||
|
||||
/// ACP-compliant tool call content wrapper supporting multiple content types
|
||||
///
|
||||
/// The ACP protocol requires tool call content to be wrapped in a discriminated
|
||||
/// union that supports three types:
|
||||
/// - **Content** - Regular content blocks (text, images, etc.)
|
||||
/// - **Diff** - File change diffs (oldText, newText)
|
||||
/// - **Terminal** - Terminal session references
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// use dirigent_protocol::{ToolCallContent, ContentBlock};
|
||||
///
|
||||
/// // Create a text content wrapper
|
||||
/// let content = ToolCallContent::from_content_block(ContentBlock::Text {
|
||||
/// text: "Tool output".to_string()
|
||||
/// });
|
||||
///
|
||||
/// // Create a diff
|
||||
/// let diff = ToolCallContent::diff(
|
||||
/// "/src/main.rs".to_string(),
|
||||
/// Some("old code".to_string()),
|
||||
/// "new code".to_string(),
|
||||
/// );
|
||||
///
|
||||
/// // Create a terminal reference
|
||||
/// let terminal = ToolCallContent::terminal("term_123".to_string());
|
||||
/// ```
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ToolCallContent {
|
||||
/// Regular content (text, images, etc.)
|
||||
Content {
|
||||
content: ContentBlock,
|
||||
},
|
||||
|
||||
/// File diff showing changes
|
||||
Diff {
|
||||
path: String,
|
||||
#[serde(rename = "oldText", skip_serializing_if = "Option::is_none")]
|
||||
old_text: Option<String>,
|
||||
#[serde(rename = "newText")]
|
||||
new_text: String,
|
||||
},
|
||||
|
||||
/// Reference to a terminal session
|
||||
Terminal {
|
||||
#[serde(rename = "terminalId")]
|
||||
terminal_id: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl ToolCallContent {
|
||||
/// Create a content wrapper from a ContentBlock
|
||||
pub fn from_content_block(block: ContentBlock) -> Self {
|
||||
Self::Content { content: block }
|
||||
}
|
||||
|
||||
/// Create a diff wrapper
|
||||
pub fn diff(path: String, old_text: Option<String>, new_text: String) -> Self {
|
||||
Self::Diff { path, old_text, new_text }
|
||||
}
|
||||
|
||||
/// Create a terminal reference
|
||||
pub fn terminal(terminal_id: String) -> Self {
|
||||
Self::Terminal { terminal_id }
|
||||
}
|
||||
}
|
||||
|
||||
/// Unique identifier for a tool call
|
||||
pub type ToolCallId = String;
|
||||
|
||||
/// Status of a tool call in its lifecycle
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ToolCallStatus {
|
||||
/// Tool call has been created but not yet started
|
||||
Pending,
|
||||
/// Tool call is currently executing
|
||||
#[serde(rename = "in_progress")]
|
||||
Running,
|
||||
/// Tool call completed successfully
|
||||
Completed,
|
||||
/// Tool call failed with an error
|
||||
#[serde(rename = "failed")]
|
||||
Error,
|
||||
}
|
||||
|
||||
/// Origin of a tool execution
|
||||
///
|
||||
/// Distinguishes where the tool is actually executed:
|
||||
/// - Internal: Dirigent runs the tool after user permission
|
||||
/// - External: Agent runs tool directly (we observe)
|
||||
/// - Forwarded: Upstream ACP server (transitionary)
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ToolOrigin {
|
||||
/// Tool executed by Dirigent after user permission
|
||||
Internal,
|
||||
/// Tool executed by agent directly (we observe)
|
||||
#[default]
|
||||
External,
|
||||
/// Tool forwarded from upstream ACP server
|
||||
Forwarded,
|
||||
}
|
||||
|
||||
/// Represents a tool call and its lifecycle
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
/// Unique identifier for this tool call
|
||||
pub id: ToolCallId,
|
||||
/// Name of the tool being called
|
||||
pub tool_name: String,
|
||||
/// Current status of the tool call
|
||||
pub status: ToolCallStatus,
|
||||
/// Content associated with this tool call (wrapped in ACP-compliant format)
|
||||
///
|
||||
/// Each content item is wrapped in a discriminated union that can be:
|
||||
/// - Content (text, images, etc.)
|
||||
/// - Diff (file changes)
|
||||
/// - Terminal (terminal session reference)
|
||||
#[serde(default)]
|
||||
pub content: Vec<ToolCallContent>,
|
||||
/// Raw input parameters (preserved for debugging/inspection)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub raw_input: Option<Value>,
|
||||
/// Raw output result (preserved for debugging/inspection)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub raw_output: Option<Value>,
|
||||
/// Optional title for the tool call
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub title: Option<String>,
|
||||
/// Error message if status is Error
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error: Option<String>,
|
||||
/// Additional metadata
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<Value>,
|
||||
/// Origin of this tool execution
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub origin: Option<ToolOrigin>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_status_pending_serialization() {
|
||||
let status = ToolCallStatus::Pending;
|
||||
let json = serde_json::to_string(&status).unwrap();
|
||||
assert_eq!(json, r#""pending""#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_status_running_serialization() {
|
||||
let status = ToolCallStatus::Running;
|
||||
let json = serde_json::to_string(&status).unwrap();
|
||||
assert_eq!(json, r#""in_progress""#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_status_completed_serialization() {
|
||||
let status = ToolCallStatus::Completed;
|
||||
let json = serde_json::to_string(&status).unwrap();
|
||||
assert_eq!(json, r#""completed""#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_status_error_serialization() {
|
||||
let status = ToolCallStatus::Error;
|
||||
let json = serde_json::to_string(&status).unwrap();
|
||||
assert_eq!(json, r#""failed""#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_serialization_minimal() {
|
||||
let tool_call = ToolCall {
|
||||
id: "call_123".to_string(),
|
||||
tool_name: "bash".to_string(),
|
||||
status: ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
title: None,
|
||||
error: None,
|
||||
metadata: None,
|
||||
origin: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&tool_call).unwrap();
|
||||
|
||||
// Verify required fields are present
|
||||
assert!(json.contains(r#""id":"call_123""#));
|
||||
assert!(json.contains(r#""tool_name":"bash""#));
|
||||
assert!(json.contains(r#""status":"pending""#));
|
||||
assert!(json.contains(r#""content":[]"#));
|
||||
|
||||
// Verify optional fields are NOT present when None
|
||||
assert!(!json.contains(r#""raw_input""#));
|
||||
assert!(!json.contains(r#""raw_output""#));
|
||||
assert!(!json.contains(r#""title""#));
|
||||
assert!(!json.contains(r#""error""#));
|
||||
assert!(!json.contains(r#""metadata""#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_serialization_complete() {
|
||||
let tool_call = ToolCall {
|
||||
id: "call_456".to_string(),
|
||||
tool_name: "read_file".to_string(),
|
||||
status: ToolCallStatus::Completed,
|
||||
content: vec![ToolCallContent::from_content_block(ContentBlock::Text {
|
||||
text: "File contents".to_string(),
|
||||
})],
|
||||
raw_input: Some(json!({"path": "/tmp/test.txt"})),
|
||||
raw_output: Some(json!({"success": true})),
|
||||
title: Some("Read test file".to_string()),
|
||||
error: None,
|
||||
metadata: Some(json!({"duration_ms": 42})),
|
||||
origin: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&tool_call).unwrap();
|
||||
|
||||
// Verify all fields are present
|
||||
assert!(json.contains(r#""id":"call_456""#));
|
||||
assert!(json.contains(r#""tool_name":"read_file""#));
|
||||
assert!(json.contains(r#""status":"completed""#));
|
||||
assert!(json.contains(r#""text":"File contents""#));
|
||||
assert!(json.contains(r#""raw_input""#));
|
||||
assert!(json.contains(r#""raw_output""#));
|
||||
assert!(json.contains(r#""title":"Read test file""#));
|
||||
assert!(json.contains(r#""metadata""#));
|
||||
assert!(!json.contains(r#""error""#)); // Still None
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_serialization_with_error() {
|
||||
let tool_call = ToolCall {
|
||||
id: "call_789".to_string(),
|
||||
tool_name: "write_file".to_string(),
|
||||
status: ToolCallStatus::Error,
|
||||
content: vec![],
|
||||
raw_input: Some(json!({"path": "/tmp/readonly.txt"})),
|
||||
raw_output: None,
|
||||
title: Some("Write to readonly file".to_string()),
|
||||
error: Some("Permission denied".to_string()),
|
||||
metadata: None,
|
||||
origin: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&tool_call).unwrap();
|
||||
|
||||
assert!(json.contains(r#""status":"failed""#));
|
||||
assert!(json.contains(r#""error":"Permission denied""#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_roundtrip() {
|
||||
let original = ToolCall {
|
||||
id: "call_roundtrip".to_string(),
|
||||
tool_name: "test_tool".to_string(),
|
||||
status: ToolCallStatus::Running,
|
||||
content: vec![
|
||||
ToolCallContent::from_content_block(ContentBlock::Text {
|
||||
text: "Output line 1".to_string(),
|
||||
}),
|
||||
ToolCallContent::from_content_block(ContentBlock::Text {
|
||||
text: "Output line 2".to_string(),
|
||||
}),
|
||||
],
|
||||
raw_input: Some(json!({"arg": "value"})),
|
||||
raw_output: None,
|
||||
title: Some("Test Tool Call".to_string()),
|
||||
error: None,
|
||||
metadata: Some(json!({"test": true})),
|
||||
origin: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&original).unwrap();
|
||||
let deserialized: ToolCall = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(original, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_default_content() {
|
||||
// Test that content defaults to empty vec when not present in JSON
|
||||
let json = r#"{
|
||||
"id": "call_default",
|
||||
"tool_name": "test",
|
||||
"status": "pending"
|
||||
}"#;
|
||||
|
||||
let tool_call: ToolCall = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(tool_call.content, vec![]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_optional_fields_skip_when_none() {
|
||||
let tool_call = ToolCall {
|
||||
id: "call_skip".to_string(),
|
||||
tool_name: "test".to_string(),
|
||||
status: ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
title: None,
|
||||
error: None,
|
||||
metadata: None,
|
||||
origin: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_value(&tool_call).unwrap();
|
||||
let obj = json.as_object().unwrap();
|
||||
|
||||
// Verify optional fields are not in the serialized object
|
||||
assert!(!obj.contains_key("raw_input"));
|
||||
assert!(!obj.contains_key("raw_output"));
|
||||
assert!(!obj.contains_key("title"));
|
||||
assert!(!obj.contains_key("error"));
|
||||
assert!(!obj.contains_key("metadata"));
|
||||
|
||||
// Verify required fields ARE in the serialized object
|
||||
assert!(obj.contains_key("id"));
|
||||
assert!(obj.contains_key("tool_name"));
|
||||
assert!(obj.contains_key("status"));
|
||||
assert!(obj.contains_key("content"));
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ToolCallContent Tests
|
||||
// ============================================================================
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_content_wrapper_content_serialization() {
|
||||
let content = ToolCallContent::from_content_block(ContentBlock::Text {
|
||||
text: "test output".to_string(),
|
||||
});
|
||||
|
||||
let json = serde_json::to_value(&content).unwrap();
|
||||
|
||||
assert_eq!(json["type"], "content");
|
||||
assert!(json["content"].is_object());
|
||||
assert_eq!(json["content"]["type"], "text");
|
||||
assert_eq!(json["content"]["text"], "test output");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_content_wrapper_content_deserialization() {
|
||||
let json = json!({
|
||||
"type": "content",
|
||||
"content": {
|
||||
"type": "text",
|
||||
"text": "deserialized output"
|
||||
}
|
||||
});
|
||||
|
||||
let content: ToolCallContent = serde_json::from_value(json).unwrap();
|
||||
|
||||
match content {
|
||||
ToolCallContent::Content { content } => {
|
||||
match content {
|
||||
ContentBlock::Text { text } => {
|
||||
assert_eq!(text, "deserialized output");
|
||||
}
|
||||
_ => panic!("Expected Text content block"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Expected Content variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_content_diff_serialization() {
|
||||
let diff = ToolCallContent::diff(
|
||||
"/src/main.rs".to_string(),
|
||||
Some("old code".to_string()),
|
||||
"new code".to_string(),
|
||||
);
|
||||
|
||||
let json = serde_json::to_value(&diff).unwrap();
|
||||
|
||||
assert_eq!(json["type"], "diff");
|
||||
assert_eq!(json["path"], "/src/main.rs");
|
||||
assert_eq!(json["oldText"], "old code");
|
||||
assert_eq!(json["newText"], "new code");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_content_diff_without_old_text() {
|
||||
let diff = ToolCallContent::diff(
|
||||
"/src/new_file.rs".to_string(),
|
||||
None,
|
||||
"new file content".to_string(),
|
||||
);
|
||||
|
||||
let json = serde_json::to_value(&diff).unwrap();
|
||||
|
||||
assert_eq!(json["type"], "diff");
|
||||
assert_eq!(json["path"], "/src/new_file.rs");
|
||||
assert!(json.get("oldText").is_none()); // Should be omitted when None
|
||||
assert_eq!(json["newText"], "new file content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_content_diff_deserialization() {
|
||||
let json = json!({
|
||||
"type": "diff",
|
||||
"path": "/test/file.rs",
|
||||
"oldText": "before",
|
||||
"newText": "after"
|
||||
});
|
||||
|
||||
let content: ToolCallContent = serde_json::from_value(json).unwrap();
|
||||
|
||||
match content {
|
||||
ToolCallContent::Diff { path, old_text, new_text } => {
|
||||
assert_eq!(path, "/test/file.rs");
|
||||
assert_eq!(old_text, Some("before".to_string()));
|
||||
assert_eq!(new_text, "after");
|
||||
}
|
||||
_ => panic!("Expected Diff variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_content_terminal_serialization() {
|
||||
let terminal = ToolCallContent::terminal("term_123".to_string());
|
||||
|
||||
let json = serde_json::to_value(&terminal).unwrap();
|
||||
|
||||
assert_eq!(json["type"], "terminal");
|
||||
assert_eq!(json["terminalId"], "term_123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_content_terminal_deserialization() {
|
||||
let json = json!({
|
||||
"type": "terminal",
|
||||
"terminalId": "term_456"
|
||||
});
|
||||
|
||||
let content: ToolCallContent = serde_json::from_value(json).unwrap();
|
||||
|
||||
match content {
|
||||
ToolCallContent::Terminal { terminal_id } => {
|
||||
assert_eq!(terminal_id, "term_456");
|
||||
}
|
||||
_ => panic!("Expected Terminal variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_content_roundtrip() {
|
||||
// Test all three variants
|
||||
let variants = vec![
|
||||
ToolCallContent::from_content_block(ContentBlock::Text {
|
||||
text: "test".to_string(),
|
||||
}),
|
||||
ToolCallContent::diff("path.rs".to_string(), Some("old".to_string()), "new".to_string()),
|
||||
ToolCallContent::terminal("term_789".to_string()),
|
||||
];
|
||||
|
||||
for original in variants {
|
||||
let json = serde_json::to_value(&original).unwrap();
|
||||
let deserialized: ToolCallContent = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(original, deserialized);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_with_mixed_content_types() {
|
||||
let tool_call = ToolCall {
|
||||
id: "call_mixed".to_string(),
|
||||
tool_name: "edit_tool".to_string(),
|
||||
status: ToolCallStatus::Completed,
|
||||
content: vec![
|
||||
ToolCallContent::from_content_block(ContentBlock::Text {
|
||||
text: "Editing file...".to_string(),
|
||||
}),
|
||||
ToolCallContent::diff(
|
||||
"/src/lib.rs".to_string(),
|
||||
Some("fn old() {}".to_string()),
|
||||
"fn new() {}".to_string(),
|
||||
),
|
||||
ToolCallContent::from_content_block(ContentBlock::Text {
|
||||
text: "Edit complete".to_string(),
|
||||
}),
|
||||
],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
title: Some("Edit file".to_string()),
|
||||
error: None,
|
||||
metadata: None,
|
||||
origin: None,
|
||||
};
|
||||
|
||||
// Serialize and deserialize
|
||||
let json = serde_json::to_value(&tool_call).unwrap();
|
||||
let deserialized: ToolCall = serde_json::from_value(json).unwrap();
|
||||
|
||||
assert_eq!(tool_call, deserialized);
|
||||
assert_eq!(deserialized.content.len(), 3);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,654 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::types::content::ContentBlock;
|
||||
use crate::types::meta::Meta;
|
||||
use crate::types::tool::ToolCall;
|
||||
|
||||
/// ACP-style session updates for streaming content
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum SessionUpdate {
|
||||
/// User message content chunk
|
||||
UserMessageChunk {
|
||||
message_id: String,
|
||||
content: ContentBlock,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
_meta: Option<Meta>,
|
||||
},
|
||||
/// Agent message content chunk
|
||||
AgentMessageChunk {
|
||||
message_id: String,
|
||||
content: ContentBlock,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
_meta: Option<Meta>,
|
||||
},
|
||||
/// Agent thought content chunk (internal reasoning)
|
||||
AgentThoughtChunk {
|
||||
message_id: String,
|
||||
content: ContentBlock,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
_meta: Option<Meta>,
|
||||
},
|
||||
/// Tool call created or initiated
|
||||
ToolCall {
|
||||
message_id: String,
|
||||
tool_call: ToolCall,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
_meta: Option<Meta>,
|
||||
},
|
||||
/// Tool call update (status change, new content, etc.)
|
||||
ToolCallUpdate {
|
||||
message_id: String,
|
||||
tool_call_id: String,
|
||||
tool_call: ToolCall,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
_meta: Option<Meta>,
|
||||
},
|
||||
/// Unknown update type (forward compatibility - pass through as raw JSON)
|
||||
#[serde(untagged)]
|
||||
Unknown {
|
||||
#[serde(flatten)]
|
||||
data: serde_json::Value,
|
||||
},
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::ToolCallContent;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_user_message_chunk_serialization() {
|
||||
let update = SessionUpdate::UserMessageChunk {
|
||||
message_id: "msg_123".to_string(),
|
||||
content: ContentBlock::Text {
|
||||
text: "Hello".to_string(),
|
||||
},
|
||||
_meta: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&update).unwrap();
|
||||
assert!(json.contains(r#""type":"user_message_chunk"#));
|
||||
assert!(json.contains(r#""message_id":"msg_123"#));
|
||||
assert!(json.contains(r#""text":"Hello"#));
|
||||
assert!(!json.contains(r#""_meta""#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_user_message_chunk_deserialization() {
|
||||
let json = r#"{
|
||||
"type": "user_message_chunk",
|
||||
"message_id": "msg_123",
|
||||
"content": {
|
||||
"type": "text",
|
||||
"text": "Hello"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let update: SessionUpdate = serde_json::from_str(json).unwrap();
|
||||
match update {
|
||||
SessionUpdate::UserMessageChunk {
|
||||
message_id,
|
||||
content,
|
||||
_meta,
|
||||
} => {
|
||||
assert_eq!(message_id, "msg_123");
|
||||
assert_eq!(
|
||||
content,
|
||||
ContentBlock::Text {
|
||||
text: "Hello".to_string()
|
||||
}
|
||||
);
|
||||
assert_eq!(_meta, None);
|
||||
}
|
||||
_ => panic!("Expected UserMessageChunk"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_user_message_chunk_roundtrip() {
|
||||
let original = SessionUpdate::UserMessageChunk {
|
||||
message_id: "msg_456".to_string(),
|
||||
content: ContentBlock::Text {
|
||||
text: "Roundtrip test".to_string(),
|
||||
},
|
||||
_meta: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&original).unwrap();
|
||||
let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(original, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_user_message_chunk_with_meta() {
|
||||
let update = SessionUpdate::UserMessageChunk {
|
||||
message_id: "msg_789".to_string(),
|
||||
content: ContentBlock::Text {
|
||||
text: "With meta".to_string(),
|
||||
},
|
||||
_meta: Some(Meta {
|
||||
provider: None,
|
||||
extra: std::collections::HashMap::new(),
|
||||
}),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&update).unwrap();
|
||||
assert!(json.contains(r#""_meta":{}"#));
|
||||
|
||||
let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(update, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agent_message_chunk_serialization() {
|
||||
let update = SessionUpdate::AgentMessageChunk {
|
||||
message_id: "msg_agent_1".to_string(),
|
||||
content: ContentBlock::Text {
|
||||
text: "Agent response".to_string(),
|
||||
},
|
||||
_meta: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&update).unwrap();
|
||||
assert!(json.contains(r#""type":"agent_message_chunk"#));
|
||||
assert!(json.contains(r#""message_id":"msg_agent_1"#));
|
||||
assert!(json.contains(r#""text":"Agent response"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agent_message_chunk_deserialization() {
|
||||
let json = r#"{
|
||||
"type": "agent_message_chunk",
|
||||
"message_id": "msg_agent_2",
|
||||
"content": {
|
||||
"type": "text",
|
||||
"text": "Agent here"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let update: SessionUpdate = serde_json::from_str(json).unwrap();
|
||||
match update {
|
||||
SessionUpdate::AgentMessageChunk {
|
||||
message_id,
|
||||
content,
|
||||
_meta,
|
||||
} => {
|
||||
assert_eq!(message_id, "msg_agent_2");
|
||||
assert_eq!(
|
||||
content,
|
||||
ContentBlock::Text {
|
||||
text: "Agent here".to_string()
|
||||
}
|
||||
);
|
||||
assert_eq!(_meta, None);
|
||||
}
|
||||
_ => panic!("Expected AgentMessageChunk"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agent_message_chunk_roundtrip() {
|
||||
let original = SessionUpdate::AgentMessageChunk {
|
||||
message_id: "msg_agent_rt".to_string(),
|
||||
content: ContentBlock::ResourceLink {
|
||||
uri: "file:///test.txt".to_string(),
|
||||
name: Some("test.txt".to_string()),
|
||||
mime_type: Some("text/plain".to_string()),
|
||||
},
|
||||
_meta: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&original).unwrap();
|
||||
let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(original, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agent_message_chunk_with_meta() {
|
||||
let mut extra = std::collections::HashMap::new();
|
||||
extra.insert("timestamp".to_string(), json!("2025-11-10T12:00:00Z"));
|
||||
|
||||
let update = SessionUpdate::AgentMessageChunk {
|
||||
message_id: "msg_agent_meta".to_string(),
|
||||
content: ContentBlock::Text {
|
||||
text: "With metadata".to_string(),
|
||||
},
|
||||
_meta: Some(Meta {
|
||||
provider: None,
|
||||
extra,
|
||||
}),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&update).unwrap();
|
||||
assert!(json.contains(r#""_meta""#));
|
||||
assert!(json.contains(r#""timestamp""#));
|
||||
|
||||
let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(update, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agent_thought_chunk_serialization() {
|
||||
let update = SessionUpdate::AgentThoughtChunk {
|
||||
message_id: "msg_thought_1".to_string(),
|
||||
content: ContentBlock::Text {
|
||||
text: "Thinking...".to_string(),
|
||||
},
|
||||
_meta: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&update).unwrap();
|
||||
assert!(json.contains(r#""type":"agent_thought_chunk"#));
|
||||
assert!(json.contains(r#""message_id":"msg_thought_1"#));
|
||||
assert!(json.contains(r#""text":"Thinking..."#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agent_thought_chunk_deserialization() {
|
||||
let json = r#"{
|
||||
"type": "agent_thought_chunk",
|
||||
"message_id": "msg_thought_2",
|
||||
"content": {
|
||||
"type": "text",
|
||||
"text": "Analyzing the problem..."
|
||||
}
|
||||
}"#;
|
||||
|
||||
let update: SessionUpdate = serde_json::from_str(json).unwrap();
|
||||
match update {
|
||||
SessionUpdate::AgentThoughtChunk {
|
||||
message_id,
|
||||
content,
|
||||
_meta,
|
||||
} => {
|
||||
assert_eq!(message_id, "msg_thought_2");
|
||||
assert_eq!(
|
||||
content,
|
||||
ContentBlock::Text {
|
||||
text: "Analyzing the problem...".to_string()
|
||||
}
|
||||
);
|
||||
assert_eq!(_meta, None);
|
||||
}
|
||||
_ => panic!("Expected AgentThoughtChunk"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agent_thought_chunk_roundtrip() {
|
||||
let original = SessionUpdate::AgentThoughtChunk {
|
||||
message_id: "msg_thought_rt".to_string(),
|
||||
content: ContentBlock::Text {
|
||||
text: "Internal reasoning".to_string(),
|
||||
},
|
||||
_meta: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&original).unwrap();
|
||||
let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(original, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agent_thought_chunk_with_meta() {
|
||||
let update = SessionUpdate::AgentThoughtChunk {
|
||||
message_id: "msg_thought_meta".to_string(),
|
||||
content: ContentBlock::Text {
|
||||
text: "Thought with meta".to_string(),
|
||||
},
|
||||
_meta: Some(Meta::default()),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&update).unwrap();
|
||||
assert!(json.contains(r#""_meta":{}"#));
|
||||
|
||||
let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(update, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_serialization() {
|
||||
let tool_call = ToolCall {
|
||||
id: "call_123".to_string(),
|
||||
tool_name: "bash".to_string(),
|
||||
status: crate::types::tool::ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
title: Some("Run bash command".to_string()),
|
||||
error: None,
|
||||
metadata: None,
|
||||
|
||||
origin: None,
|
||||
|
||||
};
|
||||
|
||||
let update = SessionUpdate::ToolCall {
|
||||
message_id: "msg_tool_1".to_string(),
|
||||
tool_call,
|
||||
_meta: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&update).unwrap();
|
||||
assert!(json.contains(r#""type":"tool_call"#));
|
||||
assert!(json.contains(r#""message_id":"msg_tool_1"#));
|
||||
assert!(json.contains(r#""tool_name":"bash"#));
|
||||
assert!(json.contains(r#""status":"pending"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_deserialization() {
|
||||
let json = r#"{
|
||||
"type": "tool_call",
|
||||
"message_id": "msg_tool_2",
|
||||
"tool_call": {
|
||||
"id": "call_456",
|
||||
"tool_name": "read",
|
||||
"status": "in_progress",
|
||||
"content": []
|
||||
}
|
||||
}"#;
|
||||
|
||||
let update: SessionUpdate = serde_json::from_str(json).unwrap();
|
||||
match update {
|
||||
SessionUpdate::ToolCall {
|
||||
message_id,
|
||||
tool_call,
|
||||
_meta,
|
||||
} => {
|
||||
assert_eq!(message_id, "msg_tool_2");
|
||||
assert_eq!(tool_call.id, "call_456");
|
||||
assert_eq!(tool_call.tool_name, "read");
|
||||
assert_eq!(
|
||||
tool_call.status,
|
||||
crate::types::tool::ToolCallStatus::Running
|
||||
);
|
||||
assert_eq!(_meta, None);
|
||||
}
|
||||
_ => panic!("Expected ToolCall"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_roundtrip() {
|
||||
let tool_call = ToolCall {
|
||||
id: "call_rt".to_string(),
|
||||
tool_name: "write".to_string(),
|
||||
status: crate::types::tool::ToolCallStatus::Completed,
|
||||
content: vec![ToolCallContent::from_content_block(ContentBlock::Text {
|
||||
text: "File written".to_string(),
|
||||
})],
|
||||
raw_input: Some(json!({"path": "/tmp/test.txt"})),
|
||||
raw_output: Some(json!({"success": true})),
|
||||
title: Some("Write file".to_string()),
|
||||
error: None,
|
||||
metadata: None,
|
||||
|
||||
origin: None,
|
||||
|
||||
};
|
||||
|
||||
let original = SessionUpdate::ToolCall {
|
||||
message_id: "msg_tool_rt".to_string(),
|
||||
tool_call,
|
||||
_meta: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&original).unwrap();
|
||||
let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(original, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_with_meta() {
|
||||
let tool_call = ToolCall {
|
||||
id: "call_meta".to_string(),
|
||||
tool_name: "bash".to_string(),
|
||||
status: crate::types::tool::ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
title: None,
|
||||
error: None,
|
||||
metadata: None,
|
||||
|
||||
origin: None,
|
||||
|
||||
};
|
||||
|
||||
let update = SessionUpdate::ToolCall {
|
||||
message_id: "msg_tool_meta".to_string(),
|
||||
tool_call,
|
||||
_meta: Some(Meta::default()),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&update).unwrap();
|
||||
assert!(json.contains(r#""_meta":{}"#));
|
||||
|
||||
let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(update, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_update_serialization() {
|
||||
let tool_call = ToolCall {
|
||||
id: "call_update_1".to_string(),
|
||||
tool_name: "bash".to_string(),
|
||||
status: crate::types::tool::ToolCallStatus::Running,
|
||||
content: vec![ToolCallContent::from_content_block(ContentBlock::Text {
|
||||
text: "Output line 1".to_string(),
|
||||
})],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
title: Some("Running bash".to_string()),
|
||||
error: None,
|
||||
metadata: None,
|
||||
|
||||
origin: None,
|
||||
|
||||
};
|
||||
|
||||
let update = SessionUpdate::ToolCallUpdate {
|
||||
message_id: "msg_update_1".to_string(),
|
||||
tool_call_id: "call_update_1".to_string(),
|
||||
tool_call,
|
||||
_meta: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&update).unwrap();
|
||||
assert!(json.contains(r#""type":"tool_call_update"#));
|
||||
assert!(json.contains(r#""message_id":"msg_update_1"#));
|
||||
assert!(json.contains(r#""tool_call_id":"call_update_1"#));
|
||||
// ToolCallStatus::Running serializes as "in_progress"
|
||||
assert!(json.contains(r#""status":"in_progress"#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_update_deserialization() {
|
||||
let json = r#"{
|
||||
"type": "tool_call_update",
|
||||
"message_id": "msg_update_2",
|
||||
"tool_call_id": "call_update_2",
|
||||
"tool_call": {
|
||||
"id": "call_update_2",
|
||||
"tool_name": "read",
|
||||
"status": "completed",
|
||||
"content": [
|
||||
{
|
||||
"type": "content",
|
||||
"content": {
|
||||
"type": "text",
|
||||
"text": "File contents"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}"#;
|
||||
|
||||
let update: SessionUpdate = serde_json::from_str(json).unwrap();
|
||||
match update {
|
||||
SessionUpdate::ToolCallUpdate {
|
||||
message_id,
|
||||
tool_call_id,
|
||||
tool_call,
|
||||
_meta,
|
||||
} => {
|
||||
assert_eq!(message_id, "msg_update_2");
|
||||
assert_eq!(tool_call_id, "call_update_2");
|
||||
assert_eq!(tool_call.id, "call_update_2");
|
||||
assert_eq!(
|
||||
tool_call.status,
|
||||
crate::types::tool::ToolCallStatus::Completed
|
||||
);
|
||||
assert_eq!(
|
||||
tool_call.content,
|
||||
vec![crate::types::tool::ToolCallContent::from_content_block(ContentBlock::Text {
|
||||
text: "File contents".to_string()
|
||||
})]
|
||||
);
|
||||
assert_eq!(_meta, None);
|
||||
}
|
||||
_ => panic!("Expected ToolCallUpdate"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_update_roundtrip() {
|
||||
let tool_call = ToolCall {
|
||||
id: "call_update_rt".to_string(),
|
||||
tool_name: "bash".to_string(),
|
||||
status: crate::types::tool::ToolCallStatus::Error,
|
||||
content: vec![],
|
||||
raw_input: Some(json!({"command": "invalid"})),
|
||||
raw_output: None,
|
||||
title: Some("Failed command".to_string()),
|
||||
error: Some("Command not found".to_string()),
|
||||
metadata: None,
|
||||
|
||||
origin: None,
|
||||
|
||||
};
|
||||
|
||||
let original = SessionUpdate::ToolCallUpdate {
|
||||
message_id: "msg_update_rt".to_string(),
|
||||
tool_call_id: "call_update_rt".to_string(),
|
||||
tool_call,
|
||||
_meta: None,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&original).unwrap();
|
||||
let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(original, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_update_with_meta() {
|
||||
let tool_call = ToolCall {
|
||||
id: "call_update_meta".to_string(),
|
||||
tool_name: "write".to_string(),
|
||||
status: crate::types::tool::ToolCallStatus::Completed,
|
||||
content: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
title: None,
|
||||
error: None,
|
||||
metadata: None,
|
||||
|
||||
origin: None,
|
||||
|
||||
};
|
||||
|
||||
let update = SessionUpdate::ToolCallUpdate {
|
||||
message_id: "msg_update_meta".to_string(),
|
||||
tool_call_id: "call_update_meta".to_string(),
|
||||
tool_call,
|
||||
_meta: Some(Meta::default()),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&update).unwrap();
|
||||
assert!(json.contains(r#""_meta":{}"#));
|
||||
|
||||
let deserialized: SessionUpdate = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(update, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_variants_have_snake_case_type_tags() {
|
||||
// Test that each variant serializes with the correct snake_case type tag
|
||||
|
||||
let user_chunk = SessionUpdate::UserMessageChunk {
|
||||
message_id: "m1".to_string(),
|
||||
content: ContentBlock::Text {
|
||||
text: "test".to_string(),
|
||||
},
|
||||
_meta: None,
|
||||
};
|
||||
let json = serde_json::to_string(&user_chunk).unwrap();
|
||||
assert!(json.contains(r#""type":"user_message_chunk"#));
|
||||
|
||||
let agent_chunk = SessionUpdate::AgentMessageChunk {
|
||||
message_id: "m2".to_string(),
|
||||
content: ContentBlock::Text {
|
||||
text: "test".to_string(),
|
||||
},
|
||||
_meta: None,
|
||||
};
|
||||
let json = serde_json::to_string(&agent_chunk).unwrap();
|
||||
assert!(json.contains(r#""type":"agent_message_chunk"#));
|
||||
|
||||
let thought_chunk = SessionUpdate::AgentThoughtChunk {
|
||||
message_id: "m3".to_string(),
|
||||
content: ContentBlock::Text {
|
||||
text: "test".to_string(),
|
||||
},
|
||||
_meta: None,
|
||||
};
|
||||
let json = serde_json::to_string(&thought_chunk).unwrap();
|
||||
assert!(json.contains(r#""type":"agent_thought_chunk"#));
|
||||
|
||||
let tool_call = SessionUpdate::ToolCall {
|
||||
message_id: "m4".to_string(),
|
||||
tool_call: ToolCall {
|
||||
id: "c1".to_string(),
|
||||
tool_name: "test".to_string(),
|
||||
status: crate::types::tool::ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
title: None,
|
||||
error: None,
|
||||
metadata: None,
|
||||
|
||||
origin: None,
|
||||
|
||||
},
|
||||
_meta: None,
|
||||
};
|
||||
let json = serde_json::to_string(&tool_call).unwrap();
|
||||
assert!(json.contains(r#""type":"tool_call"#));
|
||||
|
||||
let tool_call_update = SessionUpdate::ToolCallUpdate {
|
||||
message_id: "m5".to_string(),
|
||||
tool_call_id: "c2".to_string(),
|
||||
tool_call: ToolCall {
|
||||
id: "c2".to_string(),
|
||||
tool_name: "test".to_string(),
|
||||
status: crate::types::tool::ToolCallStatus::Pending,
|
||||
content: vec![],
|
||||
raw_input: None,
|
||||
raw_output: None,
|
||||
title: None,
|
||||
error: None,
|
||||
metadata: None,
|
||||
|
||||
origin: None,
|
||||
|
||||
},
|
||||
_meta: None,
|
||||
};
|
||||
let json = serde_json::to_string(&tool_call_update).unwrap();
|
||||
assert!(json.contains(r#""type":"tool_call_update"#));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user