Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 178 additions & 0 deletions internal/runtime/supervisor/lifecycle_drain_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package supervisor

import (
"context"
"sync"
"testing"
"time"

"github.com/stretchr/testify/require"
"go.uber.org/zap"

"github.com/smart-mcp-proxy/mcpproxy-go/internal/config"
"github.com/smart-mcp-proxy/mcpproxy-go/internal/runtime/configsvc"
)

// blockingUpstream is a test double whose ConnectServer blocks until released,
// so a test can hold a Connect "in flight" and observe whether Close() (the
// upstream disconnect path driven by Supervisor.Stop) overlaps it.
//
// This reproduces the MCP-770 root cause: runtime.Close -> Supervisor.Stop ->
// ShutdownAll/Disconnect must NOT run while a reconcile-dispatched Connect is
// still executing on the same client.
type blockingUpstream struct {
mu sync.Mutex

connectStarted chan struct{} // closed when ConnectServer first enters
releaseConnect chan struct{} // ConnectServer blocks until this is closed

connectInFlight bool // true while ConnectServer is blocked
overlapDetected bool // set true if Close() ran while a Connect was in flight
closed bool

states map[string]*ServerState
eventCh chan Event
}

// newBlockingUpstream builds a blockingUpstream with both signalling channels
// open, an empty state table, and an event channel buffered for 10 events.
func newBlockingUpstream() *blockingUpstream {
	up := new(blockingUpstream)
	up.connectStarted = make(chan struct{})
	up.releaseConnect = make(chan struct{})
	up.states = make(map[string]*ServerState)
	up.eventCh = make(chan Event, 10)
	return up
}

// AddServer registers (or overwrites) the state entry for name. The entry
// starts disconnected and takes its Enabled flag from cfg. It always returns nil.
func (b *blockingUpstream) AddServer(name string, cfg *config.ServerConfig) error {
	// Connected is left at its zero value: a freshly added server is not connected.
	entry := &ServerState{
		Name:    name,
		Config:  cfg,
		Enabled: cfg.Enabled,
	}

	b.mu.Lock()
	defer b.mu.Unlock()
	b.states[name] = entry
	return nil
}

// RemoveServer drops the state entry for name, if any. Removing an unknown
// server is a no-op. It always returns nil.
func (b *blockingUpstream) RemoveServer(name string) error {
	b.mu.Lock()
	delete(b.states, name)
	b.mu.Unlock()

	return nil
}

// ConnectServer simulates a slow Connect. It flags the connect as in flight,
// closes connectStarted on first entry, and then parks until releaseConnect is
// closed. Once released it clears the in-flight flag and, if name is known,
// marks that server Connected. It always returns nil.
//
// The context argument is ignored on purpose: cancelling it does not unblock
// the call, so only release() lets the Connect finish.
func (b *blockingUpstream) ConnectServer(_ context.Context, name string) error {
	// Phase 1: announce that a Connect is now in flight.
	func() {
		b.mu.Lock()
		defer b.mu.Unlock()

		b.connectInFlight = true
		select {
		case <-b.connectStarted:
			// Already signalled by an earlier call; closing again would panic.
		default:
			close(b.connectStarted)
		}
	}()

	// Phase 2: park here, outside the lock, so Close() can observe the
	// in-flight flag while we wait.
	<-b.releaseConnect

	// Phase 3: record completion.
	b.mu.Lock()
	defer b.mu.Unlock()

	b.connectInFlight = false
	if entry, known := b.states[name]; known {
		entry.Connected = true
	}
	return nil
}

// DisconnectServer is a stub: it does nothing and reports success.
func (b *blockingUpstream) DisconnectServer(_ string) error {
	return nil
}

// ConnectAll is a stub: it does nothing and reports success.
func (b *blockingUpstream) ConnectAll(_ context.Context) error {
	return nil
}

// GetServerState returns a copy of the state recorded for name, so callers
// cannot mutate the double's internal entry. For an unknown name it returns
// (nil, nil) rather than an error.
func (b *blockingUpstream) GetServerState(name string) (*ServerState, error) {
	b.mu.Lock()
	defer b.mu.Unlock()

	entry, known := b.states[name]
	if !known {
		return nil, nil
	}
	snapshot := *entry
	return &snapshot, nil
}

// GetAllStates returns a snapshot of every recorded server state. Both the map
// and each ServerState in it are copies, so callers cannot alter the double's
// internal table.
func (b *blockingUpstream) GetAllStates() map[string]*ServerState {
	b.mu.Lock()
	defer b.mu.Unlock()

	snapshot := make(map[string]*ServerState, len(b.states))
	for name := range b.states {
		entry := *b.states[name]
		snapshot[name] = &entry
	}
	return snapshot
}

