sync from monorepo @ 2452e92e
This commit is contained in:
@@ -0,0 +1,717 @@
|
||||
//! Protocol-level message accumulator for incremental message assembly.
|
||||
//!
|
||||
//! Handles streaming message deltas and assembles them into complete
|
||||
//! [`AccumulatedMessage`] values using protocol types. Unlike the archivist's
|
||||
//! accumulator, this module produces protocol-native output without UUID
|
||||
//! parsing, markdown generation, or storage metadata.
|
||||
//!
|
||||
//! The accumulator preserves the order of content parts (text, thinking, tool
|
||||
//! calls) as they arrive in the event stream, enabling inline tool rendering
|
||||
//! and faithful forwarding to downstream consumers.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::conversation::MessagePart;
|
||||
use crate::types::ContentBlock;
|
||||
|
||||
/// Tool call data accumulated during streaming.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolCallData {
|
||||
pub id: String,
|
||||
pub tool_name: String,
|
||||
pub input: Value,
|
||||
pub output: Option<Value>,
|
||||
}
|
||||
|
||||
/// A single accumulated content part, preserving event-stream order.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AccumulatedPart {
|
||||
Text { text: String },
|
||||
Thinking { text: String },
|
||||
Tool { data: ToolCallData },
|
||||
}
|
||||
|
||||
/// A fully accumulated message assembled from streaming chunks.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AccumulatedMessage {
|
||||
pub message_id: String,
|
||||
pub session_id: String,
|
||||
pub connector_id: String,
|
||||
pub role: String,
|
||||
pub parts: Vec<AccumulatedPart>,
|
||||
pub created_at: Option<DateTime<Utc>>,
|
||||
pub last_activity: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl AccumulatedMessage {
|
||||
/// Build an `AccumulatedMessage` directly from a slice of [`MessagePart`]s.
|
||||
///
|
||||
/// This is the path for non-streaming clients that deliver content in a
|
||||
/// single `MessageCompleted` event rather than via incremental chunks.
|
||||
pub fn from_message_parts(
|
||||
message_id: String,
|
||||
session_id: String,
|
||||
connector_id: String,
|
||||
role: String,
|
||||
parts: &[MessagePart],
|
||||
) -> Self {
|
||||
let now = Utc::now();
|
||||
let mut accumulated_parts = Vec::new();
|
||||
|
||||
for part in parts {
|
||||
match part {
|
||||
MessagePart::Text { text } => {
|
||||
if !text.is_empty() {
|
||||
accumulated_parts.push(AccumulatedPart::Text { text: text.clone() });
|
||||
}
|
||||
}
|
||||
MessagePart::Thinking { text } => {
|
||||
if !text.is_empty() {
|
||||
accumulated_parts.push(AccumulatedPart::Thinking { text: text.clone() });
|
||||
}
|
||||
}
|
||||
MessagePart::Code { language, code } => {
|
||||
if !code.is_empty() {
|
||||
let fenced = format!("```{}\n{}\n```", language, code);
|
||||
accumulated_parts.push(AccumulatedPart::Text { text: fenced });
|
||||
}
|
||||
}
|
||||
MessagePart::Tool {
|
||||
tool,
|
||||
tool_call_id,
|
||||
input,
|
||||
output,
|
||||
} => {
|
||||
accumulated_parts.push(AccumulatedPart::Tool {
|
||||
data: ToolCallData {
|
||||
id: tool_call_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| String::new()),
|
||||
tool_name: tool.clone(),
|
||||
input: input.clone(),
|
||||
output: output.clone(),
|
||||
},
|
||||
});
|
||||
}
|
||||
MessagePart::File { path, content: _ } => {
|
||||
let label = format!("\u{1f4c4} File: {}", path);
|
||||
accumulated_parts.push(AccumulatedPart::Text { text: label });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
message_id,
|
||||
session_id,
|
||||
connector_id,
|
||||
role,
|
||||
parts: accumulated_parts,
|
||||
created_at: Some(now),
|
||||
last_activity: now,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert accumulated parts back to protocol [`MessagePart`]s.
|
||||
pub fn to_message_parts(&self) -> Vec<MessagePart> {
|
||||
self.parts
|
||||
.iter()
|
||||
.map(|part| match part {
|
||||
AccumulatedPart::Text { text } => MessagePart::Text { text: text.clone() },
|
||||
AccumulatedPart::Thinking { text } => {
|
||||
MessagePart::Thinking { text: text.clone() }
|
||||
}
|
||||
AccumulatedPart::Tool { data } => MessagePart::Tool {
|
||||
tool: data.tool_name.clone(),
|
||||
tool_call_id: if data.id.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(data.id.clone())
|
||||
},
|
||||
input: data.input.clone(),
|
||||
output: data.output.clone(),
|
||||
},
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Returns `true` when the message has no accumulated content.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.parts.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Internal buffer
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Buffer for accumulating streaming chunks into a complete message.
|
||||
#[derive(Debug)]
|
||||
struct MessageBuffer {
|
||||
message_id: String,
|
||||
session_id: String,
|
||||
connector_id: String,
|
||||
role: String,
|
||||
parts: Vec<AccumulatedPart>,
|
||||
created_at: Option<DateTime<Utc>>,
|
||||
last_activity: DateTime<Utc>,
|
||||
}
|
||||
|
||||
impl MessageBuffer {
|
||||
fn new(message_id: String, session_id: String, connector_id: String, role: String) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
message_id,
|
||||
session_id,
|
||||
connector_id,
|
||||
role,
|
||||
parts: Vec::new(),
|
||||
created_at: None,
|
||||
last_activity: now,
|
||||
}
|
||||
}
|
||||
|
||||
fn touch(&mut self) {
|
||||
let now = Utc::now();
|
||||
if self.created_at.is_none() {
|
||||
self.created_at = Some(now);
|
||||
}
|
||||
self.last_activity = now;
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// MessageAccumulator
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Accumulator for assembling streaming message deltas into complete messages.
|
||||
///
|
||||
/// Each in-flight message is identified by its `message_id` and tracked in an
|
||||
/// internal buffer. Text and thinking chunks are coalesced when consecutive;
|
||||
/// tool calls are deduplicated by `tool_call_id`.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct MessageAccumulator {
|
||||
buffers: HashMap<String, MessageBuffer>,
|
||||
}
|
||||
|
||||
impl MessageAccumulator {
|
||||
/// Create a new, empty accumulator.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
buffers: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a content chunk to the message buffer.
|
||||
///
|
||||
/// Consecutive text chunks are coalesced into a single `AccumulatedPart::Text`.
|
||||
pub fn add_chunk(
|
||||
&mut self,
|
||||
message_id: &str,
|
||||
session_id: &str,
|
||||
connector_id: &str,
|
||||
role: &str,
|
||||
content: ContentBlock,
|
||||
) {
|
||||
let buffer = self
|
||||
.buffers
|
||||
.entry(message_id.to_string())
|
||||
.or_insert_with(|| {
|
||||
MessageBuffer::new(
|
||||
message_id.to_string(),
|
||||
session_id.to_string(),
|
||||
connector_id.to_string(),
|
||||
role.to_string(),
|
||||
)
|
||||
});
|
||||
|
||||
buffer.touch();
|
||||
|
||||
match content {
|
||||
ContentBlock::Text { text } => {
|
||||
if let Some(AccumulatedPart::Text { text: existing }) = buffer.parts.last_mut() {
|
||||
existing.push_str(&text);
|
||||
} else {
|
||||
buffer.parts.push(AccumulatedPart::Text { text });
|
||||
}
|
||||
}
|
||||
ContentBlock::ResourceLink { .. } => {
|
||||
// ResourceLink is not accumulated as text content for now.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Add thinking content to the message buffer.
|
||||
///
|
||||
/// Consecutive thinking chunks are coalesced into a single
|
||||
/// `AccumulatedPart::Thinking`.
|
||||
pub fn add_thinking(
|
||||
&mut self,
|
||||
message_id: &str,
|
||||
session_id: &str,
|
||||
connector_id: &str,
|
||||
content: &str,
|
||||
) {
|
||||
let buffer = self
|
||||
.buffers
|
||||
.entry(message_id.to_string())
|
||||
.or_insert_with(|| {
|
||||
MessageBuffer::new(
|
||||
message_id.to_string(),
|
||||
session_id.to_string(),
|
||||
connector_id.to_string(),
|
||||
"assistant".to_string(),
|
||||
)
|
||||
});
|
||||
|
||||
buffer.touch();
|
||||
|
||||
if let Some(AccumulatedPart::Thinking { text: existing }) = buffer.parts.last_mut() {
|
||||
existing.push_str(content);
|
||||
} else {
|
||||
buffer
|
||||
.parts
|
||||
.push(AccumulatedPart::Thinking { text: content.to_string() });
|
||||
}
|
||||
}
|
||||
|
||||
/// Add or update a tool call in the message buffer.
|
||||
///
|
||||
/// If a tool call with the same `id` already exists in the buffer, the
|
||||
/// existing entry is updated (input is overwritten only when non-empty;
|
||||
/// output is overwritten when `Some`). Otherwise a new entry is appended,
|
||||
/// preserving event-stream ordering.
|
||||
pub fn add_or_update_tool_call(&mut self, message_id: &str, tool_call: ToolCallData) {
|
||||
if let Some(buffer) = self.buffers.get_mut(message_id) {
|
||||
buffer.last_activity = Utc::now();
|
||||
|
||||
// Try to find and update an existing tool call with the same id.
|
||||
for part in buffer.parts.iter_mut() {
|
||||
if let AccumulatedPart::Tool { data } = part {
|
||||
if data.id == tool_call.id {
|
||||
data.tool_name = tool_call.tool_name;
|
||||
|
||||
if tool_call.input != Value::Null
|
||||
&& tool_call.input != serde_json::json!({})
|
||||
{
|
||||
data.input = tool_call.input;
|
||||
}
|
||||
|
||||
if tool_call.output.is_some() {
|
||||
data.output = tool_call.output;
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// First time seeing this tool_call_id -- append.
|
||||
buffer
|
||||
.parts
|
||||
.push(AccumulatedPart::Tool { data: tool_call });
|
||||
}
|
||||
}
|
||||
|
||||
/// Finalize a message and return its accumulated content.
|
||||
///
|
||||
/// The internal buffer for `message_id` is removed. Returns `None` if no
|
||||
/// buffer exists for the given id.
|
||||
pub fn finalize(&mut self, message_id: &str) -> Option<AccumulatedMessage> {
|
||||
let buffer = self.buffers.remove(message_id)?;
|
||||
|
||||
Some(AccumulatedMessage {
|
||||
message_id: buffer.message_id,
|
||||
session_id: buffer.session_id,
|
||||
connector_id: buffer.connector_id,
|
||||
role: buffer.role,
|
||||
parts: buffer.parts,
|
||||
created_at: buffer.created_at,
|
||||
last_activity: buffer.last_activity,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns `true` if a buffer exists for the given message id.
|
||||
pub fn has_buffer(&self, message_id: &str) -> bool {
|
||||
self.buffers.contains_key(message_id)
|
||||
}
|
||||
|
||||
/// Return all buffered message ids that belong to `session_id`.
|
||||
pub fn message_ids_for_session(&self, session_id: &str) -> Vec<String> {
|
||||
self.buffers
|
||||
.iter()
|
||||
.filter(|(_, buf)| buf.session_id == session_id)
|
||||
.map(|(id, _)| id.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Return all currently buffered message ids.
|
||||
pub fn active_message_ids(&self) -> Vec<String> {
|
||||
self.buffers.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Return message ids whose buffers have not been touched for longer than
|
||||
/// `threshold`.
|
||||
pub fn stale_message_ids(&self, threshold: std::time::Duration) -> Vec<String> {
|
||||
let now = Utc::now();
|
||||
self.buffers
|
||||
.iter()
|
||||
.filter(|(_, buf)| {
|
||||
let inactive = now.signed_duration_since(buf.last_activity);
|
||||
inactive
|
||||
.to_std()
|
||||
.unwrap_or(std::time::Duration::ZERO)
|
||||
> threshold
|
||||
})
|
||||
.map(|(id, _)| id.clone())
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_empty_accumulator() {
|
||||
let mut acc = MessageAccumulator::new();
|
||||
assert!(acc.finalize("nonexistent").is_none());
|
||||
assert!(acc.active_message_ids().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_text_chunk_coalescing() {
|
||||
let mut acc = MessageAccumulator::new();
|
||||
|
||||
acc.add_chunk("msg1", "s1", "c1", "user", ContentBlock::Text {
|
||||
text: "Hello, ".to_string(),
|
||||
});
|
||||
acc.add_chunk("msg1", "s1", "c1", "user", ContentBlock::Text {
|
||||
text: "world!".to_string(),
|
||||
});
|
||||
|
||||
let msg = acc.finalize("msg1").unwrap();
|
||||
assert_eq!(msg.parts.len(), 1);
|
||||
match &msg.parts[0] {
|
||||
AccumulatedPart::Text { text } => assert_eq!(text, "Hello, world!"),
|
||||
other => panic!("Expected Text, got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_thinking_coalescing() {
|
||||
let mut acc = MessageAccumulator::new();
|
||||
|
||||
acc.add_thinking("msg1", "s1", "c1", "First. ");
|
||||
acc.add_thinking("msg1", "s1", "c1", "Second.");
|
||||
|
||||
let msg = acc.finalize("msg1").unwrap();
|
||||
assert_eq!(msg.parts.len(), 1);
|
||||
match &msg.parts[0] {
|
||||
AccumulatedPart::Thinking { text } => assert_eq!(text, "First. Second."),
|
||||
other => panic!("Expected Thinking, got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_interleaved_parts_preserve_order() {
|
||||
let mut acc = MessageAccumulator::new();
|
||||
|
||||
// text, tool, text -- should produce 3 distinct parts
|
||||
acc.add_chunk("msg1", "s1", "c1", "assistant", ContentBlock::Text {
|
||||
text: "Before tool.".to_string(),
|
||||
});
|
||||
|
||||
acc.add_or_update_tool_call("msg1", ToolCallData {
|
||||
id: "tc1".to_string(),
|
||||
tool_name: "grep".to_string(),
|
||||
input: serde_json::json!({"q": "x"}),
|
||||
output: None,
|
||||
});
|
||||
|
||||
acc.add_chunk("msg1", "s1", "c1", "assistant", ContentBlock::Text {
|
||||
text: "After tool.".to_string(),
|
||||
});
|
||||
|
||||
let msg = acc.finalize("msg1").unwrap();
|
||||
assert_eq!(msg.parts.len(), 3);
|
||||
assert!(matches!(&msg.parts[0], AccumulatedPart::Text { .. }));
|
||||
assert!(matches!(&msg.parts[1], AccumulatedPart::Tool { .. }));
|
||||
assert!(matches!(&msg.parts[2], AccumulatedPart::Text { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_deduplication() {
|
||||
let mut acc = MessageAccumulator::new();
|
||||
|
||||
// Create buffer first
|
||||
acc.add_chunk("msg1", "s1", "c1", "assistant", ContentBlock::Text {
|
||||
text: "hi".to_string(),
|
||||
});
|
||||
|
||||
// Initial tool call
|
||||
acc.add_or_update_tool_call("msg1", ToolCallData {
|
||||
id: "tc1".to_string(),
|
||||
tool_name: "read".to_string(),
|
||||
input: serde_json::json!({"path": "foo.rs"}),
|
||||
output: None,
|
||||
});
|
||||
|
||||
// Update same tool call with output (empty input should NOT overwrite)
|
||||
acc.add_or_update_tool_call("msg1", ToolCallData {
|
||||
id: "tc1".to_string(),
|
||||
tool_name: "read".to_string(),
|
||||
input: serde_json::json!({}),
|
||||
output: Some(serde_json::json!({"content": "fn main() {}"})),
|
||||
});
|
||||
|
||||
let msg = acc.finalize("msg1").unwrap();
|
||||
// text + 1 tool (not 2)
|
||||
assert_eq!(msg.parts.len(), 2);
|
||||
|
||||
match &msg.parts[1] {
|
||||
AccumulatedPart::Tool { data } => {
|
||||
assert_eq!(data.id, "tc1");
|
||||
// Input preserved from first call (non-empty), not overwritten by empty update
|
||||
assert_eq!(data.input, serde_json::json!({"path": "foo.rs"}));
|
||||
// Output set from update
|
||||
assert_eq!(
|
||||
data.output,
|
||||
Some(serde_json::json!({"content": "fn main() {}"}))
|
||||
);
|
||||
}
|
||||
other => panic!("Expected Tool, got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_message_parts_non_streaming() {
|
||||
let parts = vec![
|
||||
MessagePart::Text {
|
||||
text: "Hello".to_string(),
|
||||
},
|
||||
MessagePart::Thinking {
|
||||
text: "hmm".to_string(),
|
||||
},
|
||||
MessagePart::Code {
|
||||
language: "rs".to_string(),
|
||||
code: "fn main() {}".to_string(),
|
||||
},
|
||||
MessagePart::Tool {
|
||||
tool: "grep".to_string(),
|
||||
tool_call_id: Some("tc1".to_string()),
|
||||
input: serde_json::json!({"q": "x"}),
|
||||
output: Some(serde_json::json!("found")),
|
||||
},
|
||||
MessagePart::File {
|
||||
path: "README.md".to_string(),
|
||||
content: "# Title".to_string(),
|
||||
},
|
||||
// Empty text and code should be skipped
|
||||
MessagePart::Text {
|
||||
text: String::new(),
|
||||
},
|
||||
MessagePart::Code {
|
||||
language: "py".to_string(),
|
||||
code: String::new(),
|
||||
},
|
||||
];
|
||||
|
||||
let msg = AccumulatedMessage::from_message_parts(
|
||||
"msg1".into(),
|
||||
"s1".into(),
|
||||
"c1".into(),
|
||||
"assistant".into(),
|
||||
&parts,
|
||||
);
|
||||
|
||||
// 5 non-empty parts: text, thinking, code-as-text, tool, file-as-text
|
||||
assert_eq!(msg.parts.len(), 5);
|
||||
|
||||
match &msg.parts[0] {
|
||||
AccumulatedPart::Text { text } => assert_eq!(text, "Hello"),
|
||||
other => panic!("Expected Text, got {:?}", other),
|
||||
}
|
||||
match &msg.parts[1] {
|
||||
AccumulatedPart::Thinking { text } => assert_eq!(text, "hmm"),
|
||||
other => panic!("Expected Thinking, got {:?}", other),
|
||||
}
|
||||
match &msg.parts[2] {
|
||||
AccumulatedPart::Text { text } => {
|
||||
assert!(text.contains("```rs"));
|
||||
assert!(text.contains("fn main() {}"));
|
||||
}
|
||||
other => panic!("Expected Text (code), got {:?}", other),
|
||||
}
|
||||
match &msg.parts[3] {
|
||||
AccumulatedPart::Tool { data } => {
|
||||
assert_eq!(data.tool_name, "grep");
|
||||
assert_eq!(data.id, "tc1");
|
||||
assert_eq!(data.output, Some(serde_json::json!("found")));
|
||||
}
|
||||
other => panic!("Expected Tool, got {:?}", other),
|
||||
}
|
||||
match &msg.parts[4] {
|
||||
AccumulatedPart::Text { text } => {
|
||||
assert!(text.contains("File: README.md"));
|
||||
}
|
||||
other => panic!("Expected Text (file), got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_and_stale_queries() {
|
||||
let mut acc = MessageAccumulator::new();
|
||||
|
||||
acc.add_chunk("msg1", "s1", "c1", "user", ContentBlock::Text {
|
||||
text: "a".to_string(),
|
||||
});
|
||||
acc.add_chunk("msg2", "s1", "c1", "assistant", ContentBlock::Text {
|
||||
text: "b".to_string(),
|
||||
});
|
||||
acc.add_chunk("msg3", "s2", "c1", "user", ContentBlock::Text {
|
||||
text: "c".to_string(),
|
||||
});
|
||||
|
||||
// message_ids_for_session
|
||||
let mut s1_ids = acc.message_ids_for_session("s1");
|
||||
s1_ids.sort();
|
||||
assert_eq!(s1_ids, vec!["msg1", "msg2"]);
|
||||
|
||||
let s2_ids = acc.message_ids_for_session("s2");
|
||||
assert_eq!(s2_ids, vec!["msg3"]);
|
||||
|
||||
assert!(acc.message_ids_for_session("s3").is_empty());
|
||||
|
||||
// active_message_ids
|
||||
let mut all = acc.active_message_ids();
|
||||
all.sort();
|
||||
assert_eq!(all, vec!["msg1", "msg2", "msg3"]);
|
||||
|
||||
// has_buffer
|
||||
assert!(acc.has_buffer("msg1"));
|
||||
assert!(!acc.has_buffer("msg99"));
|
||||
|
||||
// stale_message_ids with zero threshold -- everything is stale
|
||||
// (last_activity <= now)
|
||||
let _stale_zero = acc.stale_message_ids(std::time::Duration::ZERO);
|
||||
// All three should be considered stale since last_activity <= now
|
||||
// (Due to timing, they might not all be strictly < now, so we check
|
||||
// with a small threshold instead.)
|
||||
let stale_lenient = acc.stale_message_ids(std::time::Duration::from_secs(0));
|
||||
// At minimum, none should be stale with a huge threshold
|
||||
let not_stale = acc.stale_message_ids(std::time::Duration::from_secs(3600));
|
||||
assert!(not_stale.is_empty());
|
||||
|
||||
// Verify stale detection works by checking the lenient case doesn't
|
||||
// return more ids than we have buffers
|
||||
assert!(stale_lenient.len() <= 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_message_parts() {
|
||||
let mut acc = MessageAccumulator::new();
|
||||
|
||||
acc.add_chunk("msg1", "s1", "c1", "assistant", ContentBlock::Text {
|
||||
text: "Hello ".to_string(),
|
||||
});
|
||||
acc.add_chunk("msg1", "s1", "c1", "assistant", ContentBlock::Text {
|
||||
text: "world.".to_string(),
|
||||
});
|
||||
acc.add_thinking("msg1", "s1", "c1", "thinking...");
|
||||
acc.add_or_update_tool_call("msg1", ToolCallData {
|
||||
id: "tc1".to_string(),
|
||||
tool_name: "search".to_string(),
|
||||
input: serde_json::json!({"q": "test"}),
|
||||
output: Some(serde_json::json!("result")),
|
||||
});
|
||||
|
||||
let msg = acc.finalize("msg1").unwrap();
|
||||
let parts = msg.to_message_parts();
|
||||
|
||||
assert_eq!(parts.len(), 3);
|
||||
|
||||
// Coalesced text
|
||||
match &parts[0] {
|
||||
MessagePart::Text { text } => assert_eq!(text, "Hello world."),
|
||||
other => panic!("Expected Text, got {:?}", other),
|
||||
}
|
||||
|
||||
// Thinking
|
||||
match &parts[1] {
|
||||
MessagePart::Thinking { text } => assert_eq!(text, "thinking..."),
|
||||
other => panic!("Expected Thinking, got {:?}", other),
|
||||
}
|
||||
|
||||
// Tool roundtrip
|
||||
match &parts[2] {
|
||||
MessagePart::Tool {
|
||||
tool,
|
||||
tool_call_id,
|
||||
input,
|
||||
output,
|
||||
} => {
|
||||
assert_eq!(tool, "search");
|
||||
assert_eq!(tool_call_id, &Some("tc1".to_string()));
|
||||
assert_eq!(input, &serde_json::json!({"q": "test"}));
|
||||
assert_eq!(output, &Some(serde_json::json!("result")));
|
||||
}
|
||||
other => panic!("Expected Tool, got {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_empty() {
|
||||
let msg = AccumulatedMessage::from_message_parts(
|
||||
"msg1".into(),
|
||||
"s1".into(),
|
||||
"c1".into(),
|
||||
"user".into(),
|
||||
&[],
|
||||
);
|
||||
assert!(msg.is_empty());
|
||||
|
||||
let msg2 = AccumulatedMessage::from_message_parts(
|
||||
"msg2".into(),
|
||||
"s1".into(),
|
||||
"c1".into(),
|
||||
"user".into(),
|
||||
&[MessagePart::Text {
|
||||
text: "hi".to_string(),
|
||||
}],
|
||||
);
|
||||
assert!(!msg2.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_message_parts_empty_tool_id() {
|
||||
// Tool with no tool_call_id should roundtrip as None
|
||||
let parts = vec![MessagePart::Tool {
|
||||
tool: "bash".to_string(),
|
||||
tool_call_id: None,
|
||||
input: serde_json::json!({"cmd": "ls"}),
|
||||
output: None,
|
||||
}];
|
||||
|
||||
let msg = AccumulatedMessage::from_message_parts(
|
||||
"msg1".into(),
|
||||
"s1".into(),
|
||||
"c1".into(),
|
||||
"assistant".into(),
|
||||
&parts,
|
||||
);
|
||||
|
||||
let roundtripped = msg.to_message_parts();
|
||||
match &roundtripped[0] {
|
||||
MessagePart::Tool { tool_call_id, .. } => {
|
||||
assert_eq!(tool_call_id, &None);
|
||||
}
|
||||
other => panic!("Expected Tool, got {:?}", other),
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user