718 lines
24 KiB
Rust
718 lines
24 KiB
Rust
//! Protocol-level message accumulator for incremental message assembly.
//!
//! Handles streaming message deltas and assembles them into complete
//! [`AccumulatedMessage`] values using protocol types. Unlike the archivist's
//! accumulator, this module produces protocol-native output without UUID
//! parsing, markdown generation, or storage metadata.
//!
//! The accumulator preserves the order of content parts (text, thinking, tool
//! calls) as they arrive in the event stream, enabling inline tool rendering
//! and faithful forwarding to downstream consumers.
|
|
|
|
use chrono::{DateTime, Utc};
|
|
use serde_json::Value;
|
|
use std::collections::HashMap;
|
|
|
|
use crate::conversation::MessagePart;
|
|
use crate::types::ContentBlock;
|
|
|
|
/// Tool call data accumulated during streaming.
|
|
#[derive(Debug, Clone)]
|
|
pub struct ToolCallData {
|
|
pub id: String,
|
|
pub tool_name: String,
|
|
pub input: Value,
|
|
pub output: Option<Value>,
|
|
}
|
|
|
|
/// A single accumulated content part, preserving event-stream order.
|
|
#[derive(Debug, Clone)]
|
|
pub enum AccumulatedPart {
|
|
Text { text: String },
|
|
Thinking { text: String },
|
|
Tool { data: ToolCallData },
|
|
}
|
|
|
|
/// A fully accumulated message assembled from streaming chunks.
|
|
#[derive(Debug, Clone)]
|
|
pub struct AccumulatedMessage {
|
|
pub message_id: String,
|
|
pub session_id: String,
|
|
pub connector_id: String,
|
|
pub role: String,
|
|
pub parts: Vec<AccumulatedPart>,
|
|
pub created_at: Option<DateTime<Utc>>,
|
|
pub last_activity: DateTime<Utc>,
|
|
}
|
|
|
|
impl AccumulatedMessage {
|
|
/// Build an `AccumulatedMessage` directly from a slice of [`MessagePart`]s.
|
|
///
|
|
/// This is the path for non-streaming clients that deliver content in a
|
|
/// single `MessageCompleted` event rather than via incremental chunks.
|
|
pub fn from_message_parts(
|
|
message_id: String,
|
|
session_id: String,
|
|
connector_id: String,
|
|
role: String,
|
|
parts: &[MessagePart],
|
|
) -> Self {
|
|
let now = Utc::now();
|
|
let mut accumulated_parts = Vec::new();
|
|
|
|
for part in parts {
|
|
match part {
|
|
MessagePart::Text { text } => {
|
|
if !text.is_empty() {
|
|
accumulated_parts.push(AccumulatedPart::Text { text: text.clone() });
|
|
}
|
|
}
|
|
MessagePart::Thinking { text } => {
|
|
if !text.is_empty() {
|
|
accumulated_parts.push(AccumulatedPart::Thinking { text: text.clone() });
|
|
}
|
|
}
|
|
MessagePart::Code { language, code } => {
|
|
if !code.is_empty() {
|
|
let fenced = format!("```{}\n{}\n```", language, code);
|
|
accumulated_parts.push(AccumulatedPart::Text { text: fenced });
|
|
}
|
|
}
|
|
MessagePart::Tool {
|
|
tool,
|
|
tool_call_id,
|
|
input,
|
|
output,
|
|
} => {
|
|
accumulated_parts.push(AccumulatedPart::Tool {
|
|
data: ToolCallData {
|
|
id: tool_call_id
|
|
.clone()
|
|
.unwrap_or_else(|| String::new()),
|
|
tool_name: tool.clone(),
|
|
input: input.clone(),
|
|
output: output.clone(),
|
|
},
|
|
});
|
|
}
|
|
MessagePart::File { path, content: _ } => {
|
|
let label = format!("\u{1f4c4} File: {}", path);
|
|
accumulated_parts.push(AccumulatedPart::Text { text: label });
|
|
}
|
|
}
|
|
}
|
|
|
|
Self {
|
|
message_id,
|
|
session_id,
|
|
connector_id,
|
|
role,
|
|
parts: accumulated_parts,
|
|
created_at: Some(now),
|
|
last_activity: now,
|
|
}
|
|
}
|
|
|
|
/// Convert accumulated parts back to protocol [`MessagePart`]s.
|
|
pub fn to_message_parts(&self) -> Vec<MessagePart> {
|
|
self.parts
|
|
.iter()
|
|
.map(|part| match part {
|
|
AccumulatedPart::Text { text } => MessagePart::Text { text: text.clone() },
|
|
AccumulatedPart::Thinking { text } => {
|
|
MessagePart::Thinking { text: text.clone() }
|
|
}
|
|
AccumulatedPart::Tool { data } => MessagePart::Tool {
|
|
tool: data.tool_name.clone(),
|
|
tool_call_id: if data.id.is_empty() {
|
|
None
|
|
} else {
|
|
Some(data.id.clone())
|
|
},
|
|
input: data.input.clone(),
|
|
output: data.output.clone(),
|
|
},
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
/// Returns `true` when the message has no accumulated content.
|
|
pub fn is_empty(&self) -> bool {
|
|
self.parts.is_empty()
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
// Internal buffer
// ---------------------------------------------------------------------------
|
|
|
|
/// Buffer for accumulating streaming chunks into a complete message.
|
|
#[derive(Debug)]
|
|
struct MessageBuffer {
|
|
message_id: String,
|
|
session_id: String,
|
|
connector_id: String,
|
|
role: String,
|
|
parts: Vec<AccumulatedPart>,
|
|
created_at: Option<DateTime<Utc>>,
|
|
last_activity: DateTime<Utc>,
|
|
}
|
|
|
|
impl MessageBuffer {
|
|
fn new(message_id: String, session_id: String, connector_id: String, role: String) -> Self {
|
|
let now = Utc::now();
|
|
Self {
|
|
message_id,
|
|
session_id,
|
|
connector_id,
|
|
role,
|
|
parts: Vec::new(),
|
|
created_at: None,
|
|
last_activity: now,
|
|
}
|
|
}
|
|
|
|
fn touch(&mut self) {
|
|
let now = Utc::now();
|
|
if self.created_at.is_none() {
|
|
self.created_at = Some(now);
|
|
}
|
|
self.last_activity = now;
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
// MessageAccumulator
// ---------------------------------------------------------------------------
|
|
|
|
/// Accumulator for assembling streaming message deltas into complete messages.
|
|
///
|
|
/// Each in-flight message is identified by its `message_id` and tracked in an
|
|
/// internal buffer. Text and thinking chunks are coalesced when consecutive;
|
|
/// tool calls are deduplicated by `tool_call_id`.
|
|
#[derive(Debug, Default)]
|
|
pub struct MessageAccumulator {
|
|
buffers: HashMap<String, MessageBuffer>,
|
|
}
|
|
|
|
impl MessageAccumulator {
|
|
/// Create a new, empty accumulator.
|
|
pub fn new() -> Self {
|
|
Self {
|
|
buffers: HashMap::new(),
|
|
}
|
|
}
|
|
|
|
/// Add a content chunk to the message buffer.
|
|
///
|
|
/// Consecutive text chunks are coalesced into a single `AccumulatedPart::Text`.
|
|
pub fn add_chunk(
|
|
&mut self,
|
|
message_id: &str,
|
|
session_id: &str,
|
|
connector_id: &str,
|
|
role: &str,
|
|
content: ContentBlock,
|
|
) {
|
|
let buffer = self
|
|
.buffers
|
|
.entry(message_id.to_string())
|
|
.or_insert_with(|| {
|
|
MessageBuffer::new(
|
|
message_id.to_string(),
|
|
session_id.to_string(),
|
|
connector_id.to_string(),
|
|
role.to_string(),
|
|
)
|
|
});
|
|
|
|
buffer.touch();
|
|
|
|
match content {
|
|
ContentBlock::Text { text } => {
|
|
if let Some(AccumulatedPart::Text { text: existing }) = buffer.parts.last_mut() {
|
|
existing.push_str(&text);
|
|
} else {
|
|
buffer.parts.push(AccumulatedPart::Text { text });
|
|
}
|
|
}
|
|
ContentBlock::ResourceLink { .. } => {
|
|
// ResourceLink is not accumulated as text content for now.
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Add thinking content to the message buffer.
|
|
///
|
|
/// Consecutive thinking chunks are coalesced into a single
|
|
/// `AccumulatedPart::Thinking`.
|
|
pub fn add_thinking(
|
|
&mut self,
|
|
message_id: &str,
|
|
session_id: &str,
|
|
connector_id: &str,
|
|
content: &str,
|
|
) {
|
|
let buffer = self
|
|
.buffers
|
|
.entry(message_id.to_string())
|
|
.or_insert_with(|| {
|
|
MessageBuffer::new(
|
|
message_id.to_string(),
|
|
session_id.to_string(),
|
|
connector_id.to_string(),
|
|
"assistant".to_string(),
|
|
)
|
|
});
|
|
|
|
buffer.touch();
|
|
|
|
if let Some(AccumulatedPart::Thinking { text: existing }) = buffer.parts.last_mut() {
|
|
existing.push_str(content);
|
|
} else {
|
|
buffer
|
|
.parts
|
|
.push(AccumulatedPart::Thinking { text: content.to_string() });
|
|
}
|
|
}
|
|
|
|
/// Add or update a tool call in the message buffer.
|
|
///
|
|
/// If a tool call with the same `id` already exists in the buffer, the
|
|
/// existing entry is updated (input is overwritten only when non-empty;
|
|
/// output is overwritten when `Some`). Otherwise a new entry is appended,
|
|
/// preserving event-stream ordering.
|
|
pub fn add_or_update_tool_call(&mut self, message_id: &str, tool_call: ToolCallData) {
|
|
if let Some(buffer) = self.buffers.get_mut(message_id) {
|
|
buffer.last_activity = Utc::now();
|
|
|
|
// Try to find and update an existing tool call with the same id.
|
|
for part in buffer.parts.iter_mut() {
|
|
if let AccumulatedPart::Tool { data } = part {
|
|
if data.id == tool_call.id {
|
|
data.tool_name = tool_call.tool_name;
|
|
|
|
if tool_call.input != Value::Null
|
|
&& tool_call.input != serde_json::json!({})
|
|
{
|
|
data.input = tool_call.input;
|
|
}
|
|
|
|
if tool_call.output.is_some() {
|
|
data.output = tool_call.output;
|
|
}
|
|
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
|
|
// First time seeing this tool_call_id -- append.
|
|
buffer
|
|
.parts
|
|
.push(AccumulatedPart::Tool { data: tool_call });
|
|
}
|
|
}
|
|
|
|
/// Finalize a message and return its accumulated content.
|
|
///
|
|
/// The internal buffer for `message_id` is removed. Returns `None` if no
|
|
/// buffer exists for the given id.
|
|
pub fn finalize(&mut self, message_id: &str) -> Option<AccumulatedMessage> {
|
|
let buffer = self.buffers.remove(message_id)?;
|
|
|
|
Some(AccumulatedMessage {
|
|
message_id: buffer.message_id,
|
|
session_id: buffer.session_id,
|
|
connector_id: buffer.connector_id,
|
|
role: buffer.role,
|
|
parts: buffer.parts,
|
|
created_at: buffer.created_at,
|
|
last_activity: buffer.last_activity,
|
|
})
|
|
}
|
|
|
|
/// Returns `true` if a buffer exists for the given message id.
|
|
pub fn has_buffer(&self, message_id: &str) -> bool {
|
|
self.buffers.contains_key(message_id)
|
|
}
|
|
|
|
/// Return all buffered message ids that belong to `session_id`.
|
|
pub fn message_ids_for_session(&self, session_id: &str) -> Vec<String> {
|
|
self.buffers
|
|
.iter()
|
|
.filter(|(_, buf)| buf.session_id == session_id)
|
|
.map(|(id, _)| id.clone())
|
|
.collect()
|
|
}
|
|
|
|
/// Return all currently buffered message ids.
|
|
pub fn active_message_ids(&self) -> Vec<String> {
|
|
self.buffers.keys().cloned().collect()
|
|
}
|
|
|
|
/// Return message ids whose buffers have not been touched for longer than
|
|
/// `threshold`.
|
|
pub fn stale_message_ids(&self, threshold: std::time::Duration) -> Vec<String> {
|
|
let now = Utc::now();
|
|
self.buffers
|
|
.iter()
|
|
.filter(|(_, buf)| {
|
|
let inactive = now.signed_duration_since(buf.last_activity);
|
|
inactive
|
|
.to_std()
|
|
.unwrap_or(std::time::Duration::ZERO)
|
|
> threshold
|
|
})
|
|
.map(|(id, _)| id.clone())
|
|
.collect()
|
|
}
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_empty_accumulator() {
        let mut acc = MessageAccumulator::new();
        assert!(acc.finalize("nonexistent").is_none());
        assert!(acc.active_message_ids().is_empty());
    }

    #[test]
    fn test_text_chunk_coalescing() {
        let mut acc = MessageAccumulator::new();

        acc.add_chunk("msg1", "s1", "c1", "user", ContentBlock::Text {
            text: "Hello, ".to_string(),
        });
        acc.add_chunk("msg1", "s1", "c1", "user", ContentBlock::Text {
            text: "world!".to_string(),
        });

        let msg = acc.finalize("msg1").unwrap();
        // Finalizing removes the buffer.
        assert!(!acc.has_buffer("msg1"));
        assert_eq!(msg.parts.len(), 1);
        match &msg.parts[0] {
            AccumulatedPart::Text { text } => assert_eq!(text, "Hello, world!"),
            other => panic!("Expected Text, got {:?}", other),
        }
    }

    #[test]
    fn test_thinking_coalescing() {
        let mut acc = MessageAccumulator::new();

        acc.add_thinking("msg1", "s1", "c1", "First. ");
        acc.add_thinking("msg1", "s1", "c1", "Second.");

        let msg = acc.finalize("msg1").unwrap();
        assert_eq!(msg.parts.len(), 1);
        match &msg.parts[0] {
            AccumulatedPart::Thinking { text } => assert_eq!(text, "First. Second."),
            other => panic!("Expected Thinking, got {:?}", other),
        }
    }

    #[test]
    fn test_interleaved_parts_preserve_order() {
        let mut acc = MessageAccumulator::new();

        // text, tool, text -- should produce 3 distinct parts
        acc.add_chunk("msg1", "s1", "c1", "assistant", ContentBlock::Text {
            text: "Before tool.".to_string(),
        });

        acc.add_or_update_tool_call("msg1", ToolCallData {
            id: "tc1".to_string(),
            tool_name: "grep".to_string(),
            input: serde_json::json!({"q": "x"}),
            output: None,
        });

        acc.add_chunk("msg1", "s1", "c1", "assistant", ContentBlock::Text {
            text: "After tool.".to_string(),
        });

        let msg = acc.finalize("msg1").unwrap();
        assert_eq!(msg.parts.len(), 3);
        assert!(matches!(&msg.parts[0], AccumulatedPart::Text { .. }));
        assert!(matches!(&msg.parts[1], AccumulatedPart::Tool { .. }));
        assert!(matches!(&msg.parts[2], AccumulatedPart::Text { .. }));
    }

    #[test]
    fn test_tool_call_deduplication() {
        let mut acc = MessageAccumulator::new();

        // Create buffer first
        acc.add_chunk("msg1", "s1", "c1", "assistant", ContentBlock::Text {
            text: "hi".to_string(),
        });

        // Initial tool call
        acc.add_or_update_tool_call("msg1", ToolCallData {
            id: "tc1".to_string(),
            tool_name: "read".to_string(),
            input: serde_json::json!({"path": "foo.rs"}),
            output: None,
        });

        // Update same tool call with output (empty input should NOT overwrite)
        acc.add_or_update_tool_call("msg1", ToolCallData {
            id: "tc1".to_string(),
            tool_name: "read".to_string(),
            input: serde_json::json!({}),
            output: Some(serde_json::json!({"content": "fn main() {}"})),
        });

        let msg = acc.finalize("msg1").unwrap();
        // text + 1 tool (not 2)
        assert_eq!(msg.parts.len(), 2);

        match &msg.parts[1] {
            AccumulatedPart::Tool { data } => {
                assert_eq!(data.id, "tc1");
                // Input preserved from first call (non-empty), not overwritten by empty update
                assert_eq!(data.input, serde_json::json!({"path": "foo.rs"}));
                // Output set from update
                assert_eq!(
                    data.output,
                    Some(serde_json::json!({"content": "fn main() {}"}))
                );
            }
            other => panic!("Expected Tool, got {:?}", other),
        }
    }

    #[test]
    fn test_from_message_parts_non_streaming() {
        let parts = vec![
            MessagePart::Text {
                text: "Hello".to_string(),
            },
            MessagePart::Thinking {
                text: "hmm".to_string(),
            },
            MessagePart::Code {
                language: "rs".to_string(),
                code: "fn main() {}".to_string(),
            },
            MessagePart::Tool {
                tool: "grep".to_string(),
                tool_call_id: Some("tc1".to_string()),
                input: serde_json::json!({"q": "x"}),
                output: Some(serde_json::json!("found")),
            },
            MessagePart::File {
                path: "README.md".to_string(),
                content: "# Title".to_string(),
            },
            // Empty text and code should be skipped
            MessagePart::Text {
                text: String::new(),
            },
            MessagePart::Code {
                language: "py".to_string(),
                code: String::new(),
            },
        ];

        let msg = AccumulatedMessage::from_message_parts(
            "msg1".into(),
            "s1".into(),
            "c1".into(),
            "assistant".into(),
            &parts,
        );

        // 5 non-empty parts: text, thinking, code-as-text, tool, file-as-text
        assert_eq!(msg.parts.len(), 5);

        match &msg.parts[0] {
            AccumulatedPart::Text { text } => assert_eq!(text, "Hello"),
            other => panic!("Expected Text, got {:?}", other),
        }
        match &msg.parts[1] {
            AccumulatedPart::Thinking { text } => assert_eq!(text, "hmm"),
            other => panic!("Expected Thinking, got {:?}", other),
        }
        match &msg.parts[2] {
            AccumulatedPart::Text { text } => {
                assert!(text.contains("```rs"));
                assert!(text.contains("fn main() {}"));
            }
            other => panic!("Expected Text (code), got {:?}", other),
        }
        match &msg.parts[3] {
            AccumulatedPart::Tool { data } => {
                assert_eq!(data.tool_name, "grep");
                assert_eq!(data.id, "tc1");
                assert_eq!(data.output, Some(serde_json::json!("found")));
            }
            other => panic!("Expected Tool, got {:?}", other),
        }
        match &msg.parts[4] {
            AccumulatedPart::Text { text } => {
                assert!(text.contains("File: README.md"));
            }
            other => panic!("Expected Text (file), got {:?}", other),
        }
    }

    #[test]
    fn test_session_and_stale_queries() {
        let mut acc = MessageAccumulator::new();

        acc.add_chunk("msg1", "s1", "c1", "user", ContentBlock::Text {
            text: "a".to_string(),
        });
        acc.add_chunk("msg2", "s1", "c1", "assistant", ContentBlock::Text {
            text: "b".to_string(),
        });
        acc.add_chunk("msg3", "s2", "c1", "user", ContentBlock::Text {
            text: "c".to_string(),
        });

        // message_ids_for_session
        let mut s1_ids = acc.message_ids_for_session("s1");
        s1_ids.sort();
        assert_eq!(s1_ids, vec!["msg1", "msg2"]);

        let s2_ids = acc.message_ids_for_session("s2");
        assert_eq!(s2_ids, vec!["msg3"]);

        assert!(acc.message_ids_for_session("s3").is_empty());

        // active_message_ids
        let mut all = acc.active_message_ids();
        all.sort();
        assert_eq!(all, vec!["msg1", "msg2", "msg3"]);

        // has_buffer
        assert!(acc.has_buffer("msg1"));
        assert!(!acc.has_buffer("msg99"));

        // Buffers touched just now are not stale under a generous threshold.
        let threshold = std::time::Duration::from_secs(3600);
        assert!(acc.stale_message_ids(threshold).is_empty());

        // Backdate one buffer far into the past: exactly that one is stale.
        let epoch: DateTime<Utc> = "1970-01-01T00:00:00Z".parse().unwrap();
        acc.buffers.get_mut("msg1").unwrap().last_activity = epoch;
        assert_eq!(acc.stale_message_ids(threshold), vec!["msg1"]);
    }

    #[test]
    fn test_to_message_parts() {
        let mut acc = MessageAccumulator::new();

        acc.add_chunk("msg1", "s1", "c1", "assistant", ContentBlock::Text {
            text: "Hello ".to_string(),
        });
        acc.add_chunk("msg1", "s1", "c1", "assistant", ContentBlock::Text {
            text: "world.".to_string(),
        });
        acc.add_thinking("msg1", "s1", "c1", "thinking...");
        acc.add_or_update_tool_call("msg1", ToolCallData {
            id: "tc1".to_string(),
            tool_name: "search".to_string(),
            input: serde_json::json!({"q": "test"}),
            output: Some(serde_json::json!("result")),
        });

        let msg = acc.finalize("msg1").unwrap();
        let parts = msg.to_message_parts();

        assert_eq!(parts.len(), 3);

        // Coalesced text
        match &parts[0] {
            MessagePart::Text { text } => assert_eq!(text, "Hello world."),
            other => panic!("Expected Text, got {:?}", other),
        }

        // Thinking
        match &parts[1] {
            MessagePart::Thinking { text } => assert_eq!(text, "thinking..."),
            other => panic!("Expected Thinking, got {:?}", other),
        }

        // Tool roundtrip
        match &parts[2] {
            MessagePart::Tool {
                tool,
                tool_call_id,
                input,
                output,
            } => {
                assert_eq!(tool, "search");
                assert_eq!(tool_call_id, &Some("tc1".to_string()));
                assert_eq!(input, &serde_json::json!({"q": "test"}));
                assert_eq!(output, &Some(serde_json::json!("result")));
            }
            other => panic!("Expected Tool, got {:?}", other),
        }
    }

    #[test]
    fn test_is_empty() {
        let msg = AccumulatedMessage::from_message_parts(
            "msg1".into(),
            "s1".into(),
            "c1".into(),
            "user".into(),
            &[],
        );
        assert!(msg.is_empty());

        let msg2 = AccumulatedMessage::from_message_parts(
            "msg2".into(),
            "s1".into(),
            "c1".into(),
            "user".into(),
            &[MessagePart::Text {
                text: "hi".to_string(),
            }],
        );
        assert!(!msg2.is_empty());
    }

    #[test]
    fn test_to_message_parts_empty_tool_id() {
        // Tool with no tool_call_id should roundtrip as None
        let parts = vec![MessagePart::Tool {
            tool: "bash".to_string(),
            tool_call_id: None,
            input: serde_json::json!({"cmd": "ls"}),
            output: None,
        }];

        let msg = AccumulatedMessage::from_message_parts(
            "msg1".into(),
            "s1".into(),
            "c1".into(),
            "assistant".into(),
            &parts,
        );

        let roundtripped = msg.to_message_parts();
        match &roundtripped[0] {
            MessagePart::Tool { tool_call_id, .. } => {
                assert_eq!(tool_call_id, &None);
            }
            other => panic!("Expected Tool, got {:?}", other),
        }
    }
}
|