Skip to content

Commit 3e35703

Browse files
committed
Move PAT Scope fetching into a middleware.
1 parent f3f88d0 commit 3e35703

File tree

5 files changed

+250
-17
lines changed

5 files changed

+250
-17
lines changed

pkg/context/token.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,6 @@ func WithTokenInfo(ctx context.Context, tokenInfo *TokenInfo) context.Context {
2323
return context.WithValue(ctx, tokenCtxKey, tokenInfo)
2424
}
2525

26-
func SetTokenScopes(ctx context.Context, scopes []string) {
27-
if tokenInfo, ok := GetTokenInfo(ctx); ok {
28-
tokenInfo.Scopes = scopes
29-
tokenInfo.ScopesFetched = true
30-
}
31-
}
32-
3326
// GetTokenInfo retrieves the authentication token from the context
3427
func GetTokenInfo(ctx context.Context) (*TokenInfo, bool) {
3528
if tokenInfo, ok := ctx.Value(tokenCtxKey).(*TokenInfo); ok {

pkg/http/handler.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ func (h *Handler) RegisterMiddleware(r chi.Router) {
120120
middleware.ExtractUserToken(h.oauthCfg),
121121
middleware.WithRequestConfig,
122122
middleware.WithMCPParse(),
123+
middleware.WithPATScopes(h.logger, h.scopeFetcher),
123124
)
124125

125126
if h.config.ScopeChallenge {
@@ -266,19 +267,20 @@ func PATScopeFilter(b *inventory.Builder, r *http.Request, fetcher scopes.Fetche
266267
return b
267268
}
268269

269-
// Fetch token scopes for scope-based tool filtering (PAT tokens only)
270+
// Scopes should have already been fetched by the WithPATScopes middleware.
270271
// Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header.
271272
// Fine-grained PATs and other token types don't support this, so we skip filtering.
272273
if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken {
273-
scopesList, err := fetcher.FetchTokenScopes(ctx, tokenInfo.Token)
274-
if err != nil {
275-
return b
274+
if tokenInfo.ScopesFetched {
275+
return b.WithFilter(github.CreateToolScopeFilter(tokenInfo.Scopes))
276+
} else {
277+
scopesList, err := fetcher.FetchTokenScopes(ctx, tokenInfo.Token)
278+
if err != nil {
279+
return b
280+
}
281+
282+
return b.WithFilter(github.CreateToolScopeFilter(scopesList))
276283
}
277-
278-
// Store fetched scopes in context for downstream use
279-
ghcontext.SetTokenScopes(ctx, scopesList)
280-
281-
return b.WithFilter(github.CreateToolScopeFilter(scopesList))
282284
}
283285

284286
return b

pkg/http/middleware/pat_scope.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package middleware
2+
3+
import (
4+
"log/slog"
5+
"net/http"
6+
7+
ghcontext "github.com/github/github-mcp-server/pkg/context"
8+
"github.com/github/github-mcp-server/pkg/scopes"
9+
"github.com/github/github-mcp-server/pkg/utils"
10+
)
11+
12+
// WithScopeChallenge creates a new middleware that determines if an OAuth request contains sufficient scopes to
13+
// complete the request and returns a scope challenge if not.
14+
func WithPATScopes(logger *slog.Logger, scopeFetcher scopes.FetcherInterface) func(http.Handler) http.Handler {
15+
return func(next http.Handler) http.Handler {
16+
fn := func(w http.ResponseWriter, r *http.Request) {
17+
ctx := r.Context()
18+
19+
tokenInfo, ok := ghcontext.GetTokenInfo(ctx)
20+
if !ok || tokenInfo == nil {
21+
logger.Warn("no token info found in context")
22+
next.ServeHTTP(w, r)
23+
return
24+
}
25+
26+
// Fetch token scopes for scope-based tool filtering (PAT tokens only)
27+
// Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header.
28+
// Fine-grained PATs and other token types don't support this, so we skip filtering.
29+
if tokenInfo.TokenType == utils.TokenTypePersonalAccessToken {
30+
scopesList, err := scopeFetcher.FetchTokenScopes(ctx, tokenInfo.Token)
31+
if err != nil {
32+
logger.Warn("failed to fetch PAT scopes", "error", err)
33+
next.ServeHTTP(w, r)
34+
return
35+
}
36+
37+
tokenInfo.Scopes = scopesList
38+
tokenInfo.ScopesFetched = true
39+
40+
// Store fetched scopes in context for downstream use
41+
ctx := ghcontext.WithTokenInfo(ctx, tokenInfo)
42+
43+
next.ServeHTTP(w, r.WithContext(ctx))
44+
}
45+
}
46+
return http.HandlerFunc(fn)
47+
}
48+
}
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
package middleware
2+
3+
import (
4+
"context"
5+
"errors"
6+
"log/slog"
7+
"net/http"
8+
"net/http/httptest"
9+
"testing"
10+
11+
ghcontext "github.com/github/github-mcp-server/pkg/context"
12+
"github.com/github/github-mcp-server/pkg/utils"
13+
"github.com/stretchr/testify/assert"
14+
"github.com/stretchr/testify/require"
15+
)
16+
17+
// mockScopeFetcher is a mock implementation of scopes.FetcherInterface
18+
type mockScopeFetcher struct {
19+
scopes []string
20+
err error
21+
}
22+
23+
func (m *mockScopeFetcher) FetchTokenScopes(_ context.Context, _ string) ([]string, error) {
24+
return m.scopes, m.err
25+
}
26+
27+
func TestWithPATScopes(t *testing.T) {
28+
logger := slog.Default()
29+
30+
tests := []struct {
31+
name string
32+
tokenInfo *ghcontext.TokenInfo
33+
fetcherScopes []string
34+
fetcherErr error
35+
expectScopesFetched bool
36+
expectedScopes []string
37+
expectNextHandlerCalled bool
38+
}{
39+
{
40+
name: "no token info in context calls next handler",
41+
tokenInfo: nil,
42+
expectScopesFetched: false,
43+
expectedScopes: nil,
44+
expectNextHandlerCalled: true,
45+
},
46+
{
47+
name: "non-PAT token type skips scope fetching",
48+
tokenInfo: &ghcontext.TokenInfo{
49+
Token: "gho_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
50+
TokenType: utils.TokenTypeOAuthAccessToken,
51+
},
52+
expectScopesFetched: false,
53+
expectedScopes: nil,
54+
expectNextHandlerCalled: false, // middleware doesn't call next for non-PAT tokens
55+
},
56+
{
57+
name: "fine-grained PAT skips scope fetching",
58+
tokenInfo: &ghcontext.TokenInfo{
59+
Token: "github_pat_xxxxxxxxxxxxxxxxxxxxxxx",
60+
TokenType: utils.TokenTypeFineGrainedPersonalAccessToken,
61+
},
62+
expectScopesFetched: false,
63+
expectedScopes: nil,
64+
expectNextHandlerCalled: false,
65+
},
66+
{
67+
name: "classic PAT fetches and stores scopes",
68+
tokenInfo: &ghcontext.TokenInfo{
69+
Token: "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
70+
TokenType: utils.TokenTypePersonalAccessToken,
71+
},
72+
fetcherScopes: []string{"repo", "user", "read:org"},
73+
expectScopesFetched: true,
74+
expectedScopes: []string{"repo", "user", "read:org"},
75+
expectNextHandlerCalled: true,
76+
},
77+
{
78+
name: "classic PAT with empty scopes",
79+
tokenInfo: &ghcontext.TokenInfo{
80+
Token: "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
81+
TokenType: utils.TokenTypePersonalAccessToken,
82+
},
83+
fetcherScopes: []string{},
84+
expectScopesFetched: true,
85+
expectedScopes: []string{},
86+
expectNextHandlerCalled: true,
87+
},
88+
{
89+
name: "fetcher error calls next handler without scopes",
90+
tokenInfo: &ghcontext.TokenInfo{
91+
Token: "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
92+
TokenType: utils.TokenTypePersonalAccessToken,
93+
},
94+
fetcherErr: errors.New("network error"),
95+
expectScopesFetched: false,
96+
expectedScopes: nil,
97+
expectNextHandlerCalled: true,
98+
},
99+
{
100+
name: "old-style PAT (40 hex chars) fetches scopes",
101+
tokenInfo: &ghcontext.TokenInfo{
102+
Token: "0123456789abcdef0123456789abcdef01234567",
103+
TokenType: utils.TokenTypePersonalAccessToken,
104+
},
105+
fetcherScopes: []string{"repo"},
106+
expectScopesFetched: true,
107+
expectedScopes: []string{"repo"},
108+
expectNextHandlerCalled: true,
109+
},
110+
}
111+
112+
for _, tt := range tests {
113+
t.Run(tt.name, func(t *testing.T) {
114+
var capturedTokenInfo *ghcontext.TokenInfo
115+
var nextHandlerCalled bool
116+
117+
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
118+
nextHandlerCalled = true
119+
capturedTokenInfo, _ = ghcontext.GetTokenInfo(r.Context())
120+
w.WriteHeader(http.StatusOK)
121+
})
122+
123+
fetcher := &mockScopeFetcher{
124+
scopes: tt.fetcherScopes,
125+
err: tt.fetcherErr,
126+
}
127+
128+
middleware := WithPATScopes(logger, fetcher)
129+
handler := middleware(nextHandler)
130+
131+
req := httptest.NewRequest(http.MethodGet, "/test", nil)
132+
133+
// Set up context with token info if provided
134+
if tt.tokenInfo != nil {
135+
ctx := ghcontext.WithTokenInfo(req.Context(), tt.tokenInfo)
136+
req = req.WithContext(ctx)
137+
}
138+
139+
rr := httptest.NewRecorder()
140+
handler.ServeHTTP(rr, req)
141+
142+
assert.Equal(t, tt.expectNextHandlerCalled, nextHandlerCalled, "next handler called mismatch")
143+
144+
if tt.expectNextHandlerCalled && tt.tokenInfo != nil {
145+
require.NotNil(t, capturedTokenInfo, "expected token info in context")
146+
assert.Equal(t, tt.expectScopesFetched, capturedTokenInfo.ScopesFetched)
147+
assert.Equal(t, tt.expectedScopes, capturedTokenInfo.Scopes)
148+
}
149+
})
150+
}
151+
}
152+
153+
func TestWithPATScopes_PreservesExistingTokenInfo(t *testing.T) {
154+
logger := slog.Default()
155+
156+
var capturedTokenInfo *ghcontext.TokenInfo
157+
158+
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
159+
capturedTokenInfo, _ = ghcontext.GetTokenInfo(r.Context())
160+
w.WriteHeader(http.StatusOK)
161+
})
162+
163+
fetcher := &mockScopeFetcher{
164+
scopes: []string{"repo", "user"},
165+
}
166+
167+
originalTokenInfo := &ghcontext.TokenInfo{
168+
Token: "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
169+
TokenType: utils.TokenTypePersonalAccessToken,
170+
}
171+
172+
middleware := WithPATScopes(logger, fetcher)
173+
handler := middleware(nextHandler)
174+
175+
req := httptest.NewRequest(http.MethodGet, "/test", nil)
176+
ctx := ghcontext.WithTokenInfo(req.Context(), originalTokenInfo)
177+
req = req.WithContext(ctx)
178+
179+
rr := httptest.NewRecorder()
180+
handler.ServeHTTP(rr, req)
181+
182+
require.NotNil(t, capturedTokenInfo)
183+
assert.Equal(t, originalTokenInfo.Token, capturedTokenInfo.Token)
184+
assert.Equal(t, originalTokenInfo.TokenType, capturedTokenInfo.TokenType)
185+
assert.True(t, capturedTokenInfo.ScopesFetched)
186+
assert.Equal(t, []string{"repo", "user"}, capturedTokenInfo.Scopes)
187+
}

pkg/http/middleware/scope_challenge.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,10 @@ func WithScopeChallenge(oauthCfg *oauth.Config, scopeFetcher scopes.FetcherInter
102102
}
103103

104104
// Store active scopes in context for downstream use
105-
ghcontext.SetTokenScopes(ctx, activeScopes)
105+
tokenInfo.Scopes = activeScopes
106+
tokenInfo.ScopesFetched = true
107+
ctx = ghcontext.WithTokenInfo(ctx, tokenInfo)
108+
r = r.WithContext(ctx)
106109

107110
// Check if user has the required scopes
108111
if toolScopeInfo.HasAcceptedScope(activeScopes...) {

0 commit comments

Comments
 (0)