diff --git a/internal/runtime/supervisor/actor_pool.go b/internal/runtime/supervisor/actor_pool.go index 5cdc3c24..1ad6e2e1 100644 --- a/internal/runtime/supervisor/actor_pool.go +++ b/internal/runtime/supervisor/actor_pool.go @@ -211,12 +211,12 @@ func (p *ActorPoolSimple) GetAllStates() map[string]*ServerState { state := &ServerState{ Name: name, - Config: client.Config, - Enabled: client.Config.Enabled, + Config: client.GetConfig(), + Enabled: client.GetConfig().Enabled, Connected: connected, } - if client.Config.Quarantined { + if client.GetConfig().Quarantined { state.Quarantined = true } diff --git a/internal/runtime/supervisor/actor_pool_complex_reference.go b/internal/runtime/supervisor/actor_pool_complex_reference.go index 2599332b..24b26abe 100644 --- a/internal/runtime/supervisor/actor_pool_complex_reference.go +++ b/internal/runtime/supervisor/actor_pool_complex_reference.go @@ -16,10 +16,10 @@ import ( // ActorPool manages the lifecycle of server actors and provides stats for Supervisor. // This replaces UpstreamAdapter with direct Actor integration (Phase 7.2). 
type ActorPool struct { - actors map[string]*actor.Actor - mu sync.RWMutex - logger *zap.Logger - manager *upstream.Manager // Use existing manager for client creation + actors map[string]*actor.Actor + mu sync.RWMutex + logger *zap.Logger + manager *upstream.Manager // Use existing manager for client creation // Event aggregation eventCh chan Event @@ -218,12 +218,12 @@ func (p *ActorPool) GetServerState(name string) (*ServerState, error) { state := &ServerState{ Name: name, - Config: client.Config, - Enabled: client.Config.Enabled, + Config: client.GetConfig(), + Enabled: client.GetConfig().Enabled, Connected: client.IsConnected(), } - if client.Config.Quarantined { + if client.GetConfig().Quarantined { state.Quarantined = true } @@ -258,12 +258,12 @@ func (p *ActorPool) GetAllStates() map[string]*ServerState { connected := client.IsConnected() state := &ServerState{ Name: name, - Config: client.Config, - Enabled: client.Config.Enabled, + Config: client.GetConfig(), + Enabled: client.GetConfig().Enabled, Connected: connected, } - if client.Config.Quarantined { + if client.GetConfig().Quarantined { state.Quarantined = true } @@ -328,9 +328,9 @@ func (p *ActorPool) forwardActorEvents(name string, a *actor.Actor) { ServerName: name, Timestamp: event.Timestamp, Payload: map[string]interface{}{ - "connected": event.State == actor.StateConnected, - "state": string(event.State), - "actor_event": string(event.Type), + "connected": event.State == actor.StateConnected, + "state": string(event.State), + "actor_event": string(event.Type), }, }) } diff --git a/internal/upstream/client_test.go b/internal/upstream/client_test.go index 93cf8391..101dfa49 100644 --- a/internal/upstream/client_test.go +++ b/internal/upstream/client_test.go @@ -309,7 +309,7 @@ func TestClient_Headers_Support(t *testing.T) { require.NotNil(t, client) // Test that headers are stored in config - assert.Equal(t, tt.headers, client.Config.Headers) + assert.Equal(t, tt.headers, 
client.GetConfig().Headers) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() diff --git a/internal/upstream/core/client.go b/internal/upstream/core/client.go index 375380bc..724f87ee 100644 --- a/internal/upstream/core/client.go +++ b/internal/upstream/core/client.go @@ -75,10 +75,21 @@ type Client struct { // Cached tools list from successful immediate call cachedTools []mcp.Tool - // Stderr monitoring + // monitoringMu serializes the stderr/process monitoring lifecycle methods + // (Start*/Stop*Monitoring). Connect (StartStderrMonitoring) and Disconnect + // (StopStderrMonitoring) can run concurrently on the same client during a + // reconcile-vs-shutdown overlap, racing the ctx/cancel/WaitGroup fields + // below (notably WG.Add vs WG.Wait). This mutex makes start and stop + // mutually exclusive. It is never held across c.mu. + monitoringMu sync.Mutex + + // Stderr monitoring. stderrMonitoringDone is a per-cycle channel closed by + // the monitor goroutine when it exits; Stop waits on it instead of a reused + // sync.WaitGroup, so an abandoned (timed-out) wait never races a later + // Start's counter. All three fields are written only under monitoringMu. stderrMonitoringCtx context.Context stderrMonitoringCancel context.CancelFunc - stderrMonitoringWG sync.WaitGroup + stderrMonitoringDone chan struct{} // Ring buffer of recent stderr lines from the subprocess. 
// Populated by monitorStderr; surfaced in initialize failure messages so @@ -92,7 +103,7 @@ type Client struct { processGroupID int // Process group ID for proper cleanup processMonitorCtx context.Context processMonitorCancel context.CancelFunc - processMonitorWG sync.WaitGroup + processMonitorDone chan struct{} // Docker container tracking containerID string diff --git a/internal/upstream/core/monitoring.go b/internal/upstream/core/monitoring.go index 92be5fe3..809a684d 100644 --- a/internal/upstream/core/monitoring.go +++ b/internal/upstream/core/monitoring.go @@ -24,17 +24,24 @@ const ( // StartStderrMonitoring starts monitoring stderr output and logging it func (c *Client) StartStderrMonitoring() { + c.monitoringMu.Lock() + defer c.monitoringMu.Unlock() + if c.stderr == nil || c.transportType != transportStdio { return } - // Create context for stderr monitoring - c.stderrMonitoringCtx, c.stderrMonitoringCancel = context.WithCancel(context.Background()) + // Create context for stderr monitoring. The monitor goroutine receives the + // context and its done channel as locals so an abandoned (timed-out) + // goroutine never reads the shared fields a later Start may overwrite. 
+ ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + c.stderrMonitoringCtx, c.stderrMonitoringCancel = ctx, cancel + c.stderrMonitoringDone = done - c.stderrMonitoringWG.Add(1) go func() { - defer c.stderrMonitoringWG.Done() - c.monitorStderr() + defer close(done) + c.monitorStderr(ctx) }() c.logger.Debug("Started stderr monitoring", @@ -43,41 +50,55 @@ func (c *Client) StartStderrMonitoring() { // StopStderrMonitoring stops stderr monitoring func (c *Client) StopStderrMonitoring() { - if c.stderrMonitoringCancel != nil { - c.stderrMonitoringCancel() + c.monitoringMu.Lock() + defer c.monitoringMu.Unlock() - // Use a timeout for the wait to prevent hanging - done := make(chan struct{}) - go func() { - c.stderrMonitoringWG.Wait() - close(done) - }() + if c.stderrMonitoringCancel == nil { + return + } - select { - case <-done: - c.logger.Debug("Stopped stderr monitoring", - zap.String("server", c.config.Name)) - case <-time.After(500 * time.Millisecond): - c.logger.Warn("Stderr monitoring stop timed out after 500ms, forcing shutdown", - zap.String("server", c.config.Name)) - } + c.stderrMonitoringCancel() + done := c.stderrMonitoringDone + c.stderrMonitoringCancel = nil + c.stderrMonitoringDone = nil + if done == nil { + return + } + + // Wait for the monitor goroutine directly under monitoringMu (no detached + // waiter that could outlive the lock). On timeout the goroutine is abandoned; + // it closes its own done channel and touches only its captured ctx, so it + // cannot race a subsequent Start. 
+ select { + case <-done: + c.logger.Debug("Stopped stderr monitoring", + zap.String("server", c.config.Name)) + case <-time.After(500 * time.Millisecond): + c.logger.Warn("Stderr monitoring stop timed out after 500ms, forcing shutdown", + zap.String("server", c.config.Name)) } } // StartProcessMonitoring starts monitoring the underlying process func (c *Client) StartProcessMonitoring() { + c.monitoringMu.Lock() + defer c.monitoringMu.Unlock() + // Start monitoring even if processCmd is nil for Docker containers if c.processCmd == nil && !c.isDockerCommand { return } - // Create context for process monitoring - c.processMonitorCtx, c.processMonitorCancel = context.WithCancel(context.Background()) + // Create context for process monitoring (ctx + done passed as locals; see + // StartStderrMonitoring for the abandoned-goroutine rationale). + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + c.processMonitorCtx, c.processMonitorCancel = ctx, cancel + c.processMonitorDone = done - c.processMonitorWG.Add(1) go func() { - defer c.processMonitorWG.Done() - c.monitorProcess() + defer close(done) + c.monitorProcess(ctx) }() if c.processCmd != nil { @@ -94,29 +115,33 @@ func (c *Client) StartProcessMonitoring() { // StopProcessMonitoring stops process monitoring func (c *Client) StopProcessMonitoring() { - if c.processMonitorCancel != nil { - c.processMonitorCancel() + c.monitoringMu.Lock() + defer c.monitoringMu.Unlock() - // Use a timeout for the wait to prevent hanging - done := make(chan struct{}) - go func() { - c.processMonitorWG.Wait() - close(done) - }() + if c.processMonitorCancel == nil { + return + } - select { - case <-done: - c.logger.Debug("Stopped process monitoring", - zap.String("server", c.config.Name)) - case <-time.After(500 * time.Millisecond): - c.logger.Warn("Process monitoring stop timed out after 500ms, forcing shutdown", - zap.String("server", c.config.Name)) - } + c.processMonitorCancel() + done := 
c.processMonitorDone + c.processMonitorCancel = nil + c.processMonitorDone = nil + if done == nil { + return + } + + select { + case <-done: + c.logger.Debug("Stopped process monitoring", + zap.String("server", c.config.Name)) + case <-time.After(500 * time.Millisecond): + c.logger.Warn("Process monitoring stop timed out after 500ms, forcing shutdown", + zap.String("server", c.config.Name)) } } // monitorProcess monitors the underlying process health -func (c *Client) monitorProcess() { +func (c *Client) monitorProcess(ctx context.Context) { // Only return early if we have neither processCmd nor Docker command if c.processCmd == nil && !c.isDockerCommand { return @@ -130,7 +155,7 @@ func (c *Client) monitorProcess() { for { select { - case <-c.processMonitorCtx.Done(): + case <-ctx.Done(): return case <-ticker.C: if isDocker { @@ -141,11 +166,11 @@ func (c *Client) monitorProcess() { } // monitorStderr monitors stderr output and logs it to both main and server-specific logs -func (c *Client) monitorStderr() { +func (c *Client) monitorStderr(ctx context.Context) { scanner := bufio.NewScanner(c.stderr) for scanner.Scan() { select { - case <-c.stderrMonitoringCtx.Done(): + case <-ctx.Done(): return default: line := strings.TrimSpace(scanner.Text()) diff --git a/internal/upstream/core/monitoring_race_test.go b/internal/upstream/core/monitoring_race_test.go new file mode 100644 index 00000000..6ab87c5a --- /dev/null +++ b/internal/upstream/core/monitoring_race_test.go @@ -0,0 +1,86 @@ +package core + +import ( + "io" + "strings" + "sync" + "testing" + + "go.uber.org/zap" + + "github.com/smart-mcp-proxy/mcpproxy-go/internal/config" +) + +// TestStderrMonitoring_StartStopRace reproduces the Connect-vs-Disconnect race +// on the stderr-monitoring lifecycle fields (stderrMonitoringCtx/Cancel/WG). 
+// StartStderrMonitoring runs from connectStdio during a reconcile-driven Connect +// while StopStderrMonitoring runs from Disconnect during Manager.ShutdownAll, with +// no synchronization on those fields — the -race detector flags WG.Add (Start) +// vs WG.Wait (Stop). Run under `go test -race`: trips without monitoringMu, green +// with it. A reused empty stderr reader returns EOF immediately so monitorStderr +// exits at once and the loop stays fast. +func TestStderrMonitoring_StartStopRace(t *testing.T) { + c := &Client{ + transportType: transportStdio, + stderr: strings.NewReader(""), + logger: zap.NewNop(), + config: &config.ServerConfig{Name: "race"}, + } + + const iterations = 500 + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + c.StartStderrMonitoring() + } + }() + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + c.StopStderrMonitoring() + } + }() + + wg.Wait() + c.StopStderrMonitoring() +} + +// TestStderrMonitoring_AbandonedMonitorNoRace models the round-5 escape: the +// monitor goroutine is still alive when Stop is called (its stderr Read blocks), +// so Stop hits the 500ms timeout and abandons it. With the old reused-WaitGroup +// design the abandoned WG.Wait raced the next cycle's WG.Add; the per-cycle done +// channel + ctx-as-param design must keep concurrent Start/Stop race-free even +// while a prior monitor lingers. A blocking pipe keeps monitorStderr alive; +// closing the writer on cleanup lets the leaked goroutines exit. 
+func TestStderrMonitoring_AbandonedMonitorNoRace(t *testing.T) { + pr, pw := io.Pipe() + t.Cleanup(func() { _ = pw.Close() }) + + c := &Client{ + transportType: transportStdio, + stderr: pr, // Read blocks until the writer is closed + logger: zap.NewNop(), + config: &config.ServerConfig{Name: "race"}, + } + + const cycles = 4 // each Stop times out at 500ms; keep small + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + for i := 0; i < cycles; i++ { + c.StartStderrMonitoring() + } + }() + go func() { + defer wg.Done() + for i := 0; i < cycles; i++ { + c.StopStderrMonitoring() + } + }() + wg.Wait() + c.StopStderrMonitoring() +} diff --git a/internal/upstream/managed/client.go b/internal/upstream/managed/client.go index d4c65db5..eee83ec9 100644 --- a/internal/upstream/managed/client.go +++ b/internal/upstream/managed/client.go @@ -5,6 +5,7 @@ import ( "fmt" "strings" "sync" + "sync/atomic" "time" "github.com/smart-mcp-proxy/mcpproxy-go/internal/config" @@ -20,8 +21,15 @@ import ( // Client wraps a core client with state management, concurrency control, and background recovery type Client struct { - id string - Config *config.ServerConfig // Public field for compatibility with existing code + id string + // cfg holds the server configuration as an atomic pointer. SetConfig swaps it + // (reconcile add path, off mc.mu) while many readers — including detached + // state-change callback goroutines and Connect's unlocked phase — read it + // concurrently. An atomic pointer makes every read/write data-race-free and + // is lock-free, so it is safe to read whether or not mc.mu is held (the RLock + // accessor approach would deadlock the in-lock readers). Access via + // GetConfig() / SetConfig() only — never touch the field directly. 
(MCP-770) + cfg atomic.Pointer[config.ServerConfig] coreClient *core.Client logger *zap.Logger StateManager *types.StateManager // Public field for callback access @@ -91,7 +99,6 @@ func NewClient(id string, serverConfig *config.ServerConfig, logger *zap.Logger, // Create managed client mc := &Client{ id: id, - Config: serverConfig, coreClient: coreClient, logger: logger.With(zap.String("component", "managed_client")), StateManager: types.NewStateManager(), @@ -100,6 +107,7 @@ func NewClient(id string, serverConfig *config.ServerConfig, logger *zap.Logger, storage: storage, stopMonitoring: make(chan struct{}), } + mc.cfg.Store(serverConfig) // Set up state change callback mc.StateManager.SetStateChangeCallback(mc.onStateChange) @@ -152,8 +160,14 @@ func (mc *Client) Connect(ctx context.Context) error { return fmt.Errorf("connection already in progress or established (state: %s)", mc.StateManager.GetState().String()) } + // Snapshot the server name before the unlocked phase. Phase 3 below runs WITHOUT + // mc.mu while SetConfig may atomically swap the config, so mc.GetConfig() there + // could return a different config than this attempt began with (MCP-770: + // SetConfig vs Connect). Use this local for any logging in the unlocked window. 
+ serverName := mc.GetConfig().Name + mc.logger.Info("Starting managed connection to upstream server", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.String("current_state", mc.StateManager.GetState().String()), zap.Bool("list_tools_in_progress", mc.listToolsInProgress)) @@ -164,11 +178,11 @@ func (mc *Client) Connect(ctx context.Context) error { currentState := mc.StateManager.GetState() if currentState == types.StateError || currentState == types.StateDisconnected { mc.logger.Debug("Disconnecting core client before reconnect to clear stale state", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.String("from_state", currentState.String())) if err := mc.coreClient.Disconnect(); err != nil { mc.logger.Debug("Core client disconnect before reconnect returned", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.Error(err)) } } @@ -194,7 +208,7 @@ func (mc *Client) Connect(ctx context.Context) error { // Phase 3: Execute the actual connection (potentially slow - OAuth, MCP initialize) // mc.mu is NOT held here, so Disconnect/SetConfig/GetConfig won't block mc.logger.Debug("Invoking core client Connect for managed client", - zap.String("server", mc.Config.Name)) + zap.String("server", serverName)) connectErr := mc.coreClient.Connect(connectCtx) // Phase 4: Re-acquire lock to update state based on result @@ -205,7 +219,7 @@ func (mc *Client) Connect(ctx context.Context) error { // Check if this is a deferred OAuth requirement (pending user action) if core.IsOAuthPending(connectErr) { mc.logger.Info("⏳ OAuth authentication pending user action", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) // Transition to PendingAuth state instead of Error mc.StateManager.TransitionTo(types.StatePendingAuth) mc.StateManager.SetError(connectErr) @@ -216,7 +230,7 @@ func (mc *Client) Connect(ctx context.Context) error { // Check if 
this is a token refresh scenario vs full re-auth isRefreshScenario := mc.isTokenRefreshScenario(connectErr) mc.logger.Info("🎯 OAuth authorization required during MCP initialization", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.Bool("token_refresh_scenario", isRefreshScenario)) // Don't apply backoff for OAuth authorization requirement mc.StateManager.SetError(connectErr) @@ -225,7 +239,7 @@ func (mc *Client) Connect(ctx context.Context) error { // Check if this is a token refresh scenario vs full re-auth isRefreshScenario := mc.isTokenRefreshScenario(connectErr) mc.logger.Warn("OAuth authentication failed, applying extended backoff", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.Bool("token_refresh_scenario", isRefreshScenario), zap.Error(connectErr)) mc.StateManager.SetOAuthError(connectErr) @@ -236,7 +250,7 @@ func (mc *Client) Connect(ctx context.Context) error { } mc.logger.Debug("Core client Connect returned successfully", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) // Transition to ready state only if not already ready if mc.StateManager.GetState() != types.StateReady { @@ -254,11 +268,11 @@ func (mc *Client) Connect(ctx context.Context) error { } mc.logger.Info("Successfully established managed connection", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) // Add a small delay before starting background monitoring to let connection stabilize mc.logger.Debug("🔍 Adding stabilization delay before starting background monitoring", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) // Create cancellable context for monitoring startup monitoringCtx, monitoringCancel := context.WithCancel(context.Background()) @@ -271,13 +285,13 @@ func (mc *Client) Connect(ctx context.Context) error { mc.mu.Lock() if mc.monitoringCancelFunc != nil { mc.logger.Debug("🔍 Starting background 
monitoring after stabilization delay", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) mc.startBackgroundMonitoring() } mc.mu.Unlock() case <-monitoringCtx.Done(): mc.logger.Debug("🔍 Background monitoring startup cancelled", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) } }() @@ -292,7 +306,7 @@ func (mc *Client) Disconnect() error { mc.mu.Lock() defer mc.mu.Unlock() - mc.logger.Info("Disconnecting managed client", zap.String("server", mc.Config.Name)) + mc.logger.Info("Disconnecting managed client", zap.String("server", mc.GetConfig().Name)) // Ensure no ListTools operations remain after acquiring the lock mc.cancelInFlightListTools() @@ -315,7 +329,7 @@ func (mc *Client) Disconnect() error { mc.StateManager.Reset() mc.logger.Debug("Managed client disconnect complete", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.Bool("list_tools_in_progress", mc.listToolsInProgress)) return nil @@ -341,18 +355,16 @@ func (mc *Client) GetConnectionInfo() types.ConnectionInfo { return mc.StateManager.GetConnectionInfo() } -// GetConfig returns a thread-safe copy of the server configuration +// GetConfig returns the current server configuration pointer in a thread-safe, +// lock-free manner. Safe to call whether or not mc.mu is held. func (mc *Client) GetConfig() *config.ServerConfig { - mc.mu.RLock() - defer mc.mu.RUnlock() - return mc.Config + return mc.cfg.Load() } -// SetConfig updates the server configuration in a thread-safe manner +// SetConfig atomically swaps the server configuration. Lock-free; callers need +// not hold mc.mu — the swap is atomic and safe with or without the lock. 
func (mc *Client) SetConfig(config *config.ServerConfig) { - mc.mu.Lock() - defer mc.mu.Unlock() - mc.Config = config + mc.cfg.Store(config) } // GetServerInfo returns server information @@ -409,11 +421,11 @@ func (mc *Client) IsDockerIsolated() bool { return false } // Check if server has isolation explicitly disabled - if mc.Config.Isolation != nil && mc.Config.Isolation.Enabled != nil && !*mc.Config.Isolation.Enabled { + if mc.GetConfig().Isolation != nil && mc.GetConfig().Isolation.Enabled != nil && !*mc.GetConfig().Isolation.Enabled { return false } // Only stdio servers with commands get Docker-isolated - return mc.Config.Command != "" + return mc.GetConfig().Command != "" } // SetUserLoggedOut marks that the user has explicitly logged out @@ -494,13 +506,13 @@ func (mc *Client) publishListToolsResult(tools []*config.ToolMetadata, err error // callers onto a single in-flight upstream call. func (mc *Client) ListTools(ctx context.Context) ([]*config.ToolMetadata, error) { mc.logger.Debug("🔍 ListTools called", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.String("state", mc.StateManager.GetState().String()), zap.Bool("connected", mc.IsConnected())) if !mc.IsConnected() { mc.logger.Debug("🔍 ListTools rejected - client not connected", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.String("state", mc.StateManager.GetState().String())) return nil, fmt.Errorf("client not connected (state: %s)", mc.StateManager.GetState().String()) } @@ -525,11 +537,11 @@ func (mc *Client) ListTools(ctx context.Context) ([]*config.ToolMetadata, error) // Defensive fallback: every leader path is supposed to allocate a // wait channel via acquireListToolsContext, so this should be // unreachable. Fail fast rather than block forever on a nil channel. 
- return nil, fmt.Errorf("ListTools operation already in progress for server %s", mc.Config.Name) + return nil, fmt.Errorf("ListTools operation already in progress for server %s", mc.GetConfig().Name) } mc.logger.Debug("🔍 ListTools already in progress, waiting for shared result", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) select { case <-ctx.Done(): @@ -554,10 +566,10 @@ func (mc *Client) runListToolsAsLeader(listCtx context.Context, release func() b defer func() { if release() { mc.logger.Debug("🔍 ListTools operation completed, flag reset", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) } else { mc.logger.Debug("🔍 ListTools operation completed while disconnected", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) } }() @@ -566,12 +578,12 @@ func (mc *Client) runListToolsAsLeader(listCtx context.Context, release func() b if err != nil { mc.logger.Error("ListTools operation failed", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.Error(err)) if mc.isConnectionError(err) { mc.logger.Warn("Connection error detected during ListTools, updating server state", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.Error(err)) mc.StateManager.SetError(err) } @@ -595,13 +607,13 @@ func (mc *Client) CallTool(ctx context.Context, toolName string, args map[string // Use different log levels based on error type if mc.isNormalReconnectionError(err) { mc.logger.Warn("Tool call failed due to connection loss, will attempt reconnection", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.String("tool", toolName), zap.String("error_type", "normal_reconnection"), zap.Error(err)) } else { mc.logger.Error("Tool call failed with connection error", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.String("tool", toolName), 
zap.Error(err)) } @@ -609,7 +621,7 @@ func (mc *Client) CallTool(ctx context.Context, toolName string, args map[string } else { // Log non-connection errors at error level mc.logger.Error("Tool call failed", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.String("tool", toolName), zap.Error(err)) } @@ -630,7 +642,7 @@ func (mc *Client) cancelInFlightListTools() { } mc.logger.Debug("Cancelling in-flight ListTools operation", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) cancel() @@ -649,7 +661,7 @@ func (mc *Client) cancelInFlightListTools() { } mc.logger.Debug("Timed out waiting for ListTools operation to cancel", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) } // cancelInFlightConnect cancels any in-flight Connect() operation. @@ -665,7 +677,7 @@ func (mc *Client) cancelInFlightConnect() { } mc.logger.Debug("Cancelling in-flight Connect operation", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) cancel() } @@ -674,15 +686,15 @@ func (mc *Client) onStateChange(oldState, newState types.ConnectionState, info * mc.logger.Info("State transition", zap.String("from", oldState.String()), zap.String("to", newState.String()), - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) // Handle error states with appropriate log levels if newState == types.StateError && info.LastError != nil { // Check for deprecated endpoint errors first - these require URL changes, not reconnection if mc.isDeprecatedEndpointError(info.LastError) { mc.logger.Error("⚠️ ENDPOINT DEPRECATED: Server URL needs to be updated", - zap.String("server", mc.Config.Name), - zap.String("current_url", mc.Config.URL), + zap.String("server", mc.GetConfig().Name), + zap.String("current_url", mc.GetConfig().URL), zap.String("error_type", "endpoint_deprecated"), zap.String("action", "Update the server URL in your configuration"), 
zap.String("hint", "The server may have migrated from /sse to /mcp - check the server's documentation"), @@ -692,13 +704,13 @@ func (mc *Client) onStateChange(oldState, newState types.ConnectionState, info * if mc.isNormalReconnectionError(info.LastError) { mc.logger.Warn("Connection error, will attempt automatic reconnection", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.String("error_type", "normal_reconnection"), zap.Error(info.LastError), zap.Int("retry_count", info.RetryCount)) } else { mc.logger.Error("Connection error", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.Error(info.LastError), zap.Int("retry_count", info.RetryCount)) } @@ -721,7 +733,7 @@ func (mc *Client) stopBackgroundMonitoring() { // Only proceed if monitoring was actually started if !mc.monitoringStarted { mc.logger.Debug("Background monitoring was never started, skipping stop", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) return } @@ -737,10 +749,10 @@ func (mc *Client) stopBackgroundMonitoring() { select { case <-done: mc.logger.Debug("Background monitoring stopped successfully", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) case <-time.After(1 * time.Second): mc.logger.Warn("Background monitoring stop timed out after 1s, forcing shutdown", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) } mc.monitoringStarted = false @@ -760,7 +772,7 @@ func (mc *Client) backgroundHealthCheck() { mc.performHealthCheck() case <-mc.stopMonitoring: mc.logger.Debug("Background health monitoring stopped", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) return } } @@ -771,7 +783,7 @@ func (mc *Client) performHealthCheck() { // Skip all health/reconnect work when user explicitly logged out if mc.IsUserLoggedOut() { mc.logger.Debug("Health check skipped - user explicitly logged 
out", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) return } @@ -780,14 +792,14 @@ func (mc *Client) performHealthCheck() { if mc.StateManager.ShouldRetryOAuth() { info := mc.StateManager.GetConnectionInfo() mc.logger.Info("Attempting OAuth reconnection with extended backoff", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.Int("oauth_retry_count", info.OAuthRetryCount), zap.Time("last_oauth_attempt", info.LastOAuthAttempt)) mc.tryReconnect() } else { info := mc.StateManager.GetConnectionInfo() mc.logger.Debug("OAuth backoff period not elapsed, skipping reconnection", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.Int("oauth_retry_count", info.OAuthRetryCount), zap.Time("last_oauth_attempt", info.LastOAuthAttempt)) } @@ -801,14 +813,14 @@ func (mc *Client) performHealthCheck() { // Log once at WARN then suppress — server needs manual reconnect if info.RetryCount == types.MaxConnectionRetries { mc.logger.Warn("Giving up automatic reconnection after max retries — use manual reconnect or reconnect-on-use", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.Int("retry_count", info.RetryCount)) } return } if mc.ShouldRetry() { mc.logger.Info("Attempting automatic reconnection with exponential backoff", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.Int("retry_count", info.RetryCount)) mc.tryReconnect() @@ -824,8 +836,8 @@ func (mc *Client) performHealthCheck() { // Skip health checks for Docker servers to avoid interference with container management if mc.isDockerServer() { mc.logger.Debug("Skipping health check for Docker server", - zap.String("server", mc.Config.Name), - zap.String("command", mc.Config.Command)) + zap.String("server", mc.GetConfig().Name), + zap.String("command", mc.GetConfig().Command)) return } @@ -836,7 +848,7 @@ func (mc *Client) 
performHealthCheck() { listCtx, release, ok := mc.acquireListToolsContext(ctx, 5*time.Second) if !ok { mc.logger.Debug("Health check skipped - ListTools already in progress", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) return } @@ -852,20 +864,20 @@ func (mc *Client) performHealthCheck() { if mc.isConnectionError(err) { if mc.recordHealthCheckFailure(err) { mc.logger.Warn("Health check failed repeatedly, marking as error", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.Int("consecutive_failures", mc.consecutiveHealthFailures), zap.Error(err)) mc.StateManager.SetError(err) } else { mc.logger.Info("Health check failed transiently, tolerating below threshold", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.Int("consecutive_failures", mc.consecutiveHealthFailures), zap.Int("threshold", healthCheckFailureThreshold), zap.Error(err)) } } else { mc.logger.Debug("Health check failed with timeout (high activity), ignoring", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.Error(err)) } return @@ -873,7 +885,7 @@ func (mc *Client) performHealthCheck() { mc.recordHealthCheckSuccess() mc.logger.Debug("Health check passed successfully", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) } // recordHealthCheckFailure increments the consecutive-failure counter and @@ -957,14 +969,14 @@ func (mc *Client) ForceReconnect(reason string) { if mc.IsUserLoggedOut() { mc.logger.Info("Force reconnect skipped - user explicitly logged out", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.String("reason", reason)) return } serverName := "" - if mc.Config != nil { - serverName = mc.Config.Name + if mc.GetConfig() != nil { + serverName = mc.GetConfig().Name } if mc.IsConnected() { @@ -995,7 +1007,7 @@ func (mc *Client) ForceReconnect(reason string) { func (mc 
*Client) tryReconnect() { if mc.IsUserLoggedOut() { mc.logger.Info("Skipping reconnection attempt - user explicitly logged out", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) return } @@ -1004,7 +1016,7 @@ func (mc *Client) tryReconnect() { if mc.reconnectInProgress { mc.reconnectMu.Unlock() mc.logger.Debug("Reconnection already in progress, skipping duplicate attempt", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) return } mc.reconnectInProgress = true @@ -1022,7 +1034,7 @@ func (mc *Client) tryReconnect() { defer cancel() mc.logger.Info("Starting reconnection attempt", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.String("current_state", mc.StateManager.GetState().String())) // First, disconnect the current client to clean up any broken connections @@ -1031,7 +1043,7 @@ func (mc *Client) tryReconnect() { mc.cancelInFlightListTools() if err := mc.coreClient.Disconnect(); err != nil { mc.logger.Warn("Failed to disconnect during reconnection attempt", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.Error(err)) } @@ -1046,19 +1058,19 @@ func (mc *Client) tryReconnect() { // Use different log levels based on error type and retry count if mc.isOAuthError(err) { mc.logger.Warn("OAuth reconnection attempt failed, extended backoff will apply", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.String("error_type", "oauth_authentication"), zap.Error(err), zap.Int("oauth_retry_count", info.OAuthRetryCount)) } else if mc.isNormalReconnectionError(err) && info.RetryCount <= 5 { mc.logger.Warn("Reconnection attempt failed, will retry with exponential backoff", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.String("error_type", "normal_reconnection"), zap.Error(err), zap.Int("retry_count", info.RetryCount)) } else { 
mc.logger.Error("Reconnection attempt failed", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.Error(err), zap.Int("retry_count", info.RetryCount)) } @@ -1067,7 +1079,7 @@ func (mc *Client) tryReconnect() { } mc.logger.Info("Reconnection attempt successful", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.String("new_state", mc.StateManager.GetState().String())) } @@ -1121,8 +1133,8 @@ func (mc *Client) TryReconnectSync(ctx context.Context) error { }() serverName := "" - if mc.Config != nil { - serverName = mc.Config.Name + if mc.GetConfig() != nil { + serverName = mc.GetConfig().Name } mc.logger.Info("TryReconnectSync: starting synchronous reconnect", @@ -1261,7 +1273,7 @@ func (mc *Client) isTokenRefreshScenario(err error) bool { for _, indicator := range tokenRefreshIndicators { if containsString(errStr, indicator) { mc.logger.Debug("🔄 Detected token refresh scenario", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.String("indicator", indicator)) return true } @@ -1378,7 +1390,7 @@ func (mc *Client) GetCachedToolCount(ctx context.Context) (int, error) { // Cache miss or expired - need to fetch fresh count if !mc.IsConnected() { mc.logger.Debug("🔍 Tool count fetch skipped - client not connected", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.String("state", mc.StateManager.GetState().String())) return 0, fmt.Errorf("client not connected (state: %s)", mc.StateManager.GetState().String()) } @@ -1386,14 +1398,14 @@ func (mc *Client) GetCachedToolCount(ctx context.Context) (int, error) { listCtx, release, ok := mc.acquireListToolsContext(ctx, 30*time.Second) if !ok { mc.logger.Debug("🔍 Tool count fetch skipped - ListTools already in progress", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) // Return cached count even if expired rather than causing another concurrent 
call return cachedCount, nil } defer release() mc.logger.Debug("🔍 Tool count cache miss - fetching fresh count", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.Bool("cache_expired", !cachedTime.IsZero()), zap.Duration("cache_age", time.Since(cachedTime))) @@ -1403,7 +1415,7 @@ func (mc *Client) GetCachedToolCount(ctx context.Context) (int, error) { mc.publishListToolsResult(tools, err) if err != nil { mc.logger.Debug("Tool count fetch failed, returning cached value", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.Error(err), zap.Int("cached_count", cachedCount)) @@ -1425,7 +1437,7 @@ func (mc *Client) GetCachedToolCount(ctx context.Context) (int, error) { mc.setToolCountCache(freshCount) mc.logger.Debug("🔍 Tool count cache updated", - zap.String("server", mc.Config.Name), + zap.String("server", mc.GetConfig().Name), zap.Int("fresh_count", freshCount), zap.Int("previous_count", cachedCount)) @@ -1450,7 +1462,7 @@ func (mc *Client) InvalidateToolCountCache() { mc.toolCountMu.Unlock() mc.logger.Debug("🔍 Tool count cache invalidated", - zap.String("server", mc.Config.Name)) + zap.String("server", mc.GetConfig().Name)) } // Helper function to check if string contains substring @@ -1493,5 +1505,5 @@ func (mc *Client) setToolCountCache(count int) { // isDockerServer checks if the server is running via Docker func (mc *Client) isDockerServer() bool { - return containsString(mc.Config.Command, "docker") + return containsString(mc.GetConfig().Command, "docker") } diff --git a/internal/upstream/managed/health_flap_test.go b/internal/upstream/managed/health_flap_test.go index b2b9dcd2..d80cc1bb 100644 --- a/internal/upstream/managed/health_flap_test.go +++ b/internal/upstream/managed/health_flap_test.go @@ -16,9 +16,9 @@ import ( func newTestClientForHealth(t *testing.T) *Client { t.Helper() mc := &Client{ - Config: &config.ServerConfig{Name: "flap-server"}, logger: zap.NewNop(), } + 
mc.SetConfig(&config.ServerConfig{Name: "flap-server"}) mc.StateManager = types.NewStateManager() mc.StateManager.TransitionTo(types.StateConnecting) mc.StateManager.TransitionTo(types.StateReady) diff --git a/internal/upstream/managed/listtools_coalescing_test.go b/internal/upstream/managed/listtools_coalescing_test.go index 427d1138..da15435c 100644 --- a/internal/upstream/managed/listtools_coalescing_test.go +++ b/internal/upstream/managed/listtools_coalescing_test.go @@ -18,9 +18,9 @@ import ( func newTestReadyClient(t *testing.T) *Client { t.Helper() mc := &Client{ - Config: &config.ServerConfig{Name: "test-server"}, logger: zap.NewNop(), } + mc.SetConfig(&config.ServerConfig{Name: "test-server"}) mc.StateManager = types.NewStateManager() mc.StateManager.TransitionTo(types.StateConnecting) mc.StateManager.TransitionTo(types.StateReady) diff --git a/internal/upstream/manager.go b/internal/upstream/manager.go index b66c9895..acc4c4f2 100644 --- a/internal/upstream/manager.go +++ b/internal/upstream/manager.go @@ -273,7 +273,7 @@ func (m *Manager) AddServerConfig(id string, serverConfig *config.ServerConfig) // Check if existing client exists and if config has changed var clientToDisconnect *managed.Client if existingClient, exists := m.clients[id]; exists { - existingConfig := existingClient.Config + existingConfig := existingClient.GetConfig() // Compare configurations to determine if reconnection is needed configChanged := existingConfig.URL != serverConfig.URL || @@ -822,14 +822,21 @@ func (m *Manager) DiscoverTools(ctx context.Context) ([]*config.ToolMetadata, er for id, client := range m.clients { name := "" quarantined := false - if client != nil && client.Config != nil { - name = client.Config.Name - quarantined = client.Config.Quarantined + enabled := false + // Read config through the thread-safe GetConfig() accessor — the reconcile + // add path (AddServerConfig) calls SetConfig (an atomic swap) off m.mu, so + // a direct config-field read would race 
with it (MCP-770). + if client != nil { + if cfg := client.GetConfig(); cfg != nil { + name = cfg.Name + quarantined = cfg.Quarantined + enabled = cfg.Enabled + } } snapshots = append(snapshots, clientSnapshot{ id: id, name: name, - enabled: client != nil && client.Config != nil && client.Config.Enabled, + enabled: enabled, quarantined: quarantined, client: client, }) @@ -916,7 +923,7 @@ func (m *Manager) CallTool(ctx context.Context, toolName string, args map[string // Find the client for this server var targetClient *managed.Client for _, client := range m.clients { - if client.Config.Name == serverName { + if client.GetConfig().Name == serverName { targetClient = client break } @@ -930,11 +937,11 @@ func (m *Manager) CallTool(ctx context.Context, toolName string, args map[string m.logger.Debug("CallTool: client found", zap.String("server_name", serverName), - zap.Bool("enabled", targetClient.Config.Enabled), + zap.Bool("enabled", targetClient.GetConfig().Enabled), zap.Bool("connected", targetClient.IsConnected()), zap.String("state", targetClient.GetState().String())) - if !targetClient.Config.Enabled { + if !targetClient.GetConfig().Enabled { return nil, fmt.Errorf("client for server %s is disabled", serverName) } @@ -947,9 +954,9 @@ func (m *Manager) CallTool(ctx context.Context, toolName string, args map[string // Attempt reconnect-on-use if enabled for this server reconnected := false - if targetClient.Config.ReconnectOnUse && + if targetClient.GetConfig().ReconnectOnUse && !targetClient.IsUserLoggedOut() && - !targetClient.Config.Quarantined { + !targetClient.GetConfig().Quarantined { m.logger.Info("reconnect_on_use: attempting reconnect for tool call", zap.String("server", serverName), zap.String("tool", actualToolName), @@ -1074,29 +1081,29 @@ func (m *Manager) ConnectAll(ctx context.Context) error { for id, client := range clients { m.logger.Debug("Evaluating client for connection", zap.String("id", id), - zap.String("name", client.Config.Name), - 
zap.Bool("enabled", client.Config.Enabled), + zap.String("name", client.GetConfig().Name), + zap.Bool("enabled", client.GetConfig().Enabled), zap.Bool("is_connected", client.IsConnected()), zap.Bool("is_connecting", client.IsConnecting()), zap.String("current_state", client.GetState().String()), - zap.Bool("quarantined", client.Config.Quarantined)) + zap.Bool("quarantined", client.GetConfig().Quarantined)) - if !client.Config.Enabled { + if !client.GetConfig().Enabled { m.logger.Debug("Skipping disabled client", zap.String("id", id), - zap.String("name", client.Config.Name)) + zap.String("name", client.GetConfig().Name)) if client.IsConnected() { - m.logger.Info("Disconnecting disabled client", zap.String("id", id), zap.String("name", client.Config.Name)) + m.logger.Info("Disconnecting disabled client", zap.String("id", id), zap.String("name", client.GetConfig().Name)) _ = client.Disconnect() } continue } - if client.Config.Quarantined { + if client.GetConfig().Quarantined { m.logger.Info("Skipping quarantined client", zap.String("id", id), - zap.String("name", client.Config.Name)) + zap.String("name", client.GetConfig().Name)) continue } @@ -1104,7 +1111,7 @@ func (m *Manager) ConnectAll(ctx context.Context) error { if client.IsUserLoggedOut() { m.logger.Debug("Skipping client - user explicitly logged out, waiting for manual login", zap.String("id", id), - zap.String("name", client.Config.Name)) + zap.String("name", client.GetConfig().Name)) continue } @@ -1112,14 +1119,14 @@ func (m *Manager) ConnectAll(ctx context.Context) error { if client.IsConnected() { m.logger.Debug("Client already connected, skipping", zap.String("id", id), - zap.String("name", client.Config.Name)) + zap.String("name", client.GetConfig().Name)) continue } if client.IsConnecting() { m.logger.Debug("Client already connecting, skipping", zap.String("id", id), - zap.String("name", client.Config.Name)) + zap.String("name", client.GetConfig().Name)) continue } @@ -1127,7 +1134,7 @@ func (m 
*Manager) ConnectAll(ctx context.Context) error { info := client.GetConnectionInfo() m.logger.Debug("Client backoff active, skipping connect attempt", zap.String("id", id), - zap.String("name", client.Config.Name), + zap.String("name", client.GetConfig().Name), zap.Int("retry_count", info.RetryCount), zap.Time("last_retry_time", info.LastRetryTime)) continue @@ -1135,10 +1142,10 @@ func (m *Manager) ConnectAll(ctx context.Context) error { m.logger.Info("Attempting to connect client", zap.String("id", id), - zap.String("name", client.Config.Name), - zap.String("url", client.Config.URL), - zap.String("command", client.Config.Command), - zap.String("protocol", client.Config.Protocol)) + zap.String("name", client.GetConfig().Name), + zap.String("url", client.GetConfig().URL), + zap.String("command", client.GetConfig().Command), + zap.String("protocol", client.GetConfig().Protocol)) wg.Add(1) go func(id string, c *managed.Client) { @@ -1155,13 +1162,13 @@ func (m *Manager) ConnectAll(ctx context.Context) error { if err := c.Connect(connectCtx); err != nil { m.logger.Error("Failed to connect to upstream server", zap.String("id", id), - zap.String("name", c.Config.Name), + zap.String("name", c.GetConfig().Name), zap.String("state", c.GetState().String()), zap.Error(err)) } else { m.logger.Info("Successfully initiated connection to upstream server", zap.String("id", id), - zap.String("name", c.Config.Name)) + zap.String("name", c.GetConfig().Name)) } }(id, client) } @@ -1312,15 +1319,22 @@ func (m *Manager) GetStats() map[string]interface{} { // Get detailed connection info from state manager connectionInfo := client.GetConnectionInfo() + // Read config through the thread-safe accessor to avoid racing with + // SetConfig on the reconcile add path (MCP-770). 
+ name, url, protocol := "", "", "" + if cfg := client.GetConfig(); cfg != nil { + name, url, protocol = cfg.Name, cfg.URL, cfg.Protocol + } + status := map[string]interface{}{ "state": connectionInfo.State.String(), "connected": connectionInfo.State == types.StateReady, "connecting": client.IsConnecting(), "retry_count": connectionInfo.RetryCount, "should_retry": client.ShouldRetry(), - "name": client.Config.Name, - "url": client.Config.URL, - "protocol": client.Config.Protocol, + "name": name, + "url": url, + "protocol": protocol, } if connectionInfo.State == types.StateReady { @@ -1386,7 +1400,12 @@ func (m *Manager) GetTotalToolCount() int { // Now process clients without holding lock totalTools := 0 for _, client := range clientsCopy { - if client == nil || client.Config == nil || !client.Config.Enabled || !client.IsConnected() { + if client == nil { + continue + } + // Read config through the thread-safe accessor (MCP-770). + cfg := client.GetConfig() + if cfg == nil || !cfg.Enabled || !client.IsConnected() { continue } @@ -1403,7 +1422,8 @@ func (m *Manager) ListServers() map[string]*config.ServerConfig { servers := make(map[string]*config.ServerConfig) for id, client := range m.clients { - servers[id] = client.Config + // Read config through the thread-safe accessor (MCP-770). 
+ servers[id] = client.GetConfig() } return servers } @@ -1453,7 +1473,7 @@ func (m *Manager) RetryConnection(serverName string) error { var hasToken bool var tokenExpires time.Time if m.storage != nil { - ts := oauth.NewPersistentTokenStore(client.Config.Name, client.Config.URL, m.storage) + ts := oauth.NewPersistentTokenStore(client.GetConfig().Name, client.GetConfig().URL, m.storage) if tok, err := ts.GetToken(context.Background()); err == nil && tok != nil { hasToken = true tokenExpires = tok.ExpiresAt @@ -1822,7 +1842,7 @@ func (m *Manager) StartManualOAuth(serverName string, force bool) error { return fmt.Errorf("server not found: %s", serverName) } - cfg := client.Config + cfg := client.GetConfig() m.logger.Info("Starting in-process manual OAuth", zap.String("server", cfg.Name), zap.Bool("force", force)) @@ -1905,7 +1925,7 @@ func (m *Manager) StartManualOAuthQuick(serverName string) (*core.OAuthStartResu return nil, fmt.Errorf("server not found: %s", serverName) } - cfg := client.Config + cfg := client.GetConfig() m.logger.Info("Starting quick OAuth flow (returns browser status immediately)", zap.String("server", cfg.Name)) @@ -1988,7 +2008,7 @@ func (m *Manager) StartManualOAuthWithInfo(serverName string, force bool) (*core return nil, fmt.Errorf("server not found: %s", serverName) } - cfg := client.Config + cfg := client.GetConfig() m.logger.Info("Starting in-process manual OAuth with info tracking", zap.String("server", cfg.Name), zap.Bool("force", force)) diff --git a/internal/upstream/manager_config_race_test.go b/internal/upstream/manager_config_race_test.go new file mode 100644 index 00000000..85fc289f --- /dev/null +++ b/internal/upstream/manager_config_race_test.go @@ -0,0 +1,110 @@ +package upstream + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/smart-mcp-proxy/mcpproxy-go/internal/config" +) + +// TestDiscoverTools_ConfigRace reproduces MCP-770: a data race between +// Manager.DiscoverTools (background tool indexing) 
reading client.Config and +// managed.Client.SetConfig (reconcile add path in AddServerConfig) writing it. +// +// AddServerConfig releases m.mu before calling SetConfig (to avoid deadlock with +// GetServerState), so the write is guarded only by the managed client's mc.mu. +// DiscoverTools must therefore read the config through the mutex-guarded +// GetConfig() accessor rather than touching client.Config directly. Run under +// `go test -race` — without the fix the race detector flags concurrent +// read/write on the mc.Config field. +func TestDiscoverTools_ConfigRace(t *testing.T) { + serverConfig := &config.ServerConfig{ + Name: "race-server", + URL: "http://127.0.0.1:0", + Protocol: "http", + Enabled: true, + Created: time.Now(), + } + + manager, _ := createTestManagerWithClient(t, serverConfig) + + const iterations = 200 + var wg sync.WaitGroup + wg.Add(2) + + // Writer: reconcile add path -> SetConfig swaps the mc.Config pointer. + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + // Fresh, equal config each iteration so the unchanged-config branch + // in AddServerConfig calls SetConfig with a new pointer. + cfg := *serverConfig + cfg.Created = time.Now() + _ = manager.AddServerConfig(serverConfig.Name, &cfg) + } + }() + + // Reader: background tool indexing + API-facing status readers snapshot + // client.Config. All must go through the mutex-guarded accessor. + go func() { + defer wg.Done() + ctx := context.Background() + for i := 0; i < iterations; i++ { + _, _ = manager.DiscoverTools(ctx) + _ = manager.GetStats() + _ = manager.GetTotalToolCount() + _ = manager.ListServers() + } + }() + + wg.Wait() +} + +// TestConnect_ConfigRace reproduces the sibling MCP-770 race surfaced on PR #555 +// (macOS -race unit job): reconcile spawns AddServer (-> SetConfig writes the +// mc.Config pointer under mc.mu) and ConnectServer (-> Client.Connect) as +// concurrent goroutines. 
Connect releases mc.mu before the slow core connect and +// logged the server name by dereferencing mc.Config in that unlocked window, +// racing SetConfig's write. The fix snapshots the name under the Phase-1 lock. +// Run under `go test -race`. +func TestConnect_ConfigRace(t *testing.T) { + serverConfig := &config.ServerConfig{ + Name: "race-server", + URL: "http://127.0.0.1:0", // unreachable -> core Connect fails fast + Protocol: "http", + Enabled: true, + Created: time.Now(), + } + + manager, client := createTestManagerWithClient(t, serverConfig) + + const iterations = 200 + var wg sync.WaitGroup + wg.Add(2) + + // Writer: reconcile add path -> SetConfig swaps the mc.Config pointer. + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + cfg := *serverConfig + cfg.Created = time.Now() + _ = manager.AddServerConfig(serverConfig.Name, &cfg) + } + }() + + // Reader: reconcile connect path -> Client.Connect reads the config in its + // unlocked phase. The failing core connect leaves the client in Error state, + // so each iteration passes the connecting/ready guard and reaches the read. + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + _ = client.Connect(ctx) + cancel() + } + }() + + wg.Wait() +}