diff --git a/mcp/client.go b/mcp/client.go index 9f6f2955..74dbadf0 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -399,6 +399,12 @@ func (cs *ClientSession) getCachedTool(name string) *Tool { return cs.toolCache[name] } +// hasPendingRequests is the client-side counterpart of +// ServerSession.hasPendingRequests. See keepaliveSession in shared.go. +func (cs *ClientSession) hasPendingRequests() bool { + return false +} + // registerElicitationWaiter registers a waiter for an elicitation complete // notification with the given elicitation ID. It returns two functions: an await // function that waits for the notification or context cancellation, and a cleanup diff --git a/mcp/server.go b/mcp/server.go index b86b3e6c..d1047242 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1167,6 +1167,16 @@ type ServerSession struct { state ServerSessionState } +func (ss *ServerSession) hasPendingRequests() bool { + type pendingReporter interface { + pendingClientRequests() int + } + if c, ok := ss.mcpConn.(pendingReporter); ok { + return c.pendingClientRequests() > 0 + } + return false +} + func (ss *ServerSession) updateState(mut func(*ServerSessionState)) { ss.mu.Lock() mut(&ss.state) diff --git a/mcp/shared.go b/mcp/shared.go index 1caacac3..165f4b5a 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -745,6 +745,7 @@ type listResult[T any] interface { type keepaliveSession interface { Ping(ctx context.Context, params *PingParams) error Close() error + hasPendingRequests() bool } // startKeepalive starts the keepalive mechanism for a session. @@ -769,6 +770,10 @@ func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr case <-ctx.Done(): return case <-ticker.C: + if session.hasPendingRequests() { + // Active request is a liveness signal; skip this tick. + continue + } pingCtx, pingCancel := context.WithTimeout(context.Background(), interval/2) err := session.Ping(pingCtx, nil) pingCancel() diff --git a/mcp/streamable.go b/mcp/streamable.go index 0f4e65b8..8bff3ffd 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -825,6 +825,16 @@ func (c *streamableServerConn) SessionID() string { return c.sessionID } +// pendingClientRequests returns the number of incoming requests from the +// client that the server has not yet finished responding to. Used by the +// keepalive loop to skip pings while a tool call (or any other request) +// is still being handled — the in-flight response IS the liveness signal. +func (c *streamableServerConn) pendingClientRequests() int { + c.mu.Lock() + defer c.mu.Unlock() + return len(c.requestStreams) +} + // A stream is a single logical stream of SSE events within a server session. // A stream begins with a client request, or with a client GET that has // no Last-Event-ID header.