sync from monorepo @ 2452e92e
This commit is contained in:
@@ -0,0 +1,107 @@
|
||||
//! Tool call correlation — matches assistant ToolUse blocks with their
|
||||
//! corresponding user ToolResult blocks by ID across a message sequence.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::types::{
|
||||
Content, ContentBlock, RawAssistantMessage, RawMessage, RawUserMessage, ToolCall,
|
||||
ToolExchange, ToolName, ToolResultData,
|
||||
};
|
||||
|
||||
/// Extract tool calls from an assistant message's content blocks.
|
||||
fn extract_tool_calls(msg: &RawAssistantMessage) -> Vec<ToolCall> {
|
||||
let source_uuid = msg.uuid.clone().unwrap_or_default();
|
||||
msg.message
|
||||
.content
|
||||
.iter()
|
||||
.filter_map(|block| {
|
||||
if let ContentBlock::ToolUse { id, name, input, .. } = block {
|
||||
Some(ToolCall {
|
||||
id: id.clone(),
|
||||
name: ToolName::from(name.clone()),
|
||||
input: input.clone(),
|
||||
source_message_uuid: source_uuid.clone(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Extract tool results from a user message's content blocks.
|
||||
fn extract_tool_results(msg: &RawUserMessage) -> Vec<ToolResultData> {
|
||||
let source_uuid = msg.uuid.clone().unwrap_or_default();
|
||||
match &msg.message.content {
|
||||
Content::Blocks(blocks) => blocks
|
||||
.iter()
|
||||
.filter_map(|block| {
|
||||
if let ContentBlock::ToolResult { tool_use_id, content, is_error } = block {
|
||||
// Extract text content from the tool result
|
||||
let text_content = content.as_ref().and_then(|c| match c {
|
||||
Content::Text(s) => Some(s.clone()),
|
||||
Content::Blocks(bs) => {
|
||||
// Concatenate text blocks
|
||||
let texts: Vec<&str> = bs
|
||||
.iter()
|
||||
.filter_map(|b| {
|
||||
if let ContentBlock::Text { text } = b {
|
||||
Some(text.as_str())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
if texts.is_empty() { None } else { Some(texts.join("\n")) }
|
||||
}
|
||||
});
|
||||
Some(ToolResultData {
|
||||
tool_use_id: tool_use_id.clone(),
|
||||
content: text_content,
|
||||
is_error: *is_error,
|
||||
source_message_uuid: source_uuid.clone(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
Content::Text(_) => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Correlate tool calls with their results across a message sequence.
|
||||
///
|
||||
/// Iterates messages in order, collecting ToolUse blocks from assistant
|
||||
/// messages and matching them by ID to ToolResult blocks in subsequent user
|
||||
/// messages. Any tool calls that never received a result are emitted with
|
||||
/// `result: None`.
|
||||
pub fn correlate_tools(messages: &[RawMessage]) -> Vec<ToolExchange> {
|
||||
let mut pending: HashMap<String, ToolCall> = HashMap::new();
|
||||
let mut exchanges: Vec<ToolExchange> = Vec::new();
|
||||
|
||||
for msg in messages {
|
||||
match msg {
|
||||
RawMessage::Assistant(asst) => {
|
||||
for call in extract_tool_calls(asst) {
|
||||
pending.insert(call.id.clone(), call);
|
||||
}
|
||||
}
|
||||
RawMessage::User(user) => {
|
||||
for result in extract_tool_results(user) {
|
||||
if let Some(call) = pending.remove(&result.tool_use_id) {
|
||||
exchanges.push(ToolExchange { call, result: Some(result) });
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Emit unmatched calls (no result found)
|
||||
for (_id, call) in pending {
|
||||
exchanges.push(ToolExchange { call, result: None });
|
||||
}
|
||||
|
||||
exchanges
|
||||
}
|
||||
Reference in New Issue
Block a user