// IsUserLoggedOut always reports false for any server name.
func (b *blockingUpstream) IsUserLoggedOut(_ string) bool {
	return false
}

// Subscribe returns the double's single event channel as receive-only.
func (b *blockingUpstream) Subscribe() <-chan Event {
	return b.eventCh
}

// Unsubscribe is a no-op; the event channel is left untouched.
func (b *blockingUpstream) Unsubscribe(_ <-chan Event) {}

// Close marks the double as closed. If a Connect is in flight at that moment it
// latches overlapDetected, which is the condition the drain test asserts never
// happens. The latch is never cleared by a later Close.
func (b *blockingUpstream) Close() {
	b.mu.Lock()
	defer b.mu.Unlock()

	b.closed = true
	b.overlapDetected = b.overlapDetected || b.connectInFlight
}

// release unblocks any ConnectServer call parked on releaseConnect, and every
// later call as well, by closing the channel.
//
// It is idempotent: the close is guarded the same way ConnectServer guards
// connectStarted, so release can be called from both the test body and a
// deferred cleanup without a "close of closed channel" panic.
func (b *blockingUpstream) release() {
	b.mu.Lock()
	defer b.mu.Unlock()

	select {
	case <-b.releaseConnect:
		// Already released.
	default:
		close(b.releaseConnect)
	}
}

// sawOverlap reports whether Close() was ever observed running while a Connect
// was still in flight.
func (b *blockingUpstream) sawOverlap() bool {
	b.mu.Lock()
	overlapped := b.overlapDetected
	b.mu.Unlock()

	return overlapped
}

// TestSupervisor_Stop_DrainsInFlightConnectBeforeClose is the MCP-783 regression
// guard. Stop() must wait for in-flight reconcile action goroutines (here, a slow
// Connect) to finish before it disconnects upstream clients via upstream.Close().
// Before the drain fix, Stop() returned immediately and Close() overlapped the
// still-running Connect — the root of the MCP-770 race cascade.
//
// The test proceeds in four steps:
//  1. reconcile dispatches a Connect that parks inside the blocking upstream;
//  2. Stop() is started in the background and must NOT return while it is parked;
//  3. the Connect is released and Stop() must then return;
//  4. the upstream must not have seen Close() during the in-flight Connect.
func TestSupervisor_Stop_DrainsInFlightConnectBeforeClose(t *testing.T) {
	cfg := &config.Config{
		Listen:  "127.0.0.1:8080",
		Servers: []*config.ServerConfig{{Name: "slow-server", Enabled: true, Quarantined: false}},
	}
	configSvc := configsvc.NewService(cfg, "/tmp/config.json", zap.NewNop())
	defer configSvc.Close()

	up := newBlockingUpstream()
	sup := New(configSvc, up, zap.NewNop())

	// Release the parked Connect exactly once, whichever path the test takes.
	// Without the deferred call, a failed assertion below (t.Fatal) would leave
	// the Connect goroutine blocked on releaseConnect for the rest of the test
	// binary's life. The Once keeps the explicit release further down and this
	// deferred one from closing the channel twice.
	var releaseOnce sync.Once
	releaseConnect := func() { releaseOnce.Do(up.release) }
	defer releaseConnect()

	// Dispatch the Connect action (runs in its own goroutine).
	require.NoError(t, sup.reconcile(configSvc.Current()))

	// Wait until Connect is actually in flight (blocked on releaseConnect).
	select {
	case <-up.connectStarted:
	case <-time.After(2 * time.Second):
		t.Fatal("ConnectServer never started")
	}

	// Call Stop() in the background; it must block on draining the in-flight Connect.
	stopReturned := make(chan struct{})
	go func() {
		sup.Stop()
		close(stopReturned)
	}()

	// Stop() must NOT return while Connect is still in flight.
	select {
	case <-stopReturned:
		t.Fatal("Stop() returned before in-flight Connect completed (no drain)")
	case <-time.After(200 * time.Millisecond):
		// expected: Stop is draining
	}

	// Release the Connect; Stop() should now complete.
	releaseConnect()
	select {
	case <-stopReturned:
	case <-time.After(3 * time.Second):
		t.Fatal("Stop() did not return after Connect was released")
	}

	require.False(t, up.sawOverlap(),
		"upstream.Close() overlapped an in-flight Connect — drain-before-disconnect failed")
}
65 changes: 64 additions & 1 deletion internal/runtime/supervisor/supervisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,15 @@ type Supervisor struct {
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup

// actionWg tracks in-flight reconcile action goroutines (Connect/Disconnect/
// Reconnect/Remove). Stop() drains it BEFORE disconnecting upstream clients so
// a Connect can never overlap a Disconnect on the same client (root fix for the
// MCP-770 race cascade, MCP-783). stopping (guarded by stateMu) gates dispatch
// so no new action is added once Stop() begins — preventing a WaitGroup
// Add-after-Wait.
actionWg sync.WaitGroup
stopping bool
}

