Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,13 @@ type ClientOptions struct {
// If the peer fails to respond to pings originating from the keepalive check,
// the session is automatically closed.
KeepAlive time.Duration
// KeepAliveFailureThreshold is the number of consecutive keepalive ping
// failures tolerated before the session is closed. A value of 0 or 1
// closes the session on the first failure (the default). Higher values
// align with the spec's "multiple failed pings MAY trigger a connection
// reset" guidance, letting a transient miss pass without tearing down an
// otherwise live session. Has no effect unless KeepAlive is non-zero.
KeepAliveFailureThreshold int
}

// toolContextKeyType is the context key type for passing tool definitions
Expand Down Expand Up @@ -441,7 +448,7 @@ func (cs *ClientSession) registerElicitationWaiter(elicitationID string) (await

// startKeepalive starts the keepalive mechanism for this client session.
func (cs *ClientSession) startKeepalive(interval time.Duration) {
startKeepalive(cs, interval, &cs.keepaliveCancel, cs.client.opts.Logger)
startKeepalive(cs, interval, cs.client.opts.KeepAliveFailureThreshold, &cs.keepaliveCancel, cs.client.opts.Logger)
}

// AddRoots adds the given roots to the client,
Expand Down
76 changes: 76 additions & 0 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1920,6 +1920,82 @@ func TestKeepAliveFailure_Logged(t *testing.T) {
})
}

// scriptedKeepaliveSession is a keepaliveSession test double whose Ping
// returns errors from a script (one entry consumed per call; the last entry
// repeats once exhausted), and records how many times Close was called. Ping
// returns immediately so the keepalive loop's pace is driven purely by the
// ticker, making the test deterministic under synctest.
type scriptedKeepaliveSession struct {
pingErrs []error
pingCalls atomic.Int64
closeCalls atomic.Int64
}

func (s *scriptedKeepaliveSession) Ping(context.Context, *PingParams) error {
n := int(s.pingCalls.Add(1)) - 1
if n >= len(s.pingErrs) {
n = len(s.pingErrs) - 1
}
return s.pingErrs[n]
}

func (s *scriptedKeepaliveSession) Close() error {
s.closeCalls.Add(1)
return nil
}

// TestStartKeepalive_FailureThreshold verifies that the session is kept alive
// across consecutive ping failures below the threshold and only closed once the
// threshold is reached.
func TestStartKeepalive_FailureThreshold(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
const interval = 100 * time.Millisecond
sess := &scriptedKeepaliveSession{pingErrs: []error{errors.New("boom")}}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
var cancel context.CancelFunc
startKeepalive(sess, interval, 3, &cancel, logger)
defer cancel()

// After two ticks → two failures, still below threshold 3: not closed.
time.Sleep(2*interval + interval/2)
synctest.Wait()
if got := sess.closeCalls.Load(); got != 0 {
t.Fatalf("session closed below threshold: closeCalls=%d (pingCalls=%d)", got, sess.pingCalls.Load())
}

// Third tick → third failure reaches threshold: session closed.
time.Sleep(interval)
synctest.Wait()
if got := sess.closeCalls.Load(); got != 1 {
t.Fatalf("expected one Close at threshold, got closeCalls=%d (pingCalls=%d)", got, sess.pingCalls.Load())
}
})
}

// TestStartKeepalive_SuccessResetsFailures verifies that a successful ping
// resets the consecutive-failure counter, so an isolated failure between
// successes never accumulates toward the threshold.
func TestStartKeepalive_SuccessResetsFailures(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
const interval = 100 * time.Millisecond
// fail, success, fail, fail, then success (the tail repeats): the run
// never has 3 consecutive failures, so the session is never closed.
sess := &scriptedKeepaliveSession{pingErrs: []error{
errors.New("boom"), nil, errors.New("boom"), errors.New("boom"), nil,
}}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
var cancel context.CancelFunc
startKeepalive(sess, interval, 3, &cancel, logger)
defer cancel()

time.Sleep(6 * interval)
synctest.Wait()
if got := sess.closeCalls.Load(); got != 0 {
t.Fatalf("session closed despite a success resetting the counter: closeCalls=%d (pingCalls=%d)", got, sess.pingCalls.Load())
}
})
}

