135 lines
4.1 KiB
Rust
135 lines
4.1 KiB
Rust
//! Transport-neutral tool events. Connector adapters translate these into their own wire types.
|
|
|
|
use crate::tool::ToolKind;
|
|
use bytes::Bytes;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::sync::Arc;
|
|
use tokio::sync::mpsc;
|
|
|
|
/// Opaque permission-request id, allocated by the dispatcher.
|
|
#[derive(Clone, Debug, Eq, PartialEq, Hash, Serialize, Deserialize)]
|
|
pub struct PermissionRequestId(Arc<str>);
|
|
|
|
impl PermissionRequestId {
|
|
/// Construct a new permission-request id from any string-like value.
|
|
pub fn new(value: impl Into<Arc<str>>) -> Self {
|
|
Self(value.into())
|
|
}
|
|
|
|
/// Borrow the inner id as a string slice.
|
|
pub fn as_str(&self) -> &str {
|
|
&self.0
|
|
}
|
|
}
|
|
|
|
/// Where a tool is operating (file path + optional line).
|
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
pub struct ToolLocation {
|
|
pub path: String,
|
|
pub line: Option<u32>,
|
|
}
|
|
|
|
/// Result content shape. Mirrors what most providers accept as a tool result.
|
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
#[serde(tag = "type", rename_all = "snake_case")]
|
|
pub enum ToolResultContent {
|
|
Text { text: Arc<str> },
|
|
Json { value: serde_json::Value },
|
|
Image { mime: Arc<str>, #[serde(with = "serde_bytes_arc")] data: Bytes },
|
|
Parts { parts: Vec<ToolResultContent> },
|
|
}
|
|
|
|
impl ToolResultContent {
|
|
pub fn text(s: impl Into<Arc<str>>) -> Self { Self::Text { text: s.into() } }
|
|
}
|
|
|
|
/// Events emitted by a running tool. Transport-agnostic.
|
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
|
#[serde(tag = "type", rename_all = "snake_case")]
|
|
pub enum ToolEvent {
|
|
Started { title: Arc<str>, kind: ToolKind, location: Option<ToolLocation> },
|
|
TitleUpdate { title: Arc<str>, location: Option<ToolLocation> },
|
|
PartialOutput { content: ToolResultContent },
|
|
Status { message: Arc<str> },
|
|
PermissionRequested { request_id: PermissionRequestId, summary: Arc<str> },
|
|
Completed,
|
|
Failed,
|
|
}
|
|
|
|
/// Sink a tool emits events into. Cheap to clone.
|
|
#[derive(Clone, Debug)]
|
|
pub struct ToolEventSink {
|
|
tx: mpsc::UnboundedSender<ToolEvent>,
|
|
}
|
|
|
|
impl ToolEventSink {
|
|
pub fn new() -> (Self, mpsc::UnboundedReceiver<ToolEvent>) {
|
|
let (tx, rx) = mpsc::unbounded_channel();
|
|
(Self { tx }, rx)
|
|
}
|
|
|
|
/// Best-effort emit. Drops the event if the receiver is gone.
|
|
pub fn emit(&self, event: ToolEvent) {
|
|
let _ = self.tx.send(event);
|
|
}
|
|
}
|
|
|
|
mod serde_bytes_arc {
|
|
use bytes::Bytes;
|
|
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
|
|
|
pub fn serialize<S: Serializer>(b: &Bytes, s: S) -> Result<S::Ok, S::Error> {
|
|
serde_bytes::Bytes::new(b.as_ref()).serialize(s)
|
|
}
|
|
pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Bytes, D::Error> {
|
|
let v: Vec<u8> = serde_bytes::ByteBuf::deserialize(d)?.into_vec();
|
|
Ok(Bytes::from(v))
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn sink_round_trips_event() {
        let (sink, mut rx) = ToolEventSink::new();
        sink.emit(ToolEvent::Status { message: "hi".into() });
        let got = rx.recv().await.unwrap();
        match got {
            ToolEvent::Status { message } => assert_eq!(&*message, "hi"),
            other => panic!("wrong variant: {:?}", other),
        }
    }

    #[test]
    fn emit_after_receiver_dropped_is_silent() {
        let (sink, rx) = ToolEventSink::new();
        drop(rx);
        // `emit` is documented as best-effort: a closed channel must not panic.
        sink.emit(ToolEvent::Completed);
    }

    #[test]
    fn result_content_text_helper() {
        match ToolResultContent::text("hello") {
            ToolResultContent::Text { text } => assert_eq!(&*text, "hello"),
            other => panic!("expected Text, got {:?}", other),
        }
    }

    #[test]
    fn tool_event_serde_round_trip() {
        let ev = ToolEvent::PartialOutput { content: ToolResultContent::text("x") };
        let json = serde_json::to_string(&ev).unwrap();

        // Pin the wire shape: both enums are internally tagged on "type".
        let shape: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(
            shape,
            serde_json::json!({
                "type": "partial_output",
                "content": { "type": "text", "text": "x" }
            })
        );

        // And make sure decoding restores the same variant and payload.
        let back: ToolEvent = serde_json::from_str(&json).unwrap();
        match back {
            ToolEvent::PartialOutput { content: ToolResultContent::Text { text } } => {
                assert_eq!(&*text, "x")
            }
            other => panic!("round trip produced {:?}", other),
        }
    }

    #[test]
    fn unit_variants_serialize_as_bare_tag() {
        let json = serde_json::to_string(&ToolEvent::Completed).unwrap();
        assert_eq!(json, r#"{"type":"completed"}"#);
        let back: ToolEvent = serde_json::from_str(&json).unwrap();
        assert!(matches!(back, ToolEvent::Completed));
    }

    #[test]
    fn image_content_round_trips_bytes() {
        let payload: &'static [u8] = &[0x89, b'P', b'N', b'G'];
        let img = ToolResultContent::Image {
            mime: "image/png".into(),
            data: Bytes::from_static(payload),
        };
        let json = serde_json::to_string(&img).unwrap();
        let back: ToolResultContent = serde_json::from_str(&json).unwrap();
        match back {
            ToolResultContent::Image { mime, data } => {
                assert_eq!(&*mime, "image/png");
                assert_eq!(data, Bytes::from_static(payload));
            }
            other => panic!("round trip produced {:?}", other),
        }
    }

    #[test]
    fn permission_request_id_constructor_and_accessor() {
        let id = PermissionRequestId::new("abc");
        assert_eq!(id.as_str(), "abc");
    }

    #[test]
    fn permission_request_id_serde_is_transparent_string() {
        let id = PermissionRequestId::new("foo");
        let json = serde_json::to_string(&id).unwrap();
        assert_eq!(json, "\"foo\"");
        let back: PermissionRequestId = serde_json::from_str(&json).unwrap();
        assert_eq!(back, id);
    }
}
|