// inspectionFailureInfo tracks inspection failures for circuit breaker pattern
Expand Down Expand Up @@ -311,6 +320,16 @@ func (s *Supervisor) reconcile(configSnapshot *configsvc.Snapshot) error {

plan := s.computeReconcilePlan(configSnapshot, actualStates, userLoggedOut)

// MCP-783: once Stop() has begun (stopping set under stateMu), do not dispatch
// new action goroutines. This keeps all actionWg.Add calls strictly ordered
// before Stop()'s actionWg.Wait (no Add-after-Wait) and guarantees no Connect
// can start after we begin draining for disconnect.
if s.stopping {
s.logger.Debug("Supervisor stopping, skipping reconcile action dispatch")
s.updateSnapshot(configSnapshot, actualStates)
return nil
}

// Phase 6 Fix: Execute actions asynchronously to prevent blocking
// Each action runs in its own goroutine with timeout
actionCount := 0
Expand All @@ -324,8 +343,12 @@ func (s *Supervisor) reconcile(configSnapshot *configsvc.Snapshot) error {
zap.String("server", serverName),
zap.String("action", string(action)))

// Launch each action in a goroutine - no waiting!
// Launch each action in a goroutine. Tracked by actionWg (Add under stateMu,
// before the goroutine starts) so Stop() can drain in-flight actions before
// disconnecting clients.
s.actionWg.Add(1)
go func(name string, act ReconcileAction, snapshot *configsvc.Snapshot) {
defer s.actionWg.Done()
if err := s.executeAction(name, act, snapshot); err != nil {
s.logger.Error("Failed to execute action",
zap.String("server", name),
Expand Down Expand Up @@ -1038,12 +1061,32 @@ func (s *Supervisor) emitEvent(event Event) {
}
}

// actionDrainTimeout is the upper bound on how long Stop() waits for in-flight
// reconcile action goroutines before it goes on to disconnect clients.
//
// It is deliberately longer than the per-action context timeout (executeAction,
// 30s): an action that honours its cancelled context returns before this fires.
// The value is only a backstop so that a wedged Connect cannot hang shutdown
// forever.
const actionDrainTimeout time.Duration = 35 * time.Second

// Stop gracefully stops the supervisor.
func (s *Supervisor) Stop() {
s.logger.Info("Stopping supervisor")

// MCP-783: mark stopping under stateMu so reconcile() dispatches no further
// action goroutines. Serializing on stateMu (the same lock reconcile holds while
// dispatching) ensures every actionWg.Add has happened before the drain below.
s.stateMu.Lock()
s.stopping = true
s.stateMu.Unlock()

s.cancel()
s.wg.Wait()

// Drain in-flight reconcile actions (Connect/Disconnect/...) BEFORE disconnecting
// upstream clients. Without this, ShutdownAll -> Disconnect overlaps an in-flight
// Connect on the same client — the root of the MCP-770 race cascade.
s.drainActions()

// Close upstream adapter
s.upstream.Close()

Expand All @@ -1058,6 +1101,26 @@ func (s *Supervisor) Stop() {
s.logger.Info("Supervisor stopped")
}

// drainActions waits for in-flight reconcile action goroutines to finish, bounded
// by actionDrainTimeout. Called from Stop() before disconnecting clients so a
// Connect can never overlap a Disconnect on the same client (MCP-783).
//
// On timeout it logs a warning and returns anyway. The helper goroutine that is
// blocked in actionWg.Wait then stays alive until the outstanding actions call
// Done, because a WaitGroup wait cannot be cancelled.
func (s *Supervisor) drainActions() {
	done := make(chan struct{})
	go func() {
		defer close(done)
		s.actionWg.Wait()
	}()

	// Use an explicit timer instead of time.After so it is stopped as soon as
	// the drain completes, rather than staying pending for the full timeout
	// when the drain is fast.
	timer := time.NewTimer(actionDrainTimeout)
	defer timer.Stop()

	select {
	case <-done:
		s.logger.Debug("Drained in-flight reconcile actions before disconnect")
	case <-timer.C:
		s.logger.Warn("Timed out draining in-flight reconcile actions before disconnect; "+
			"proceeding to disconnect (a Connect may still be in flight)",
			zap.Duration("timeout", actionDrainTimeout))
	}
}

// RequestInspectionExemption grants temporary connection permission for a quarantined server.
// This allows security inspection to temporarily connect to quarantined servers.
// Triggers immediate reconciliation to connect the server.
Expand Down
Loading