diff --git a/src/agent-client-protocol/src/agent.rs b/src/agent-client-protocol/src/agent.rs index 0606b42..dd163e1 100644 --- a/src/agent-client-protocol/src/agent.rs +++ b/src/agent-client-protocol/src/agent.rs @@ -11,6 +11,8 @@ use agent_client_protocol_schema::{ use agent_client_protocol_schema::{CloseSessionRequest, CloseSessionResponse}; #[cfg(feature = "unstable_session_fork")] use agent_client_protocol_schema::{ForkSessionRequest, ForkSessionResponse}; +#[cfg(feature = "unstable_logout")] +use agent_client_protocol_schema::{LogoutRequest, LogoutResponse}; #[cfg(feature = "unstable_session_resume")] use agent_client_protocol_schema::{ResumeSessionRequest, ResumeSessionResponse}; #[cfg(feature = "unstable_session_model")] @@ -46,6 +48,21 @@ pub trait Agent { /// See protocol docs: [Initialization](https://agentclientprotocol.com/protocol/initialization) async fn authenticate(&self, args: AuthenticateRequest) -> Result; + /// **UNSTABLE** + /// + /// This capability is not part of the spec yet, and may be removed or changed at any point. + /// + /// Logs out of the current authenticated state. + /// + /// After a successful logout, all new sessions will require authentication. + /// There is no guarantee about the behavior of already running sessions. + /// + /// Only available if the Agent supports the `auth.logout` capability. + #[cfg(feature = "unstable_logout")] + async fn logout(&self, _args: LogoutRequest) -> Result { + Err(Error::method_not_found()) + } + /// Creates a new conversation session with the agent. /// /// Sessions represent independent conversation contexts with their own history and state. @@ -229,6 +246,10 @@ impl Agent for Rc { async fn authenticate(&self, args: AuthenticateRequest) -> Result { self.as_ref().authenticate(args).await } + #[cfg(feature = "unstable_logout")] + async fn logout(&self, args: LogoutRequest) -> Result { + self.as_ref().logout(args).await + } async fn new_session(&self, args: NewSessionRequest) -> Result { self.as_ref().new_session(args).await } @@ -292,6 +313,10 @@ impl Agent for Arc { async fn authenticate(&self, args: AuthenticateRequest) -> Result { self.as_ref().authenticate(args).await } + #[cfg(feature = "unstable_logout")] + async fn logout(&self, args: LogoutRequest) -> Result { + self.as_ref().logout(args).await + } async fn new_session(&self, args: NewSessionRequest) -> Result { self.as_ref().new_session(args).await } diff --git a/src/agent-client-protocol/src/lib.rs b/src/agent-client-protocol/src/lib.rs index 813e6e6..383473b 100644 --- a/src/agent-client-protocol/src/lib.rs +++ b/src/agent-client-protocol/src/lib.rs @@ -95,6 +95,17 @@ impl Agent for ClientSideConnection { .map(Option::unwrap_or_default) } + #[cfg(feature = "unstable_logout")] + async fn logout(&self, args: LogoutRequest) -> Result { + self.conn + .request::>( + AGENT_METHOD_NAMES.logout, + Some(ClientRequest::LogoutRequest(args)), + ) + .await + .map(Option::unwrap_or_default) + } + async fn new_session(&self, args: NewSessionRequest) -> Result { self.conn .request( @@ -554,6 +565,10 @@ impl Side for AgentSide { m if m == AGENT_METHOD_NAMES.authenticate => serde_json::from_str(params.get()) .map(ClientRequest::AuthenticateRequest) .map_err(Into::into), + #[cfg(feature = "unstable_logout")] + m if m == AGENT_METHOD_NAMES.logout => serde_json::from_str(params.get()) + .map(ClientRequest::LogoutRequest) + .map_err(Into::into), m if m == AGENT_METHOD_NAMES.session_new => serde_json::from_str(params.get()) .map(ClientRequest::NewSessionRequest) .map_err(Into::into), @@ -635,6 +650,11 @@ impl MessageHandler for T { let response = self.authenticate(args).await?; Ok(AgentResponse::AuthenticateResponse(response)) } + #[cfg(feature = "unstable_logout")] + ClientRequest::LogoutRequest(args) => { + let response = self.logout(args).await?; + Ok(AgentResponse::LogoutResponse(response)) + } ClientRequest::NewSessionRequest(args) => { let response = self.new_session(args).await?; Ok(AgentResponse::NewSessionResponse(response)) diff --git a/src/agent-client-protocol/src/rpc_tests.rs b/src/agent-client-protocol/src/rpc_tests.rs index e45c827..ecdbf2d 100644 --- a/src/agent-client-protocol/src/rpc_tests.rs +++ b/src/agent-client-protocol/src/rpc_tests.rs @@ -133,6 +133,8 @@ struct TestAgent { sessions: Arc>>, prompts_received: Arc>>, cancellations_received: Arc>>, + #[cfg(feature = "unstable_logout")] + logout_count: Arc>, extension_notifications: Arc>>, } @@ -144,6 +146,8 @@ impl TestAgent { sessions: Arc::new(Mutex::new(std::collections::HashMap::new())), prompts_received: Arc::new(Mutex::new(Vec::new())), cancellations_received: Arc::new(Mutex::new(Vec::new())), + #[cfg(feature = "unstable_logout")] + logout_count: Arc::new(Mutex::new(0)), extension_notifications: Arc::new(Mutex::new(Vec::new())), } } @@ -153,6 +157,12 @@ impl TestAgent { impl Agent for TestAgent { async fn initialize(&self, arguments: InitializeRequest) -> Result { Ok(InitializeResponse::new(arguments.protocol_version) + .agent_capabilities( + AgentCapabilities::new().auth( + agent_client_protocol_schema::AgentAuthCapabilities::new() + .logout(agent_client_protocol_schema::LogoutCapabilities::new()), + ), + ) .agent_info(Implementation::new("test-agent", "0.0.0").title("Test Agent"))) } @@ -160,6 +170,15 @@ impl Agent for TestAgent { Ok(AuthenticateResponse::default()) } + #[cfg(feature = "unstable_logout")] + async fn logout( + &self, + _arguments: agent_client_protocol_schema::LogoutRequest, + ) -> Result { + *self.logout_count.lock().unwrap() += 1; + Ok(agent_client_protocol_schema::LogoutResponse::default()) + } + async fn new_session(&self, arguments: NewSessionRequest) -> Result { let session_id = SessionId::new("test-session-123"); self.sessions @@ -886,6 +905,44 @@ async fn test_session_info_update() { .await; } +#[cfg(feature = "unstable_logout")] +#[tokio::test] +async fn test_logout() { + let local_set = tokio::task::LocalSet::new(); + local_set + .run_until(async { + let client = TestClient::new(); + let agent = TestAgent::new(); + + let (agent_conn, _client_conn) = create_connection_pair(&client, &agent); + + let initialize_response = + agent_conn + .initialize(InitializeRequest::new(ProtocolVersion::LATEST).client_info( + Implementation::new("test-client", "0.0.0").title("Test Client"), + )) + .await + .expect("initialize failed"); + + assert!( + initialize_response.agent_capabilities.auth.logout.is_some(), + "agent should advertise auth.logout capability" + ); + + let response = agent_conn + .logout(agent_client_protocol_schema::LogoutRequest::new()) + .await + .expect("logout failed"); + + assert_eq!( + response, + agent_client_protocol_schema::LogoutResponse::default() + ); + assert_eq!(*agent.logout_count.lock().unwrap(), 1); + }) + .await; +} + #[tokio::test] async fn test_set_session_config_option() { let local_set = tokio::task::LocalSet::new();