Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions pkg/context/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,47 @@ func GetTokenInfo(ctx context.Context) (*TokenInfo, bool) {

type tokenScopesKey struct{}

type tokenScopesValue struct {
Token string
Scopes []string
}

// WithTokenScopes adds token scopes to the context
func WithTokenScopes(ctx context.Context, scopes []string) context.Context {
return context.WithValue(ctx, tokenScopesKey{}, scopes)
}

// WithTokenScopesForToken adds token scopes and the associated token to the context.
func WithTokenScopesForToken(ctx context.Context, token string, scopes []string) context.Context {
return context.WithValue(ctx, tokenScopesKey{}, tokenScopesValue{
Token: token,
Scopes: scopes,
})
}

// GetTokenScopes retrieves token scopes from the context
func GetTokenScopes(ctx context.Context) ([]string, bool) {
if scoped, ok := ctx.Value(tokenScopesKey{}).(tokenScopesValue); ok {
return scoped.Scopes, true
}
if scopes, ok := ctx.Value(tokenScopesKey{}).([]string); ok {
return scopes, true
}
return nil, false
}

// GetTokenScopesForToken retrieves token scopes only when they are bound to the active token.
func GetTokenScopesForToken(ctx context.Context, token string) ([]string, bool) {
if scoped, ok := ctx.Value(tokenScopesKey{}).(tokenScopesValue); ok {
if scoped.Token == token {
return scoped.Scopes, true
}
return nil, false
}
if token == "" {
if scopes, ok := ctx.Value(tokenScopesKey{}).([]string); ok {
return scopes, true
}
}
return nil, false
}
32 changes: 32 additions & 0 deletions pkg/context/token_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package context

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
)

func TestGetTokenScopesForToken_MatchesBoundToken(t *testing.T) {
ctx := WithTokenScopesForToken(context.Background(), "token-a", []string{"repo"})

scopes, ok := GetTokenScopesForToken(ctx, "token-a")
assert.True(t, ok)
assert.Equal(t, []string{"repo"}, scopes)

scopes, ok = GetTokenScopesForToken(ctx, "token-b")
assert.False(t, ok)
assert.Nil(t, scopes)
}

func TestGetTokenScopesForToken_DoesNotReuseLegacyScopesForNonEmptyToken(t *testing.T) {
ctx := WithTokenScopes(context.Background(), []string{"repo"})

scopes, ok := GetTokenScopesForToken(ctx, "token-a")
assert.False(t, ok)
assert.Nil(t, scopes)

legacyScopes, legacyOK := GetTokenScopes(ctx)
assert.True(t, legacyOK)
assert.Equal(t, []string{"repo"}, legacyScopes)
}
2 changes: 1 addition & 1 deletion pkg/http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ func PATScopeFilter(b *inventory.Builder, r *http.Request, fetcher scopes.Fetche
// Fine-grained PATs and other token types don't support this, so we skip filtering.
if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken {
// Check if scopes are already in context (should be set by WithPATScopes). If not, fetch them.
existingScopes, ok := ghcontext.GetTokenScopes(ctx)
existingScopes, ok := ghcontext.GetTokenScopesForToken(ctx, tokenInfo.Token)
if ok {
return b.WithFilter(github.CreateToolScopeFilter(existingScopes))
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/http/middleware/pat_scope.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func WithPATScopes(logger *slog.Logger, scopeFetcher scopes.FetcherInterface) fu
// Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header.
// Fine-grained PATs and other token types don't support this, so we skip filtering.
if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken {
existingScopes, ok := ghcontext.GetTokenScopes(ctx)
existingScopes, ok := ghcontext.GetTokenScopesForToken(ctx, tokenInfo.Token)
if ok {
logger.Debug("using existing scopes from context", "scopes", existingScopes)
next.ServeHTTP(w, r)
Expand All @@ -41,7 +41,7 @@ func WithPATScopes(logger *slog.Logger, scopeFetcher scopes.FetcherInterface) fu
}

// Store fetched scopes in context for downstream use
ctx = ghcontext.WithTokenScopes(ctx, scopesList)
ctx = ghcontext.WithTokenScopesForToken(ctx, tokenInfo.Token, scopesList)

next.ServeHTTP(w, r.WithContext(ctx))
return
Expand Down
47 changes: 44 additions & 3 deletions pkg/http/middleware/pat_scope_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@ import (

// mockScopeFetcher is a mock implementation of scopes.FetcherInterface
type mockScopeFetcher struct {
scopes []string
err error
scopes []string
err error
callCount int
tokens []string
}

func (m *mockScopeFetcher) FetchTokenScopes(_ context.Context, _ string) ([]string, error) {
func (m *mockScopeFetcher) FetchTokenScopes(_ context.Context, token string) ([]string, error) {
m.callCount += 1
m.tokens = append(m.tokens, token)
return m.scopes, m.err
}

Expand Down Expand Up @@ -188,3 +192,40 @@ func TestWithPATScopes_PreservesExistingTokenInfo(t *testing.T) {
assert.True(t, scopesFound)
assert.Equal(t, []string{"repo", "user"}, capturedScopes)
}

func TestWithPATScopes_RefetchesWhenCachedScopesBelongToDifferentToken(t *testing.T) {
logger := slog.Default()

var capturedScopes []string
var scopesFound bool

nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedScopes, scopesFound = ghcontext.GetTokenScopes(r.Context())
w.WriteHeader(http.StatusOK)
})

fetcher := &mockScopeFetcher{
scopes: []string{"read:org"},
}

middleware := WithPATScopes(logger, fetcher)
handler := middleware(nextHandler)

req := httptest.NewRequest(http.MethodGet, "/test", nil)
ctx := req.Context()
ctx = ghcontext.WithTokenInfo(ctx, &ghcontext.TokenInfo{
Token: "ghp_new_token",
TokenType: utils.TokenTypePersonalAccessToken,
})
ctx = ghcontext.WithTokenScopesForToken(ctx, "ghp_old_token", []string{"repo"})
req = req.WithContext(ctx)

rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

assert.True(t, scopesFound)
assert.Equal(t, []string{"read:org"}, capturedScopes)
assert.Equal(t, 1, fetcher.callCount)
require.Len(t, fetcher.tokens, 1)
assert.Equal(t, "ghp_new_token", fetcher.tokens[0])
}
4 changes: 2 additions & 2 deletions pkg/http/middleware/scope_challenge.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func WithScopeChallenge(oauthCfg *oauth.Config, scopeFetcher scopes.FetcherInter

// Get OAuth scopes for Token. First check if scopes are already in context, then fetch from GitHub if not present.
// This allows Remote Server to pass scope info to avoid redundant GitHub API calls.
activeScopes, ok := ghcontext.GetTokenScopes(ctx)
activeScopes, ok := ghcontext.GetTokenScopesForToken(ctx, tokenInfo.Token)
if !ok || (len(activeScopes) == 0 && tokenInfo.Token != "") {
activeScopes, err = scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token)
if err != nil {
Expand All @@ -106,7 +106,7 @@ func WithScopeChallenge(oauthCfg *oauth.Config, scopeFetcher scopes.FetcherInter
}

// Store active scopes in context for downstream use
ctx = ghcontext.WithTokenScopes(ctx, activeScopes)
ctx = ghcontext.WithTokenScopesForToken(ctx, tokenInfo.Token, activeScopes)
r = r.WithContext(ctx)

// Check if user has the required scopes
Expand Down