From b4fd1e8ad219ad663d375465c3362ab765990cd1 Mon Sep 17 00:00:00 2001 From: Alex Holmberg Date: Thu, 5 Feb 2026 10:05:22 +0100 Subject: [PATCH] feat: vendor ag-ui-core and ag-ui-server crates Move AG-UI protocol crates into the repository to fix release-plz dependency resolution issues. These crates enable frontend connectivity for agent mode. --- Cargo.toml | 4 +- crates/ag-ui-core/Cargo.toml | 16 + crates/ag-ui-core/src/error.rs | 30 + crates/ag-ui-core/src/event.rs | 2481 ++++++++++++++++++++++ crates/ag-ui-core/src/lib.rs | 64 + crates/ag-ui-core/src/patch.rs | 622 ++++++ crates/ag-ui-core/src/state.rs | 645 ++++++ crates/ag-ui-core/src/types/content.rs | 451 ++++ crates/ag-ui-core/src/types/ids.rs | 156 ++ crates/ag-ui-core/src/types/input.rs | 289 +++ crates/ag-ui-core/src/types/message.rs | 714 +++++++ crates/ag-ui-core/src/types/mod.rs | 20 + crates/ag-ui-core/src/types/tool.rs | 78 + crates/ag-ui-server/Cargo.toml | 28 + crates/ag-ui-server/src/error.rs | 31 + crates/ag-ui-server/src/lib.rs | 43 + crates/ag-ui-server/src/producer.rs | 1055 +++++++++ crates/ag-ui-server/src/transport/mod.rs | 64 + crates/ag-ui-server/src/transport/sse.rs | 291 +++ crates/ag-ui-server/src/transport/ws.rs | 443 ++++ 20 files changed, 7523 insertions(+), 2 deletions(-) create mode 100644 crates/ag-ui-core/Cargo.toml create mode 100644 crates/ag-ui-core/src/error.rs create mode 100644 crates/ag-ui-core/src/event.rs create mode 100644 crates/ag-ui-core/src/lib.rs create mode 100644 crates/ag-ui-core/src/patch.rs create mode 100644 crates/ag-ui-core/src/state.rs create mode 100644 crates/ag-ui-core/src/types/content.rs create mode 100644 crates/ag-ui-core/src/types/ids.rs create mode 100644 crates/ag-ui-core/src/types/input.rs create mode 100644 crates/ag-ui-core/src/types/message.rs create mode 100644 crates/ag-ui-core/src/types/mod.rs create mode 100644 crates/ag-ui-core/src/types/tool.rs create mode 100644 crates/ag-ui-server/Cargo.toml create mode 100644 crates/ag-ui-server/src/error.rs create mode 100644 crates/ag-ui-server/src/lib.rs create mode 100644 crates/ag-ui-server/src/producer.rs create mode 100644 crates/ag-ui-server/src/transport/mod.rs create mode 100644 crates/ag-ui-server/src/transport/sse.rs create mode 100644 crates/ag-ui-server/src/transport/ws.rs diff --git a/Cargo.toml b/Cargo.toml index 48f3a5a6..5f76557b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -85,8 +85,8 @@ futures-util = "0.3" rig-core = { version = "0.28", features = ["derive", "image"] } # AG-UI Protocol - enables frontend connectivity for agent mode -ag-ui-core = { path = "../ag-ui-sdk/crates/ag-ui-core" } -ag-ui-server = { path = "../ag-ui-sdk/crates/ag-ui-server" } +ag-ui-core = { path = "crates/ag-ui-core" } +ag-ui-server = { path = "crates/ag-ui-server" } axum = { version = "0.8", features = ["ws"] } tower-http = { version = "0.6", features = ["cors"] } tokio-stream = { version = "0.1", features = ["sync"] } diff --git a/crates/ag-ui-core/Cargo.toml b/crates/ag-ui-core/Cargo.toml new file mode 100644 index 00000000..aa26d058 --- /dev/null +++ b/crates/ag-ui-core/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "ag-ui-core" +version = "0.1.0" +edition = "2024" +rust-version = "1.88" +license = "MIT" +description = "Core type library for AG-UI protocol - Syncable SDK" +readme = "README.md" + +[dependencies] +thiserror = "2" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +json-patch = "3" +jsonptr = "0.6" +uuid = { version = "1", features = ["v4", "serde"] } diff --git a/crates/ag-ui-core/src/error.rs b/crates/ag-ui-core/src/error.rs new file mode 100644 index 00000000..d7af543e --- /dev/null +++ b/crates/ag-ui-core/src/error.rs @@ -0,0 +1,30 @@ +//! Error types for AG-UI core operations. + +use thiserror::Error; + +/// Errors that can occur in AG-UI core operations. +#[derive(Debug, Error)] +pub enum AgUiError { + /// Error during JSON serialization/deserialization + #[error("Serialization error: {0}")] + Serialization(#[from] serde_json::Error), + + /// Validation error for event or message data + #[error("Validation error: {0}")] + Validation(String), + + /// Invalid event format or structure + #[error("Invalid event: {0}")] + InvalidEvent(String), + + /// Invalid message format or content + #[error("Invalid message: {0}")] + InvalidMessage(String), + + /// State operation error + #[error("State error: {0}")] + State(String), +} + +/// Result type alias using AgUiError +pub type Result = std::result::Result; diff --git a/crates/ag-ui-core/src/event.rs b/crates/ag-ui-core/src/event.rs new file mode 100644 index 00000000..1522f757 --- /dev/null +++ b/crates/ag-ui-core/src/event.rs @@ -0,0 +1,2481 @@ +//! AG-UI Event Types +//! +//! This module defines all AG-UI protocol event types including: +//! - Text message events (start, content, end, chunk) +//! - Thinking text message events +//! - Tool call events (start, args, end, result) +//! - State events (snapshot, delta) +//! - Run lifecycle events (started, finished, error) +//! - Step events (started, finished) +//! - Custom and raw events + +use crate::state::AgentState; +use crate::types::{Message, MessageId, Role, RunId, ThreadId, ToolCallId}; +use crate::JsonValue; +use serde::{Deserialize, Serialize}; + +/// Event types for the AG-UI protocol. +/// +/// This enum enumerates all possible event types in the protocol. +/// Event types are serialized using SCREAMING_SNAKE_CASE (e.g., `TEXT_MESSAGE_START`). +/// +/// # Event Categories +/// +/// - **Text Messages**: `TextMessageStart`, `TextMessageContent`, `TextMessageEnd`, `TextMessageChunk` +/// - **Thinking Messages**: `ThinkingTextMessageStart`, `ThinkingTextMessageContent`, `ThinkingTextMessageEnd` +/// - **Tool Calls**: `ToolCallStart`, `ToolCallArgs`, `ToolCallEnd`, `ToolCallChunk`, `ToolCallResult` +/// - **Thinking Steps**: `ThinkingStart`, `ThinkingEnd` +/// - **State**: `StateSnapshot`, `StateDelta` +/// - **Messages**: `MessagesSnapshot` +/// - **Run Lifecycle**: `RunStarted`, `RunFinished`, `RunError` +/// - **Step Lifecycle**: `StepStarted`, `StepFinished` +/// - **Other**: `Raw`, `Custom` +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum EventType { + /// Start of a text message from the assistant. + TextMessageStart, + /// Content chunk of a text message (streaming delta). + TextMessageContent, + /// End of a text message. + TextMessageEnd, + /// Complete text message chunk (non-streaming alternative). + TextMessageChunk, + /// Start of a thinking text message (extended thinking). + ThinkingTextMessageStart, + /// Content chunk of a thinking text message. + ThinkingTextMessageContent, + /// End of a thinking text message. + ThinkingTextMessageEnd, + /// Start of a tool call. + ToolCallStart, + /// Arguments chunk for a tool call (streaming). + ToolCallArgs, + /// End of a tool call. + ToolCallEnd, + /// Complete tool call chunk (non-streaming alternative). + ToolCallChunk, + /// Result of a tool call execution. + ToolCallResult, + /// Start of a thinking step (chain-of-thought). + ThinkingStart, + /// End of a thinking step. + ThinkingEnd, + /// Complete state snapshot. + StateSnapshot, + /// Incremental state update (JSON Patch RFC 6902). + StateDelta, + /// Complete messages snapshot. + MessagesSnapshot, + /// Complete activity snapshot. + ActivitySnapshot, + /// Incremental activity update (JSON Patch RFC 6902). + ActivityDelta, + /// Raw event from the underlying provider. + Raw, + /// Custom application-specific event. + Custom, + /// Agent run has started. + RunStarted, + /// Agent run has finished successfully. + RunFinished, + /// Agent run encountered an error. + RunError, + /// A step within a run has started. + StepStarted, + /// A step within a run has finished. + StepFinished, +} + +impl EventType { + /// Returns the string representation of the event type. + pub fn as_str(&self) -> &'static str { + match self { + EventType::TextMessageStart => "TEXT_MESSAGE_START", + EventType::TextMessageContent => "TEXT_MESSAGE_CONTENT", + EventType::TextMessageEnd => "TEXT_MESSAGE_END", + EventType::TextMessageChunk => "TEXT_MESSAGE_CHUNK", + EventType::ThinkingTextMessageStart => "THINKING_TEXT_MESSAGE_START", + EventType::ThinkingTextMessageContent => "THINKING_TEXT_MESSAGE_CONTENT", + EventType::ThinkingTextMessageEnd => "THINKING_TEXT_MESSAGE_END", + EventType::ToolCallStart => "TOOL_CALL_START", + EventType::ToolCallArgs => "TOOL_CALL_ARGS", + EventType::ToolCallEnd => "TOOL_CALL_END", + EventType::ToolCallChunk => "TOOL_CALL_CHUNK", + EventType::ToolCallResult => "TOOL_CALL_RESULT", + EventType::ThinkingStart => "THINKING_START", + EventType::ThinkingEnd => "THINKING_END", + EventType::StateSnapshot => "STATE_SNAPSHOT", + EventType::StateDelta => "STATE_DELTA", + EventType::MessagesSnapshot => "MESSAGES_SNAPSHOT", + EventType::ActivitySnapshot => "ACTIVITY_SNAPSHOT", + EventType::ActivityDelta => "ACTIVITY_DELTA", + EventType::Raw => "RAW", + EventType::Custom => "CUSTOM", + EventType::RunStarted => "RUN_STARTED", + EventType::RunFinished => "RUN_FINISHED", + EventType::RunError => "RUN_ERROR", + EventType::StepStarted => "STEP_STARTED", + EventType::StepFinished => "STEP_FINISHED", + } + } +} + +impl std::fmt::Display for EventType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +/// Base event structure for all AG-UI protocol events. +/// +/// Contains common fields that are present in all event types. +/// Individual event structs flatten this into their structure. +/// +/// # Fields +/// +/// - `timestamp`: Optional Unix timestamp in milliseconds since epoch +/// - `raw_event`: Optional raw event from the underlying provider (for debugging/passthrough) +/// +/// # Example +/// +/// ```rust +/// use ag_ui_core::event::BaseEvent; +/// +/// let base = BaseEvent { +/// timestamp: Some(1706123456789.0), +/// raw_event: None, +/// }; +/// ``` +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +pub struct BaseEvent { + /// Unix timestamp in milliseconds since epoch. + #[serde(skip_serializing_if = "Option::is_none")] + pub timestamp: Option, + /// Raw event from the underlying provider (for debugging/passthrough). + #[serde(rename = "rawEvent", skip_serializing_if = "Option::is_none")] + pub raw_event: Option, +} + +impl BaseEvent { + /// Creates a new empty BaseEvent. + pub fn new() -> Self { + Self::default() + } + + /// Creates a BaseEvent with the current timestamp. + pub fn with_current_timestamp() -> Self { + Self { + timestamp: Some( + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map(|d| d.as_millis() as f64) + .unwrap_or(0.0), + ), + raw_event: None, + } + } + + /// Sets the timestamp for this event. + pub fn timestamp(mut self, timestamp: f64) -> Self { + self.timestamp = Some(timestamp); + self + } + + /// Sets the raw event for this event. + pub fn raw_event(mut self, raw_event: JsonValue) -> Self { + self.raw_event = Some(raw_event); + self + } +} + +/// Validation errors for AG-UI protocol events. +/// +/// These errors occur when event data fails validation rules. +#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)] +pub enum EventValidationError { + /// Delta content must not be empty. + #[error("Delta must not be an empty string")] + EmptyDelta, + /// Event format is invalid. + #[error("Invalid event format: {0}")] + InvalidFormat(String), + /// Required field is missing. + #[error("Missing required field: {0}")] + MissingField(String), + /// Event type mismatch. + #[error("Event type mismatch: expected {expected}, got {actual}")] + TypeMismatch { + /// Expected event type. + expected: String, + /// Actual event type. + actual: String, + }, +} + +// ============================================================================= +// Text Message Events +// ============================================================================= + +/// Event indicating the start of a text message. +/// +/// This event is sent when the agent begins generating a text message. +/// The message_id identifies this message throughout the streaming process. +/// +/// # Example +/// +/// ```rust +/// use ag_ui_core::{MessageId, Role}; +/// use ag_ui_core::event::TextMessageStartEvent; +/// +/// let event = TextMessageStartEvent::new(MessageId::random()); +/// assert_eq!(event.role, Role::Assistant); +/// ``` +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct TextMessageStartEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// Unique identifier for this message. + #[serde(rename = "messageId")] + pub message_id: MessageId, + /// The role of the message sender (always Assistant for new messages). + pub role: Role, +} + +impl TextMessageStartEvent { + /// Creates a new TextMessageStartEvent with the given message ID. + pub fn new(message_id: impl Into) -> Self { + Self { + base: BaseEvent::default(), + message_id: message_id.into(), + role: Role::Assistant, + } + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } + + /// Sets the raw event for this event. + pub fn with_raw_event(mut self, raw_event: JsonValue) -> Self { + self.base.raw_event = Some(raw_event); + self + } +} + +/// Event containing a piece of text message content. +/// +/// This event is sent for each chunk of content as the agent generates a message. +/// The delta field contains the new text to append to the message. +/// +/// # Validation +/// +/// The delta must not be empty. Use `new()` which returns a Result to validate, +/// or `new_unchecked()` if you've already validated the input. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct TextMessageContentEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// The message ID this content belongs to. + #[serde(rename = "messageId")] + pub message_id: MessageId, + /// The text content delta to append. + pub delta: String, +} + +impl TextMessageContentEvent { + /// Creates a new TextMessageContentEvent with validation. + /// + /// Returns an error if delta is empty. + pub fn new( + message_id: impl Into, + delta: impl Into, + ) -> Result { + let delta = delta.into(); + if delta.is_empty() { + return Err(EventValidationError::EmptyDelta); + } + Ok(Self { + base: BaseEvent::default(), + message_id: message_id.into(), + delta, + }) + } + + /// Creates a new TextMessageContentEvent without validation. + /// + /// Use this only if you've already validated the delta is not empty. + pub fn new_unchecked(message_id: impl Into, delta: impl Into) -> Self { + Self { + base: BaseEvent::default(), + message_id: message_id.into(), + delta: delta.into(), + } + } + + /// Validates this event's data. + pub fn validate(&self) -> Result<(), EventValidationError> { + if self.delta.is_empty() { + return Err(EventValidationError::EmptyDelta); + } + Ok(()) + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +/// Event indicating the end of a text message. +/// +/// This event is sent when the agent completes a text message. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct TextMessageEndEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// The message ID that has completed. + #[serde(rename = "messageId")] + pub message_id: MessageId, +} + +impl TextMessageEndEvent { + /// Creates a new TextMessageEndEvent. + pub fn new(message_id: impl Into) -> Self { + Self { + base: BaseEvent::default(), + message_id: message_id.into(), + } + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +/// Event containing a chunk of text message content. +/// +/// This event combines start, content, and potentially end information in a single event. +/// Used as a non-streaming alternative where all fields may be optional. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct TextMessageChunkEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// Optional message ID (may be omitted for continuation chunks). + #[serde(rename = "messageId", skip_serializing_if = "Option::is_none")] + pub message_id: Option, + /// The role of the message sender. + pub role: Role, + /// Optional text content delta. + #[serde(skip_serializing_if = "Option::is_none")] + pub delta: Option, +} + +impl TextMessageChunkEvent { + /// Creates a new TextMessageChunkEvent with the given role. + pub fn new(role: Role) -> Self { + Self { + base: BaseEvent::default(), + message_id: None, + role, + delta: None, + } + } + + /// Sets the message ID for this event. + pub fn with_message_id(mut self, message_id: impl Into) -> Self { + self.message_id = Some(message_id.into()); + self + } + + /// Sets the delta for this event. + pub fn with_delta(mut self, delta: impl Into) -> Self { + self.delta = Some(delta.into()); + self + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +// ============================================================================= +// Thinking Text Message Events +// ============================================================================= + +/// Event indicating the start of a thinking text message. +/// +/// This event is sent when the agent begins generating internal thinking content +/// (extended thinking / chain-of-thought). Unlike regular messages, thinking +/// messages don't have a message ID as they're ephemeral. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ThinkingTextMessageStartEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, +} + +impl ThinkingTextMessageStartEvent { + /// Creates a new ThinkingTextMessageStartEvent. + pub fn new() -> Self { + Self { + base: BaseEvent::default(), + } + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +impl Default for ThinkingTextMessageStartEvent { + fn default() -> Self { + Self::new() + } +} + +/// Event containing a piece of thinking text message content. +/// +/// This event contains chunks of the agent's internal thinking process. +/// Unlike regular content events, thinking content doesn't validate for +/// empty delta as it may represent the start of a stream. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ThinkingTextMessageContentEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// The thinking content delta. + pub delta: String, +} + +impl ThinkingTextMessageContentEvent { + /// Creates a new ThinkingTextMessageContentEvent. + pub fn new(delta: impl Into) -> Self { + Self { + base: BaseEvent::default(), + delta: delta.into(), + } + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +/// Event indicating the end of a thinking text message. +/// +/// This event is sent when the agent completes its internal thinking process. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ThinkingTextMessageEndEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, +} + +impl ThinkingTextMessageEndEvent { + /// Creates a new ThinkingTextMessageEndEvent. + pub fn new() -> Self { + Self { + base: BaseEvent::default(), + } + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +impl Default for ThinkingTextMessageEndEvent { + fn default() -> Self { + Self::new() + } +} + +// ============================================================================= +// Tool Call Events +// ============================================================================= + +/// Event indicating the start of a tool call. +/// +/// This event is sent when the agent begins calling a tool with specific parameters. +/// The tool_call_id identifies this call throughout the streaming process. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ToolCallStartEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// Unique identifier for this tool call. + #[serde(rename = "toolCallId")] + pub tool_call_id: ToolCallId, + /// Name of the tool being called. + #[serde(rename = "toolCallName")] + pub tool_call_name: String, + /// Optional parent message ID if this call is part of a message. + #[serde(rename = "parentMessageId", skip_serializing_if = "Option::is_none")] + pub parent_message_id: Option, +} + +impl ToolCallStartEvent { + /// Creates a new ToolCallStartEvent. + pub fn new(tool_call_id: impl Into, tool_call_name: impl Into) -> Self { + Self { + base: BaseEvent::default(), + tool_call_id: tool_call_id.into(), + tool_call_name: tool_call_name.into(), + parent_message_id: None, + } + } + + /// Sets the parent message ID. + pub fn with_parent_message_id(mut self, message_id: impl Into) -> Self { + self.parent_message_id = Some(message_id.into()); + self + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +/// Event containing tool call arguments. +/// +/// This event contains chunks of the arguments being passed to a tool. +/// Arguments are streamed as JSON string deltas. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ToolCallArgsEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// The tool call ID this argument chunk belongs to. + #[serde(rename = "toolCallId")] + pub tool_call_id: ToolCallId, + /// The argument delta (JSON string chunk). + pub delta: String, +} + +impl ToolCallArgsEvent { + /// Creates a new ToolCallArgsEvent. + pub fn new(tool_call_id: impl Into, delta: impl Into) -> Self { + Self { + base: BaseEvent::default(), + tool_call_id: tool_call_id.into(), + delta: delta.into(), + } + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +/// Event indicating the end of a tool call. +/// +/// This event is sent when the agent completes sending arguments to a tool. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ToolCallEndEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// The tool call ID that has completed. + #[serde(rename = "toolCallId")] + pub tool_call_id: ToolCallId, +} + +impl ToolCallEndEvent { + /// Creates a new ToolCallEndEvent. + pub fn new(tool_call_id: impl Into) -> Self { + Self { + base: BaseEvent::default(), + tool_call_id: tool_call_id.into(), + } + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +/// Event containing a chunk of tool call content. +/// +/// This event combines start, args, and potentially end information in a single event. +/// Used as a non-streaming alternative where all fields may be optional. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ToolCallChunkEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// Optional tool call ID (may be omitted for continuation chunks). + #[serde(rename = "toolCallId", skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + /// Optional tool name. + #[serde(rename = "toolCallName", skip_serializing_if = "Option::is_none")] + pub tool_call_name: Option, + /// Optional parent message ID. + #[serde(rename = "parentMessageId", skip_serializing_if = "Option::is_none")] + pub parent_message_id: Option, + /// Optional argument delta. + #[serde(skip_serializing_if = "Option::is_none")] + pub delta: Option, +} + +impl ToolCallChunkEvent { + /// Creates a new empty ToolCallChunkEvent. + pub fn new() -> Self { + Self { + base: BaseEvent::default(), + tool_call_id: None, + tool_call_name: None, + parent_message_id: None, + delta: None, + } + } + + /// Sets the tool call ID. + pub fn with_tool_call_id(mut self, tool_call_id: impl Into) -> Self { + self.tool_call_id = Some(tool_call_id.into()); + self + } + + /// Sets the tool call name. + pub fn with_tool_call_name(mut self, name: impl Into) -> Self { + self.tool_call_name = Some(name.into()); + self + } + + /// Sets the parent message ID. + pub fn with_parent_message_id(mut self, message_id: impl Into) -> Self { + self.parent_message_id = Some(message_id.into()); + self + } + + /// Sets the delta. + pub fn with_delta(mut self, delta: impl Into) -> Self { + self.delta = Some(delta.into()); + self + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +impl Default for ToolCallChunkEvent { + fn default() -> Self { + Self::new() + } +} + +/// Event containing the result of a tool call. +/// +/// This event is sent when a tool has completed execution and returns its result. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ToolCallResultEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// Message ID for the result message. + #[serde(rename = "messageId")] + pub message_id: MessageId, + /// The tool call ID this result corresponds to. + #[serde(rename = "toolCallId")] + pub tool_call_id: ToolCallId, + /// The result content. + pub content: String, + /// Role (always Tool). + #[serde(default = "Role::tool")] + pub role: Role, +} + +impl ToolCallResultEvent { + /// Creates a new ToolCallResultEvent. + pub fn new( + message_id: impl Into, + tool_call_id: impl Into, + content: impl Into, + ) -> Self { + Self { + base: BaseEvent::default(), + message_id: message_id.into(), + tool_call_id: tool_call_id.into(), + content: content.into(), + role: Role::Tool, + } + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +// ============================================================================= +// Run Lifecycle Events +// ============================================================================= + +/// Event indicating that a run has started. +/// +/// This event is sent when an agent run begins execution within a specific thread. +/// A run represents a single agent execution that may produce multiple events. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct RunStartedEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// The thread ID this run belongs to. + #[serde(rename = "threadId")] + pub thread_id: ThreadId, + /// Unique identifier for this run. + #[serde(rename = "runId")] + pub run_id: RunId, +} + +impl RunStartedEvent { + /// Creates a new RunStartedEvent. + pub fn new(thread_id: impl Into, run_id: impl Into) -> Self { + Self { + base: BaseEvent::default(), + thread_id: thread_id.into(), + run_id: run_id.into(), + } + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +/// Outcome of a run finishing. +/// +/// Used to indicate whether a run completed successfully or was interrupted +/// for human-in-the-loop interaction. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum RunFinishedOutcome { + /// Run completed successfully. + Success, + /// Run was interrupted and requires human input to continue. + Interrupt, +} + +impl Default for RunFinishedOutcome { + fn default() -> Self { + Self::Success + } +} + +/// Information about a run interrupt. +/// +/// When a run finishes with `outcome == Interrupt`, this struct contains +/// information about why the interrupt occurred and what input is needed. +/// +/// # Example +/// +/// ```rust +/// use ag_ui_core::InterruptInfo; +/// +/// let info = InterruptInfo::new() +/// .with_id("approval-001") +/// .with_reason("human_approval") +/// .with_payload(serde_json::json!({ +/// "action": "DELETE", +/// "table": "users", +/// "affectedRows": 42 +/// })); +/// ``` +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +pub struct InterruptInfo { + /// Optional identifier for tracking this interrupt across resume. + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + /// Optional reason describing why the interrupt occurred. + /// Common values: "human_approval", "upload_required", "policy_hold" + #[serde(skip_serializing_if = "Option::is_none")] + pub reason: Option, + /// Optional payload with context for the interrupt UI. + /// Contains arbitrary JSON data for rendering approval forms, proposals, etc. + #[serde(skip_serializing_if = "Option::is_none")] + pub payload: Option, +} + +impl InterruptInfo { + /// Creates a new empty InterruptInfo. + pub fn new() -> Self { + Self::default() + } + + /// Sets the interrupt ID. + pub fn with_id(mut self, id: impl Into) -> Self { + self.id = Some(id.into()); + self + } + + /// Sets the interrupt reason. + pub fn with_reason(mut self, reason: impl Into) -> Self { + self.reason = Some(reason.into()); + self + } + + /// Sets the interrupt payload. + pub fn with_payload(mut self, payload: JsonValue) -> Self { + self.payload = Some(payload); + self + } +} + +/// Event indicating that a run has finished. +/// +/// This event is sent when an agent run completes, either successfully or +/// with an interrupt requiring human input. +/// +/// # Interrupt Flow +/// +/// When `outcome == Interrupt`, the agent indicates that on the next run, +/// a value needs to be provided via `RunAgentInput.resume` to continue. +/// +/// # Example +/// +/// ```rust +/// use ag_ui_core::{RunFinishedEvent, RunFinishedOutcome, InterruptInfo, ThreadId, RunId}; +/// +/// // Success case +/// let success = RunFinishedEvent::new(ThreadId::random(), RunId::random()) +/// .with_result(serde_json::json!({"status": "done"})); +/// +/// // Interrupt case +/// let interrupt = RunFinishedEvent::new(ThreadId::random(), RunId::random()) +/// .with_interrupt( +/// InterruptInfo::new() +/// .with_reason("human_approval") +/// .with_payload(serde_json::json!({"action": "send_email"})) +/// ); +/// ``` +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct RunFinishedEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// The thread ID this run belongs to. + #[serde(rename = "threadId")] + pub thread_id: ThreadId, + /// The run ID that finished. + #[serde(rename = "runId")] + pub run_id: RunId, + /// Outcome of the run. Optional for backward compatibility. + /// When omitted, outcome is inferred: if interrupt is present, it's Interrupt; otherwise Success. + #[serde(skip_serializing_if = "Option::is_none")] + pub outcome: Option, + /// Optional result value from the run. + /// Present when outcome is Success (or omitted with no interrupt). + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + /// Optional interrupt information. + /// Present when outcome is Interrupt (or omitted with interrupt present). + #[serde(skip_serializing_if = "Option::is_none")] + pub interrupt: Option, +} + +impl RunFinishedEvent { + /// Creates a new RunFinishedEvent with Success outcome. + pub fn new(thread_id: impl Into, run_id: impl Into) -> Self { + Self { + base: BaseEvent::default(), + thread_id: thread_id.into(), + run_id: run_id.into(), + outcome: None, + result: None, + interrupt: None, + } + } + + /// Sets the outcome explicitly. + pub fn with_outcome(mut self, outcome: RunFinishedOutcome) -> Self { + self.outcome = Some(outcome); + self + } + + /// Sets the result for this event (implies Success outcome). + pub fn with_result(mut self, result: JsonValue) -> Self { + self.result = Some(result); + self + } + + /// Sets the interrupt info (implies Interrupt outcome). + pub fn with_interrupt(mut self, interrupt: InterruptInfo) -> Self { + self.outcome = Some(RunFinishedOutcome::Interrupt); + self.interrupt = Some(interrupt); + self + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } + + /// Returns the effective outcome of this event. + /// + /// If outcome is explicitly set, returns that. Otherwise: + /// - If interrupt is present, returns Interrupt + /// - Otherwise, returns Success + pub fn effective_outcome(&self) -> RunFinishedOutcome { + self.outcome.unwrap_or_else(|| { + if self.interrupt.is_some() { + RunFinishedOutcome::Interrupt + } else { + RunFinishedOutcome::Success + } + }) + } +} + +/// Event indicating that a run has encountered an error. +/// +/// This event is sent when an agent run fails with an error. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct RunErrorEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// Error message describing what went wrong. + pub message: String, + /// Optional error code for programmatic handling. + #[serde(skip_serializing_if = "Option::is_none")] + pub code: Option, +} + +impl RunErrorEvent { + /// Creates a new RunErrorEvent. + pub fn new(message: impl Into) -> Self { + Self { + base: BaseEvent::default(), + message: message.into(), + code: None, + } + } + + /// Sets the error code. + pub fn with_code(mut self, code: impl Into) -> Self { + self.code = Some(code.into()); + self + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +// ============================================================================= +// Step Events +// ============================================================================= + +/// Event indicating that a step has started. +/// +/// This event is sent when a specific named step within a run begins execution. +/// Steps allow tracking progress through multi-stage agent workflows. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct StepStartedEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// Name of the step that started. + #[serde(rename = "stepName")] + pub step_name: String, +} + +impl StepStartedEvent { + /// Creates a new StepStartedEvent. + pub fn new(step_name: impl Into) -> Self { + Self { + base: BaseEvent::default(), + step_name: step_name.into(), + } + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +/// Event indicating that a step has finished. +/// +/// This event is sent when a specific named step within a run completes execution. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct StepFinishedEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// Name of the step that finished. + #[serde(rename = "stepName")] + pub step_name: String, +} + +impl StepFinishedEvent { + /// Creates a new StepFinishedEvent. + pub fn new(step_name: impl Into) -> Self { + Self { + base: BaseEvent::default(), + step_name: step_name.into(), + } + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +// ============================================================================= +// State Events +// ============================================================================= + +/// Event containing a complete state snapshot. +/// +/// This event is sent to provide the full current state of the agent. +/// The state is generic over `StateT` which must implement `AgentState`. +/// +/// # Type Parameter +/// +/// - `StateT`: The type of state, defaults to `JsonValue` for flexibility. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(bound(deserialize = ""))] +pub struct StateSnapshotEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// The complete state snapshot. + pub snapshot: StateT, +} + +impl StateSnapshotEvent { + /// Creates a new StateSnapshotEvent with the given state. + pub fn new(snapshot: StateT) -> Self { + Self { + base: BaseEvent::default(), + snapshot, + } + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +impl Default for StateSnapshotEvent { + fn default() -> Self { + Self { + base: BaseEvent::default(), + snapshot: StateT::default(), + } + } +} + +/// Event containing incremental state updates as JSON Patch operations. +/// +/// This event is sent to update state incrementally using RFC 6902 JSON Patch format. +/// The delta is a vector of patch operations (add, remove, replace, move, copy, test). +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct StateDeltaEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// JSON Patch operations per RFC 6902. + pub delta: Vec, +} + +impl StateDeltaEvent { + /// Creates a new StateDeltaEvent with the given patch operations. + pub fn new(delta: Vec) -> Self { + Self { + base: BaseEvent::default(), + delta, + } + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +impl Default for StateDeltaEvent { + fn default() -> Self { + Self { + base: BaseEvent::default(), + delta: Vec::new(), + } + } +} + +/// Event containing a complete snapshot of all messages. +/// +/// This event is sent to provide the full message history to the client. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct MessagesSnapshotEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// Complete list of messages. + pub messages: Vec, +} + +impl MessagesSnapshotEvent { + /// Creates a new MessagesSnapshotEvent with the given messages. + pub fn new(messages: Vec) -> Self { + Self { + base: BaseEvent::default(), + messages, + } + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +impl Default for MessagesSnapshotEvent { + fn default() -> Self { + Self { + base: BaseEvent::default(), + messages: Vec::new(), + } + } +} + +// ============================================================================= +// Activity Events +// ============================================================================= + +/// Event containing a complete activity snapshot. +/// +/// This event creates a new activity message or replaces an existing one. +/// Activity messages track structured agent activities like planning or research. +/// +/// # Example +/// +/// ```rust +/// use ag_ui_core::event::ActivitySnapshotEvent; +/// use ag_ui_core::MessageId; +/// use serde_json::json; +/// +/// let event = ActivitySnapshotEvent::new( +/// MessageId::random(), +/// "PLAN", +/// json!({"steps": ["research", "implement", "test"]}), +/// ); +/// ``` +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ActivitySnapshotEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// The message ID for this activity. + #[serde(rename = "messageId")] + pub message_id: MessageId, + /// The type of activity (e.g., "PLAN", "RESEARCH"). + #[serde(rename = "activityType")] + pub activity_type: String, + /// The activity content as a flexible JSON object. + pub content: JsonValue, + /// Whether to replace the existing activity content (default: true). + #[serde(skip_serializing_if = "Option::is_none")] + pub replace: Option, +} + +impl ActivitySnapshotEvent { + /// Creates a new ActivitySnapshotEvent with the given message ID, type, and content. + pub fn new( + message_id: impl Into, + activity_type: impl Into, + content: JsonValue, + ) -> Self { + Self { + base: BaseEvent::default(), + message_id: message_id.into(), + activity_type: activity_type.into(), + content, + replace: None, + } + } + + /// Sets whether to replace the existing activity content. + pub fn with_replace(mut self, replace: bool) -> Self { + self.replace = Some(replace); + self + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +/// Event containing an incremental activity update. +/// +/// This event applies a JSON Patch (RFC 6902) to an existing activity's content. +/// Use this for efficient partial updates to activity content. +/// +/// # Example +/// +/// ```rust +/// use ag_ui_core::event::ActivityDeltaEvent; +/// use ag_ui_core::MessageId; +/// use serde_json::json; +/// +/// let event = ActivityDeltaEvent::new( +/// MessageId::random(), +/// "PLAN", +/// vec![json!({"op": "add", "path": "/steps/-", "value": "deploy"})], +/// ); +/// ``` +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ActivityDeltaEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// The message ID for this activity. + #[serde(rename = "messageId")] + pub message_id: MessageId, + /// The type of activity (e.g., "PLAN", "RESEARCH"). + #[serde(rename = "activityType")] + pub activity_type: String, + /// JSON Patch operations (RFC 6902) to apply to the content. + pub patch: Vec, +} + +impl ActivityDeltaEvent { + /// Creates a new ActivityDeltaEvent with the given message ID, type, and patch. + pub fn new( + message_id: impl Into, + activity_type: impl Into, + patch: Vec, + ) -> Self { + Self { + base: BaseEvent::default(), + message_id: message_id.into(), + activity_type: activity_type.into(), + patch, + } + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +// ============================================================================= +// Thinking Step Events +// ============================================================================= + +/// Event indicating that a thinking step has started. +/// +/// This event is sent when the agent begins a chain-of-thought reasoning step. +/// Unlike ThinkingTextMessage events (which contain actual thinking content), +/// this event marks the boundary of a thinking block. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ThinkingStartEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// Optional title for the thinking step. + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, +} + +impl ThinkingStartEvent { + /// Creates a new ThinkingStartEvent. + pub fn new() -> Self { + Self { + base: BaseEvent::default(), + title: None, + } + } + + /// Sets the title for this thinking step. + pub fn with_title(mut self, title: impl Into) -> Self { + self.title = Some(title.into()); + self + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +impl Default for ThinkingStartEvent { + fn default() -> Self { + Self::new() + } +} + +/// Event indicating that a thinking step has ended. +/// +/// This event is sent when the agent completes a chain-of-thought reasoning step. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ThinkingEndEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, +} + +impl ThinkingEndEvent { + /// Creates a new ThinkingEndEvent. + pub fn new() -> Self { + Self { + base: BaseEvent::default(), + } + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +impl Default for ThinkingEndEvent { + fn default() -> Self { + Self::new() + } +} + +// ============================================================================= +// Special Events +// ============================================================================= + +/// Event containing raw data from the underlying provider. +/// +/// This event is sent to pass through raw provider-specific data that +/// doesn't fit into other event types. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct RawEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// The raw event data. + pub event: JsonValue, + /// Optional source identifier for the raw event. + #[serde(skip_serializing_if = "Option::is_none")] + pub source: Option, +} + +impl RawEvent { + /// Creates a new RawEvent with the given event data. + pub fn new(event: JsonValue) -> Self { + Self { + base: BaseEvent::default(), + event, + source: None, + } + } + + /// Sets the source identifier. + pub fn with_source(mut self, source: impl Into) -> Self { + self.source = Some(source.into()); + self + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +/// Event for custom application-specific data. +/// +/// This event allows sending arbitrary named events with custom payloads. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct CustomEvent { + /// Common event fields (timestamp, rawEvent). + #[serde(flatten)] + pub base: BaseEvent, + /// Name of the custom event. + pub name: String, + /// Custom event payload. + pub value: JsonValue, +} + +impl CustomEvent { + /// Creates a new CustomEvent with the given name and value. + pub fn new(name: impl Into, value: JsonValue) -> Self { + Self { + base: BaseEvent::default(), + name: name.into(), + value, + } + } + + /// Sets the timestamp for this event. + pub fn with_timestamp(mut self, timestamp: f64) -> Self { + self.base.timestamp = Some(timestamp); + self + } +} + +// ============================================================================= +// Event Union +// ============================================================================= + +/// Union of all possible events in the Agent User Interaction Protocol. +/// +/// This enum represents any event that can be sent or received in the AG-UI protocol. +/// Events are serialized with a `type` discriminant in SCREAMING_SNAKE_CASE format. +/// +/// # Type Parameter +/// +/// - `StateT`: The type of state for `StateSnapshot` events, defaults to `JsonValue`. +/// +/// # Serialization +/// +/// Events are serialized as JSON objects with a `type` field indicating the variant: +/// ```json +/// {"type": "TEXT_MESSAGE_START", "messageId": "...", "role": "assistant"} +/// ``` +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "SCREAMING_SNAKE_CASE", bound(deserialize = ""))] +pub enum Event { + /// Start of a text message from the assistant. + TextMessageStart(TextMessageStartEvent), + /// Content chunk of a text message (streaming delta). + TextMessageContent(TextMessageContentEvent), + /// End of a text message. + TextMessageEnd(TextMessageEndEvent), + /// Complete text message chunk (non-streaming alternative). + TextMessageChunk(TextMessageChunkEvent), + /// Start of a thinking text message (extended thinking). + ThinkingTextMessageStart(ThinkingTextMessageStartEvent), + /// Content chunk of a thinking text message. + ThinkingTextMessageContent(ThinkingTextMessageContentEvent), + /// End of a thinking text message. + ThinkingTextMessageEnd(ThinkingTextMessageEndEvent), + /// Start of a tool call. + ToolCallStart(ToolCallStartEvent), + /// Arguments chunk for a tool call (streaming). + ToolCallArgs(ToolCallArgsEvent), + /// End of a tool call. + ToolCallEnd(ToolCallEndEvent), + /// Complete tool call chunk (non-streaming alternative). + ToolCallChunk(ToolCallChunkEvent), + /// Result of a tool call execution. + ToolCallResult(ToolCallResultEvent), + /// Start of a thinking step (chain-of-thought boundary). + ThinkingStart(ThinkingStartEvent), + /// End of a thinking step. + ThinkingEnd(ThinkingEndEvent), + /// Complete state snapshot. + StateSnapshot(StateSnapshotEvent), + /// Incremental state update (JSON Patch). + StateDelta(StateDeltaEvent), + /// Complete messages snapshot. + MessagesSnapshot(MessagesSnapshotEvent), + /// Complete activity snapshot. + ActivitySnapshot(ActivitySnapshotEvent), + /// Incremental activity update (JSON Patch). + ActivityDelta(ActivityDeltaEvent), + /// Raw event from the underlying provider. + Raw(RawEvent), + /// Custom application-specific event. + Custom(CustomEvent), + /// Agent run has started. + RunStarted(RunStartedEvent), + /// Agent run has finished successfully. + RunFinished(RunFinishedEvent), + /// Agent run encountered an error. + RunError(RunErrorEvent), + /// A step within a run has started. + StepStarted(StepStartedEvent), + /// A step within a run has finished. + StepFinished(StepFinishedEvent), +} + +impl Event { + /// Returns the event type for this event. + pub fn event_type(&self) -> EventType { + match self { + Event::TextMessageStart(_) => EventType::TextMessageStart, + Event::TextMessageContent(_) => EventType::TextMessageContent, + Event::TextMessageEnd(_) => EventType::TextMessageEnd, + Event::TextMessageChunk(_) => EventType::TextMessageChunk, + Event::ThinkingTextMessageStart(_) => EventType::ThinkingTextMessageStart, + Event::ThinkingTextMessageContent(_) => EventType::ThinkingTextMessageContent, + Event::ThinkingTextMessageEnd(_) => EventType::ThinkingTextMessageEnd, + Event::ToolCallStart(_) => EventType::ToolCallStart, + Event::ToolCallArgs(_) => EventType::ToolCallArgs, + Event::ToolCallEnd(_) => EventType::ToolCallEnd, + Event::ToolCallChunk(_) => EventType::ToolCallChunk, + Event::ToolCallResult(_) => EventType::ToolCallResult, + Event::ThinkingStart(_) => EventType::ThinkingStart, + Event::ThinkingEnd(_) => EventType::ThinkingEnd, + Event::StateSnapshot(_) => EventType::StateSnapshot, + Event::StateDelta(_) => EventType::StateDelta, + Event::MessagesSnapshot(_) => EventType::MessagesSnapshot, + Event::ActivitySnapshot(_) => EventType::ActivitySnapshot, + Event::ActivityDelta(_) => EventType::ActivityDelta, + Event::Raw(_) => EventType::Raw, + Event::Custom(_) => EventType::Custom, + Event::RunStarted(_) => EventType::RunStarted, + Event::RunFinished(_) => EventType::RunFinished, + Event::RunError(_) => EventType::RunError, + Event::StepStarted(_) => EventType::StepStarted, + Event::StepFinished(_) => EventType::StepFinished, + } + } + + /// Returns the timestamp of this event if available. + pub fn timestamp(&self) -> Option { + match self { + Event::TextMessageStart(e) => e.base.timestamp, + Event::TextMessageContent(e) => e.base.timestamp, + Event::TextMessageEnd(e) => e.base.timestamp, + Event::TextMessageChunk(e) => e.base.timestamp, + Event::ThinkingTextMessageStart(e) => e.base.timestamp, + Event::ThinkingTextMessageContent(e) => e.base.timestamp, + Event::ThinkingTextMessageEnd(e) => e.base.timestamp, + Event::ToolCallStart(e) => e.base.timestamp, + Event::ToolCallArgs(e) => e.base.timestamp, + Event::ToolCallEnd(e) => e.base.timestamp, + Event::ToolCallChunk(e) => e.base.timestamp, + Event::ToolCallResult(e) => e.base.timestamp, + Event::ThinkingStart(e) => e.base.timestamp, + Event::ThinkingEnd(e) => e.base.timestamp, + Event::StateSnapshot(e) => e.base.timestamp, + Event::StateDelta(e) => e.base.timestamp, + Event::MessagesSnapshot(e) => e.base.timestamp, + Event::ActivitySnapshot(e) => e.base.timestamp, + Event::ActivityDelta(e) => e.base.timestamp, + Event::Raw(e) => e.base.timestamp, + Event::Custom(e) => e.base.timestamp, + Event::RunStarted(e) => e.base.timestamp, + Event::RunFinished(e) => e.base.timestamp, + Event::RunError(e) => e.base.timestamp, + Event::StepStarted(e) => e.base.timestamp, + Event::StepFinished(e) => e.base.timestamp, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_event_type_serialization() { + let event_type = EventType::TextMessageStart; + let json = serde_json::to_string(&event_type).unwrap(); + assert_eq!(json, "\"TEXT_MESSAGE_START\""); + + let event_type = EventType::ToolCallArgs; + let json = serde_json::to_string(&event_type).unwrap(); + assert_eq!(json, "\"TOOL_CALL_ARGS\""); + + let event_type = EventType::StateSnapshot; + let json = serde_json::to_string(&event_type).unwrap(); + assert_eq!(json, "\"STATE_SNAPSHOT\""); + } + + #[test] + fn test_event_type_deserialization() { + let event_type: EventType = serde_json::from_str("\"RUN_STARTED\"").unwrap(); + assert_eq!(event_type, EventType::RunStarted); + + let event_type: EventType = serde_json::from_str("\"THINKING_TEXT_MESSAGE_CONTENT\"").unwrap(); + assert_eq!(event_type, EventType::ThinkingTextMessageContent); + } + + #[test] + fn test_event_type_as_str() { + assert_eq!(EventType::TextMessageStart.as_str(), "TEXT_MESSAGE_START"); + assert_eq!(EventType::RunFinished.as_str(), "RUN_FINISHED"); + assert_eq!(EventType::Custom.as_str(), "CUSTOM"); + } + + #[test] + fn test_event_type_display() { + assert_eq!(format!("{}", EventType::TextMessageStart), "TEXT_MESSAGE_START"); + assert_eq!(format!("{}", EventType::StateDelta), "STATE_DELTA"); + } + + #[test] + fn test_base_event_serialization() { + let event = BaseEvent { + timestamp: Some(1706123456789.0), + raw_event: None, + }; + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"timestamp\":1706123456789.0")); + assert!(!json.contains("rawEvent")); // skipped when None + } + + #[test] + fn test_base_event_with_raw_event() { + let event = BaseEvent { + timestamp: None, + raw_event: Some(serde_json::json!({"provider": "openai"})), + }; + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"rawEvent\"")); + assert!(json.contains("\"provider\":\"openai\"")); + } + + #[test] + fn test_base_event_builder() { + let event = BaseEvent::new() + .timestamp(1234567890.0) + .raw_event(serde_json::json!({"test": true})); + + assert_eq!(event.timestamp, Some(1234567890.0)); + assert!(event.raw_event.is_some()); + } + + #[test] + fn test_event_validation_error_display() { + let error = EventValidationError::EmptyDelta; + assert_eq!(error.to_string(), "Delta must not be an empty string"); + + let error = EventValidationError::InvalidFormat("bad json".to_string()); + assert_eq!(error.to_string(), "Invalid event format: bad json"); + + let error = EventValidationError::MissingField("message_id".to_string()); + assert_eq!(error.to_string(), "Missing required field: message_id"); + + let error = EventValidationError::TypeMismatch { + expected: "TEXT_MESSAGE_START".to_string(), + actual: "RUN_STARTED".to_string(), + }; + assert_eq!( + error.to_string(), + "Event type mismatch: expected TEXT_MESSAGE_START, got RUN_STARTED" + ); + } + + #[test] + fn test_event_validation_error_is_std_error() { + fn requires_error(_: E) {} + requires_error(EventValidationError::EmptyDelta); + } + + #[test] + fn test_all_event_types_roundtrip() { + let all_types = [ + EventType::TextMessageStart, + EventType::TextMessageContent, + EventType::TextMessageEnd, + EventType::TextMessageChunk, + EventType::ThinkingTextMessageStart, + EventType::ThinkingTextMessageContent, + EventType::ThinkingTextMessageEnd, + EventType::ToolCallStart, + EventType::ToolCallArgs, + EventType::ToolCallEnd, + EventType::ToolCallChunk, + EventType::ToolCallResult, + EventType::ThinkingStart, + EventType::ThinkingEnd, + EventType::StateSnapshot, + EventType::StateDelta, + EventType::MessagesSnapshot, + EventType::ActivitySnapshot, + EventType::ActivityDelta, + EventType::Raw, + EventType::Custom, + EventType::RunStarted, + EventType::RunFinished, + EventType::RunError, + EventType::StepStarted, + EventType::StepFinished, + ]; + + for event_type in all_types { + let json = serde_json::to_string(&event_type).unwrap(); + let parsed: EventType = serde_json::from_str(&json).unwrap(); + assert_eq!(event_type, parsed); + } + } + + // ========================================================================= + // Text Message Event Tests + // ========================================================================= + + #[test] + fn test_text_message_start_event() { + use crate::types::{MessageId, Role}; + + let event = TextMessageStartEvent::new(MessageId::random()); + assert_eq!(event.role, Role::Assistant); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"messageId\"")); + assert!(json.contains("\"role\":\"assistant\"")); + } + + #[test] + fn test_text_message_start_event_with_timestamp() { + use crate::types::MessageId; + + let event = TextMessageStartEvent::new(MessageId::random()).with_timestamp(1234567890.0); + assert_eq!(event.base.timestamp, Some(1234567890.0)); + } + + #[test] + fn test_text_message_content_event_validation() { + use crate::types::MessageId; + + // Valid delta + let result = TextMessageContentEvent::new(MessageId::random(), "Hello"); + assert!(result.is_ok()); + + // Empty delta should fail + let result = TextMessageContentEvent::new(MessageId::random(), ""); + assert!(matches!(result, Err(EventValidationError::EmptyDelta))); + } + + #[test] + fn test_text_message_content_event_validate_method() { + use crate::types::MessageId; + + let event = TextMessageContentEvent::new_unchecked(MessageId::random(), ""); + assert!(matches!(event.validate(), Err(EventValidationError::EmptyDelta))); + + let event = TextMessageContentEvent::new_unchecked(MessageId::random(), "Hello"); + assert!(event.validate().is_ok()); + } + + #[test] + fn test_text_message_content_event_serialization() { + use crate::types::MessageId; + + let event = TextMessageContentEvent::new(MessageId::random(), "Hello, world!").unwrap(); + let json = serde_json::to_string(&event).unwrap(); + + assert!(json.contains("\"messageId\"")); + assert!(json.contains("\"delta\":\"Hello, world!\"")); + } + + #[test] + fn test_text_message_end_event() { + use crate::types::MessageId; + + let msg_id = MessageId::random(); + let event = TextMessageEndEvent::new(msg_id.clone()); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"messageId\"")); + } + + #[test] + fn test_text_message_chunk_event() { + use crate::types::{MessageId, Role}; + + let event = TextMessageChunkEvent::new(Role::Assistant) + .with_message_id(MessageId::random()) + .with_delta("chunk content"); + + assert!(event.message_id.is_some()); + assert_eq!(event.delta, Some("chunk content".to_string())); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"messageId\"")); + assert!(json.contains("\"delta\":\"chunk content\"")); + } + + #[test] + fn test_text_message_chunk_event_skips_none() { + use crate::types::Role; + + let event = TextMessageChunkEvent::new(Role::Assistant); + let json = serde_json::to_string(&event).unwrap(); + + // Should not contain optional fields when None + assert!(!json.contains("\"messageId\"")); + assert!(!json.contains("\"delta\"")); + assert!(json.contains("\"role\":\"assistant\"")); + } + + // ========================================================================= + // Thinking Text Message Event Tests + // ========================================================================= + + #[test] + fn test_thinking_text_message_start_event() { + let event = ThinkingTextMessageStartEvent::new(); + let json = serde_json::to_string(&event).unwrap(); + + // Should be minimal - just empty object or with timestamp if set + assert_eq!(json, "{}"); + } + + #[test] + fn test_thinking_text_message_start_event_with_timestamp() { + let event = ThinkingTextMessageStartEvent::new().with_timestamp(1234567890.0); + let json = serde_json::to_string(&event).unwrap(); + + assert!(json.contains("\"timestamp\":1234567890.0")); + } + + #[test] + fn test_thinking_text_message_content_event() { + let event = ThinkingTextMessageContentEvent::new("Let me think about this..."); + + assert_eq!(event.delta, "Let me think about this..."); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"delta\":\"Let me think about this...\"")); + } + + #[test] + fn test_thinking_text_message_content_event_allows_empty() { + // Unlike TextMessageContentEvent, ThinkingTextMessageContentEvent allows empty delta + let event = ThinkingTextMessageContentEvent::new(""); + assert_eq!(event.delta, ""); + } + + #[test] + fn test_thinking_text_message_end_event() { + let event = ThinkingTextMessageEndEvent::new(); + let json = serde_json::to_string(&event).unwrap(); + + // Should be minimal + assert_eq!(json, "{}"); + } + + #[test] + fn test_thinking_text_message_events_default() { + let start = ThinkingTextMessageStartEvent::default(); + let end = ThinkingTextMessageEndEvent::default(); + + assert!(start.base.timestamp.is_none()); + assert!(end.base.timestamp.is_none()); + } + + // ========================================================================= + // Tool Call Event Tests + // ========================================================================= + + #[test] + fn test_tool_call_start_event() { + use crate::types::ToolCallId; + + let event = ToolCallStartEvent::new(ToolCallId::random(), "get_weather"); + + assert_eq!(event.tool_call_name, "get_weather"); + assert!(event.parent_message_id.is_none()); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"toolCallId\"")); + assert!(json.contains("\"toolCallName\":\"get_weather\"")); + assert!(!json.contains("parentMessageId")); // skipped when None + } + + #[test] + fn test_tool_call_start_event_with_parent() { + use crate::types::{MessageId, ToolCallId}; + + let event = ToolCallStartEvent::new(ToolCallId::random(), "get_weather") + .with_parent_message_id(MessageId::random()); + + assert!(event.parent_message_id.is_some()); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"parentMessageId\"")); + } + + #[test] + fn test_tool_call_args_event() { + use crate::types::ToolCallId; + + let event = ToolCallArgsEvent::new(ToolCallId::random(), r#"{"location":"#); + + assert_eq!(event.delta, r#"{"location":"#); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"toolCallId\"")); + assert!(json.contains("\"delta\"")); + } + + #[test] + fn test_tool_call_end_event() { + use crate::types::ToolCallId; + + let event = ToolCallEndEvent::new(ToolCallId::random()); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"toolCallId\"")); + } + + #[test] + fn test_tool_call_chunk_event() { + use crate::types::ToolCallId; + + let event = ToolCallChunkEvent::new() + .with_tool_call_id(ToolCallId::random()) + .with_tool_call_name("search") + .with_delta(r#"{"query": "rust"}"#); + + assert!(event.tool_call_id.is_some()); + assert_eq!(event.tool_call_name, Some("search".to_string())); + assert!(event.delta.is_some()); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"toolCallId\"")); + assert!(json.contains("\"toolCallName\":\"search\"")); + assert!(json.contains("\"delta\"")); + } + + #[test] + fn test_tool_call_chunk_event_skips_none() { + let event = ToolCallChunkEvent::new(); + let json = serde_json::to_string(&event).unwrap(); + + // Should not contain optional fields when None + assert_eq!(json, "{}"); + } + + #[test] + fn test_tool_call_result_event() { + use crate::types::{MessageId, Role, ToolCallId}; + + let event = ToolCallResultEvent::new( + MessageId::random(), + ToolCallId::random(), + r#"{"weather": "sunny", "temp": 72}"#, + ); + + assert_eq!(event.role, Role::Tool); + assert!(event.content.contains("sunny")); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"messageId\"")); + assert!(json.contains("\"toolCallId\"")); + assert!(json.contains("\"content\"")); + assert!(json.contains("\"role\":\"tool\"")); + } + + #[test] + fn test_tool_call_result_event_deserialize_default_role() { + // Test that role defaults to "tool" when not present in JSON + let json = r#"{"messageId":"550e8400-e29b-41d4-a716-446655440000","toolCallId":"6ba7b810-9dad-11d1-80b4-00c04fd430c8","content":"result"}"#; + let event: ToolCallResultEvent = serde_json::from_str(json).unwrap(); + + assert_eq!(event.role, Role::Tool); + } + + // ========================================================================= + // Run Lifecycle Event Tests + // ========================================================================= + + #[test] + fn test_run_started_event() { + use crate::types::{RunId, ThreadId}; + + let event = RunStartedEvent::new(ThreadId::random(), RunId::random()); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"threadId\"")); + assert!(json.contains("\"runId\"")); + } + + #[test] + fn test_run_finished_event() { + use crate::types::{RunId, ThreadId}; + + let event = RunFinishedEvent::new(ThreadId::random(), RunId::random()); + + assert!(event.result.is_none()); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"threadId\"")); + assert!(json.contains("\"runId\"")); + assert!(!json.contains("\"result\"")); // skipped when None + } + + #[test] + fn test_run_finished_event_with_result() { + use crate::types::{RunId, ThreadId}; + + let event = RunFinishedEvent::new(ThreadId::random(), RunId::random()) + .with_result(serde_json::json!({"success": true})); + + assert!(event.result.is_some()); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"result\"")); + assert!(json.contains("\"success\":true")); + } + + #[test] + fn test_run_error_event() { + let event = RunErrorEvent::new("Connection timeout"); + + assert_eq!(event.message, "Connection timeout"); + assert!(event.code.is_none()); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"message\":\"Connection timeout\"")); + assert!(!json.contains("\"code\"")); // skipped when None + } + + #[test] + fn test_run_error_event_with_code() { + let event = RunErrorEvent::new("Rate limit exceeded").with_code("RATE_LIMITED"); + + assert_eq!(event.code, Some("RATE_LIMITED".to_string())); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"code\":\"RATE_LIMITED\"")); + } + + // ========================================================================= + // Interrupt Tests + // ========================================================================= + + #[test] + fn test_run_finished_outcome_serialization() { + // Test SCREAMING_SNAKE_CASE serialization + let success = RunFinishedOutcome::Success; + let interrupt = RunFinishedOutcome::Interrupt; + + let success_json = serde_json::to_string(&success).unwrap(); + let interrupt_json = serde_json::to_string(&interrupt).unwrap(); + + assert_eq!(success_json, "\"SUCCESS\""); + assert_eq!(interrupt_json, "\"INTERRUPT\""); + + // Test deserialization + let deserialized: RunFinishedOutcome = serde_json::from_str("\"SUCCESS\"").unwrap(); + assert_eq!(deserialized, RunFinishedOutcome::Success); + + let deserialized: RunFinishedOutcome = serde_json::from_str("\"INTERRUPT\"").unwrap(); + assert_eq!(deserialized, RunFinishedOutcome::Interrupt); + } + + #[test] + fn test_run_finished_outcome_default() { + let outcome = RunFinishedOutcome::default(); + assert_eq!(outcome, RunFinishedOutcome::Success); + } + + #[test] + fn test_interrupt_info_empty() { + let info = InterruptInfo::new(); + + assert!(info.id.is_none()); + assert!(info.reason.is_none()); + assert!(info.payload.is_none()); + + // Empty struct should serialize to {} + let json = serde_json::to_string(&info).unwrap(); + assert_eq!(json, "{}"); + } + + #[test] + fn test_interrupt_info_with_all_fields() { + let info = InterruptInfo::new() + .with_id("approval-001") + .with_reason("human_approval") + .with_payload(serde_json::json!({"action": "delete", "rows": 42})); + + assert_eq!(info.id, Some("approval-001".to_string())); + assert_eq!(info.reason, Some("human_approval".to_string())); + assert!(info.payload.is_some()); + + let json = serde_json::to_string(&info).unwrap(); + assert!(json.contains("\"id\":\"approval-001\"")); + assert!(json.contains("\"reason\":\"human_approval\"")); + assert!(json.contains("\"action\":\"delete\"")); + } + + #[test] + fn test_run_finished_event_with_interrupt() { + use crate::types::{RunId, ThreadId}; + + let event = RunFinishedEvent::new(ThreadId::random(), RunId::random()) + .with_interrupt( + InterruptInfo::new() + .with_reason("human_approval") + .with_payload(serde_json::json!({"proposal": "send email"})) + ); + + // with_interrupt sets outcome to Interrupt + assert_eq!(event.outcome, Some(RunFinishedOutcome::Interrupt)); + assert!(event.interrupt.is_some()); + assert!(event.result.is_none()); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"outcome\":\"INTERRUPT\"")); + assert!(json.contains("\"interrupt\"")); + assert!(json.contains("\"reason\":\"human_approval\"")); + } + + #[test] + fn test_run_finished_event_backward_compatibility() { + use crate::types::{RunId, ThreadId}; + + // Old-style event without outcome field + let event = RunFinishedEvent::new(ThreadId::random(), RunId::random()) + .with_result(serde_json::json!({"done": true})); + + // outcome is None (backward compat) + assert!(event.outcome.is_none()); + assert!(event.interrupt.is_none()); + + // effective_outcome should infer Success + assert_eq!(event.effective_outcome(), RunFinishedOutcome::Success); + + let json = serde_json::to_string(&event).unwrap(); + assert!(!json.contains("\"outcome\"")); // skipped when None + } + + #[test] + fn test_run_finished_event_effective_outcome() { + use crate::types::{RunId, ThreadId}; + + // No outcome, no interrupt → Success + let event1 = RunFinishedEvent::new(ThreadId::random(), RunId::random()); + assert_eq!(event1.effective_outcome(), RunFinishedOutcome::Success); + + // No outcome, has interrupt → Interrupt + let mut event2 = RunFinishedEvent::new(ThreadId::random(), RunId::random()); + event2.interrupt = Some(InterruptInfo::new()); + assert_eq!(event2.effective_outcome(), RunFinishedOutcome::Interrupt); + + // Explicit outcome overrides + let event3 = RunFinishedEvent::new(ThreadId::random(), RunId::random()) + .with_outcome(RunFinishedOutcome::Interrupt); + assert_eq!(event3.effective_outcome(), RunFinishedOutcome::Interrupt); + } + + // ========================================================================= + // Step Event Tests + // ========================================================================= + + #[test] + fn test_step_started_event() { + let event = StepStartedEvent::new("process_input"); + + assert_eq!(event.step_name, "process_input"); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"stepName\":\"process_input\"")); + } + + #[test] + fn test_step_finished_event() { + let event = StepFinishedEvent::new("generate_response"); + + assert_eq!(event.step_name, "generate_response"); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"stepName\":\"generate_response\"")); + } + + #[test] + fn test_step_events_with_timestamp() { + let start = StepStartedEvent::new("step1").with_timestamp(1234567890.0); + let end = StepFinishedEvent::new("step1").with_timestamp(1234567891.0); + + assert_eq!(start.base.timestamp, Some(1234567890.0)); + assert_eq!(end.base.timestamp, Some(1234567891.0)); + } + + // ========================================================================= + // State Event Tests + // ========================================================================= + + #[test] + fn test_state_snapshot_event() { + let event = StateSnapshotEvent::new(serde_json::json!({"count": 42})); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"snapshot\"")); + assert!(json.contains("\"count\":42")); + } + + #[test] + fn test_state_snapshot_event_default() { + let event: StateSnapshotEvent<()> = StateSnapshotEvent::default(); + assert!(event.base.timestamp.is_none()); + } + + #[test] + fn test_state_delta_event() { + let patches = vec![ + serde_json::json!({"op": "replace", "path": "/count", "value": 43}), + serde_json::json!({"op": "add", "path": "/new_field", "value": "hello"}), + ]; + let event = StateDeltaEvent::new(patches); + + assert_eq!(event.delta.len(), 2); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"delta\"")); + assert!(json.contains("\"op\":\"replace\"")); + } + + #[test] + fn test_state_delta_event_default() { + let event = StateDeltaEvent::default(); + assert!(event.delta.is_empty()); + } + + #[test] + fn test_messages_snapshot_event() { + use crate::types::{Message, MessageId}; + + let messages = vec![ + Message::User { + id: MessageId::random(), + content: "Hello".to_string(), + name: None, + }, + Message::Assistant { + id: MessageId::random(), + content: Some("Hi there!".to_string()), + name: None, + tool_calls: None, + }, + ]; + let event = MessagesSnapshotEvent::new(messages); + + assert_eq!(event.messages.len(), 2); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"messages\"")); + } + + #[test] + fn test_messages_snapshot_event_default() { + let event = MessagesSnapshotEvent::default(); + assert!(event.messages.is_empty()); + } + + // ========================================================================= + // Thinking Step Event Tests + // ========================================================================= + + #[test] + fn test_thinking_start_event() { + let event = ThinkingStartEvent::new(); + + assert!(event.title.is_none()); + + let json = serde_json::to_string(&event).unwrap(); + assert!(!json.contains("\"title\"")); // skipped when None + } + + #[test] + fn test_thinking_start_event_with_title() { + let event = ThinkingStartEvent::new().with_title("Analyzing query"); + + assert_eq!(event.title, Some("Analyzing query".to_string())); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"title\":\"Analyzing query\"")); + } + + #[test] + fn test_thinking_end_event() { + let event = ThinkingEndEvent::new(); + + let json = serde_json::to_string(&event).unwrap(); + assert_eq!(json, "{}"); + } + + #[test] + fn test_thinking_step_events_default() { + let start = ThinkingStartEvent::default(); + let end = ThinkingEndEvent::default(); + + assert!(start.title.is_none()); + assert!(end.base.timestamp.is_none()); + } + + // ========================================================================= + // Special Event Tests + // ========================================================================= + + #[test] + fn test_raw_event() { + let event = RawEvent::new(serde_json::json!({"provider_data": "openai"})); + + assert!(event.source.is_none()); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"event\"")); + assert!(json.contains("\"provider_data\":\"openai\"")); + assert!(!json.contains("\"source\"")); // skipped when None + } + + #[test] + fn test_raw_event_with_source() { + let event = RawEvent::new(serde_json::json!({})).with_source("anthropic"); + + assert_eq!(event.source, Some("anthropic".to_string())); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"source\":\"anthropic\"")); + } + + #[test] + fn test_custom_event() { + let event = CustomEvent::new("user_action", serde_json::json!({"clicked": "button"})); + + assert_eq!(event.name, "user_action"); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"name\":\"user_action\"")); + assert!(json.contains("\"value\"")); + assert!(json.contains("\"clicked\":\"button\"")); + } + + // ========================================================================= + // Event Enum Tests + // ========================================================================= + + #[test] + fn test_event_enum_serialization() { + use crate::types::MessageId; + + let event: Event = Event::TextMessageStart(TextMessageStartEvent::new( + MessageId::random(), + )); + + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"type\":\"TEXT_MESSAGE_START\"")); + assert!(json.contains("\"messageId\"")); + assert!(json.contains("\"role\":\"assistant\"")); + } + + #[test] + fn test_event_enum_deserialization() { + let json = r#"{"type":"RUN_ERROR","message":"Test error"}"#; + let event: Event = serde_json::from_str(json).unwrap(); + + match event { + Event::RunError(e) => assert_eq!(e.message, "Test error"), + _ => panic!("Expected RunError variant"), + } + } + + #[test] + fn test_event_type_method() { + use crate::types::MessageId; + + let event: Event = Event::TextMessageEnd(TextMessageEndEvent::new(MessageId::random())); + assert_eq!(event.event_type(), EventType::TextMessageEnd); + + let event: Event = Event::RunStarted(RunStartedEvent::new( + crate::types::ThreadId::random(), + crate::types::RunId::random(), + )); + assert_eq!(event.event_type(), EventType::RunStarted); + + let event: Event = Event::Custom(CustomEvent::new("test", serde_json::json!({}))); + assert_eq!(event.event_type(), EventType::Custom); + } + + #[test] + fn test_event_timestamp_method() { + use crate::types::MessageId; + + let event: Event = Event::TextMessageStart( + TextMessageStartEvent::new(MessageId::random()) + .with_timestamp(1234567890.0), + ); + assert_eq!(event.timestamp(), Some(1234567890.0)); + + let event: Event = Event::ThinkingEnd(ThinkingEndEvent::new()); + assert_eq!(event.timestamp(), None); + } + + #[test] + fn test_event_all_variants_serialize() { + use crate::types::{Message, MessageId, RunId, ThreadId, ToolCallId}; + + // Test that all event variants can be serialized + let events: Vec = vec![ + Event::TextMessageStart(TextMessageStartEvent::new(MessageId::random())), + Event::TextMessageContent(TextMessageContentEvent::new_unchecked(MessageId::random(), "Hello")), + Event::TextMessageEnd(TextMessageEndEvent::new(MessageId::random())), + Event::TextMessageChunk(TextMessageChunkEvent::new(Role::Assistant).with_delta("Hi")), + Event::ThinkingTextMessageStart(ThinkingTextMessageStartEvent::new()), + Event::ThinkingTextMessageContent(ThinkingTextMessageContentEvent::new("thinking...")), + Event::ThinkingTextMessageEnd(ThinkingTextMessageEndEvent::new()), + Event::ToolCallStart(ToolCallStartEvent::new(ToolCallId::random(), "test_tool")), + Event::ToolCallArgs(ToolCallArgsEvent::new(ToolCallId::random(), "{}")), + Event::ToolCallEnd(ToolCallEndEvent::new(ToolCallId::random())), + Event::ToolCallChunk(ToolCallChunkEvent::new()), + Event::ToolCallResult(ToolCallResultEvent::new(MessageId::random(), ToolCallId::random(), "result")), + Event::ThinkingStart(ThinkingStartEvent::new()), + Event::ThinkingEnd(ThinkingEndEvent::new()), + Event::StateSnapshot(StateSnapshotEvent::new(serde_json::json!({}))), + Event::StateDelta(StateDeltaEvent::new(vec![])), + Event::MessagesSnapshot(MessagesSnapshotEvent::new(vec![Message::Assistant { + id: MessageId::random(), + content: Some("Hi".to_string()), + name: None, + tool_calls: None, + }])), + Event::ActivitySnapshot(ActivitySnapshotEvent::new(MessageId::random(), "PLAN", serde_json::json!({"steps": []}))), + Event::ActivityDelta(ActivityDeltaEvent::new(MessageId::random(), "PLAN", vec![serde_json::json!({"op": "add", "path": "/steps/-", "value": "test"})])), + Event::Raw(RawEvent::new(serde_json::json!({}))), + Event::Custom(CustomEvent::new("test", serde_json::json!({}))), + Event::RunStarted(RunStartedEvent::new(ThreadId::random(), RunId::random())), + Event::RunFinished(RunFinishedEvent::new(ThreadId::random(), RunId::random())), + Event::RunError(RunErrorEvent::new("error")), + Event::StepStarted(StepStartedEvent::new("step")), + Event::StepFinished(StepFinishedEvent::new("step")), + ]; + + for event in events { + let json = serde_json::to_string(&event).unwrap(); + assert!(json.contains("\"type\":")); + + // Roundtrip test + let deserialized: Event = serde_json::from_str(&json).unwrap(); + assert_eq!(event.event_type(), deserialized.event_type()); + } + } +} diff --git a/crates/ag-ui-core/src/lib.rs b/crates/ag-ui-core/src/lib.rs new file mode 100644 index 00000000..6b9da45b --- /dev/null +++ b/crates/ag-ui-core/src/lib.rs @@ -0,0 +1,64 @@ +//! AG-UI Core Types +//! +//! This crate provides the core type definitions for the AG-UI (Agent-User Interaction) +//! protocol. It includes event types, message structures, state management primitives, +//! and error handling for building AG-UI compatible agents. +//! +//! # Overview +//! +//! AG-UI is an event-based protocol that standardizes how AI agents communicate with +//! user-facing applications. This crate provides: +//! +//! - **Event types**: All ~25 AG-UI protocol event types (text messages, tool calls, state, etc.) +//! - **Message types**: Structured message formats for agent-user communication +//! - **State management**: State snapshots and JSON Patch delta operations +//! - **Error handling**: Comprehensive error types for protocol operations +//! +//! # Usage +//! +//! ```rust,ignore +//! use ag_ui_core::{Event, Result}; +//! ``` + +pub mod error; +pub mod event; +pub mod patch; +pub mod state; +pub mod types; + +// Re-export key types for convenience +pub use error::{AgUiError, Result}; + +/// Re-export serde_json::Value for consistent JSON handling across the crate +pub use serde_json::Value as JsonValue; + +// Re-export all types at crate root for convenient access +pub use types::*; + +// Re-export state traits and helpers +pub use state::{diff_states, AgentState, FwdProps, StateManager, TypedStateManager}; + +// Re-export event types +pub use event::{ + // Foundation types + BaseEvent, Event, EventType, EventValidationError, + // Text message events + TextMessageChunkEvent, TextMessageContentEvent, TextMessageEndEvent, TextMessageStartEvent, + // Thinking text message events + ThinkingTextMessageContentEvent, ThinkingTextMessageEndEvent, ThinkingTextMessageStartEvent, + // Tool call events + ToolCallArgsEvent, ToolCallChunkEvent, ToolCallEndEvent, ToolCallResultEvent, + ToolCallStartEvent, + // Thinking step events + ThinkingEndEvent, ThinkingStartEvent, + // State events + MessagesSnapshotEvent, StateDeltaEvent, StateSnapshotEvent, + // Activity events + ActivityDeltaEvent, ActivitySnapshotEvent, + // Special events + CustomEvent, RawEvent, + // Run lifecycle events + InterruptInfo, RunErrorEvent, RunFinishedEvent, RunFinishedOutcome, RunStartedEvent, + // Step events + StepFinishedEvent, StepStartedEvent, +}; diff --git a/crates/ag-ui-core/src/patch.rs b/crates/ag-ui-core/src/patch.rs new file mode 100644 index 00000000..e43a8f64 --- /dev/null +++ b/crates/ag-ui-core/src/patch.rs @@ -0,0 +1,622 @@ +//! JSON Patch utilities for AG-UI state delta generation. +//! +//! This module provides utilities for working with JSON Patch (RFC 6902) +//! operations, enabling efficient state synchronization between agents and +//! frontends through delta updates. +//! +//! # Overview +//! +//! JSON Patch is a format for describing changes to a JSON document. Instead +//! of sending the entire state on every update, you can send just the changes +//! (patches) which is more efficient for large state objects. +//! +//! # Example +//! +//! ```rust +//! use ag_ui_core::patch::{create_patch, apply_patch}; +//! use serde_json::json; +//! +//! // Create a patch from two states +//! let old_state = json!({"count": 0, "items": []}); +//! let new_state = json!({"count": 1, "items": ["apple"]}); +//! +//! let patch = create_patch(&old_state, &new_state); +//! +//! // Apply patch to recreate the new state +//! let mut state = old_state.clone(); +//! apply_patch(&mut state, &patch).unwrap(); +//! assert_eq!(state, new_state); +//! ``` + +use serde_json::Value as JsonValue; +use std::error::Error; +use std::fmt; + +// Re-export json_patch types for convenience +pub use json_patch::{ + AddOperation, CopyOperation, MoveOperation, Patch, PatchOperation, RemoveOperation, + ReplaceOperation, TestOperation, +}; +use jsonptr::PointerBuf; + +/// Error type for patch operations. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PatchError { + message: String, +} + +impl PatchError { + /// Creates a new patch error with the given message. + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + } + } +} + +impl fmt::Display for PatchError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Patch error: {}", self.message) + } +} + +impl Error for PatchError {} + +impl From for PatchError { + fn from(err: json_patch::PatchError) -> Self { + Self::new(format!("{}", err)) + } +} + +/// Creates a JSON Patch representing the difference between two JSON values. +/// +/// The patch, when applied to `from`, will produce `to`. +/// +/// # Example +/// +/// ```rust +/// use ag_ui_core::patch::create_patch; +/// use serde_json::json; +/// +/// let from = json!({"name": "Alice", "age": 30}); +/// let to = json!({"name": "Alice", "age": 31}); +/// +/// let patch = create_patch(&from, &to); +/// +/// // The patch contains a "replace" operation for the age field +/// assert!(!patch.0.is_empty()); +/// ``` +pub fn create_patch(from: &JsonValue, to: &JsonValue) -> Patch { + json_patch::diff(from, to) +} + +/// Applies a JSON Patch to a JSON value in place. +/// +/// # Errors +/// +/// Returns an error if any patch operation fails (e.g., path doesn't exist +/// for a remove operation, or test operation fails). +/// +/// # Example +/// +/// ```rust +/// use ag_ui_core::patch::{create_patch, apply_patch}; +/// use serde_json::json; +/// +/// let mut state = json!({"count": 0}); +/// let patch = create_patch(&json!({"count": 0}), &json!({"count": 5})); +/// +/// apply_patch(&mut state, &patch).unwrap(); +/// assert_eq!(state["count"], 5); +/// ``` +pub fn apply_patch(target: &mut JsonValue, patch: &Patch) -> Result<(), PatchError> { + json_patch::patch(target, patch.0.as_slice()).map_err(PatchError::from) +} + +/// Applies a JSON Patch from a JSON array representation. +/// +/// This is useful when you receive patches as raw JSON values (e.g., from +/// network events). +/// +/// # Errors +/// +/// Returns an error if the patch is not a valid JSON Patch array or if +/// any operation fails. +/// +/// # Example +/// +/// ```rust +/// use ag_ui_core::patch::apply_patch_from_value; +/// use serde_json::json; +/// +/// let mut state = json!({"count": 0}); +/// let patch_json = json!([ +/// {"op": "replace", "path": "/count", "value": 10} +/// ]); +/// +/// apply_patch_from_value(&mut state, &patch_json).unwrap(); +/// assert_eq!(state["count"], 10); +/// ``` +pub fn apply_patch_from_value(target: &mut JsonValue, patch: &JsonValue) -> Result<(), PatchError> { + let patch: Patch = serde_json::from_value(patch.clone()) + .map_err(|e| PatchError::new(format!("Invalid patch format: {}", e)))?; + apply_patch(target, &patch) +} + +/// Converts a Patch to a JSON value for serialization. +/// +/// This is useful when you need to send patches over the network or +/// store them as JSON. +/// +/// # Example +/// +/// ```rust +/// use ag_ui_core::patch::{create_patch, patch_to_value}; +/// use serde_json::json; +/// +/// let patch = create_patch( +/// &json!({"x": 1}), +/// &json!({"x": 2}), +/// ); +/// +/// let json = patch_to_value(&patch); +/// assert!(json.is_array()); +/// ``` +pub fn patch_to_value(patch: &Patch) -> JsonValue { + serde_json::to_value(patch).unwrap_or(JsonValue::Array(vec![])) +} + +/// Converts a Patch to a vector of JSON values. +/// +/// This is the format expected by StateDeltaEvent and ActivityDeltaEvent. +/// +/// # Example +/// +/// ```rust +/// use ag_ui_core::patch::{create_patch, patch_to_vec}; +/// use serde_json::json; +/// +/// let patch = create_patch( +/// &json!({"items": []}), +/// &json!({"items": ["a"]}), +/// ); +/// +/// let ops = patch_to_vec(&patch); +/// // Each operation is a separate JSON object +/// assert!(!ops.is_empty()); +/// ``` +pub fn patch_to_vec(patch: &Patch) -> Vec { + patch + .0 + .iter() + .filter_map(|op| serde_json::to_value(op).ok()) + .collect() +} + +/// A builder for constructing JSON Patches programmatically. +/// +/// This provides a more ergonomic way to create patches when you know +/// exactly what operations you want to perform. +/// +/// # Example +/// +/// ```rust +/// use ag_ui_core::patch::PatchBuilder; +/// use serde_json::json; +/// +/// let patch = PatchBuilder::new() +/// .add("/name", json!("Alice")) +/// .replace("/age", json!(31)) +/// .remove("/temp") +/// .build(); +/// +/// assert_eq!(patch.0.len(), 3); +/// ``` +#[derive(Debug, Clone, Default)] +pub struct PatchBuilder { + operations: Vec, +} + +impl PatchBuilder { + /// Creates a new empty patch builder. + pub fn new() -> Self { + Self::default() + } + + /// Adds an "add" operation to the patch. + /// + /// The add operation adds a value at the target location. If the target + /// location specifies an array index, the value is inserted at that index. + pub fn add(mut self, path: impl AsRef, value: JsonValue) -> Self { + self.operations.push(PatchOperation::Add(AddOperation { + path: PointerBuf::parse(path.as_ref()).unwrap_or_default(), + value, + })); + self + } + + /// Adds a "remove" operation to the patch. + /// + /// The remove operation removes the value at the target location. + pub fn remove(mut self, path: impl AsRef) -> Self { + self.operations + .push(PatchOperation::Remove(RemoveOperation { + path: PointerBuf::parse(path.as_ref()).unwrap_or_default(), + })); + self + } + + /// Adds a "replace" operation to the patch. + /// + /// The replace operation replaces the value at the target location with + /// the new value. + pub fn replace(mut self, path: impl AsRef, value: JsonValue) -> Self { + self.operations + .push(PatchOperation::Replace(ReplaceOperation { + path: PointerBuf::parse(path.as_ref()).unwrap_or_default(), + value, + })); + self + } + + /// Adds a "move" operation to the patch. + /// + /// The move operation removes the value at a specified location and + /// adds it to the target location. + pub fn move_value(mut self, from: impl AsRef, path: impl AsRef) -> Self { + self.operations.push(PatchOperation::Move(MoveOperation { + from: PointerBuf::parse(from.as_ref()).unwrap_or_default(), + path: PointerBuf::parse(path.as_ref()).unwrap_or_default(), + })); + self + } + + /// Adds a "copy" operation to the patch. + /// + /// The copy operation copies the value at a specified location to the + /// target location. + pub fn copy(mut self, from: impl AsRef, path: impl AsRef) -> Self { + self.operations.push(PatchOperation::Copy(CopyOperation { + from: PointerBuf::parse(from.as_ref()).unwrap_or_default(), + path: PointerBuf::parse(path.as_ref()).unwrap_or_default(), + })); + self + } + + /// Adds a "test" operation to the patch. + /// + /// The test operation tests that a value at the target location is equal + /// to a specified value. If the test fails, the entire patch fails. + pub fn test(mut self, path: impl AsRef, value: JsonValue) -> Self { + self.operations.push(PatchOperation::Test(TestOperation { + path: PointerBuf::parse(path.as_ref()).unwrap_or_default(), + value, + })); + self + } + + /// Builds the patch from the accumulated operations. + pub fn build(self) -> Patch { + Patch(self.operations) + } + + /// Builds the patch and returns it as a vector of JSON values. + /// + /// This is the format expected by StateDeltaEvent and ActivityDeltaEvent. + pub fn build_vec(self) -> Vec { + patch_to_vec(&self.build()) + } +} + +/// Checks if applying a patch would succeed without actually modifying the target. +/// +/// This is useful for validation before committing to a patch operation. +/// +/// # Example +/// +/// ```rust +/// use ag_ui_core::patch::{can_apply_patch, PatchBuilder}; +/// use serde_json::json; +/// +/// let state = json!({"count": 0}); +/// let valid_patch = PatchBuilder::new().replace("/count", json!(1)).build(); +/// let invalid_patch = PatchBuilder::new().remove("/nonexistent").build(); +/// +/// assert!(can_apply_patch(&state, &valid_patch)); +/// assert!(!can_apply_patch(&state, &invalid_patch)); +/// ``` +pub fn can_apply_patch(target: &JsonValue, patch: &Patch) -> bool { + let mut test_target = target.clone(); + apply_patch(&mut test_target, patch).is_ok() +} + +/// Merges two patches into one. +/// +/// The resulting patch applies the operations from the first patch followed +/// by operations from the second patch. +/// +/// Note: This is a simple concatenation and does not optimize or simplify +/// the resulting patch. +/// +/// # Example +/// +/// ```rust +/// use ag_ui_core::patch::{merge_patches, PatchBuilder}; +/// use serde_json::json; +/// +/// let patch1 = PatchBuilder::new().add("/a", json!(1)).build(); +/// let patch2 = PatchBuilder::new().add("/b", json!(2)).build(); +/// +/// let merged = merge_patches(&patch1, &patch2); +/// assert_eq!(merged.0.len(), 2); +/// ``` +pub fn merge_patches(first: &Patch, second: &Patch) -> Patch { + let mut operations = first.0.clone(); + operations.extend(second.0.clone()); + Patch(operations) +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_create_patch_simple() { + let from = json!({"count": 0}); + let to = json!({"count": 5}); + + let patch = create_patch(&from, &to); + assert!(!patch.0.is_empty()); + + // Apply patch and verify + let mut result = from.clone(); + apply_patch(&mut result, &patch).unwrap(); + assert_eq!(result, to); + } + + #[test] + fn test_create_patch_add_field() { + let from = json!({"name": "Alice"}); + let to = json!({"name": "Alice", "age": 30}); + + let patch = create_patch(&from, &to); + + let mut result = from.clone(); + apply_patch(&mut result, &patch).unwrap(); + assert_eq!(result, to); + } + + #[test] + fn test_create_patch_remove_field() { + let from = json!({"name": "Alice", "temp": "value"}); + let to = json!({"name": "Alice"}); + + let patch = create_patch(&from, &to); + + let mut result = from.clone(); + apply_patch(&mut result, &patch).unwrap(); + assert_eq!(result, to); + } + + #[test] + fn test_create_patch_array_operations() { + let from = json!({"items": ["a", "b"]}); + let to = json!({"items": ["a", "b", "c"]}); + + let patch = create_patch(&from, &to); + + let mut result = from.clone(); + apply_patch(&mut result, &patch).unwrap(); + assert_eq!(result, to); + } + + #[test] + fn test_apply_patch_from_value() { + let mut state = json!({"count": 0}); + let patch_json = json!([ + {"op": "replace", "path": "/count", "value": 42} + ]); + + apply_patch_from_value(&mut state, &patch_json).unwrap(); + assert_eq!(state["count"], 42); + } + + #[test] + fn test_apply_patch_from_value_invalid() { + let mut state = json!({"count": 0}); + let invalid_patch = json!("not an array"); + + let result = apply_patch_from_value(&mut state, &invalid_patch); + assert!(result.is_err()); + } + + #[test] + fn test_patch_to_value() { + let patch = create_patch(&json!({"x": 1}), &json!({"x": 2})); + let value = patch_to_value(&patch); + + assert!(value.is_array()); + } + + #[test] + fn test_patch_to_vec() { + let patch = create_patch(&json!({"a": 1, "b": 2}), &json!({"a": 1, "b": 3, "c": 4})); + let ops = patch_to_vec(&patch); + + // Should have operations for changing b and adding c + assert!(!ops.is_empty()); + for op in &ops { + assert!(op.is_object()); + assert!(op.get("op").is_some()); + } + } + + #[test] + fn test_patch_builder_add() { + let patch = PatchBuilder::new() + .add("/name", json!("Alice")) + .build(); + + let mut state = json!({}); + apply_patch(&mut state, &patch).unwrap(); + assert_eq!(state["name"], "Alice"); + } + + #[test] + fn test_patch_builder_replace() { + let patch = PatchBuilder::new() + .replace("/count", json!(10)) + .build(); + + let mut state = json!({"count": 0}); + apply_patch(&mut state, &patch).unwrap(); + assert_eq!(state["count"], 10); + } + + #[test] + fn test_patch_builder_remove() { + let patch = PatchBuilder::new().remove("/temp").build(); + + let mut state = json!({"name": "Alice", "temp": "value"}); + apply_patch(&mut state, &patch).unwrap(); + assert!(state.get("temp").is_none()); + assert_eq!(state["name"], "Alice"); + } + + #[test] + fn test_patch_builder_move() { + let patch = PatchBuilder::new() + .move_value("/old", "/new") + .build(); + + let mut state = json!({"old": "value"}); + apply_patch(&mut state, &patch).unwrap(); + assert!(state.get("old").is_none()); + assert_eq!(state["new"], "value"); + } + + #[test] + fn test_patch_builder_copy() { + let patch = PatchBuilder::new() + .copy("/source", "/dest") + .build(); + + let mut state = json!({"source": "value"}); + apply_patch(&mut state, &patch).unwrap(); + assert_eq!(state["source"], "value"); + assert_eq!(state["dest"], "value"); + } + + #[test] + fn test_patch_builder_test() { + // Test operation succeeds + let patch = PatchBuilder::new() + .test("/count", json!(0)) + .replace("/count", json!(1)) + .build(); + + let mut state = json!({"count": 0}); + apply_patch(&mut state, &patch).unwrap(); + assert_eq!(state["count"], 1); + } + + #[test] + fn test_patch_builder_test_fails() { + let patch = PatchBuilder::new() + .test("/count", json!(999)) // Wrong value + .replace("/count", json!(1)) + .build(); + + let mut state = json!({"count": 0}); + let result = apply_patch(&mut state, &patch); + assert!(result.is_err()); + } + + #[test] + fn test_patch_builder_build_vec() { + let ops = PatchBuilder::new() + .add("/a", json!(1)) + .replace("/b", json!(2)) + .build_vec(); + + assert_eq!(ops.len(), 2); + } + + #[test] + fn test_can_apply_patch() { + let state = json!({"count": 0}); + + let valid_patch = PatchBuilder::new().replace("/count", json!(1)).build(); + assert!(can_apply_patch(&state, &valid_patch)); + + let invalid_patch = PatchBuilder::new().remove("/nonexistent").build(); + assert!(!can_apply_patch(&state, &invalid_patch)); + } + + #[test] + fn test_merge_patches() { + let patch1 = PatchBuilder::new().add("/a", json!(1)).build(); + let patch2 = PatchBuilder::new().add("/b", json!(2)).build(); + + let merged = merge_patches(&patch1, &patch2); + assert_eq!(merged.0.len(), 2); + + let mut state = json!({}); + apply_patch(&mut state, &merged).unwrap(); + assert_eq!(state["a"], 1); + assert_eq!(state["b"], 2); + } + + #[test] + fn test_patch_error_display() { + let err = PatchError::new("test error"); + assert!(err.to_string().contains("test error")); + } + + #[test] + fn test_complex_nested_patch() { + let from = json!({ + "user": { + "profile": { + "name": "Alice", + "settings": { + "theme": "light" + } + } + } + }); + + let to = json!({ + "user": { + "profile": { + "name": "Alice", + "settings": { + "theme": "dark", + "notifications": true + } + } + } + }); + + let patch = create_patch(&from, &to); + + let mut result = from.clone(); + apply_patch(&mut result, &patch).unwrap(); + assert_eq!(result, to); + } + + #[test] + fn test_empty_patch() { + let state = json!({"count": 0}); + let patch = create_patch(&state, &state); + + // Patch of identical values should be empty + assert!(patch.0.is_empty()); + + // Applying empty patch should be no-op + let mut result = state.clone(); + apply_patch(&mut result, &patch).unwrap(); + assert_eq!(result, state); + } +} diff --git a/crates/ag-ui-core/src/state.rs b/crates/ag-ui-core/src/state.rs new file mode 100644 index 00000000..b2df1f01 --- /dev/null +++ b/crates/ag-ui-core/src/state.rs @@ -0,0 +1,645 @@ +//! AG-UI State Management +//! +//! This module provides state management traits and utilities for AG-UI: +//! - `AgentState`: Marker trait for types that can represent agent state +//! - `FwdProps`: Marker trait for types that can be forwarded as props to UI +//! - `StateManager`: Helper for managing state and generating deltas +//! +//! These traits enable generic state handling in events while ensuring +//! the necessary bounds for serialization and async operations. +//! +//! # State Synchronization +//! +//! AG-UI supports two modes of state synchronization: +//! - **Snapshots**: Send the complete state (simpler but less efficient) +//! - **Deltas**: Send JSON Patch operations (more efficient for large states) +//! +//! The `StateManager` helper makes it easy to track state changes and +//! generate appropriate events. +//! +//! # Example +//! +//! ```rust +//! use ag_ui_core::state::StateManager; +//! use serde_json::json; +//! +//! let mut manager = StateManager::new(json!({"count": 0})); +//! +//! // Update state and get the delta +//! let delta = manager.update(json!({"count": 1})); +//! assert!(delta.is_some()); +//! +//! // Get current state +//! assert_eq!(manager.current()["count"], 1); +//! ``` + +use crate::patch::{create_patch, Patch}; +use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; +use std::fmt::Debug; + +/// Marker trait for types that can represent agent state. +/// +/// Types implementing this trait can be used as the state type in +/// state-related events (StateSnapshot, StateDelta, etc.). +/// +/// # Bounds +/// +/// - `'static`: Required for async operations +/// - `Debug`: For debugging and logging +/// - `Clone`: State may need to be copied +/// - `Send + Sync`: For thread-safe async operations +/// - `Serialize + Deserialize`: For JSON serialization +/// - `Default`: For initializing empty state +/// +/// # Example +/// +/// ```rust +/// use ag_ui_core::AgentState; +/// use serde::{Deserialize, Serialize}; +/// +/// #[derive(Debug, Clone, Default, Serialize, Deserialize)] +/// struct MyState { +/// counter: u32, +/// messages: Vec, +/// } +/// +/// impl AgentState for MyState {} +/// ``` +pub trait AgentState: + 'static + Debug + Clone + Send + Sync + for<'de> Deserialize<'de> + Serialize + Default +{ +} + +/// Marker trait for types that can be forwarded as props to UI components. +/// +/// Types implementing this trait can be passed through the AG-UI protocol +/// to frontend components as properties. +/// +/// # Bounds +/// +/// - `'static`: Required for async operations +/// - `Clone`: Props may need to be copied +/// - `Send + Sync`: For thread-safe async operations +/// - `Serialize + Deserialize`: For JSON serialization +/// - `Default`: For initializing empty props +/// +/// # Example +/// +/// ```rust +/// use ag_ui_core::FwdProps; +/// use serde::{Deserialize, Serialize}; +/// +/// #[derive(Clone, Default, Serialize, Deserialize)] +/// struct MyProps { +/// theme: String, +/// locale: String, +/// } +/// +/// impl FwdProps for MyProps {} +/// ``` +pub trait FwdProps: + 'static + Clone + Send + Sync + for<'de> Deserialize<'de> + Serialize + Default +{ +} + +// Implement AgentState for common types + +impl AgentState for JsonValue {} +impl AgentState for () {} + +// Implement FwdProps for common types + +impl FwdProps for JsonValue {} +impl FwdProps for () {} + +// ============================================================================= +// State Helper Utilities +// ============================================================================= + +/// Computes the difference between two JSON states as a JSON Patch. +/// +/// Returns `None` if the states are identical. +/// +/// # Example +/// +/// ```rust +/// use ag_ui_core::state::diff_states; +/// use serde_json::json; +/// +/// let old = json!({"count": 0}); +/// let new = json!({"count": 5}); +/// +/// let patch = diff_states(&old, &new); +/// assert!(patch.is_some()); +/// ``` +pub fn diff_states(old: &JsonValue, new: &JsonValue) -> Option { + let patch = create_patch(old, new); + if patch.0.is_empty() { + None + } else { + Some(patch) + } +} + +/// A helper for managing state and generating deltas. +/// +/// `StateManager` tracks the current state and provides methods to update +/// it while automatically computing the JSON Patch delta between states. +/// This is useful for efficiently synchronizing state with frontends. +/// +/// # Example +/// +/// ```rust +/// use ag_ui_core::state::StateManager; +/// use serde_json::json; +/// +/// let mut manager = StateManager::new(json!({"count": 0, "items": []})); +/// +/// // Update state - returns delta patch +/// let delta = manager.update(json!({"count": 1, "items": []})); +/// assert!(delta.is_some()); +/// +/// // No change - returns None +/// let delta = manager.update(json!({"count": 1, "items": []})); +/// assert!(delta.is_none()); +/// +/// // Check current state +/// assert_eq!(manager.current()["count"], 1); +/// ``` +#[derive(Debug, Clone)] +pub struct StateManager { + current: JsonValue, + version: u64, +} + +impl StateManager { + /// Creates a new state manager with the given initial state. + pub fn new(initial: JsonValue) -> Self { + Self { + current: initial, + version: 0, + } + } + + /// Returns a reference to the current state. + pub fn current(&self) -> &JsonValue { + &self.current + } + + /// Returns the current state version (increments on each update). + pub fn version(&self) -> u64 { + self.version + } + + /// Updates the state and returns the delta patch if there were changes. + /// + /// Returns `None` if the new state is identical to the current state. + pub fn update(&mut self, new_state: JsonValue) -> Option { + let patch = diff_states(&self.current, &new_state); + if patch.is_some() { + self.current = new_state; + self.version += 1; + } + patch + } + + /// Updates the state using a closure and returns the delta patch. + /// + /// The closure receives a mutable reference to the current state. + /// After the closure completes, the delta is computed. + /// + /// # Example + /// + /// ```rust + /// use ag_ui_core::state::StateManager; + /// use serde_json::json; + /// + /// let mut manager = StateManager::new(json!({"count": 0})); + /// + /// let delta = manager.update_with(|state| { + /// state["count"] = json!(10); + /// }); + /// + /// assert!(delta.is_some()); + /// assert_eq!(manager.current()["count"], 10); + /// ``` + pub fn update_with(&mut self, f: F) -> Option + where + F: FnOnce(&mut JsonValue), + { + let old_state = self.current.clone(); + f(&mut self.current); + let patch = diff_states(&old_state, &self.current); + if patch.is_some() { + self.version += 1; + } + patch + } + + /// Resets the state to a new value without computing a delta. + /// + /// Use this when you want to replace the entire state (e.g., on reconnection) + /// and will send a snapshot instead of a delta. + pub fn reset(&mut self, new_state: JsonValue) { + self.current = new_state; + self.version += 1; + } + + /// Takes a snapshot of the current state. + /// + /// Returns a clone of the current state value. + pub fn snapshot(&self) -> JsonValue { + self.current.clone() + } +} + +impl Default for StateManager { + fn default() -> Self { + Self::new(JsonValue::Object(serde_json::Map::new())) + } +} + +/// A typed state manager for custom state types. +/// +/// This provides the same functionality as `StateManager` but works with +/// strongly-typed state objects that implement `AgentState`. +/// +/// # Example +/// +/// ```rust +/// use ag_ui_core::state::TypedStateManager; +/// use ag_ui_core::AgentState; +/// use serde::{Deserialize, Serialize}; +/// +/// #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)] +/// struct AppState { +/// count: u32, +/// user: Option, +/// } +/// +/// impl AgentState for AppState {} +/// +/// let mut manager = TypedStateManager::new(AppState { count: 0, user: None }); +/// +/// let delta = manager.update(AppState { count: 1, user: None }); +/// assert!(delta.is_some()); +/// +/// assert_eq!(manager.current().count, 1); +/// ``` +#[derive(Debug, Clone)] +pub struct TypedStateManager { + current: S, + version: u64, +} + +impl TypedStateManager { + /// Creates a new typed state manager with the given initial state. + pub fn new(initial: S) -> Self { + Self { + current: initial, + version: 0, + } + } + + /// Returns a reference to the current state. + pub fn current(&self) -> &S { + &self.current + } + + /// Returns the current state version (increments on each update). + pub fn version(&self) -> u64 { + self.version + } + + /// Updates the state and returns the delta patch if there were changes. + /// + /// Returns `None` if the new state is identical to the current state. + pub fn update(&mut self, new_state: S) -> Option { + if self.current == new_state { + return None; + } + + let old_json = serde_json::to_value(&self.current).ok()?; + let new_json = serde_json::to_value(&new_state).ok()?; + let patch = diff_states(&old_json, &new_json); + + self.current = new_state; + self.version += 1; + patch + } + + /// Resets the state to a new value without computing a delta. + pub fn reset(&mut self, new_state: S) { + self.current = new_state; + self.version += 1; + } + + /// Takes a snapshot of the current state as JSON. + pub fn snapshot(&self) -> JsonValue { + serde_json::to_value(&self.current).unwrap_or(JsonValue::Null) + } + + /// Returns the current state as a JSON value. + pub fn as_json(&self) -> JsonValue { + serde_json::to_value(&self.current).unwrap_or(JsonValue::Null) + } +} + +impl Default for TypedStateManager { + fn default() -> Self { + Self::new(S::default()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Debug, Clone, Default, Serialize, Deserialize)] + struct TestState { + value: i32, + } + + impl AgentState for TestState {} + + #[derive(Clone, Default, Serialize, Deserialize)] + struct TestProps { + name: String, + } + + impl FwdProps for TestProps {} + + #[test] + fn test_json_value_implements_agent_state() { + fn requires_agent_state(_: T) {} + requires_agent_state(JsonValue::Null); + } + + #[test] + fn test_unit_implements_agent_state() { + fn requires_agent_state(_: T) {} + requires_agent_state(()); + } + + #[test] + fn test_json_value_implements_fwd_props() { + fn requires_fwd_props(_: T) {} + requires_fwd_props(JsonValue::Null); + } + + #[test] + fn test_unit_implements_fwd_props() { + fn requires_fwd_props(_: T) {} + requires_fwd_props(()); + } + + #[test] + fn test_custom_state_type() { + fn requires_agent_state(_: T) {} + requires_agent_state(TestState { value: 42 }); + } + + #[test] + fn test_custom_props_type() { + fn requires_fwd_props(_: T) {} + requires_fwd_props(TestProps { + name: "test".to_string(), + }); + } + + // ========================================================================= + // State Helper Tests + // ========================================================================= + + #[test] + fn test_diff_states_with_changes() { + use serde_json::json; + + let old = json!({"count": 0}); + let new = json!({"count": 5}); + + let patch = diff_states(&old, &new); + assert!(patch.is_some()); + } + + #[test] + fn test_diff_states_no_changes() { + use serde_json::json; + + let state = json!({"count": 0}); + + let patch = diff_states(&state, &state); + assert!(patch.is_none()); + } + + #[test] + fn test_state_manager_new() { + use serde_json::json; + + let manager = StateManager::new(json!({"count": 0})); + assert_eq!(manager.current()["count"], 0); + assert_eq!(manager.version(), 0); + } + + #[test] + fn test_state_manager_update_with_changes() { + use serde_json::json; + + let mut manager = StateManager::new(json!({"count": 0})); + + let delta = manager.update(json!({"count": 5})); + assert!(delta.is_some()); + assert_eq!(manager.current()["count"], 5); + assert_eq!(manager.version(), 1); + } + + #[test] + fn test_state_manager_update_no_changes() { + use serde_json::json; + + let mut manager = StateManager::new(json!({"count": 0})); + + let delta = manager.update(json!({"count": 0})); + assert!(delta.is_none()); + assert_eq!(manager.version(), 0); // Version shouldn't increment + } + + #[test] + fn test_state_manager_update_with_closure() { + use serde_json::json; + + let mut manager = StateManager::new(json!({"count": 0})); + + let delta = manager.update_with(|state| { + state["count"] = json!(10); + }); + + assert!(delta.is_some()); + assert_eq!(manager.current()["count"], 10); + assert_eq!(manager.version(), 1); + } + + #[test] + fn test_state_manager_update_with_no_changes() { + use serde_json::json; + + let mut manager = StateManager::new(json!({"count": 0})); + + let delta = manager.update_with(|_state| { + // No changes + }); + + assert!(delta.is_none()); + assert_eq!(manager.version(), 0); + } + + #[test] + fn test_state_manager_reset() { + use serde_json::json; + + let mut manager = StateManager::new(json!({"count": 0})); + manager.reset(json!({"count": 100, "new_field": true})); + + assert_eq!(manager.current()["count"], 100); + assert_eq!(manager.current()["new_field"], true); + assert_eq!(manager.version(), 1); + } + + #[test] + fn test_state_manager_snapshot() { + use serde_json::json; + + let manager = StateManager::new(json!({"count": 42})); + let snapshot = manager.snapshot(); + + assert_eq!(snapshot, json!({"count": 42})); + } + + #[test] + fn test_state_manager_default() { + let manager = StateManager::default(); + assert!(manager.current().is_object()); + assert_eq!(manager.version(), 0); + } + + #[test] + fn test_state_manager_multiple_updates() { + use serde_json::json; + + let mut manager = StateManager::new(json!({"count": 0})); + + manager.update(json!({"count": 1})); + manager.update(json!({"count": 2})); + manager.update(json!({"count": 3})); + + assert_eq!(manager.current()["count"], 3); + assert_eq!(manager.version(), 3); + } + + // ========================================================================= + // TypedStateManager Tests + // ========================================================================= + + #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)] + struct AppState { + count: u32, + name: String, + } + + impl AgentState for AppState {} + + #[test] + fn test_typed_state_manager_new() { + let manager = TypedStateManager::new(AppState { + count: 0, + name: "test".to_string(), + }); + + assert_eq!(manager.current().count, 0); + assert_eq!(manager.current().name, "test"); + assert_eq!(manager.version(), 0); + } + + #[test] + fn test_typed_state_manager_update() { + let mut manager = TypedStateManager::new(AppState { + count: 0, + name: "test".to_string(), + }); + + let delta = manager.update(AppState { + count: 5, + name: "test".to_string(), + }); + + assert!(delta.is_some()); + assert_eq!(manager.current().count, 5); + assert_eq!(manager.version(), 1); + } + + #[test] + fn test_typed_state_manager_update_no_changes() { + let mut manager = TypedStateManager::new(AppState { + count: 0, + name: "test".to_string(), + }); + + let delta = manager.update(AppState { + count: 0, + name: "test".to_string(), + }); + + assert!(delta.is_none()); + assert_eq!(manager.version(), 0); + } + + #[test] + fn test_typed_state_manager_reset() { + let mut manager = TypedStateManager::new(AppState { + count: 0, + name: "old".to_string(), + }); + + manager.reset(AppState { + count: 100, + name: "new".to_string(), + }); + + assert_eq!(manager.current().count, 100); + assert_eq!(manager.current().name, "new"); + assert_eq!(manager.version(), 1); + } + + #[test] + fn test_typed_state_manager_snapshot() { + let manager = TypedStateManager::new(AppState { + count: 42, + name: "test".to_string(), + }); + + let snapshot = manager.snapshot(); + assert_eq!(snapshot["count"], 42); + assert_eq!(snapshot["name"], "test"); + } + + #[test] + fn test_typed_state_manager_as_json() { + let manager = TypedStateManager::new(AppState { + count: 10, + name: "hello".to_string(), + }); + + let json = manager.as_json(); + assert_eq!(json["count"], 10); + assert_eq!(json["name"], "hello"); + } + + #[test] + fn test_typed_state_manager_default() { + let manager: TypedStateManager = TypedStateManager::default(); + assert_eq!(manager.current().count, 0); + assert_eq!(manager.current().name, ""); + assert_eq!(manager.version(), 0); + } +} diff --git a/crates/ag-ui-core/src/types/content.rs b/crates/ag-ui-core/src/types/content.rs new file mode 100644 index 00000000..e76b9d4e --- /dev/null +++ b/crates/ag-ui-core/src/types/content.rs @@ -0,0 +1,451 @@ +//! Content types for AG-UI protocol multimodal messages. +//! +//! This module defines input content types for handling text and binary +//! content in messages, enabling multimodal agent interactions. + +use serde::{Deserialize, Serialize}; +use std::error::Error; +use std::fmt; + +/// Error type for content validation failures. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ContentValidationError { + message: String, +} + +impl ContentValidationError { + /// Creates a new validation error with the given message. + pub fn new(message: impl Into) -> Self { + Self { + message: message.into(), + } + } +} + +impl fmt::Display for ContentValidationError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.message) + } +} + +impl Error for ContentValidationError {} + +/// Text input content for messages. +/// +/// Represents plain text content in a message. +/// +/// # Example +/// +/// ``` +/// use ag_ui_core::TextInputContent; +/// +/// let content = TextInputContent::new("Hello, world!"); +/// assert_eq!(content.text, "Hello, world!"); +/// ``` +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct TextInputContent { + /// The content type discriminator, always "text". + #[serde(rename = "type")] + pub type_tag: String, + /// The text content. + pub text: String, +} + +impl TextInputContent { + /// Creates a new text input content. + pub fn new(text: impl Into) -> Self { + Self { + type_tag: "text".to_string(), + text: text.into(), + } + } +} + +/// Binary input content for multimodal messages. +/// +/// Represents binary content such as images, files, or other media. +/// At least one of `id`, `url`, or `data` must be provided. +/// +/// # Example +/// +/// ``` +/// use ag_ui_core::BinaryInputContent; +/// +/// let content = BinaryInputContent::new("image/png") +/// .with_url("https://example.com/image.png") +/// .with_filename("screenshot.png"); +/// +/// assert_eq!(content.mime_type, "image/png"); +/// assert_eq!(content.url, Some("https://example.com/image.png".to_string())); +/// ``` +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct BinaryInputContent { + /// The content type discriminator, always "binary". + #[serde(rename = "type")] + pub type_tag: String, + /// The MIME type of the binary content. + #[serde(rename = "mimeType")] + pub mime_type: String, + /// Optional identifier for the content. + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + /// Optional URL where the content can be fetched. + #[serde(skip_serializing_if = "Option::is_none")] + pub url: Option, + /// Optional base64-encoded data. + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, + /// Optional filename for the content. + #[serde(skip_serializing_if = "Option::is_none")] + pub filename: Option, +} + +impl BinaryInputContent { + /// Creates a new binary input content with the given MIME type. + pub fn new(mime_type: impl Into) -> Self { + Self { + type_tag: "binary".to_string(), + mime_type: mime_type.into(), + id: None, + url: None, + data: None, + filename: None, + } + } + + /// Sets the content identifier. + pub fn with_id(mut self, id: impl Into) -> Self { + self.id = Some(id.into()); + self + } + + /// Sets the content URL. + pub fn with_url(mut self, url: impl Into) -> Self { + self.url = Some(url.into()); + self + } + + /// Sets the base64-encoded data. + pub fn with_data(mut self, data: impl Into) -> Self { + self.data = Some(data.into()); + self + } + + /// Sets the filename. + pub fn with_filename(mut self, filename: impl Into) -> Self { + self.filename = Some(filename.into()); + self + } + + /// Validates that at least one of id, url, or data is present. + pub fn validate(&self) -> Result<(), ContentValidationError> { + if self.id.is_none() && self.url.is_none() && self.data.is_none() { + return Err(ContentValidationError::new( + "BinaryInputContent requires at least one of: id, url, or data", + )); + } + Ok(()) + } +} + +/// Input content union type for multimodal messages. +/// +/// This is a discriminated union that can hold either text or binary content. +/// The `type` field in JSON determines which variant is used. +/// +/// # Example +/// +/// ``` +/// use ag_ui_core::InputContent; +/// +/// // Create text content +/// let text = InputContent::text("Hello!"); +/// assert!(text.is_text()); +/// +/// // Create binary content with URL +/// let binary = InputContent::binary_with_url("image/jpeg", "https://example.com/img.jpg"); +/// assert!(binary.is_binary()); +/// ``` +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "lowercase")] +pub enum InputContent { + /// Text content variant. + Text { + /// The text content. + text: String, + }, + /// Binary content variant for images, files, etc. + Binary { + /// The MIME type of the binary content. + #[serde(rename = "mimeType")] + mime_type: String, + /// Optional identifier for the content. + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + /// Optional URL where the content can be fetched. + #[serde(skip_serializing_if = "Option::is_none")] + url: Option, + /// Optional base64-encoded data. + #[serde(skip_serializing_if = "Option::is_none")] + data: Option, + /// Optional filename for the content. + #[serde(skip_serializing_if = "Option::is_none")] + filename: Option, + }, +} + +impl InputContent { + /// Creates a text content variant. + pub fn text(text: impl Into) -> Self { + Self::Text { text: text.into() } + } + + /// Creates a minimal binary content variant. + pub fn binary(mime_type: impl Into) -> Self { + Self::Binary { + mime_type: mime_type.into(), + id: None, + url: None, + data: None, + filename: None, + } + } + + /// Creates a binary content variant with a URL. + pub fn binary_with_url(mime_type: impl Into, url: impl Into) -> Self { + Self::Binary { + mime_type: mime_type.into(), + id: None, + url: Some(url.into()), + data: None, + filename: None, + } + } + + /// Creates a binary content variant with base64-encoded data. + pub fn binary_with_data(mime_type: impl Into, data: impl Into) -> Self { + Self::Binary { + mime_type: mime_type.into(), + id: None, + url: None, + data: Some(data.into()), + filename: None, + } + } + + /// Returns true if this is text content. + pub fn is_text(&self) -> bool { + matches!(self, Self::Text { .. }) + } + + /// Returns true if this is binary content. + pub fn is_binary(&self) -> bool { + matches!(self, Self::Binary { .. }) + } + + /// Returns the text content if this is a text variant. + pub fn as_text(&self) -> Option<&str> { + match self { + Self::Text { text } => Some(text), + Self::Binary { .. } => None, + } + } + + /// Validates the content. + /// + /// For text content, always succeeds. + /// For binary content, validates that at least one of id, url, or data is present. + pub fn validate(&self) -> Result<(), ContentValidationError> { + match self { + Self::Text { .. } => Ok(()), + Self::Binary { + id, url, data, .. + } => { + if id.is_none() && url.is_none() && data.is_none() { + Err(ContentValidationError::new( + "Binary content requires at least one of: id, url, or data", + )) + } else { + Ok(()) + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // Test 1: TextInputContent serialization + #[test] + fn test_text_input_content_serialization() { + let content = TextInputContent::new("Hello, world!"); + let json = serde_json::to_string(&content).unwrap(); + + assert!(json.contains("\"type\":\"text\"")); + assert!(json.contains("\"text\":\"Hello, world!\"")); + } + + // Test 2: TextInputContent deserialization + #[test] + fn test_text_input_content_deserialization() { + let json = r#"{"type":"text","text":"Hello!"}"#; + let content: TextInputContent = serde_json::from_str(json).unwrap(); + + assert_eq!(content.type_tag, "text"); + assert_eq!(content.text, "Hello!"); + } + + // Test 3: BinaryInputContent serialization + #[test] + fn test_binary_input_content_serialization() { + let content = BinaryInputContent::new("image/png") + .with_url("https://example.com/img.png") + .with_filename("test.png"); + + let json = serde_json::to_string(&content).unwrap(); + + assert!(json.contains("\"type\":\"binary\"")); + assert!(json.contains("\"mimeType\":\"image/png\"")); + assert!(json.contains("\"url\":\"https://example.com/img.png\"")); + assert!(json.contains("\"filename\":\"test.png\"")); + // Optional fields should be omitted when None + assert!(!json.contains("\"id\"")); + assert!(!json.contains("\"data\"")); + } + + // Test 4: BinaryInputContent builder pattern + #[test] + fn test_binary_input_content_builder() { + let content = BinaryInputContent::new("application/pdf") + .with_id("file-123") + .with_url("https://example.com/doc.pdf") + .with_data("base64data") + .with_filename("document.pdf"); + + assert_eq!(content.mime_type, "application/pdf"); + assert_eq!(content.id, Some("file-123".to_string())); + assert_eq!(content.url, Some("https://example.com/doc.pdf".to_string())); + assert_eq!(content.data, Some("base64data".to_string())); + assert_eq!(content.filename, Some("document.pdf".to_string())); + } + + // Test 5: InputContent text variant + #[test] + fn test_input_content_text_variant() { + let content = InputContent::text("Hello!"); + + assert!(content.is_text()); + assert!(!content.is_binary()); + assert_eq!(content.as_text(), Some("Hello!")); + } + + // Test 6: InputContent binary variant + #[test] + fn test_input_content_binary_variant() { + let content = InputContent::binary_with_url("image/jpeg", "https://example.com/img.jpg"); + + assert!(!content.is_text()); + assert!(content.is_binary()); + assert_eq!(content.as_text(), None); + } + + // Test 7: InputContent discriminated union serialization + #[test] + fn test_input_content_discriminated_union() { + // Text variant + let text = InputContent::text("Hello"); + let text_json = serde_json::to_string(&text).unwrap(); + assert!(text_json.contains("\"type\":\"text\"")); + + // Binary variant + let binary = InputContent::binary_with_url("image/png", "https://example.com/img.png"); + let binary_json = serde_json::to_string(&binary).unwrap(); + assert!(binary_json.contains("\"type\":\"binary\"")); + + // Deserialize text + let parsed_text: InputContent = serde_json::from_str(&text_json).unwrap(); + assert!(parsed_text.is_text()); + + // Deserialize binary + let parsed_binary: InputContent = serde_json::from_str(&binary_json).unwrap(); + assert!(parsed_binary.is_binary()); + } + + // Test 8: Binary validation success + #[test] + fn test_binary_validation_success() { + let with_url = BinaryInputContent::new("image/png").with_url("https://example.com/img.png"); + assert!(with_url.validate().is_ok()); + + let with_data = BinaryInputContent::new("image/png").with_data("base64data"); + assert!(with_data.validate().is_ok()); + + let with_id = BinaryInputContent::new("image/png").with_id("file-123"); + assert!(with_id.validate().is_ok()); + } + + // Test 9: Binary validation failure + #[test] + fn test_binary_validation_failure() { + let empty = BinaryInputContent::new("image/png"); + let result = empty.validate(); + + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(err.to_string().contains("at least one of")); + } + + // Test 10: InputContent roundtrip + #[test] + fn test_input_content_roundtrip() { + // Text roundtrip + let text = InputContent::text("Hello, world!"); + let text_json = serde_json::to_string(&text).unwrap(); + let text_parsed: InputContent = serde_json::from_str(&text_json).unwrap(); + assert_eq!(text, text_parsed); + + // Binary roundtrip + let binary = InputContent::Binary { + mime_type: "image/png".to_string(), + id: Some("img-123".to_string()), + url: Some("https://example.com/img.png".to_string()), + data: Some("iVBORw0KGgo=".to_string()), + filename: Some("screenshot.png".to_string()), + }; + let binary_json = serde_json::to_string(&binary).unwrap(); + let binary_parsed: InputContent = serde_json::from_str(&binary_json).unwrap(); + assert_eq!(binary, binary_parsed); + } + + // Test 11: InputContent validation + #[test] + fn test_input_content_validation() { + // Text always valid + let text = InputContent::text("Hello"); + assert!(text.validate().is_ok()); + + // Binary with url is valid + let binary_valid = InputContent::binary_with_url("image/png", "https://example.com/img.png"); + assert!(binary_valid.validate().is_ok()); + + // Binary without id/url/data is invalid + let binary_invalid = InputContent::binary("image/png"); + assert!(binary_invalid.validate().is_err()); + } + + // Test 12: BinaryInputContent deserialization + #[test] + fn test_binary_input_content_deserialization() { + let json = r#"{"type":"binary","mimeType":"image/jpeg","url":"https://example.com/img.jpg"}"#; + let content: BinaryInputContent = serde_json::from_str(json).unwrap(); + + assert_eq!(content.type_tag, "binary"); + assert_eq!(content.mime_type, "image/jpeg"); + assert_eq!(content.url, Some("https://example.com/img.jpg".to_string())); + assert_eq!(content.id, None); + assert_eq!(content.data, None); + assert_eq!(content.filename, None); + } +} diff --git a/crates/ag-ui-core/src/types/ids.rs b/crates/ag-ui-core/src/types/ids.rs new file mode 100644 index 00000000..7cd14812 --- /dev/null +++ b/crates/ag-ui-core/src/types/ids.rs @@ -0,0 +1,156 @@ +//! ID types for the AG-UI protocol. +//! +//! This module provides strongly-typed ID newtypes to prevent mixing up +//! different ID types (e.g., passing a MessageId where a ThreadId is expected). + +use serde::{Deserialize, Serialize}; +use std::ops::Deref; +use uuid::Uuid; + +/// Macro to define a newtype ID based on Uuid. +macro_rules! define_id_type { + // This arm of the macro handles calls that don't specify extra derives. + ($name:ident) => { + define_id_type!($name,); + }; + // This arm handles calls that do specify extra derives (like Eq). + ($name:ident, $($extra_derive:ident),*) => { + #[doc = concat!(stringify!($name), ": A newtype used to prevent mixing it with other ID values.")] + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq, Hash, $($extra_derive),*)] + pub struct $name(Uuid); + + impl $name { + /// Creates a new random ID. + pub fn random() -> Self { + Self(Uuid::new_v4()) + } + } + + /// Allows creating an ID from a Uuid. + impl From for $name { + fn from(uuid: Uuid) -> Self { + Self(uuid) + } + } + + /// Allows converting an ID back into a Uuid. + impl From<$name> for Uuid { + fn from(id: $name) -> Self { + id.0 + } + } + + /// Allows getting a reference to the inner Uuid. + impl AsRef for $name { + fn as_ref(&self) -> &Uuid { + &self.0 + } + } + + /// Allows printing the ID. + impl std::fmt::Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } + } + + /// Allows parsing an ID from a string slice. + impl std::str::FromStr for $name { + type Err = uuid::Error; + + fn from_str(s: &str) -> Result { + Ok(Self(Uuid::parse_str(s)?)) + } + } + + /// Allows comparing the ID with a Uuid. + impl PartialEq for $name { + fn eq(&self, other: &Uuid) -> bool { + self.0 == *other + } + } + + /// Allows comparing the ID with a string slice. + impl PartialEq for $name { + fn eq(&self, other: &str) -> bool { + if let Ok(uuid) = Uuid::parse_str(other) { + self.0 == uuid + } else { + false + } + } + } + }; +} + +// Define UUID-based ID types using the macro +define_id_type!(AgentId); +define_id_type!(ThreadId); +define_id_type!(RunId); +define_id_type!(MessageId); + +/// A tool call ID. +/// +/// Used by some providers to denote a specific ID for a tool call generation, +/// where the result of the tool call must also use this ID. +/// +/// Does not follow UUID format, instead uses "call_xxxxxxxx" format. +#[derive(Debug, PartialEq, Eq, Deserialize, Serialize, Clone)] +pub struct ToolCallId(String); + +impl ToolCallId { + /// Creates a new random tool call ID in the format "call_xxxxxxxx". + pub fn random() -> Self { + let uuid = &Uuid::new_v4().to_string()[..8]; + let id = format!("call_{uuid}"); + Self(id) + } +} + +impl Deref for ToolCallId { + type Target = str; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl> From for ToolCallId { + fn from(s: S) -> Self { + Self(s.into()) + } +} + +impl std::fmt::Display for ToolCallId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + /// Test whether tool call ID has the expected format + #[test] + fn test_tool_call_random() { + let id = ToolCallId::random(); + assert_eq!(id.0.len(), 5 + 8); // "call_" + 8 hex chars + assert!(id.0.starts_with("call_")); + } + + /// Test UUID-based ID creation and conversion + #[test] + fn test_message_id_random() { + let id = MessageId::random(); + let uuid: Uuid = id.clone().into(); + assert_eq!(id, uuid); + } + + /// Test ID parsing from string + #[test] + fn test_id_from_str() { + let uuid_str = "550e8400-e29b-41d4-a716-446655440000"; + let id: MessageId = uuid_str.parse().unwrap(); + assert_eq!(id, *uuid_str); // Dereference &str to str for PartialEq + } +} diff --git a/crates/ag-ui-core/src/types/input.rs b/crates/ag-ui-core/src/types/input.rs new file mode 100644 index 00000000..64593c39 --- /dev/null +++ b/crates/ag-ui-core/src/types/input.rs @@ -0,0 +1,289 @@ +//! Input types for AG-UI protocol requests. +//! +//! This module defines types for handling client requests to AG-UI agents, +//! including the main `RunAgentInput` request type and supporting types. + +use crate::types::ids::{RunId, ThreadId}; +use crate::types::message::Message; +use crate::types::tool::Tool; +use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; + +/// Context information provided to an agent. +/// +/// Context items provide additional information to help the agent +/// understand the user's request or environment. +/// +/// # Example +/// +/// ``` +/// use ag_ui_core::Context; +/// +/// let ctx = Context::new( +/// "current_page".to_string(), +/// "https://example.com/dashboard".to_string(), +/// ); +/// +/// assert_eq!(ctx.description, "current_page"); +/// assert_eq!(ctx.value, "https://example.com/dashboard"); +/// ``` +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Context { + /// A description of what this context represents. + pub description: String, + /// The value of the context. + pub value: String, +} + +impl Context { + /// Creates a new context with the given description and value. + pub fn new(description: String, value: String) -> Self { + Self { description, value } + } +} + +/// Input for running an agent. +/// +/// This is the primary request type sent by clients to start or continue +/// an agent run. It contains the thread and run identifiers, conversation +/// messages, available tools, context, and any custom state. +/// +/// # Example +/// +/// ``` +/// use ag_ui_core::{RunAgentInput, Context, Message, ThreadId, RunId}; +/// +/// let input = RunAgentInput::new(ThreadId::random(), RunId::random()) +/// .with_messages(vec![Message::new_user("Hello!")]) +/// .with_context(vec![ +/// Context::new("timezone".to_string(), "UTC".to_string()), +/// ]); +/// +/// assert!(input.messages.len() == 1); +/// assert!(input.context.len() == 1); +/// ``` +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct RunAgentInput { + /// The thread identifier for this conversation. + #[serde(rename = "threadId")] + pub thread_id: ThreadId, + + /// The run identifier for this specific agent invocation. + #[serde(rename = "runId")] + pub run_id: RunId, + + /// Optional parent run ID for nested or sub-agent runs. + #[serde(rename = "parentRunId", skip_serializing_if = "Option::is_none")] + pub parent_run_id: Option, + + /// The current state, can be any JSON value. + pub state: JsonValue, + + /// The conversation messages. + pub messages: Vec, + + /// The tools available to the agent. + pub tools: Vec, + + /// Additional context provided to the agent. + pub context: Vec, + + /// Forwarded properties from the client. + #[serde(rename = "forwardedProps")] + pub forwarded_props: JsonValue, +} + +impl RunAgentInput { + /// Creates a new RunAgentInput with the given thread and run IDs. + /// + /// Initializes with empty messages, tools, context, null state, + /// and null forwarded props. + pub fn new(thread_id: impl Into, run_id: impl Into) -> Self { + Self { + thread_id: thread_id.into(), + run_id: run_id.into(), + parent_run_id: None, + state: JsonValue::Null, + messages: Vec::new(), + tools: Vec::new(), + context: Vec::new(), + forwarded_props: JsonValue::Null, + } + } + + /// Sets the parent run ID for nested runs. + pub fn with_parent_run_id(mut self, parent_id: impl Into) -> Self { + self.parent_run_id = Some(parent_id.into()); + self + } + + /// Sets the state. + pub fn with_state(mut self, state: JsonValue) -> Self { + self.state = state; + self + } + + /// Sets the messages. + pub fn with_messages(mut self, messages: Vec) -> Self { + self.messages = messages; + self + } + + /// Sets the available tools. + pub fn with_tools(mut self, tools: Vec) -> Self { + self.tools = tools; + self + } + + /// Sets the context items. + pub fn with_context(mut self, context: Vec) -> Self { + self.context = context; + self + } + + /// Sets the forwarded props. + pub fn with_forwarded_props(mut self, props: JsonValue) -> Self { + self.forwarded_props = props; + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::json; + + #[test] + fn test_context_serialization() { + let ctx = Context::new("current_page".to_string(), "/dashboard".to_string()); + let json = serde_json::to_string(&ctx).unwrap(); + + assert!(json.contains("\"description\":\"current_page\"")); + assert!(json.contains("\"value\":\"/dashboard\"")); + } + + #[test] + fn test_context_deserialization() { + let json = r#"{"description":"timezone","value":"UTC"}"#; + let ctx: Context = serde_json::from_str(json).unwrap(); + + assert_eq!(ctx.description, "timezone"); + assert_eq!(ctx.value, "UTC"); + } + + #[test] + fn test_run_agent_input_minimal() { + let thread_id = ThreadId::random(); + let run_id = RunId::random(); + let input = RunAgentInput::new(thread_id.clone(), run_id.clone()); + + assert_eq!(input.thread_id, thread_id); + assert_eq!(input.run_id, run_id); + assert!(input.parent_run_id.is_none()); + assert_eq!(input.state, JsonValue::Null); + assert!(input.messages.is_empty()); + assert!(input.tools.is_empty()); + assert!(input.context.is_empty()); + assert_eq!(input.forwarded_props, JsonValue::Null); + } + + #[test] + fn test_run_agent_input_full() { + let thread_id = ThreadId::random(); + let run_id = RunId::random(); + let parent_id = RunId::random(); + + let input = RunAgentInput::new(thread_id.clone(), run_id.clone()) + .with_parent_run_id(parent_id.clone()) + .with_state(json!({"count": 42})) + .with_messages(vec![Message::new_user("Hello")]) + .with_tools(vec![Tool::new( + "get_weather".to_string(), + "Get weather".to_string(), + json!({"type": "object"}), + )]) + .with_context(vec![Context::new("tz".to_string(), "UTC".to_string())]) + .with_forwarded_props(json!({"custom": true})); + + assert_eq!(input.thread_id, thread_id); + assert_eq!(input.run_id, run_id); + assert_eq!(input.parent_run_id, Some(parent_id)); + assert_eq!(input.state, json!({"count": 42})); + assert_eq!(input.messages.len(), 1); + assert_eq!(input.tools.len(), 1); + assert_eq!(input.context.len(), 1); + assert_eq!(input.forwarded_props, json!({"custom": true})); + } + + #[test] + fn test_run_agent_input_builder() { + let input = RunAgentInput::new(ThreadId::random(), RunId::random()) + .with_state(json!(null)) + .with_messages(vec![]) + .with_tools(vec![]) + .with_context(vec![]) + .with_forwarded_props(json!({})); + + assert_eq!(input.state, JsonValue::Null); + assert!(input.messages.is_empty()); + assert_eq!(input.forwarded_props, json!({})); + } + + #[test] + fn test_run_agent_input_serialization() { + let thread_id = ThreadId::random(); + let run_id = RunId::random(); + let input = RunAgentInput::new(thread_id, run_id); + + let json = serde_json::to_string(&input).unwrap(); + + // Check camelCase field names + assert!(json.contains("\"threadId\"")); + assert!(json.contains("\"runId\"")); + assert!(json.contains("\"forwardedProps\"")); + // parentRunId should be skipped when None + assert!(!json.contains("\"parentRunId\"")); + } + + #[test] + fn test_run_agent_input_serialization_with_parent() { + let input = RunAgentInput::new(ThreadId::random(), RunId::random()) + .with_parent_run_id(RunId::random()); + + let json = serde_json::to_string(&input).unwrap(); + + // parentRunId should be present when Some + assert!(json.contains("\"parentRunId\"")); + } + + #[test] + fn test_run_agent_input_roundtrip() { + let thread_id = ThreadId::random(); + let run_id = RunId::random(); + let parent_id = RunId::random(); + + let original = RunAgentInput::new(thread_id, run_id) + .with_parent_run_id(parent_id) + .with_state(json!({"nested": {"value": 123}})) + .with_messages(vec![ + Message::new_user("Hello"), + Message::new_assistant("Hi there!"), + ]) + .with_context(vec![ + Context::new("key1".to_string(), "value1".to_string()), + Context::new("key2".to_string(), "value2".to_string()), + ]) + .with_forwarded_props(json!({"prop": "value"})); + + let json = serde_json::to_string(&original).unwrap(); + let deserialized: RunAgentInput = serde_json::from_str(&json).unwrap(); + + assert_eq!(original.thread_id, deserialized.thread_id); + assert_eq!(original.run_id, deserialized.run_id); + assert_eq!(original.parent_run_id, deserialized.parent_run_id); + assert_eq!(original.state, deserialized.state); + assert_eq!(original.messages.len(), deserialized.messages.len()); + assert_eq!(original.context.len(), deserialized.context.len()); + assert_eq!(original.forwarded_props, deserialized.forwarded_props); + } +} diff --git a/crates/ag-ui-core/src/types/message.rs b/crates/ag-ui-core/src/types/message.rs new file mode 100644 index 00000000..5cbdefd8 --- /dev/null +++ b/crates/ag-ui-core/src/types/message.rs @@ -0,0 +1,714 @@ +//! Message types for the AG-UI protocol. +//! +//! This module defines message structures for agent-user communication, +//! including role definitions and various message type variants. + +use crate::types::ids::{MessageId, ToolCallId}; +use crate::types::tool::ToolCall; +use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; + +/// A generated function call from a model. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct FunctionCall { + /// The name of the function to call. + pub name: String, + /// The arguments to pass to the function (JSON-encoded string). + pub arguments: String, +} + +/// Message role indicating the sender type. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Role { + /// Developer messages, typically for debugging. + Developer, + /// System messages, usually containing system prompts. + System, + /// Assistant messages from the AI model. + Assistant, + /// User messages from the human user. + User, + /// Tool messages containing tool/function call results. + Tool, + /// Activity messages for tracking agent activities. + Activity, +} + +// Utility methods for serde defaults +impl Role { + pub(crate) fn developer() -> Self { + Self::Developer + } + pub(crate) fn system() -> Self { + Self::System + } + pub(crate) fn assistant() -> Self { + Self::Assistant + } + pub(crate) fn user() -> Self { + Self::User + } + pub(crate) fn tool() -> Self { + Self::Tool + } + pub(crate) fn activity() -> Self { + Self::Activity + } +} + +/// A basic message with optional string content. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct BaseMessage { + /// Unique identifier for this message. + pub id: MessageId, + /// The role of the message sender. + pub role: Role, + /// The text content of the message. + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + /// Optional name for the sender. + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, +} + +/// A developer message, typically for debugging purposes. +/// Not to be confused with system messages. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct DeveloperMessage { + /// Unique identifier for this message. + pub id: MessageId, + /// The role (always Developer). + #[serde(default = "Role::developer")] + pub role: Role, + /// The text content of the message. + pub content: String, + /// Optional name for the sender. + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, +} + +impl DeveloperMessage { + /// Creates a new developer message with the given ID and content. + pub fn new(id: impl Into, content: String) -> Self { + Self { + id: id.into(), + role: Role::Developer, + content, + name: None, + } + } + + /// Sets the name for this message. + pub fn with_name(mut self, name: String) -> Self { + self.name = Some(name); + self + } +} + +/// A system message, usually containing the system prompt. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct SystemMessage { + /// Unique identifier for this message. + pub id: MessageId, + /// The role (always System). + #[serde(default = "Role::system")] + pub role: Role, + /// The text content of the message. + pub content: String, + /// Optional name for the sender. + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, +} + +impl SystemMessage { + /// Creates a new system message with the given ID and content. + pub fn new(id: impl Into, content: String) -> Self { + Self { + id: id.into(), + role: Role::System, + content, + name: None, + } + } + + /// Sets the name for this message. + pub fn with_name(mut self, name: String) -> Self { + self.name = Some(name); + self + } +} + +/// An assistant message (from the AI model). +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct AssistantMessage { + /// Unique identifier for this message. + pub id: MessageId, + /// The role (always Assistant). + #[serde(default = "Role::assistant")] + pub role: Role, + /// The text content of the message. + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + /// Optional name for the sender. + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + /// Tool calls made by the assistant. + #[serde(rename = "toolCalls", skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, +} + +impl AssistantMessage { + /// Creates a new assistant message with the given ID. + pub fn new(id: impl Into) -> Self { + Self { + id: id.into(), + role: Role::Assistant, + content: None, + name: None, + tool_calls: None, + } + } + + /// Sets the content for this message. + pub fn with_content(mut self, content: String) -> Self { + self.content = Some(content); + self + } + + /// Sets the name for this message. + pub fn with_name(mut self, name: String) -> Self { + self.name = Some(name); + self + } + + /// Sets the tool calls for this message. + pub fn with_tool_calls(mut self, tool_calls: Vec) -> Self { + self.tool_calls = Some(tool_calls); + self + } +} + +/// A user message from the human user. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct UserMessage { + /// Unique identifier for this message. + pub id: MessageId, + /// The role (always User). + #[serde(default = "Role::user")] + pub role: Role, + /// The text content of the message. + pub content: String, + /// Optional name for the sender. + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, +} + +impl UserMessage { + /// Creates a new user message with the given ID and content. + pub fn new(id: impl Into, content: String) -> Self { + Self { + id: id.into(), + role: Role::User, + content, + name: None, + } + } + + /// Sets the name for this message. + pub fn with_name(mut self, name: String) -> Self { + self.name = Some(name); + self + } +} + +/// A tool message containing the result of a tool/function call. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ToolMessage { + /// Unique identifier for this message. + pub id: MessageId, + /// The text content (tool result). + pub content: String, + /// The role (always Tool). + #[serde(default = "Role::tool")] + pub role: Role, + /// The ID of the tool call this result corresponds to. + #[serde(rename = "toolCallId")] + pub tool_call_id: ToolCallId, + /// Optional error message if the tool call failed. + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +impl ToolMessage { + /// Creates a new tool message with the given ID, content, and tool call ID. + pub fn new( + id: impl Into, + content: String, + tool_call_id: impl Into, + ) -> Self { + Self { + id: id.into(), + content, + role: Role::Tool, + tool_call_id: tool_call_id.into(), + error: None, + } + } + + /// Sets the error for this message. + pub fn with_error(mut self, error: String) -> Self { + self.error = Some(error); + self + } +} + +/// An activity message for tracking agent activities. +/// +/// Activity messages represent structured agent activities like planning, +/// research, or other non-text operations. The content is a flexible JSON +/// object that can hold activity-specific data. +/// +/// # Example +/// +/// ``` +/// use ag_ui_core::{ActivityMessage, MessageId}; +/// use serde_json::json; +/// +/// let activity = ActivityMessage::new( +/// MessageId::random(), +/// "PLAN".to_string(), +/// json!({"steps": ["research", "implement", "test"]}), +/// ); +/// +/// assert_eq!(activity.activity_type, "PLAN"); +/// ``` +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ActivityMessage { + /// Unique identifier for this message. + pub id: MessageId, + /// The role (always Activity). + #[serde(default = "Role::activity")] + pub role: Role, + /// The type of activity (e.g., "PLAN", "RESEARCH"). + #[serde(rename = "activityType")] + pub activity_type: String, + /// The activity content as a flexible JSON object. + pub content: JsonValue, +} + +impl ActivityMessage { + /// Creates a new activity message with the given ID, type, and content. + pub fn new( + id: impl Into, + activity_type: impl Into, + content: JsonValue, + ) -> Self { + Self { + id: id.into(), + role: Role::Activity, + activity_type: activity_type.into(), + content, + } + } + + /// Sets the content for this activity message. + pub fn with_content(mut self, content: JsonValue) -> Self { + self.content = content; + self + } +} + +/// Represents the different types of messages in a conversation. +/// +/// This enum provides a unified type for all message variants, using the +/// role field as the discriminant for JSON serialization. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[serde(tag = "role", rename_all = "lowercase")] +pub enum Message { + /// A developer message for debugging. + Developer { + id: MessageId, + content: String, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + }, + /// A system message (usually the system prompt). + System { + id: MessageId, + content: String, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + }, + /// An assistant message from the AI model. + Assistant { + id: MessageId, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + #[serde(rename = "toolCalls", skip_serializing_if = "Option::is_none")] + tool_calls: Option>, + }, + /// A user message from the human user. + User { + id: MessageId, + content: String, + #[serde(skip_serializing_if = "Option::is_none")] + name: Option, + }, + /// A tool message containing tool call results. + Tool { + id: MessageId, + content: String, + #[serde(rename = "toolCallId")] + tool_call_id: ToolCallId, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, + }, + /// An activity message for tracking agent activities. + Activity { + id: MessageId, + #[serde(rename = "activityType")] + activity_type: String, + content: JsonValue, + }, +} + +impl Message { + /// Creates a new message with the given role, ID, and content. + pub fn new>(role: Role, id: impl Into, content: S) -> Self { + match role { + Role::Developer => Self::Developer { + id: id.into(), + content: content.as_ref().to_string(), + name: None, + }, + Role::System => Self::System { + id: id.into(), + content: content.as_ref().to_string(), + name: None, + }, + Role::Assistant => Self::Assistant { + id: id.into(), + content: Some(content.as_ref().to_string()), + name: None, + tool_calls: None, + }, + Role::User => Self::User { + id: id.into(), + content: content.as_ref().to_string(), + name: None, + }, + Role::Tool => Self::Tool { + id: id.into(), + content: content.as_ref().to_string(), + tool_call_id: ToolCallId::random(), + error: None, + }, + Role::Activity => Self::Activity { + id: id.into(), + activity_type: "custom".to_string(), + content: JsonValue::String(content.as_ref().to_string()), + }, + } + } + + /// Creates a new user message with a random ID. + pub fn new_user>(content: S) -> Self { + Self::new(Role::User, MessageId::random(), content) + } + + /// Creates a new tool message with a random ID. + pub fn new_tool>(content: S) -> Self { + Self::new(Role::Tool, MessageId::random(), content) + } + + /// Creates a new system message with a random ID. + pub fn new_system>(content: S) -> Self { + Self::new(Role::System, MessageId::random(), content) + } + + /// Creates a new assistant message with a random ID. + pub fn new_assistant>(content: S) -> Self { + Self::new(Role::Assistant, MessageId::random(), content) + } + + /// Creates a new developer message with a random ID. + pub fn new_developer>(content: S) -> Self { + Self::new(Role::Developer, MessageId::random(), content) + } + + /// Creates a new activity message with a random ID. + pub fn new_activity(activity_type: impl Into, content: JsonValue) -> Self { + Self::Activity { + id: MessageId::random(), + activity_type: activity_type.into(), + content, + } + } + + /// Returns a reference to the message ID. + pub fn id(&self) -> &MessageId { + match self { + Message::Developer { id, .. } => id, + Message::System { id, .. } => id, + Message::Assistant { id, .. } => id, + Message::User { id, .. } => id, + Message::Tool { id, .. } => id, + Message::Activity { id, .. } => id, + } + } + + /// Returns a mutable reference to the message ID. + pub fn id_mut(&mut self) -> &mut MessageId { + match self { + Message::Developer { id, .. } => id, + Message::System { id, .. } => id, + Message::Assistant { id, .. } => id, + Message::User { id, .. } => id, + Message::Tool { id, .. } => id, + Message::Activity { id, .. } => id, + } + } + + /// Returns the role of this message. + pub fn role(&self) -> Role { + match self { + Message::Developer { .. } => Role::Developer, + Message::System { .. } => Role::System, + Message::Assistant { .. } => Role::Assistant, + Message::User { .. } => Role::User, + Message::Tool { .. } => Role::Tool, + Message::Activity { .. } => Role::Activity, + } + } + + /// Returns the content of this message, if any. + /// + /// Note: Activity messages have JSON content, not string content. + /// Use `activity_content()` to access their content. + pub fn content(&self) -> Option<&str> { + match self { + Message::Developer { content, .. } => Some(content), + Message::System { content, .. } => Some(content), + Message::User { content, .. } => Some(content), + Message::Tool { content, .. } => Some(content), + Message::Assistant { content, .. } => content.as_deref(), + Message::Activity { .. } => None, + } + } + + /// Returns a mutable reference to the content of this message. + /// + /// Note: Activity messages have JSON content, not string content. + /// Use `activity_content_mut()` to modify their content. + pub fn content_mut(&mut self) -> Option<&mut String> { + match self { + Message::Developer { content, .. } + | Message::System { content, .. } + | Message::User { content, .. } + | Message::Tool { content, .. } => Some(content), + Message::Assistant { content, .. } => { + if content.is_none() { + *content = Some(String::new()); + } + content.as_mut() + } + Message::Activity { .. } => None, + } + } + + /// Returns the activity content of this message, if it's an activity message. + pub fn activity_content(&self) -> Option<&JsonValue> { + match self { + Message::Activity { content, .. } => Some(content), + _ => None, + } + } + + /// Returns a mutable reference to the activity content, if it's an activity message. + pub fn activity_content_mut(&mut self) -> Option<&mut JsonValue> { + match self { + Message::Activity { content, .. } => Some(content), + _ => None, + } + } + + /// Returns the activity type, if this is an activity message. + pub fn activity_type(&self) -> Option<&str> { + match self { + Message::Activity { activity_type, .. } => Some(activity_type), + _ => None, + } + } + + /// Returns the tool calls for this message, if any. + pub fn tool_calls(&self) -> Option<&[ToolCall]> { + match self { + Message::Assistant { tool_calls, .. } => tool_calls.as_deref(), + _ => None, + } + } + + /// Returns a mutable reference to the tool calls for this message. + pub fn tool_calls_mut(&mut self) -> Option<&mut Vec> { + match self { + Message::Assistant { tool_calls, .. } => { + if tool_calls.is_none() { + *tool_calls = Some(Vec::new()); + } + tool_calls.as_mut() + } + _ => None, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_role_serialization() { + let role = Role::Assistant; + let json = serde_json::to_string(&role).unwrap(); + assert_eq!(json, "\"assistant\""); + + let role = Role::User; + let json = serde_json::to_string(&role).unwrap(); + assert_eq!(json, "\"user\""); + } + + #[test] + fn test_developer_message_builder() { + let msg = DeveloperMessage::new(MessageId::random(), "debug info".to_string()) + .with_name("debugger".to_string()); + + assert_eq!(msg.role, Role::Developer); + assert_eq!(msg.content, "debug info"); + assert_eq!(msg.name, Some("debugger".to_string())); + } + + #[test] + fn test_assistant_message_builder() { + let msg = AssistantMessage::new(MessageId::random()) + .with_content("Hello!".to_string()) + .with_name("Claude".to_string()); + + assert_eq!(msg.role, Role::Assistant); + assert_eq!(msg.content, Some("Hello!".to_string())); + assert_eq!(msg.name, Some("Claude".to_string())); + } + + #[test] + fn test_message_enum_serialization() { + let msg = Message::new_user("Hello, world!"); + let json = serde_json::to_string(&msg).unwrap(); + + // Should contain "role": "user" + assert!(json.contains("\"role\":\"user\"")); + assert!(json.contains("\"content\":\"Hello, world!\"")); + } + + #[test] + fn test_message_accessors() { + let msg = Message::new_assistant("I can help with that."); + + assert_eq!(msg.role(), Role::Assistant); + assert_eq!(msg.content(), Some("I can help with that.")); + assert!(msg.tool_calls().is_none()); + } + + #[test] + fn test_activity_role_serialization() { + let role = Role::Activity; + let json = serde_json::to_string(&role).unwrap(); + assert_eq!(json, "\"activity\""); + + let parsed: Role = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed, Role::Activity); + } + + #[test] + fn test_activity_message_struct() { + use serde_json::json; + + let activity = ActivityMessage::new( + MessageId::random(), + "PLAN", + json!({"steps": ["research", "implement"]}), + ); + + assert_eq!(activity.role, Role::Activity); + assert_eq!(activity.activity_type, "PLAN"); + assert_eq!(activity.content["steps"][0], "research"); + } + + #[test] + fn test_activity_message_serialization() { + use serde_json::json; + + let activity = ActivityMessage::new( + MessageId::random(), + "RESEARCH", + json!({"query": "rust async"}), + ); + + let json_str = serde_json::to_string(&activity).unwrap(); + assert!(json_str.contains("\"role\":\"activity\"")); + assert!(json_str.contains("\"activityType\":\"RESEARCH\"")); + assert!(json_str.contains("\"query\":\"rust async\"")); + } + + #[test] + fn test_activity_message_enum() { + use serde_json::json; + + let msg = Message::new_activity("PLAN", json!({"steps": ["a", "b"]})); + + assert_eq!(msg.role(), Role::Activity); + assert!(msg.content().is_none()); // Activity has JSON content, not string + assert!(msg.activity_content().is_some()); + assert_eq!(msg.activity_type(), Some("PLAN")); + } + + #[test] + fn test_activity_message_enum_serialization() { + use serde_json::json; + + let msg = Message::new_activity("DEPLOY", json!({"target": "production"})); + let json_str = serde_json::to_string(&msg).unwrap(); + + assert!(json_str.contains("\"role\":\"activity\"")); + assert!(json_str.contains("\"activityType\":\"DEPLOY\"")); + assert!(json_str.contains("\"target\":\"production\"")); + + // Roundtrip + let parsed: Message = serde_json::from_str(&json_str).unwrap(); + assert_eq!(parsed.role(), Role::Activity); + assert_eq!(parsed.activity_type(), Some("DEPLOY")); + } + + #[test] + fn test_activity_content_accessors() { + use serde_json::json; + + let mut msg = Message::new_activity("TEST", json!({"status": "pending"})); + + // Test immutable accessor + assert!(msg.activity_content().is_some()); + assert_eq!(msg.activity_content().unwrap()["status"], "pending"); + + // Test mutable accessor + if let Some(content) = msg.activity_content_mut() { + content["status"] = json!("complete"); + } + assert_eq!(msg.activity_content().unwrap()["status"], "complete"); + + // Non-activity messages should return None + let user_msg = Message::new_user("hello"); + assert!(user_msg.activity_content().is_none()); + assert!(user_msg.activity_type().is_none()); + } +} diff --git a/crates/ag-ui-core/src/types/mod.rs b/crates/ag-ui-core/src/types/mod.rs new file mode 100644 index 00000000..df69afbd --- /dev/null +++ b/crates/ag-ui-core/src/types/mod.rs @@ -0,0 +1,20 @@ +//! AG-UI Protocol Types +//! +//! This module defines core protocol types including: +//! - Message types (user, assistant, system, tool) +//! - Role definitions +//! - ID types (MessageId, RunId, ThreadId, ToolCallId) +//! - Context and input types +//! - Content types (text, binary) for multimodal messages + +mod content; +mod ids; +mod input; +mod message; +mod tool; + +pub use content::*; +pub use ids::*; +pub use input::*; +pub use message::*; +pub use tool::*; diff --git a/crates/ag-ui-core/src/types/tool.rs b/crates/ag-ui-core/src/types/tool.rs new file mode 100644 index 00000000..9af554a8 --- /dev/null +++ b/crates/ag-ui-core/src/types/tool.rs @@ -0,0 +1,78 @@ +//! Tool types for the AG-UI protocol. +//! +//! This module defines structures for tool/function calling, +//! including tool definitions and tool call representations. + +use crate::types::ids::ToolCallId; +use crate::types::message::FunctionCall; +use serde::{Deserialize, Serialize}; +use serde_json::Value as JsonValue; + +/// A tool call made by an assistant. +/// +/// Represents a specific invocation of a tool/function by the model, +/// including the tool call ID, type, and function details. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct ToolCall { + /// Unique identifier for this tool call. + pub id: ToolCallId, + /// The type of call (always "function" for now). + #[serde(rename = "type")] + pub call_type: String, + /// The function being called with its arguments. + pub function: FunctionCall, +} + +impl ToolCall { + /// Creates a new tool call with the given ID and function. + pub fn new(id: impl Into, function: FunctionCall) -> Self { + Self { + id: id.into(), + call_type: "function".to_string(), + function, + } + } +} + +/// A tool definition describing a function the model can call. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct Tool { + /// The name of the tool. + pub name: String, + /// A description of what the tool does. + pub description: String, + /// JSON Schema describing the tool's parameters. + pub parameters: JsonValue, +} + +impl Tool { + /// Creates a new tool definition. + pub fn new(name: String, description: String, parameters: JsonValue) -> Self { + Self { + name, + description, + parameters, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tool_call_serialization() { + let tool_call = ToolCall::new( + ToolCallId::random(), + FunctionCall { + name: "get_weather".to_string(), + arguments: r#"{"location": "NYC"}"#.to_string(), + }, + ); + + let json = serde_json::to_string(&tool_call).unwrap(); + // Should have "type": "function", not "call_type" + assert!(json.contains("\"type\":\"function\"")); + assert!(json.contains("\"name\":\"get_weather\"")); + } +} diff --git a/crates/ag-ui-server/Cargo.toml b/crates/ag-ui-server/Cargo.toml new file mode 100644 index 00000000..b03cf872 --- /dev/null +++ b/crates/ag-ui-server/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "ag-ui-server" +version = "0.1.0" +edition = "2024" +rust-version = "1.88" +license = "MIT" +description = "Server-side AG-UI event producer for streaming to frontends" +readme = "README.md" + +[dependencies] +# Internal +ag-ui-core = { version = "0.1.0", path = "../ag-ui-core" } + +# Error handling +thiserror = "2" + +# Serialization +serde = { version = "1", features = ["derive"] } +serde_json = "1" + +# Async runtime +tokio = { version = "1", features = ["rt", "macros", "sync"] } +tokio-stream = "0.1" +async-trait = "0.1" +futures = "0.3" + +# Web framework +axum = { version = "0.8", features = ["ws"] } diff --git a/crates/ag-ui-server/src/error.rs b/crates/ag-ui-server/src/error.rs new file mode 100644 index 00000000..9a8e4c91 --- /dev/null +++ b/crates/ag-ui-server/src/error.rs @@ -0,0 +1,31 @@ +//! Error types for AG-UI server operations. + +use ag_ui_core::AgUiError; +use thiserror::Error; + +/// Errors that can occur in AG-UI server operations. +#[derive(Debug, Error)] +pub enum ServerError { + /// Core AG-UI error + #[error("Core error: {0}")] + Core(#[from] AgUiError), + + /// Transport layer error (SSE, WebSocket, etc.) + #[error("Transport error: {0}")] + Transport(String), + + /// Serialization error during event emission + #[error("Serialization error: {0}")] + Serialization(String), + + /// Channel or stream error + #[error("Channel error: {0}")] + Channel(String), + + /// Connection error + #[error("Connection error: {0}")] + Connection(String), +} + +/// Result type alias using ServerError +pub type Result = std::result::Result; diff --git a/crates/ag-ui-server/src/lib.rs b/crates/ag-ui-server/src/lib.rs new file mode 100644 index 00000000..93898b1c --- /dev/null +++ b/crates/ag-ui-server/src/lib.rs @@ -0,0 +1,43 @@ +//! AG-UI Server SDK +//! +//! This crate provides server-side functionality for producing AG-UI protocol events. +//! It enables Rust agents to stream events to frontend applications via various +//! transports (SSE, WebSocket, etc.). +//! +//! # Overview +//! +//! The AG-UI Server SDK includes: +//! +//! - **Event Producer**: High-level API for emitting AG-UI events from agent code +//! - **Transport Layer**: SSE and WebSocket implementations for streaming events +//! - **Error Handling**: Server-specific error types +//! +//! # Usage +//! +//! ```rust,ignore +//! use ag_ui_server::{EventProducer, Result}; +//! ``` +//! +//! # Integration +//! +//! This crate is designed to integrate with the Syncable CLI agent, enabling +//! any frontend to connect and receive real-time agent events. + +pub mod error; +pub mod producer; +pub mod transport; + +// Re-export ag-ui-core types for convenience +pub use ag_ui_core::*; + +// Re-export server-specific types +pub use error::{Result, ServerError}; + +// Re-export transport types +pub use transport::{SseHandler, SseSender}; + +// Re-export producer types +pub use producer::{ + AgentSession, EventProducer, MessageStream, ThinkingMessageStream, ThinkingStep, + ToolCallStream, +}; diff --git a/crates/ag-ui-server/src/producer.rs b/crates/ag-ui-server/src/producer.rs new file mode 100644 index 00000000..0796cfbe --- /dev/null +++ b/crates/ag-ui-server/src/producer.rs @@ -0,0 +1,1055 @@ +//! Event producer API for emitting AG-UI events. +//! +//! This module provides the high-level API for agents to emit events to connected +//! frontends. It includes: +//! +//! - [`EventProducer`] trait - Core abstraction for event emission +//! - [`MessageStream`] - Helper for streaming text messages +//! - [`ToolCallStream`] - Helper for streaming tool calls +//! - [`ThinkingMessageStream`] - Helper for streaming thinking content +//! - [`ThinkingStep`] - Helper for thinking block boundaries (chain-of-thought) +//! - [`AgentSession`] - Manages run lifecycle and state +//! +//! # Example +//! +//! ```rust,ignore +//! use ag_ui_server::{transport::sse, AgentSession, MessageStream}; +//! +//! async fn handle_request() -> impl IntoResponse { +//! let (sender, handler) = sse::channel(32); +//! +//! tokio::spawn(async move { +//! let mut session = AgentSession::new(sender); +//! session.start_run().await.unwrap(); +//! +//! // Stream a message +//! let msg = MessageStream::start(session.producer()).await.unwrap(); +//! msg.content("Hello, ").await.unwrap(); +//! msg.content("world!").await.unwrap(); +//! msg.end().await.unwrap(); +//! +//! session.finish_run(None).await.unwrap(); +//! }); +//! +//! handler.into_response() +//! } +//! ``` + +use std::marker::PhantomData; + +use ag_ui_core::{ + AgentState, Event, InterruptInfo, JsonValue, MessageId, RunErrorEvent, RunFinishedEvent, + RunId, RunStartedEvent, TextMessageContentEvent, TextMessageEndEvent, TextMessageStartEvent, + ThinkingEndEvent, ThinkingStartEvent, ThinkingTextMessageContentEvent, + ThinkingTextMessageEndEvent, ThinkingTextMessageStartEvent, ThreadId, ToolCallArgsEvent, + ToolCallEndEvent, ToolCallId, ToolCallStartEvent, +}; +use async_trait::async_trait; + +use crate::error::ServerError; +use crate::transport::SseSender; + +/// Trait for producing AG-UI events. +/// +/// Implementors of this trait can emit events to connected frontends +/// through various transport mechanisms (SSE, WebSocket, etc.). +/// +/// # Example +/// +/// ```rust,ignore +/// use ag_ui_server::EventProducer; +/// use ag_ui_core::{Event, RunErrorEvent}; +/// +/// async fn emit_error(producer: &P) -> Result<(), ServerError> { +/// producer.emit(Event::RunError(RunErrorEvent::new("Something went wrong"))).await +/// } +/// ``` +#[async_trait] +pub trait EventProducer: Send + Sync { + /// Emit a single event to connected clients. + /// + /// Returns an error if the connection is closed or the event cannot be sent. + async fn emit(&self, event: Event) -> Result<(), ServerError>; + + /// Emit multiple events to connected clients. + /// + /// Events are sent in order. Stops and returns an error on the first failure. + async fn emit_many(&self, events: Vec>) -> Result<(), ServerError> { + for event in events { + self.emit(event).await?; + } + Ok(()) + } + + /// Check if the connection is still open. + /// + /// Returns `false` if the client has disconnected. + fn is_connected(&self) -> bool; +} + +// Implement EventProducer for SseSender +#[async_trait] +impl EventProducer for SseSender { + async fn emit(&self, event: Event) -> Result<(), ServerError> { + self.send(event) + .await + .map_err(|_| ServerError::Channel("SSE channel closed".into())) + } + + fn is_connected(&self) -> bool { + !self.is_closed() + } +} + +/// Helper for streaming a text message piece by piece. +/// +/// This struct manages the lifecycle of a streaming text message, automatically +/// generating message IDs and emitting the appropriate events. +/// +/// # Example +/// +/// ```rust,ignore +/// let msg = MessageStream::start(&producer).await?; +/// msg.content("Hello, ").await?; +/// msg.content("world!").await?; +/// let message_id = msg.end().await?; +/// ``` +pub struct MessageStream<'a, P: EventProducer, StateT: AgentState = JsonValue> { + producer: &'a P, + message_id: MessageId, + _state: PhantomData, +} + +impl<'a, P: EventProducer, StateT: AgentState> MessageStream<'a, P, StateT> { + /// Start a new message stream. + /// + /// Emits a `TextMessageStart` event with a randomly generated message ID. + pub async fn start(producer: &'a P) -> Result { + let message_id = MessageId::random(); + producer + .emit(Event::TextMessageStart(TextMessageStartEvent::new( + message_id.clone(), + ))) + .await?; + Ok(Self { + producer, + message_id, + _state: PhantomData, + }) + } + + /// Start a new message stream with a specific message ID. + pub async fn start_with_id( + producer: &'a P, + message_id: MessageId, + ) -> Result { + producer + .emit(Event::TextMessageStart(TextMessageStartEvent::new( + message_id.clone(), + ))) + .await?; + Ok(Self { + producer, + message_id, + _state: PhantomData, + }) + } + + /// Append content to the message. + /// + /// Emits a `TextMessageContent` event with the given delta. + /// Empty deltas are silently ignored. + pub async fn content(&self, delta: impl Into) -> Result<(), ServerError> { + let delta = delta.into(); + if delta.is_empty() { + return Ok(()); + } + self.producer + .emit(Event::TextMessageContent( + TextMessageContentEvent::new_unchecked(self.message_id.clone(), delta), + )) + .await + } + + /// End the message stream. + /// + /// Emits a `TextMessageEnd` event and returns the message ID. + /// Consumes the stream to prevent further content being added. + pub async fn end(self) -> Result { + self.producer + .emit(Event::TextMessageEnd(TextMessageEndEvent::new( + self.message_id.clone(), + ))) + .await?; + Ok(self.message_id) + } + + /// Get the message ID for this stream. + pub fn message_id(&self) -> &MessageId { + &self.message_id + } +} + +/// Helper for streaming a tool call with arguments. +/// +/// This struct manages the lifecycle of a streaming tool call, automatically +/// generating tool call IDs and emitting the appropriate events. +/// +/// # Example +/// +/// ```rust,ignore +/// let call = ToolCallStream::start(&producer, "get_weather").await?; +/// call.args(r#"{"location": "#).await?; +/// call.args(r#""New York"}"#).await?; +/// let tool_call_id = call.end().await?; +/// ``` +pub struct ToolCallStream<'a, P: EventProducer, StateT: AgentState = JsonValue> { + producer: &'a P, + tool_call_id: ToolCallId, + _state: PhantomData, +} + +impl<'a, P: EventProducer, StateT: AgentState> ToolCallStream<'a, P, StateT> { + /// Start a new tool call stream. + /// + /// Emits a `ToolCallStart` event with the given tool name and a randomly + /// generated tool call ID. + pub async fn start(producer: &'a P, name: impl Into) -> Result { + let tool_call_id = ToolCallId::random(); + producer + .emit(Event::ToolCallStart(ToolCallStartEvent::new( + tool_call_id.clone(), + name, + ))) + .await?; + Ok(Self { + producer, + tool_call_id, + _state: PhantomData, + }) + } + + /// Start a new tool call stream with a specific tool call ID. + pub async fn start_with_id( + producer: &'a P, + tool_call_id: ToolCallId, + name: impl Into, + ) -> Result { + producer + .emit(Event::ToolCallStart(ToolCallStartEvent::new( + tool_call_id.clone(), + name, + ))) + .await?; + Ok(Self { + producer, + tool_call_id, + _state: PhantomData, + }) + } + + /// Stream an argument chunk. + /// + /// Emits a `ToolCallArgs` event with the given delta. + pub async fn args(&self, delta: impl Into) -> Result<(), ServerError> { + self.producer + .emit(Event::ToolCallArgs(ToolCallArgsEvent::new( + self.tool_call_id.clone(), + delta, + ))) + .await + } + + /// End the tool call stream. + /// + /// Emits a `ToolCallEnd` event and returns the tool call ID. + /// Consumes the stream to prevent further args being added. + pub async fn end(self) -> Result { + self.producer + .emit(Event::ToolCallEnd(ToolCallEndEvent::new( + self.tool_call_id.clone(), + ))) + .await?; + Ok(self.tool_call_id) + } + + /// Get the tool call ID for this stream. + pub fn tool_call_id(&self) -> &ToolCallId { + &self.tool_call_id + } +} + +/// Helper for streaming thinking content (extended thinking / chain-of-thought). +/// +/// This struct manages the lifecycle of streaming thinking content. Unlike +/// [`MessageStream`], thinking messages don't have IDs as they're ephemeral. +/// +/// # Example +/// +/// ```rust,ignore +/// let thinking = ThinkingMessageStream::start(&producer).await?; +/// thinking.content("Let me analyze this...").await?; +/// thinking.content("The key factors are...").await?; +/// thinking.end().await?; +/// ``` +pub struct ThinkingMessageStream<'a, P: EventProducer, StateT: AgentState = JsonValue> { + producer: &'a P, + _state: PhantomData, +} + +impl<'a, P: EventProducer, StateT: AgentState> ThinkingMessageStream<'a, P, StateT> { + /// Start a new thinking message stream. + /// + /// Emits a `ThinkingTextMessageStart` event. + pub async fn start(producer: &'a P) -> Result { + producer + .emit(Event::ThinkingTextMessageStart( + ThinkingTextMessageStartEvent::new(), + )) + .await?; + Ok(Self { + producer, + _state: PhantomData, + }) + } + + /// Append content to the thinking message. + /// + /// Emits a `ThinkingTextMessageContent` event with the given delta. + /// Unlike regular messages, empty deltas are allowed for thinking content. + pub async fn content(&self, delta: impl Into) -> Result<(), ServerError> { + self.producer + .emit(Event::ThinkingTextMessageContent( + ThinkingTextMessageContentEvent::new(delta), + )) + .await + } + + /// End the thinking message stream. + /// + /// Emits a `ThinkingTextMessageEnd` event. + /// Consumes the stream to prevent further content being added. + pub async fn end(self) -> Result<(), ServerError> { + self.producer + .emit(Event::ThinkingTextMessageEnd( + ThinkingTextMessageEndEvent::new(), + )) + .await + } +} + +/// Helper for managing thinking block boundaries (chain-of-thought steps). +/// +/// This struct wraps a thinking block with `ThinkingStart` and `ThinkingEnd` events. +/// Inside a thinking step, you can emit thinking content using [`ThinkingMessageStream`]. +/// +/// # Example +/// +/// ```rust,ignore +/// // Start a thinking step with optional title +/// let step = ThinkingStep::start(&producer, Some("Analyzing user query")).await?; +/// +/// // Emit thinking content inside the step +/// let thinking = ThinkingMessageStream::start(step.producer()).await?; +/// thinking.content("First, let me consider...").await?; +/// thinking.end().await?; +/// +/// // End the thinking step +/// step.end().await?; +/// ``` +pub struct ThinkingStep<'a, P: EventProducer, StateT: AgentState = JsonValue> { + producer: &'a P, + _state: PhantomData, +} + +impl<'a, P: EventProducer, StateT: AgentState> ThinkingStep<'a, P, StateT> { + /// Start a new thinking step. + /// + /// Emits a `ThinkingStart` event with an optional title. + pub async fn start( + producer: &'a P, + title: Option>, + ) -> Result { + let event = if let Some(t) = title { + ThinkingStartEvent::new().with_title(t) + } else { + ThinkingStartEvent::new() + }; + producer.emit(Event::ThinkingStart(event)).await?; + Ok(Self { + producer, + _state: PhantomData, + }) + } + + /// End the thinking step. + /// + /// Emits a `ThinkingEnd` event. + /// Consumes the step to prevent reuse. + pub async fn end(self) -> Result<(), ServerError> { + self.producer + .emit(Event::ThinkingEnd(ThinkingEndEvent::new())) + .await + } + + /// Get a reference to the underlying producer. + /// + /// Use this to create [`ThinkingMessageStream`] instances inside the step. + pub fn producer(&self) -> &'a P { + self.producer + } +} + +/// Manages an agent session with run lifecycle events. +/// +/// This struct provides high-level management of agent runs, including +/// starting, finishing, and error handling. +/// +/// # Example +/// +/// ```rust,ignore +/// let mut session = AgentSession::new(sender); +/// +/// // Start a run +/// let run_id = session.start_run().await?; +/// +/// // Do work... +/// +/// // Finish the run +/// session.finish_run(Some(json!({"result": "success"}))).await?; +/// ``` +pub struct AgentSession, StateT: AgentState = JsonValue> { + producer: P, + thread_id: ThreadId, + current_run: Option, + _state: PhantomData, +} + +impl, StateT: AgentState> AgentSession { + /// Create a new session with the given producer. + /// + /// Generates a random thread ID for the session. + pub fn new(producer: P) -> Self { + Self { + producer, + thread_id: ThreadId::random(), + current_run: None, + _state: PhantomData, + } + } + + /// Create a new session with a specific thread ID. + pub fn with_thread_id(producer: P, thread_id: ThreadId) -> Self { + Self { + producer, + thread_id, + current_run: None, + _state: PhantomData, + } + } + + /// Start a new run. + /// + /// Emits a `RunStarted` event and stores the run ID. + /// Returns an error if a run is already in progress. + pub async fn start_run(&mut self) -> Result { + if self.current_run.is_some() { + return Err(ServerError::Channel("Run already in progress".into())); + } + let run_id = RunId::random(); + self.producer + .emit(Event::RunStarted(RunStartedEvent::new( + self.thread_id.clone(), + run_id.clone(), + ))) + .await?; + self.current_run = Some(run_id.clone()); + Ok(run_id) + } + + /// Finish the current run. + /// + /// Emits a `RunFinished` event with an optional result. + /// Does nothing if no run is in progress. + pub async fn finish_run(&mut self, result: Option) -> Result<(), ServerError> { + if let Some(run_id) = self.current_run.take() { + let mut event = RunFinishedEvent::new(self.thread_id.clone(), run_id); + if let Some(r) = result { + event = event.with_result(r); + } + self.producer.emit(Event::RunFinished(event)).await?; + } + Ok(()) + } + + /// Signal a run error. + /// + /// Emits a `RunError` event and clears the current run. + pub async fn run_error(&mut self, message: impl Into) -> Result<(), ServerError> { + self.current_run = None; + self.producer + .emit(Event::RunError(RunErrorEvent::new(message))) + .await + } + + /// Signal a run error with an error code. + pub async fn run_error_with_code( + &mut self, + message: impl Into, + code: impl Into, + ) -> Result<(), ServerError> { + self.current_run = None; + self.producer + .emit(Event::RunError( + RunErrorEvent::new(message).with_code(code), + )) + .await + } + + /// Get a reference to the underlying producer. + pub fn producer(&self) -> &P { + &self.producer + } + + /// Get the thread ID for this session. + pub fn thread_id(&self) -> &ThreadId { + &self.thread_id + } + + /// Get the current run ID, if any. + pub fn run_id(&self) -> Option<&RunId> { + self.current_run.as_ref() + } + + /// Check if a run is currently in progress. + pub fn is_running(&self) -> bool { + self.current_run.is_some() + } + + /// Check if the connection is still open. + pub fn is_connected(&self) -> bool { + self.producer.is_connected() + } + + /// Start a thinking step. + /// + /// Convenience method that creates a [`ThinkingStep`] using this session's producer. + /// + /// # Example + /// + /// ```rust,ignore + /// let step = session.start_thinking(Some("Planning response")).await?; + /// // ... emit thinking content ... + /// step.end().await?; + /// ``` + pub async fn start_thinking( + &self, + title: Option>, + ) -> Result, ServerError> { + ThinkingStep::start(&self.producer, title).await + } + + /// Interrupt the current run for human-in-the-loop interaction. + /// + /// Finishes the run with an interrupt outcome, signaling that human input + /// is required before the agent can continue. The client should display + /// appropriate UI based on the interrupt info and resume with user input. + /// + /// # Example + /// + /// ```rust,ignore + /// session.start_run().await?; + /// + /// // Request human approval + /// session.interrupt( + /// Some("human_approval"), + /// Some(serde_json::json!({"action": "send_email", "to": "user@example.com"})) + /// ).await?; + /// ``` + pub async fn interrupt( + &mut self, + reason: Option>, + payload: Option, + ) -> Result<(), ServerError> { + let run_id = self.current_run.take(); + if let Some(run_id) = run_id { + let mut info = InterruptInfo::new(); + if let Some(r) = reason { + info = info.with_reason(r); + } + if let Some(p) = payload { + info = info.with_payload(p); + } + + let event = RunFinishedEvent::new(self.thread_id.clone(), run_id).with_interrupt(info); + self.producer.emit(Event::RunFinished(event)).await?; + } + Ok(()) + } + + /// Interrupt with a specific interrupt ID for tracking. + /// + /// The interrupt ID can be used by the client to correlate the resume + /// request with the original interrupt. + /// + /// # Example + /// + /// ```rust,ignore + /// session.start_run().await?; + /// + /// // Request approval with tracking ID + /// session.interrupt_with_id( + /// "approval-001", + /// Some("database_modification"), + /// Some(serde_json::json!({"query": "DELETE FROM users WHERE inactive"})) + /// ).await?; + /// ``` + pub async fn interrupt_with_id( + &mut self, + id: impl Into, + reason: Option>, + payload: Option, + ) -> Result<(), ServerError> { + let run_id = self.current_run.take(); + if let Some(run_id) = run_id { + let mut info = InterruptInfo::new().with_id(id); + if let Some(r) = reason { + info = info.with_reason(r); + } + if let Some(p) = payload { + info = info.with_payload(p); + } + + let event = RunFinishedEvent::new(self.thread_id.clone(), run_id).with_interrupt(info); + self.producer.emit(Event::RunFinished(event)).await?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::{Arc, Mutex}; + + /// Mock producer for testing + struct MockProducer { + events: Arc>>, + connected: bool, + } + + impl MockProducer { + fn new() -> Self { + Self { + events: Arc::new(Mutex::new(Vec::new())), + connected: true, + } + } + + fn events(&self) -> Vec { + self.events.lock().unwrap().clone() + } + } + + #[async_trait] + impl EventProducer for MockProducer { + async fn emit(&self, event: Event) -> Result<(), ServerError> { + if !self.connected { + return Err(ServerError::Channel("disconnected".into())); + } + self.events.lock().unwrap().push(event); + Ok(()) + } + + fn is_connected(&self) -> bool { + self.connected + } + } + + #[tokio::test] + async fn test_event_producer_emit() { + let producer = MockProducer::new(); + + producer + .emit(Event::RunError(RunErrorEvent::new("test"))) + .await + .unwrap(); + + let events = producer.events(); + assert_eq!(events.len(), 1); + assert!(matches!(events[0], Event::RunError(_))); + } + + #[tokio::test] + async fn test_event_producer_emit_many() { + let producer = MockProducer::new(); + + producer + .emit_many(vec![ + Event::RunError(RunErrorEvent::new("error1")), + Event::RunError(RunErrorEvent::new("error2")), + ]) + .await + .unwrap(); + + let events = producer.events(); + assert_eq!(events.len(), 2); + } + + #[tokio::test] + async fn test_message_stream() { + let producer = MockProducer::new(); + + let msg = MessageStream::start(&producer).await.unwrap(); + msg.content("Hello, ").await.unwrap(); + msg.content("world!").await.unwrap(); + let _message_id = msg.end().await.unwrap(); + + let events = producer.events(); + assert_eq!(events.len(), 4); // start + 2 content + end + + assert!(matches!(events[0], Event::TextMessageStart(_))); + assert!(matches!(events[1], Event::TextMessageContent(_))); + assert!(matches!(events[2], Event::TextMessageContent(_))); + assert!(matches!(events[3], Event::TextMessageEnd(_))); + } + + #[tokio::test] + async fn test_message_stream_empty_content_ignored() { + let producer = MockProducer::new(); + + let msg = MessageStream::start(&producer).await.unwrap(); + msg.content("").await.unwrap(); // Should be ignored + msg.content("Hello").await.unwrap(); + msg.end().await.unwrap(); + + let events = producer.events(); + assert_eq!(events.len(), 3); // start + 1 content + end (empty ignored) + } + + #[tokio::test] + async fn test_tool_call_stream() { + let producer = MockProducer::new(); + + let call = ToolCallStream::start(&producer, "get_weather").await.unwrap(); + call.args(r#"{"location": "#).await.unwrap(); + call.args(r#""NYC"}"#).await.unwrap(); + let _tool_call_id = call.end().await.unwrap(); + + let events = producer.events(); + assert_eq!(events.len(), 4); // start + 2 args + end + + assert!(matches!(events[0], Event::ToolCallStart(_))); + assert!(matches!(events[1], Event::ToolCallArgs(_))); + assert!(matches!(events[2], Event::ToolCallArgs(_))); + assert!(matches!(events[3], Event::ToolCallEnd(_))); + } + + #[tokio::test] + async fn test_agent_session_run_lifecycle() { + let producer = MockProducer::new(); + let mut session = AgentSession::new(producer); + + assert!(!session.is_running()); + + // Start run + let run_id = session.start_run().await.unwrap(); + assert!(session.is_running()); + assert_eq!(session.run_id(), Some(&run_id)); + + // Finish run + session.finish_run(None).await.unwrap(); + assert!(!session.is_running()); + assert_eq!(session.run_id(), None); + + let events = session.producer().events(); + assert_eq!(events.len(), 2); + assert!(matches!(events[0], Event::RunStarted(_))); + assert!(matches!(events[1], Event::RunFinished(_))); + } + + #[tokio::test] + async fn test_agent_session_run_error() { + let producer = MockProducer::new(); + let mut session = AgentSession::new(producer); + + session.start_run().await.unwrap(); + session.run_error("Something went wrong").await.unwrap(); + + assert!(!session.is_running()); + + let events = session.producer().events(); + assert_eq!(events.len(), 2); + assert!(matches!(events[0], Event::RunStarted(_))); + assert!(matches!(events[1], Event::RunError(_))); + } + + #[tokio::test] + async fn test_agent_session_double_start_error() { + let producer = MockProducer::new(); + let mut session = AgentSession::new(producer); + + session.start_run().await.unwrap(); + let result = session.start_run().await; + + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_agent_session_finish_without_run() { + let producer = MockProducer::new(); + let mut session = AgentSession::new(producer); + + // Should not error, just do nothing + session.finish_run(None).await.unwrap(); + + let events = session.producer().events(); + assert!(events.is_empty()); + } + + // ========================================================================= + // Thinking Message Stream Tests + // ========================================================================= + + #[tokio::test] + async fn test_thinking_message_stream() { + let producer = MockProducer::new(); + + let thinking = ThinkingMessageStream::start(&producer).await.unwrap(); + thinking.content("Let me analyze...").await.unwrap(); + thinking.content("The answer is...").await.unwrap(); + thinking.end().await.unwrap(); + + let events = producer.events(); + assert_eq!(events.len(), 4); // start + 2 content + end + + assert!(matches!(events[0], Event::ThinkingTextMessageStart(_))); + assert!(matches!(events[1], Event::ThinkingTextMessageContent(_))); + assert!(matches!(events[2], Event::ThinkingTextMessageContent(_))); + assert!(matches!(events[3], Event::ThinkingTextMessageEnd(_))); + } + + #[tokio::test] + async fn test_thinking_message_stream_empty_content_allowed() { + let producer = MockProducer::new(); + + let thinking = ThinkingMessageStream::start(&producer).await.unwrap(); + thinking.content("").await.unwrap(); // Empty is allowed for thinking + thinking.content("Thinking...").await.unwrap(); + thinking.end().await.unwrap(); + + let events = producer.events(); + // Empty content is emitted (unlike regular MessageStream) + assert_eq!(events.len(), 4); // start + empty + content + end + } + + // ========================================================================= + // Thinking Step Tests + // ========================================================================= + + #[tokio::test] + async fn test_thinking_step() { + let producer = MockProducer::new(); + + let step = ThinkingStep::start(&producer, None::).await.unwrap(); + step.end().await.unwrap(); + + let events = producer.events(); + assert_eq!(events.len(), 2); // start + end + + assert!(matches!(events[0], Event::ThinkingStart(_))); + assert!(matches!(events[1], Event::ThinkingEnd(_))); + } + + #[tokio::test] + async fn test_thinking_step_with_title() { + let producer = MockProducer::new(); + + let step = ThinkingStep::start(&producer, Some("Analyzing query")) + .await + .unwrap(); + step.end().await.unwrap(); + + let events = producer.events(); + assert_eq!(events.len(), 2); + + if let Event::ThinkingStart(start) = &events[0] { + assert_eq!(start.title, Some("Analyzing query".to_string())); + } else { + panic!("Expected ThinkingStart event"); + } + } + + #[tokio::test] + async fn test_thinking_step_with_content() { + let producer = MockProducer::new(); + + let step = ThinkingStep::start(&producer, Some("Planning")) + .await + .unwrap(); + + // Emit thinking content inside the step + let thinking = ThinkingMessageStream::start(step.producer()).await.unwrap(); + thinking.content("First, consider...").await.unwrap(); + thinking.end().await.unwrap(); + + step.end().await.unwrap(); + + let events = producer.events(); + assert_eq!(events.len(), 5); // ThinkingStart + TextStart + content + TextEnd + ThinkingEnd + + assert!(matches!(events[0], Event::ThinkingStart(_))); + assert!(matches!(events[1], Event::ThinkingTextMessageStart(_))); + assert!(matches!(events[2], Event::ThinkingTextMessageContent(_))); + assert!(matches!(events[3], Event::ThinkingTextMessageEnd(_))); + assert!(matches!(events[4], Event::ThinkingEnd(_))); + } + + // ========================================================================= + // AgentSession Thinking Tests + // ========================================================================= + + #[tokio::test] + async fn test_agent_session_start_thinking() { + let producer = MockProducer::new(); + let session = AgentSession::new(producer); + + let step = session.start_thinking(Some("Reasoning")).await.unwrap(); + step.end().await.unwrap(); + + let events = session.producer().events(); + assert_eq!(events.len(), 2); + assert!(matches!(events[0], Event::ThinkingStart(_))); + assert!(matches!(events[1], Event::ThinkingEnd(_))); + } + + #[tokio::test] + async fn test_agent_session_start_thinking_no_title() { + let producer = MockProducer::new(); + let session = AgentSession::new(producer); + + let step = session.start_thinking(None::).await.unwrap(); + step.end().await.unwrap(); + + let events = session.producer().events(); + assert_eq!(events.len(), 2); + + if let Event::ThinkingStart(start) = &events[0] { + assert!(start.title.is_none()); + } else { + panic!("Expected ThinkingStart event"); + } + } + + // ========================================================================= + // AgentSession Interrupt Tests + // ========================================================================= + + #[tokio::test] + async fn test_agent_session_interrupt() { + use ag_ui_core::RunFinishedOutcome; + + let producer = MockProducer::new(); + let mut session = AgentSession::new(producer); + + session.start_run().await.unwrap(); + session + .interrupt( + Some("human_approval"), + Some(serde_json::json!({"action": "send_email"})), + ) + .await + .unwrap(); + + // Run should be cleared after interrupt + assert!(!session.is_running()); + + let events = session.producer().events(); + assert_eq!(events.len(), 2); // RunStarted + RunFinished(interrupt) + + assert!(matches!(events[0], Event::RunStarted(_))); + + if let Event::RunFinished(finished) = &events[1] { + assert_eq!(finished.outcome, Some(RunFinishedOutcome::Interrupt)); + assert!(finished.interrupt.is_some()); + let info = finished.interrupt.as_ref().unwrap(); + assert_eq!(info.reason, Some("human_approval".to_string())); + assert!(info.payload.is_some()); + } else { + panic!("Expected RunFinished event"); + } + } + + #[tokio::test] + async fn test_agent_session_interrupt_with_id() { + use ag_ui_core::RunFinishedOutcome; + + let producer = MockProducer::new(); + let mut session = AgentSession::new(producer); + + session.start_run().await.unwrap(); + session + .interrupt_with_id( + "approval-001", + Some("database_modification"), + Some(serde_json::json!({"query": "DELETE FROM users"})), + ) + .await + .unwrap(); + + assert!(!session.is_running()); + + let events = session.producer().events(); + assert_eq!(events.len(), 2); + + if let Event::RunFinished(finished) = &events[1] { + assert_eq!(finished.outcome, Some(RunFinishedOutcome::Interrupt)); + let info = finished.interrupt.as_ref().unwrap(); + assert_eq!(info.id, Some("approval-001".to_string())); + assert_eq!(info.reason, Some("database_modification".to_string())); + } else { + panic!("Expected RunFinished event"); + } + } + + #[tokio::test] + async fn test_agent_session_interrupt_without_run() { + let producer = MockProducer::new(); + let mut session = AgentSession::new(producer); + + // Interrupt without an active run should do nothing + session + .interrupt(Some("test"), None) + .await + .unwrap(); + + let events = session.producer().events(); + assert!(events.is_empty()); + } + + #[tokio::test] + async fn test_agent_session_interrupt_minimal() { + let producer = MockProducer::new(); + let mut session = AgentSession::new(producer); + + session.start_run().await.unwrap(); + + // Interrupt with no reason or payload + session + .interrupt(None::, None) + .await + .unwrap(); + + let events = session.producer().events(); + assert_eq!(events.len(), 2); + + if let Event::RunFinished(finished) = &events[1] { + let info = finished.interrupt.as_ref().unwrap(); + assert!(info.id.is_none()); + assert!(info.reason.is_none()); + assert!(info.payload.is_none()); + } else { + panic!("Expected RunFinished event"); + } + } +} diff --git a/crates/ag-ui-server/src/transport/mod.rs b/crates/ag-ui-server/src/transport/mod.rs new file mode 100644 index 00000000..d896fad0 --- /dev/null +++ b/crates/ag-ui-server/src/transport/mod.rs @@ -0,0 +1,64 @@ +//! Transport Layer for AG-UI Events +//! +//! This module provides transport implementations for streaming AG-UI events +//! to frontend clients: +//! +//! - **SSE (Server-Sent Events)**: HTTP-based unidirectional streaming via [`sse`] +//! - **WebSocket**: Bidirectional WebSocket transport via [`ws`] +//! +//! # SSE Example +//! +//! ```rust,ignore +//! use ag_ui_server::transport::sse; +//! use ag_ui_core::{Event, RunErrorEvent}; +//! +//! // Create channel pair +//! let (sender, handler) = sse::channel::(32); +//! +//! // Send events from agent code +//! sender.send(Event::RunError(RunErrorEvent::new("error"))).await?; +//! +//! // Return handler as axum response +//! handler.into_response() +//! ``` +//! +//! # WebSocket Example +//! +//! ```rust,ignore +//! use ag_ui_server::transport::ws; +//! use ag_ui_core::{Event, RunErrorEvent}; +//! use axum::extract::ws::WebSocketUpgrade; +//! +//! async fn ws_endpoint(upgrade: WebSocketUpgrade) -> impl IntoResponse { +//! let (sender, handler) = ws::channel::(32); +//! +//! tokio::spawn(async move { +//! sender.send(Event::RunError(RunErrorEvent::new("error"))).await.ok(); +//! }); +//! +//! handler.into_response(upgrade) +//! } +//! ``` +//! +//! # Choosing Between SSE and WebSocket +//! +//! | Feature | SSE | WebSocket | +//! |---------|-----|-----------| +//! | Direction | Server → Client | Bidirectional | +//! | Auto-reconnect | Built-in (EventSource) | Manual | +//! | HTTP/2 multiplexing | Yes | No | +//! | Binary data | No (text only) | Yes | +//! | Browser connection limit | Per-domain | Per-domain | + +pub mod sse; +pub mod ws; + +// Re-export SSE types (default transport) +pub use sse::{channel, format_sse_event, SendError, SseHandler, SseSender}; + +// Re-export WebSocket types with ws_ prefix to avoid conflicts +pub use ws::{ + channel as ws_channel, channel_with_config as ws_channel_with_config, + format_ws_message, SendError as WsSendError, WsConfig, WsHandler, WsSender, + DEFAULT_PING_INTERVAL, +}; diff --git a/crates/ag-ui-server/src/transport/sse.rs b/crates/ag-ui-server/src/transport/sse.rs new file mode 100644 index 00000000..73d69962 --- /dev/null +++ b/crates/ag-ui-server/src/transport/sse.rs @@ -0,0 +1,291 @@ +//! Server-Sent Events (SSE) Transport +//! +//! This module provides SSE transport for streaming AG-UI events to frontend clients. +//! It integrates with axum to provide HTTP SSE endpoints. +//! +//! # Architecture +//! +//! The SSE transport uses a channel-based design: +//! - [`SseSender`] - Used by agent code to send events into the stream +//! - [`SseHandler`] - Converted to an axum SSE response for the HTTP endpoint +//! +//! # Example +//! +//! ```rust,ignore +//! use ag_ui_server::transport::sse; +//! use ag_ui_core::{Event, TextMessageStartEvent, MessageId}; +//! +//! // Create a channel pair +//! let (sender, handler) = sse::channel::(32); +//! +//! // In your axum handler, return the SSE response +//! async fn events_endpoint() -> impl IntoResponse { +//! let (sender, handler) = sse::channel::(32); +//! +//! // Spawn task to send events +//! tokio::spawn(async move { +//! let event = Event::TextMessageStart( +//! TextMessageStartEvent::new(MessageId::random()) +//! ); +//! sender.send(event).await.ok(); +//! }); +//! +//! handler.into_response() +//! } +//! ``` + +use std::convert::Infallible; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use ag_ui_core::{AgentState, Event, JsonValue}; +use axum::response::sse::{Event as AxumSseEvent, KeepAlive, Sse}; +use axum::response::IntoResponse; +use futures::Stream; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; + +use crate::error::ServerError; + +/// Error type for SSE send operations. +#[derive(Debug, Clone)] +pub struct SendError(pub T); + +impl std::fmt::Display for SendError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "channel closed") + } +} + +impl std::error::Error for SendError {} + +/// Sender side of an SSE channel. +/// +/// Use this to send AG-UI events that will be streamed to connected clients. +/// Events are serialized to JSON and formatted as SSE data frames. +#[derive(Debug, Clone)] +pub struct SseSender { + sender: mpsc::Sender>, +} + +impl SseSender { + /// Sends an event to the SSE stream. + /// + /// Returns an error if the receiver has been dropped (client disconnected). + pub async fn send(&self, event: Event) -> Result<(), SendError>> { + self.sender.send(event).await.map_err(|e| SendError(e.0)) + } + + /// Sends multiple events to the SSE stream. + /// + /// Stops and returns an error on the first failed send. + pub async fn send_many( + &self, + events: impl IntoIterator>, + ) -> Result<(), SendError>> { + for event in events { + self.send(event).await?; + } + Ok(()) + } + + /// Tries to send an event without waiting. + /// + /// Returns an error if the channel is full or closed. + pub fn try_send(&self, event: Event) -> Result<(), SendError>> { + self.sender.try_send(event).map_err(|e| SendError(e.into_inner())) + } + + /// Checks if the receiver is still connected. + pub fn is_closed(&self) -> bool { + self.sender.is_closed() + } +} + +/// Handler side of an SSE channel. +/// +/// This is converted to an axum SSE response that streams events to the client. +/// Each event is serialized to JSON and sent as an SSE data frame. +pub struct SseHandler { + receiver: mpsc::Receiver>, +} + +impl SseHandler { + /// Converts this handler into an axum SSE response. + /// + /// The response will stream events as they are sent through the corresponding + /// [`SseSender`]. The stream ends when the sender is dropped. + pub fn into_response(self) -> impl IntoResponse { + let stream = SseEventStream { + inner: ReceiverStream::new(self.receiver), + }; + + Sse::new(stream).keep_alive(KeepAlive::default()) + } +} + +/// Internal stream wrapper that converts Events to axum SSE events. +struct SseEventStream { + inner: ReceiverStream>, +} + +impl Stream for SseEventStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match Pin::new(&mut self.inner).poll_next(cx) { + Poll::Ready(Some(event)) => { + // Serialize event to JSON + let json = match serde_json::to_string(&event) { + Ok(json) => json, + Err(e) => { + // Log error and send error event + eprintln!("SSE serialization error: {}", e); + format!(r#"{{"type":"RUN_ERROR","message":"Serialization error: {}"}}"#, e) + } + }; + + // Create SSE event with the event type as the SSE event name + let sse_event = AxumSseEvent::default() + .event(event.event_type().as_str()) + .data(json); + + Poll::Ready(Some(Ok(sse_event))) + } + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +/// Creates a new SSE channel pair. +/// +/// The `buffer` parameter controls how many events can be queued before +/// sends will block (or fail for `try_send`). +/// +/// # Arguments +/// +/// * `buffer` - The capacity of the internal channel buffer +/// +/// # Returns +/// +/// A tuple of (`SseSender`, `SseHandler`) that are connected. +/// +/// # Example +/// +/// ```rust,ignore +/// let (sender, handler) = sse::channel::(32); +/// ``` +pub fn channel(buffer: usize) -> (SseSender, SseHandler) { + let (tx, rx) = mpsc::channel(buffer); + (SseSender { sender: tx }, SseHandler { receiver: rx }) +} + +/// Serializes an event to SSE format. +/// +/// Returns the event formatted as `data: {json}\n\n`. +pub fn format_sse_event(event: &Event) -> Result { + let json = serde_json::to_string(event) + .map_err(|e| ServerError::Serialization(e.to_string()))?; + Ok(format!("data: {}\n\n", json)) +} + +#[cfg(test)] +mod tests { + use super::*; + use ag_ui_core::{ + MessageId, RunErrorEvent, TextMessageContentEvent, TextMessageStartEvent, + }; + + #[tokio::test] + async fn test_channel_creation() { + let (sender, _handler) = channel::(10); + assert!(!sender.is_closed()); + } + + #[tokio::test] + async fn test_send_event() { + let (sender, mut handler) = channel::(10); + + let event: Event = Event::TextMessageStart(TextMessageStartEvent::new(MessageId::random())); + + sender.send(event.clone()).await.unwrap(); + + // Receive from the handler's receiver directly for testing + let received = handler.receiver.recv().await.unwrap(); + assert_eq!(received.event_type(), event.event_type()); + } + + #[tokio::test] + async fn test_send_many_events() { + let (sender, mut handler) = channel::(10); + + let events: Vec = vec![ + Event::TextMessageStart(TextMessageStartEvent::new(MessageId::random())), + Event::TextMessageContent(TextMessageContentEvent::new_unchecked( + MessageId::random(), + "Hello", + )), + Event::RunError(RunErrorEvent::new("test error")), + ]; + + sender.send_many(events.clone()).await.unwrap(); + + // Verify all events received + for expected in &events { + let received = handler.receiver.recv().await.unwrap(); + assert_eq!(received.event_type(), expected.event_type()); + } + } + + #[tokio::test] + async fn test_channel_close_detection() { + let (sender, handler) = channel::(10); + + // Drop the handler + drop(handler); + + // Sender should detect closure + assert!(sender.is_closed()); + + // Send should fail + let event: Event = Event::RunError(RunErrorEvent::new("test")); + let result = sender.send(event).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_try_send() { + let (sender, _handler) = channel::(2); + + let event: Event = Event::RunError(RunErrorEvent::new("test")); + + // First two should succeed (buffer size is 2) + assert!(sender.try_send(event.clone()).is_ok()); + assert!(sender.try_send(event.clone()).is_ok()); + + // Third should fail (buffer full) + assert!(sender.try_send(event).is_err()); + } + + #[test] + fn test_format_sse_event() { + let event: Event = Event::RunError(RunErrorEvent::new("test error")); + let formatted = format_sse_event(&event).unwrap(); + + assert!(formatted.starts_with("data: ")); + assert!(formatted.ends_with("\n\n")); + assert!(formatted.contains("\"type\":\"RUN_ERROR\"")); + assert!(formatted.contains("\"message\":\"test error\"")); + } + + #[test] + fn test_format_sse_event_with_complex_event() { + let event: Event = Event::TextMessageStart(TextMessageStartEvent::new(MessageId::random())); + let formatted = format_sse_event(&event).unwrap(); + + assert!(formatted.contains("\"type\":\"TEXT_MESSAGE_START\"")); + assert!(formatted.contains("\"messageId\":")); + assert!(formatted.contains("\"role\":\"assistant\"")); + } +} diff --git a/crates/ag-ui-server/src/transport/ws.rs b/crates/ag-ui-server/src/transport/ws.rs new file mode 100644 index 00000000..333e8c68 --- /dev/null +++ b/crates/ag-ui-server/src/transport/ws.rs @@ -0,0 +1,443 @@ +//! WebSocket Transport for AG-UI Events +//! +//! This module provides WebSocket transport for streaming AG-UI events to frontend clients. +//! It integrates with axum to provide WebSocket endpoints as an alternative to SSE. +//! +//! # Architecture +//! +//! The WebSocket transport uses a channel-based design similar to SSE: +//! - [`WsSender`] - Used by agent code to send events into the WebSocket stream +//! - [`WsHandler`] - Handles the WebSocket connection and streams events +//! +//! # Example +//! +//! ```rust,ignore +//! use ag_ui_server::transport::ws; +//! use ag_ui_core::{Event, TextMessageStartEvent, MessageId}; +//! use axum::extract::ws::WebSocketUpgrade; +//! +//! async fn ws_endpoint(upgrade: WebSocketUpgrade) -> impl IntoResponse { +//! let (sender, handler) = ws::channel::(32); +//! +//! // Spawn task to send events +//! tokio::spawn(async move { +//! let event = Event::TextMessageStart( +//! TextMessageStartEvent::new(MessageId::random()) +//! ); +//! sender.send(event).await.ok(); +//! }); +//! +//! handler.into_response(upgrade) +//! } +//! ``` +//! +//! # SSE vs WebSocket +//! +//! Choose WebSocket when: +//! - You need bidirectional communication (future AG-UI extensions) +//! - You want lower latency for high-frequency updates +//! - You need to work around SSE connection limits in browsers +//! +//! Choose SSE when: +//! - You only need server-to-client streaming (current AG-UI) +//! - You want automatic reconnection (built into EventSource) +//! - You need HTTP/2 multiplexing benefits + +use std::time::Duration; + +use ag_ui_core::{AgentState, Event, JsonValue}; +use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade}; +use axum::response::IntoResponse; +use futures::{SinkExt, StreamExt}; +use tokio::sync::mpsc; +use tokio::time::interval; + +use crate::error::ServerError; + +/// Default ping interval for WebSocket keep-alive (30 seconds). +pub const DEFAULT_PING_INTERVAL: Duration = Duration::from_secs(30); + +/// Error type for WebSocket send operations. +#[derive(Debug, Clone)] +pub struct SendError(pub T); + +impl std::fmt::Display for SendError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "WebSocket channel closed") + } +} + +impl std::error::Error for SendError {} + +/// Configuration for WebSocket connections. +#[derive(Debug, Clone)] +pub struct WsConfig { + /// Interval between ping messages for keep-alive. + pub ping_interval: Duration, + /// Whether to send ping messages. + pub enable_ping: bool, +} + +impl Default for WsConfig { + fn default() -> Self { + Self { + ping_interval: DEFAULT_PING_INTERVAL, + enable_ping: true, + } + } +} + +impl WsConfig { + /// Creates a new configuration with default values. + pub fn new() -> Self { + Self::default() + } + + /// Sets the ping interval. + pub fn ping_interval(mut self, interval: Duration) -> Self { + self.ping_interval = interval; + self + } + + /// Disables ping messages. + pub fn disable_ping(mut self) -> Self { + self.enable_ping = false; + self + } +} + +/// Sender side of a WebSocket channel. +/// +/// Use this to send AG-UI events that will be streamed to connected clients. +/// Events are serialized to JSON and sent as WebSocket text messages. +#[derive(Debug, Clone)] +pub struct WsSender { + sender: mpsc::Sender>, +} + +impl WsSender { + /// Sends an event to the WebSocket stream. + /// + /// Returns an error if the receiver has been dropped (client disconnected). + pub async fn send(&self, event: Event) -> Result<(), SendError>> { + self.sender.send(event).await.map_err(|e| SendError(e.0)) + } + + /// Sends multiple events to the WebSocket stream. + /// + /// Stops and returns an error on the first failed send. + pub async fn send_many( + &self, + events: impl IntoIterator>, + ) -> Result<(), SendError>> { + for event in events { + self.send(event).await?; + } + Ok(()) + } + + /// Tries to send an event without waiting. + /// + /// Returns an error if the channel is full or closed. + pub fn try_send(&self, event: Event) -> Result<(), SendError>> { + self.sender + .try_send(event) + .map_err(|e| SendError(e.into_inner())) + } + + /// Checks if the receiver is still connected. + pub fn is_closed(&self) -> bool { + self.sender.is_closed() + } +} + +/// Handler side of a WebSocket channel. +/// +/// This handles the WebSocket connection and streams events from the sender. +pub struct WsHandler { + receiver: mpsc::Receiver>, + config: WsConfig, +} + +impl WsHandler { + /// Converts a WebSocket upgrade into an axum response. + /// + /// The response will upgrade to WebSocket and stream events as they are + /// sent through the corresponding [`WsSender`]. + pub fn into_response(self, upgrade: WebSocketUpgrade) -> impl IntoResponse { + upgrade.on_upgrade(move |socket| self.handle_socket(socket)) + } + + /// Handles the WebSocket connection. + async fn handle_socket(self, socket: WebSocket) { + let (mut ws_sender, mut ws_receiver) = socket.split(); + let mut event_receiver = self.receiver; + + // Create ping interval if enabled + let mut ping_interval = if self.config.enable_ping { + Some(interval(self.config.ping_interval)) + } else { + None + }; + + loop { + tokio::select! { + // Handle incoming events to send + event = event_receiver.recv() => { + match event { + Some(event) => { + // Serialize event to JSON + let json = match serde_json::to_string(&event) { + Ok(json) => json, + Err(e) => { + eprintln!("WebSocket serialization error: {}", e); + continue; + } + }; + + // Send as text message + if ws_sender.send(Message::Text(json.into())).await.is_err() { + // Client disconnected + break; + } + } + None => { + // Event channel closed, send close frame and exit + let _ = ws_sender.send(Message::Close(None)).await; + break; + } + } + } + + // Handle ping interval + _ = async { + if let Some(ref mut interval) = ping_interval { + interval.tick().await; + } else { + // Never completes if ping disabled + std::future::pending::<()>().await; + } + } => { + if ws_sender.send(Message::Ping(vec![].into())).await.is_err() { + break; + } + } + + // Handle incoming WebSocket messages (for close/pong) + msg = ws_receiver.next() => { + match msg { + Some(Ok(Message::Pong(_))) => { + // Pong received, connection is alive + } + Some(Ok(Message::Close(_))) | None => { + // Client closed connection + break; + } + Some(Ok(_)) => { + // Ignore other message types (Text, Binary) + // AG-UI is unidirectional server->client + } + Some(Err(_)) => { + // WebSocket error + break; + } + } + } + } + } + } +} + +/// Creates a new WebSocket channel pair with default configuration. +/// +/// The `buffer` parameter controls how many events can be queued before +/// sends will block (or fail for `try_send`). +/// +/// # Arguments +/// +/// * `buffer` - The capacity of the internal channel buffer +/// +/// # Returns +/// +/// A tuple of (`WsSender`, `WsHandler`) that are connected. +/// +/// # Example +/// +/// ```rust,ignore +/// let (sender, handler) = ws::channel::(32); +/// ``` +pub fn channel(buffer: usize) -> (WsSender, WsHandler) { + channel_with_config(buffer, WsConfig::default()) +} + +/// Creates a new WebSocket channel pair with custom configuration. +/// +/// # Arguments +/// +/// * `buffer` - The capacity of the internal channel buffer +/// * `config` - WebSocket configuration options +/// +/// # Returns +/// +/// A tuple of (`WsSender`, `WsHandler`) that are connected. +/// +/// # Example +/// +/// ```rust,ignore +/// let config = WsConfig::new() +/// .ping_interval(Duration::from_secs(15)) +/// .disable_ping(); +/// let (sender, handler) = ws::channel_with_config::(32, config); +/// ``` +pub fn channel_with_config( + buffer: usize, + config: WsConfig, +) -> (WsSender, WsHandler) { + let (tx, rx) = mpsc::channel(buffer); + ( + WsSender { sender: tx }, + WsHandler { + receiver: rx, + config, + }, + ) +} + +/// Serializes an event to a WebSocket text message. +/// +/// Returns the JSON string suitable for sending as a WebSocket text frame. +pub fn format_ws_message(event: &Event) -> Result { + serde_json::to_string(event).map_err(|e| ServerError::Serialization(e.to_string())) +} + +#[cfg(test)] +mod tests { + use super::*; + use ag_ui_core::{MessageId, RunErrorEvent, TextMessageContentEvent, TextMessageStartEvent}; + + #[tokio::test] + async fn test_channel_creation() { + let (sender, _handler) = channel::(10); + assert!(!sender.is_closed()); + } + + #[tokio::test] + async fn test_channel_with_config() { + let config = WsConfig::new() + .ping_interval(Duration::from_secs(10)) + .disable_ping(); + + let (sender, handler) = channel_with_config::(10, config); + assert!(!sender.is_closed()); + assert!(!handler.config.enable_ping); + assert_eq!(handler.config.ping_interval, Duration::from_secs(10)); + } + + #[tokio::test] + async fn test_send_event() { + let (sender, mut handler) = channel::(10); + + let event: Event = Event::TextMessageStart(TextMessageStartEvent::new(MessageId::random())); + + sender.send(event.clone()).await.unwrap(); + + // Receive from the handler's receiver directly for testing + let received = handler.receiver.recv().await.unwrap(); + assert_eq!(received.event_type(), event.event_type()); + } + + #[tokio::test] + async fn test_send_many_events() { + let (sender, mut handler) = channel::(10); + + let events: Vec = vec![ + Event::TextMessageStart(TextMessageStartEvent::new(MessageId::random())), + Event::TextMessageContent(TextMessageContentEvent::new_unchecked( + MessageId::random(), + "Hello", + )), + Event::RunError(RunErrorEvent::new("test error")), + ]; + + sender.send_many(events.clone()).await.unwrap(); + + // Verify all events received + for expected in &events { + let received = handler.receiver.recv().await.unwrap(); + assert_eq!(received.event_type(), expected.event_type()); + } + } + + #[tokio::test] + async fn test_channel_close_detection() { + let (sender, handler) = channel::(10); + + // Drop the handler + drop(handler); + + // Sender should detect closure + assert!(sender.is_closed()); + + // Send should fail + let event: Event = Event::RunError(RunErrorEvent::new("test")); + let result = sender.send(event).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_try_send() { + let (sender, _handler) = channel::(2); + + let event: Event = Event::RunError(RunErrorEvent::new("test")); + + // First two should succeed (buffer size is 2) + assert!(sender.try_send(event.clone()).is_ok()); + assert!(sender.try_send(event.clone()).is_ok()); + + // Third should fail (buffer full) + assert!(sender.try_send(event).is_err()); + } + + #[test] + fn test_format_ws_message() { + let event: Event = Event::RunError(RunErrorEvent::new("test error")); + let message = format_ws_message(&event).unwrap(); + + assert!(message.contains("\"type\":\"RUN_ERROR\"")); + assert!(message.contains("\"message\":\"test error\"")); + } + + #[test] + fn test_format_ws_message_complex() { + let event: Event = + Event::TextMessageStart(TextMessageStartEvent::new(MessageId::random())); + let message = format_ws_message(&event).unwrap(); + + assert!(message.contains("\"type\":\"TEXT_MESSAGE_START\"")); + assert!(message.contains("\"messageId\":")); + assert!(message.contains("\"role\":\"assistant\"")); + } + + #[test] + fn test_ws_config_default() { + let config = WsConfig::default(); + assert!(config.enable_ping); + assert_eq!(config.ping_interval, DEFAULT_PING_INTERVAL); + } + + #[test] + fn test_ws_config_builder() { + let config = WsConfig::new() + .ping_interval(Duration::from_secs(60)) + .disable_ping(); + + assert!(!config.enable_ping); + assert_eq!(config.ping_interval, Duration::from_secs(60)); + } + + #[test] + fn test_send_error_display() { + let error: SendError = SendError(42); + assert_eq!(format!("{}", error), "WebSocket channel closed"); + } +}