diff --git a/crates/rmcp/src/handler/client.rs b/crates/rmcp/src/handler/client.rs index 15b1c0c0..db73bf36 100644 --- a/crates/rmcp/src/handler/client.rs +++ b/crates/rmcp/src/handler/client.rs @@ -4,6 +4,7 @@ use crate::{ model::*, service::{NotificationContext, RequestContext, RoleClient, Service, ServiceRole}, }; +use std::sync::Arc; impl Service for H { async fn handle_request( @@ -210,3 +211,115 @@ impl ClientHandler for ClientInfo { self.clone() } } + +macro_rules! impl_client_handler_for_wrapper { + ($wrapper:ident) => { + impl ClientHandler for $wrapper { + fn ping( + &self, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).ping(context) + } + + fn create_message( + &self, + params: CreateMessageRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).create_message(params, context) + } + + fn list_roots( + &self, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).list_roots(context) + } + + fn create_elicitation( + &self, + request: CreateElicitationRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).create_elicitation(request, context) + } + + fn on_custom_request( + &self, + request: CustomRequest, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).on_custom_request(request, context) + } + + fn on_cancelled( + &self, + params: CancelledNotificationParam, + context: NotificationContext, + ) -> impl Future + Send + '_ { + (**self).on_cancelled(params, context) + } + + fn on_progress( + &self, + params: ProgressNotificationParam, + context: NotificationContext, + ) -> impl Future + Send + '_ { + (**self).on_progress(params, context) + } + + fn on_logging_message( + &self, + params: LoggingMessageNotificationParam, + context: NotificationContext, + ) -> impl Future + Send + '_ { + (**self).on_logging_message(params, context) + } + + fn on_resource_updated( + &self, + params: ResourceUpdatedNotificationParam, + context: NotificationContext, + ) -> impl Future + Send + '_ { + (**self).on_resource_updated(params, context) + } + + fn on_resource_list_changed( + &self, + context: NotificationContext, + ) -> impl Future + Send + '_ { + (**self).on_resource_list_changed(context) + } + + fn on_tool_list_changed( + &self, + context: NotificationContext, + ) -> impl Future + Send + '_ { + (**self).on_tool_list_changed(context) + } + + fn on_prompt_list_changed( + &self, + context: NotificationContext, + ) -> impl Future + Send + '_ { + (**self).on_prompt_list_changed(context) + } + + fn on_custom_notification( + &self, + notification: CustomNotification, + context: NotificationContext, + ) -> impl Future + Send + '_ { + (**self).on_custom_notification(notification, context) + } + + fn get_info(&self) -> ClientInfo { + (**self).get_info() + } + } + }; +} + +impl_client_handler_for_wrapper!(Box); +impl_client_handler_for_wrapper!(Arc); diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index f10cfa7c..30b90525 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -3,6 +3,7 @@ use crate::{ model::*, service::{NotificationContext, RequestContext, RoleServer, Service, ServiceRole}, }; +use std::sync::Arc; pub mod common; pub mod prompt; @@ -327,3 +328,206 @@ pub trait ServerHandler: Sized + Send + Sync + 'static { std::future::ready(Err(McpError::method_not_found::())) } } + +macro_rules! impl_server_handler_for_wrapper { + ($wrapper:ident) => { + impl ServerHandler for $wrapper { + fn enqueue_task( + &self, + request: CallToolRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).enqueue_task(request, context) + } + + fn ping( + &self, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).ping(context) + } + + fn initialize( + &self, + request: InitializeRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).initialize(request, context) + } + + fn complete( + &self, + request: CompleteRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).complete(request, context) + } + + fn set_level( + &self, + request: SetLevelRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).set_level(request, context) + } + + fn get_prompt( + &self, + request: GetPromptRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).get_prompt(request, context) + } + + fn list_prompts( + &self, + request: Option, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).list_prompts(request, context) + } + + fn list_resources( + &self, + request: Option, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).list_resources(request, context) + } + + fn list_resource_templates( + &self, + request: Option, + context: RequestContext, + ) -> impl Future> + Send + '_ + { + (**self).list_resource_templates(request, context) + } + + fn read_resource( + &self, + request: ReadResourceRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).read_resource(request, context) + } + + fn subscribe( + &self, + request: SubscribeRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).subscribe(request, context) + } + + fn unsubscribe( + &self, + request: UnsubscribeRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).unsubscribe(request, context) + } + + fn call_tool( + &self, + request: CallToolRequestParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).call_tool(request, context) + } + + fn list_tools( + &self, + request: Option, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).list_tools(request, context) + } + + fn on_custom_request( + &self, + request: CustomRequest, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).on_custom_request(request, context) + } + + fn on_cancelled( + &self, + notification: CancelledNotificationParam, + context: NotificationContext, + ) -> impl Future + Send + '_ { + (**self).on_cancelled(notification, context) + } + + fn on_progress( + &self, + notification: ProgressNotificationParam, + context: NotificationContext, + ) -> impl Future + Send + '_ { + (**self).on_progress(notification, context) + } + + fn on_initialized( + &self, + context: NotificationContext, + ) -> impl Future + Send + '_ { + (**self).on_initialized(context) + } + + fn on_roots_list_changed( + &self, + context: NotificationContext, + ) -> impl Future + Send + '_ { + (**self).on_roots_list_changed(context) + } + + fn on_custom_notification( + &self, + notification: CustomNotification, + context: NotificationContext, + ) -> impl Future + Send + '_ { + (**self).on_custom_notification(notification, context) + } + + fn get_info(&self) -> ServerInfo { + (**self).get_info() + } + + fn list_tasks( + &self, + request: Option, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).list_tasks(request, context) + } + + fn get_task_info( + &self, + request: GetTaskInfoParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).get_task_info(request, context) + } + + fn get_task_result( + &self, + request: GetTaskResultParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).get_task_result(request, context) + } + + fn cancel_task( + &self, + request: CancelTaskParam, + context: RequestContext, + ) -> impl Future> + Send + '_ { + (**self).cancel_task(request, context) + } + } + }; +} + +impl_server_handler_for_wrapper!(Box); +impl_server_handler_for_wrapper!(Arc); diff --git a/crates/rmcp/src/handler/server/router.rs b/crates/rmcp/src/handler/server/router.rs index 0b908081..1f34ba5b 100644 --- a/crates/rmcp/src/handler/server/router.rs +++ b/crates/rmcp/src/handler/server/router.rs @@ -133,6 +133,6 @@ where } fn get_info(&self) -> ::Info { - self.service.get_info() + ServerHandler::get_info(&self.service) } } diff --git a/crates/rmcp/tests/test_handler_wrappers.rs b/crates/rmcp/tests/test_handler_wrappers.rs new file mode 100644 index 00000000..06558cdf --- /dev/null +++ b/crates/rmcp/tests/test_handler_wrappers.rs @@ -0,0 +1,29 @@ +// cargo test --test test_handler_wrappers --features "client server" + +mod common; + +use std::sync::Arc; + +use rmcp::{ClientHandler, ServerHandler}; + +use common::handlers::{TestClientHandler, TestServer}; + +#[test] +fn test_wrapped_server_handlers() { + // This test asserts that, when T: ServerHandler, both Box and Arc also implement ServerHandler. + fn accepts_server_handler(_handler: H) {} + + accepts_server_handler(Box::new(TestServer::new())); + accepts_server_handler(Arc::new(TestServer::new())); +} + +#[test] +fn test_wrapped_client_handlers() { + // This test asserts that, when T: ClientHandler, both Box and Arc also implement ClientHandler. + fn accepts_client_handler(_handler: H) {} + + let client = TestClientHandler::new(false, false); + + accepts_client_handler(Box::new(client.clone())); + accepts_client_handler(Arc::new(client)); +}