diff --git a/pkg/vmcp/discovery/manager.go b/pkg/vmcp/discovery/manager.go
index 0845118ee..6dfa65951 100644
--- a/pkg/vmcp/discovery/manager.go
+++ b/pkg/vmcp/discovery/manager.go
@@ -18,6 +18,8 @@ import (
 	"sync"
 	"time"
 
+	"golang.org/x/sync/singleflight"
+
 	"github.com/stacklok/toolhive/pkg/auth"
 	"github.com/stacklok/toolhive/pkg/logger"
 	"github.com/stacklok/toolhive/pkg/vmcp"
@@ -68,6 +70,9 @@ type DefaultManager struct {
 	stopCh   chan struct{}
 	stopOnce sync.Once
 	wg       sync.WaitGroup
+	// singleFlight ensures only one aggregation happens per cache key at a time
+	// This prevents concurrent requests from all triggering aggregation
+	singleFlight singleflight.Group
 }
 
 // NewManager creates a new discovery manager with the given aggregator.
@@ -131,6 +136,9 @@ func NewManagerWithRegistry(agg aggregator.Aggregator, registry vmcp.DynamicRegi
 //
 // The context must contain an authenticated user identity (set by auth middleware).
 // Returns ErrNoIdentity if user identity is not found in context.
+//
+// This method uses singleflight to ensure that concurrent requests for the same
+// cache key only trigger one aggregation, preventing duplicate work.
 func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) {
 	// Validate user identity is present (set by auth middleware)
 	// This ensures discovery happens with proper user authentication context
@@ -142,7 +150,7 @@ func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend)
 	// Generate cache key from user identity and backend set
 	cacheKey := m.generateCacheKey(identity.Subject, backends)
 
-	// Check cache first
+	// Check cache first (with read lock)
 	if caps := m.getCachedCapabilities(cacheKey); caps != nil {
 		logger.Debugf("Cache hit for user %s (key: %s)", identity.Subject, cacheKey)
 		return caps, nil
@@ -150,16 +158,33 @@ func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend)
 
 	logger.Debugf("Cache miss - performing capability discovery for user: %s", identity.Subject)
 
-	// Cache miss - perform aggregation
-	caps, err := m.aggregator.AggregateCapabilities(ctx, backends)
+	// Use singleflight to ensure only one aggregation happens per cache key
+	// Even if multiple requests come in concurrently, they'll all wait for the same result
+	result, err, _ := m.singleFlight.Do(cacheKey, func() (interface{}, error) {
+		// Double-check cache after acquiring singleflight lock
+		// Another goroutine might have populated it while we were waiting
+		if caps := m.getCachedCapabilities(cacheKey); caps != nil {
+			logger.Debugf("Cache populated while waiting - returning cached result for user %s", identity.Subject)
+			return caps, nil
+		}
+
+		// Perform aggregation
+		caps, err := m.aggregator.AggregateCapabilities(ctx, backends)
+		if err != nil {
+			return nil, fmt.Errorf("%w: %w", ErrDiscoveryFailed, err)
+		}
+
+		// Cache the result (skips caching if at capacity and key doesn't exist)
+		m.cacheCapabilities(cacheKey, caps)
+
+		return caps, nil
+	})
+
 	if err != nil {
-		return nil, fmt.Errorf("%w: %w", ErrDiscoveryFailed, err)
+		return nil, err
 	}
 
-	// Cache the result (skips caching if at capacity and key doesn't exist)
-	m.cacheCapabilities(cacheKey, caps)
-
-	return caps, nil
+	return result.(*aggregator.AggregatedCapabilities), nil
 }
 
 // Stop gracefully stops the manager and cleans up resources.
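
Note: the sketch below is a minimal, self-contained illustration of the pattern this change introduces (a singleflight group keyed on the cache key, with a double-checked cache lookup inside the callback). It does not use the vmcp discovery types; the `cachingFetcher`, its cache, the key, and the simulated aggregation are hypothetical stand-ins.

```go
// Sketch only: demonstrates collapsing concurrent callers for the same key
// into one expensive call with golang.org/x/sync/singleflight, plus a
// double-checked cache read inside the callback. Names here are hypothetical.
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"golang.org/x/sync/singleflight"
)

type cachingFetcher struct {
	mu    sync.RWMutex
	cache map[string]string
	group singleflight.Group
	calls atomic.Int64 // counts how many times the expensive fetch actually ran
}

func (f *cachingFetcher) get(key string) (string, error) {
	// Fast path: read-locked cache check.
	f.mu.RLock()
	v, ok := f.cache[key]
	f.mu.RUnlock()
	if ok {
		return v, nil
	}

	// Slow path: all concurrent callers for the same key share one fetch.
	res, err, _ := f.group.Do(key, func() (interface{}, error) {
		// Double-check: another goroutine may have filled the cache while
		// this caller was waiting on the singleflight group.
		f.mu.RLock()
		v, ok := f.cache[key]
		f.mu.RUnlock()
		if ok {
			return v, nil
		}

		// Simulated expensive aggregation.
		f.calls.Add(1)
		time.Sleep(50 * time.Millisecond)
		v = "capabilities-for-" + key

		f.mu.Lock()
		f.cache[key] = v
		f.mu.Unlock()
		return v, nil
	})
	if err != nil {
		return "", err
	}
	return res.(string), nil
}

func main() {
	f := &cachingFetcher{cache: map[string]string{}}

	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_, _ = f.get("user-a|backend-set-1")
		}()
	}
	wg.Wait()

	// Expect exactly one expensive fetch despite ten concurrent callers.
	fmt.Printf("expensive fetches: %d\n", f.calls.Load())
}
```

Running this should report a single expensive fetch even though ten goroutines request the same key concurrently, which is the same per-cache-key deduplication Discover gains from the change above.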