func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) {
// Adding the same tool pointer twice should not panic and should not
// produce duplicates in the server's tool list.
Expand Down
9 changes: 8 additions & 1 deletion mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ type ServerOptions struct {
// If the peer fails to respond to pings originating from the keepalive check,
// the session is automatically closed.
KeepAlive time.Duration
// KeepAliveFailureThreshold is the number of consecutive keepalive ping
// failures tolerated before the session is closed. A value of 0 or 1
// closes the session on the first failure (the default). Higher values
// align with the spec's "multiple failed pings MAY trigger a connection
// reset" guidance, letting a transient miss pass without tearing down an
// otherwise live session. Has no effect unless KeepAlive is non-zero.
KeepAliveFailureThreshold int
// Function called when a client session subscribes to a resource.
SubscribeHandler func(context.Context, *SubscribeRequest) error
// Function called when a client session unsubscribes from a resource.
Expand Down Expand Up @@ -1605,7 +1612,7 @@ func (ss *ServerSession) Wait() error {

// startKeepalive starts the keepalive mechanism for this server session.
func (ss *ServerSession) startKeepalive(interval time.Duration) {
startKeepalive(ss, interval, &ss.keepaliveCancel, ss.server.opts.Logger)
startKeepalive(ss, interval, ss.server.opts.KeepAliveFailureThreshold, &ss.keepaliveCancel, ss.server.opts.Logger)
}

// pageToken is the internal structure for the opaque pagination cursor.
Expand Down
51 changes: 39 additions & 12 deletions mcp/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -751,9 +751,20 @@ type keepaliveSession interface {
// It assigns the cancel function to the provided cancelPtr and starts a goroutine
// that sends ping messages at the specified interval.
//
// logger must be non-nil; ping failures (which terminate the keepalive loop and
// close the session) are reported via logger so they are not silently dropped.
func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr *context.CancelFunc, logger *slog.Logger) {
// failureThreshold is the number of consecutive ping failures tolerated before
// the session is closed; a value below 1 is treated as 1 (close on the first
// failure). A successful ping resets the counter. This mirrors the spec's
// "multiple failed pings MAY trigger a connection reset" language, letting a
// transient miss pass without tearing down an otherwise live session.
//
// logger must be non-nil; ping failures (both the tolerated ones and the final
// one that closes the session) are reported via logger so they are not silently
// dropped.
func startKeepalive(session keepaliveSession, interval time.Duration, failureThreshold int, cancelPtr *context.CancelFunc, logger *slog.Logger) {
if failureThreshold < 1 {
failureThreshold = 1
}

ctx, cancel := context.WithCancel(context.Background())
// Assign cancel function before starting goroutine to avoid race condition.
// We cannot return it because the caller may need to cancel during the
Expand All @@ -764,6 +775,7 @@ func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr
ticker := time.NewTicker(interval)
defer ticker.Stop()

consecutiveFailures := 0
for {
select {
case <-ctx.Done():
Expand All @@ -772,17 +784,32 @@ func startKeepalive(session keepaliveSession, interval time.Duration, cancelPtr
pingCtx, pingCancel := context.WithTimeout(context.Background(), interval/2)
err := session.Ping(pingCtx, nil)
pingCancel()
if err != nil {
if errors.Is(err, jsonrpc2.ErrMethodNotFound) {
// Peer doesn't support ping, stop the keepalive process.
return
}
// Ping failed; log it before closing the session so the
// failure is observable to operators. See #218.
logger.Error("keepalive ping failed; closing session", "error", err)
_ = session.Close()
if err == nil {
consecutiveFailures = 0
continue
}
if errors.Is(err, jsonrpc2.ErrMethodNotFound) {
// Peer doesn't support ping, stop the keepalive process.
return
}
consecutiveFailures++
if consecutiveFailures < failureThreshold {
// Tolerate transient failures below the threshold; log so
// the misses are still observable to operators. See #218.
logger.Warn("keepalive ping failed; tolerating below threshold",
"error", err,
"consecutiveFailures", consecutiveFailures,
"failureThreshold", failureThreshold)
continue
}
// Threshold reached; log before closing the session so the
// failure is observable to operators. See #218.
logger.Error("keepalive ping failed; closing session",
"error", err,
"consecutiveFailures", consecutiveFailures,
"failureThreshold", failureThreshold)
_ = session.Close()
return
}
}
}()
Expand Down
Loading