41 changes: 33 additions & 8 deletions pkg/vmcp/discovery/manager.go
@@ -18,6 +18,8 @@ import (
"sync"
"time"

"golang.org/x/sync/singleflight"

"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/logger"
"github.com/stacklok/toolhive/pkg/vmcp"
@@ -68,6 +70,9 @@ type DefaultManager struct {
stopCh chan struct{}
stopOnce sync.Once
wg sync.WaitGroup
// singleFlight ensures only one aggregation happens per cache key at a time
// This prevents concurrent requests from all triggering aggregation
singleFlight singleflight.Group
}

// NewManager creates a new discovery manager with the given aggregator.
@@ -131,6 +136,9 @@ func NewManagerWithRegistry(agg aggregator.Aggregator, registry vmcp.DynamicRegi
//
// The context must contain an authenticated user identity (set by auth middleware).
// Returns ErrNoIdentity if user identity is not found in context.
//
// This method uses singleflight to ensure that concurrent requests for the same
// cache key only trigger one aggregation, preventing duplicate work.
func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend) (*aggregator.AggregatedCapabilities, error) {
// Validate user identity is present (set by auth middleware)
// This ensures discovery happens with proper user authentication context
@@ -142,24 +150,41 @@ func (m *DefaultManager) Discover(ctx context.Context, backends []vmcp.Backend)
// Generate cache key from user identity and backend set
cacheKey := m.generateCacheKey(identity.Subject, backends)

// Check cache first
// Check cache first (with read lock)
if caps := m.getCachedCapabilities(cacheKey); caps != nil {
logger.Debugf("Cache hit for user %s (key: %s)", identity.Subject, cacheKey)
return caps, nil
}
Comment on lines +153 to 157
Contributor

Blocker: Based on my reading of the singleflight source code, singleflight actually implements a cache as well. If the function has already been called for the cache key, it returns whatever result was originally produced.

Unfortunately, it's missing one important bit of functionality: expiry behavior. Without this, the cache is unbounded in size and can cause OOMs for long-running vMCPs. I don't see a good way to implement this on top of singleflight, since the only delete API is Forget.

To fix this, I'd recommend:
Create a generic, time-limited cache:

func NewCacheWithTTL[V any](ttl time.Duration) Cache[V] {...}

type Cache[V any] interface {
    Get(key string, loader func() (V, error)) (V, error)
}

I haven't thought too deeply about the implementation, so I'll leave that up to you. Factoring it out like this is nice for a few reasons:

  1. This cache would be easily reusable in different circumstances.
  2. We can thoroughly unit test the cache's behavior. The current implementation, which is coupled to capability aggregation, is hard to unit test.
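
For reference, one way the proposed Cache[V] interface could be fleshed out. This is a minimal sketch only, not part of the PR: the package name, struct fields, and the choice to keep singleflight inside the cache are assumptions made for illustration.

package cache

import (
    "sync"
    "time"

    "golang.org/x/sync/singleflight"
)

// Cache is the generic, time-limited cache proposed above.
type Cache[V any] interface {
    Get(key string, loader func() (V, error)) (V, error)
}

type entry[V any] struct {
    value     V
    expiresAt time.Time
}

type ttlCache[V any] struct {
    mu      sync.RWMutex
    ttl     time.Duration
    entries map[string]entry[V]
    group   singleflight.Group
}

// NewCacheWithTTL returns a Cache whose entries expire ttl after being loaded.
func NewCacheWithTTL[V any](ttl time.Duration) Cache[V] {
    return &ttlCache[V]{
        ttl:     ttl,
        entries: make(map[string]entry[V]),
    }
}

// Get returns the cached value for key if it is still fresh; otherwise it runs
// loader (collapsing concurrent loads for the same key via singleflight) and
// stores the result with an expiry, so stale entries are overwritten rather
// than accumulating forever.
func (c *ttlCache[V]) Get(key string, loader func() (V, error)) (V, error) {
    // Fast path: return a fresh entry under a read lock.
    c.mu.RLock()
    e, ok := c.entries[key]
    c.mu.RUnlock()
    if ok && time.Now().Before(e.expiresAt) {
        return e.value, nil
    }

    // Slow path: only one loader runs per key at a time.
    v, err, _ := c.group.Do(key, func() (interface{}, error) {
        // Re-check: another caller may have refreshed this key while we waited.
        c.mu.RLock()
        e, ok := c.entries[key]
        c.mu.RUnlock()
        if ok && time.Now().Before(e.expiresAt) {
            return e.value, nil
        }

        value, err := loader()
        if err != nil {
            return nil, err
        }

        c.mu.Lock()
        c.entries[key] = entry[V]{value: value, expiresAt: time.Now().Add(c.ttl)}
        c.mu.Unlock()
        return value, nil
    })
    if err != nil {
        var zero V
        return zero, err
    }
    return v.(V), nil
}

Note that this sketch still never deletes expired entries for keys that are never requested again, so a real implementation would likely also want a background sweep or a size cap. With something like this in place, DefaultManager could hold a Cache[*aggregator.AggregatedCapabilities] field and Discover would reduce to a single Get(cacheKey, ...) call wrapping m.aggregator.AggregateCapabilities(ctx, backends).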


logger.Debugf("Cache miss - performing capability discovery for user: %s", identity.Subject)

// Cache miss - perform aggregation
caps, err := m.aggregator.AggregateCapabilities(ctx, backends)
// Use singleflight to ensure only one aggregation happens per cache key
// Even if multiple requests come in concurrently, they'll all wait for the same result
result, err, _ := m.singleFlight.Do(cacheKey, func() (interface{}, error) {
// Double-check cache after acquiring singleflight lock
// Another goroutine might have populated it while we were waiting
if caps := m.getCachedCapabilities(cacheKey); caps != nil {
logger.Debugf("Cache populated while waiting - returning cached result for user %s", identity.Subject)
return caps, nil
}

// Perform aggregation
caps, err := m.aggregator.AggregateCapabilities(ctx, backends)
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrDiscoveryFailed, err)
}

// Cache the result (skips caching if at capacity and key doesn't exist)
m.cacheCapabilities(cacheKey, caps)

return caps, nil
})

if err != nil {
return nil, fmt.Errorf("%w: %w", ErrDiscoveryFailed, err)
return nil, err
}

// Cache the result (skips caching if at capacity and key doesn't exist)
m.cacheCapabilities(cacheKey, caps)

return caps, nil
return result.(*aggregator.AggregatedCapabilities), nil
}

// Stop gracefully stops the manager and cleans up resources.