diff --git a/pkg/context/token.go b/pkg/context/token.go index 97091a922..7398993c1 100644 --- a/pkg/context/token.go +++ b/pkg/context/token.go @@ -28,15 +28,47 @@ func GetTokenInfo(ctx context.Context) (*TokenInfo, bool) { type tokenScopesKey struct{} +type tokenScopesValue struct { + Token string + Scopes []string +} + // WithTokenScopes adds token scopes to the context func WithTokenScopes(ctx context.Context, scopes []string) context.Context { return context.WithValue(ctx, tokenScopesKey{}, scopes) } +// WithTokenScopesForToken adds token scopes and the associated token to the context. +func WithTokenScopesForToken(ctx context.Context, token string, scopes []string) context.Context { + return context.WithValue(ctx, tokenScopesKey{}, tokenScopesValue{ + Token: token, + Scopes: scopes, + }) +} + // GetTokenScopes retrieves token scopes from the context func GetTokenScopes(ctx context.Context) ([]string, bool) { + if scoped, ok := ctx.Value(tokenScopesKey{}).(tokenScopesValue); ok { + return scoped.Scopes, true + } if scopes, ok := ctx.Value(tokenScopesKey{}).([]string); ok { return scopes, true } return nil, false } + +// GetTokenScopesForToken retrieves token scopes only when they are bound to the active token. +func GetTokenScopesForToken(ctx context.Context, token string) ([]string, bool) { + if scoped, ok := ctx.Value(tokenScopesKey{}).(tokenScopesValue); ok { + if scoped.Token == token { + return scoped.Scopes, true + } + return nil, false + } + if token == "" { + if scopes, ok := ctx.Value(tokenScopesKey{}).([]string); ok { + return scopes, true + } + } + return nil, false +} diff --git a/pkg/context/token_test.go b/pkg/context/token_test.go new file mode 100644 index 000000000..0b04ab5fb --- /dev/null +++ b/pkg/context/token_test.go @@ -0,0 +1,32 @@ +package context + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetTokenScopesForToken_MatchesBoundToken(t *testing.T) { + ctx := WithTokenScopesForToken(context.Background(), "token-a", []string{"repo"}) + + scopes, ok := GetTokenScopesForToken(ctx, "token-a") + assert.True(t, ok) + assert.Equal(t, []string{"repo"}, scopes) + + scopes, ok = GetTokenScopesForToken(ctx, "token-b") + assert.False(t, ok) + assert.Nil(t, scopes) +} + +func TestGetTokenScopesForToken_DoesNotReuseLegacyScopesForNonEmptyToken(t *testing.T) { + ctx := WithTokenScopes(context.Background(), []string{"repo"}) + + scopes, ok := GetTokenScopesForToken(ctx, "token-a") + assert.False(t, ok) + assert.Nil(t, scopes) + + legacyScopes, legacyOK := GetTokenScopes(ctx) + assert.True(t, legacyOK) + assert.Equal(t, []string{"repo"}, legacyScopes) +} diff --git a/pkg/http/handler.go b/pkg/http/handler.go index 2e828211d..6dd2e5bd2 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -295,7 +295,7 @@ func PATScopeFilter(b *inventory.Builder, r *http.Request, fetcher scopes.Fetche // Fine-grained PATs and other token types don't support this, so we skip filtering. if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken { // Check if scopes are already in context (should be set by WithPATScopes). If not, fetch them. - existingScopes, ok := ghcontext.GetTokenScopes(ctx) + existingScopes, ok := ghcontext.GetTokenScopesForToken(ctx, tokenInfo.Token) if ok { return b.WithFilter(github.CreateToolScopeFilter(existingScopes)) } diff --git a/pkg/http/middleware/pat_scope.go b/pkg/http/middleware/pat_scope.go index bb1efdc01..37ab28352 100644 --- a/pkg/http/middleware/pat_scope.go +++ b/pkg/http/middleware/pat_scope.go @@ -26,7 +26,7 @@ func WithPATScopes(logger *slog.Logger, scopeFetcher scopes.FetcherInterface) fu // Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header. // Fine-grained PATs and other token types don't support this, so we skip filtering. if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken { - existingScopes, ok := ghcontext.GetTokenScopes(ctx) + existingScopes, ok := ghcontext.GetTokenScopesForToken(ctx, tokenInfo.Token) if ok { logger.Debug("using existing scopes from context", "scopes", existingScopes) next.ServeHTTP(w, r) @@ -41,7 +41,7 @@ func WithPATScopes(logger *slog.Logger, scopeFetcher scopes.FetcherInterface) fu } // Store fetched scopes in context for downstream use - ctx = ghcontext.WithTokenScopes(ctx, scopesList) + ctx = ghcontext.WithTokenScopesForToken(ctx, tokenInfo.Token, scopesList) next.ServeHTTP(w, r.WithContext(ctx)) return diff --git a/pkg/http/middleware/pat_scope_test.go b/pkg/http/middleware/pat_scope_test.go index 0607b8cf2..e09e87cfe 100644 --- a/pkg/http/middleware/pat_scope_test.go +++ b/pkg/http/middleware/pat_scope_test.go @@ -16,11 +16,15 @@ import ( // mockScopeFetcher is a mock implementation of scopes.FetcherInterface type mockScopeFetcher struct { - scopes []string - err error + scopes []string + err error + callCount int + tokens []string } -func (m *mockScopeFetcher) FetchTokenScopes(_ context.Context, _ string) ([]string, error) { +func (m *mockScopeFetcher) FetchTokenScopes(_ context.Context, token string) ([]string, error) { + m.callCount += 1 + m.tokens = append(m.tokens, token) return m.scopes, m.err } @@ -188,3 +192,40 @@ func TestWithPATScopes_PreservesExistingTokenInfo(t *testing.T) { assert.True(t, scopesFound) assert.Equal(t, []string{"repo", "user"}, capturedScopes) } + +func TestWithPATScopes_RefetchesWhenCachedScopesBelongToDifferentToken(t *testing.T) { + logger := slog.Default() + + var capturedScopes []string + var scopesFound bool + + nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedScopes, scopesFound = ghcontext.GetTokenScopes(r.Context()) + w.WriteHeader(http.StatusOK) + }) + + fetcher := &mockScopeFetcher{ + scopes: []string{"read:org"}, + } + + middleware := WithPATScopes(logger, fetcher) + handler := middleware(nextHandler) + + req := httptest.NewRequest(http.MethodGet, "/test", nil) + ctx := req.Context() + ctx = ghcontext.WithTokenInfo(ctx, &ghcontext.TokenInfo{ + Token: "ghp_new_token", + TokenType: utils.TokenTypePersonalAccessToken, + }) + ctx = ghcontext.WithTokenScopesForToken(ctx, "ghp_old_token", []string{"repo"}) + req = req.WithContext(ctx) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.True(t, scopesFound) + assert.Equal(t, []string{"read:org"}, capturedScopes) + assert.Equal(t, 1, fetcher.callCount) + require.Len(t, fetcher.tokens, 1) + assert.Equal(t, "ghp_new_token", fetcher.tokens[0]) +} diff --git a/pkg/http/middleware/scope_challenge.go b/pkg/http/middleware/scope_challenge.go index 1a86bf93c..54dccdafe 100644 --- a/pkg/http/middleware/scope_challenge.go +++ b/pkg/http/middleware/scope_challenge.go @@ -96,7 +96,7 @@ func WithScopeChallenge(oauthCfg *oauth.Config, scopeFetcher scopes.FetcherInter // Get OAuth scopes for Token. First check if scopes are already in context, then fetch from GitHub if not present. // This allows Remote Server to pass scope info to avoid redundant GitHub API calls. - activeScopes, ok := ghcontext.GetTokenScopes(ctx) + activeScopes, ok := ghcontext.GetTokenScopesForToken(ctx, tokenInfo.Token) if !ok || (len(activeScopes) == 0 && tokenInfo.Token != "") { activeScopes, err = scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token) if err != nil { @@ -106,7 +106,7 @@ func WithScopeChallenge(oauthCfg *oauth.Config, scopeFetcher scopes.FetcherInter } // Store active scopes in context for downstream use - ctx = ghcontext.WithTokenScopes(ctx, activeScopes) + ctx = ghcontext.WithTokenScopesForToken(ctx, tokenInfo.Token, activeScopes) r = r.WithContext(ctx) // Check if user has the required scopes