Skip to content

Commit 25cbf06

Browse files
committed
fix: add collision-resistant request and operation ids (#2099)
1 parent b222072 commit 25cbf06

File tree

6 files changed

+186
-0
lines changed

6 files changed

+186
-0
lines changed

pkg/context/request_ids.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package context
2+
3+
import (
4+
"context"
5+
"crypto/rand"
6+
"fmt"
7+
)
8+
9+
type requestIDCtxKey struct{}
10+
type operationIDCtxKey struct{}
11+
12+
func WithRequestID(ctx context.Context, requestID string) context.Context {
13+
return context.WithValue(ctx, requestIDCtxKey{}, requestID)
14+
}
15+
16+
func RequestID(ctx context.Context) (string, bool) {
17+
requestID, ok := ctx.Value(requestIDCtxKey{}).(string)
18+
return requestID, ok
19+
}
20+
21+
func WithOperationID(ctx context.Context, operationID string) context.Context {
22+
return context.WithValue(ctx, operationIDCtxKey{}, operationID)
23+
}
24+
25+
func OperationID(ctx context.Context) (string, bool) {
26+
operationID, ok := ctx.Value(operationIDCtxKey{}).(string)
27+
return operationID, ok
28+
}
29+
30+
func GenerateRequestID() (string, error) {
31+
return generateID("req")
32+
}
33+
34+
func GenerateOperationID() (string, error) {
35+
return generateID("op")
36+
}
37+
38+
func generateID(prefix string) (string, error) {
39+
buf := make([]byte, 16)
40+
if _, err := rand.Read(buf); err != nil {
41+
return "", fmt.Errorf("generate %s id: %w", prefix, err)
42+
}
43+
return fmt.Sprintf("%s_%x", prefix, buf), nil
44+
}

pkg/github/server.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"strings"
99
"time"
1010

11+
ghcontext "github.com/github/github-mcp-server/pkg/context"
1112
gherrors "github.com/github/github-mcp-server/pkg/errors"
1213
"github.com/github/github-mcp-server/pkg/inventory"
1314
"github.com/github/github-mcp-server/pkg/octicons"
@@ -107,6 +108,7 @@ func NewMCPServer(ctx context.Context, cfg *MCPServerConfig, deps ToolDependenci
107108
// and any middleware that needs to read or modify the context should be before it.
108109
ghServer.AddReceivingMiddleware(middleware...)
109110
ghServer.AddReceivingMiddleware(InjectDepsMiddleware(deps))
111+
ghServer.AddReceivingMiddleware(withOperationID)
110112
ghServer.AddReceivingMiddleware(addGitHubAPIErrorToContext)
111113

112114
if unrecognized := inv.UnrecognizedToolsets(); len(unrecognized) > 0 {
@@ -176,6 +178,27 @@ func addGitHubAPIErrorToContext(next mcp.MethodHandler) mcp.MethodHandler {
176178
}
177179
}
178180

181+
func withOperationID(next mcp.MethodHandler) mcp.MethodHandler {
182+
return func(ctx context.Context, method string, req mcp.Request) (result mcp.Result, err error) {
183+
requestID, ok := ghcontext.RequestID(ctx)
184+
if !ok || requestID == "" {
185+
requestID, err = ghcontext.GenerateRequestID()
186+
if err != nil {
187+
return nil, err
188+
}
189+
ctx = ghcontext.WithRequestID(ctx, requestID)
190+
}
191+
192+
operationID, err := ghcontext.GenerateOperationID()
193+
if err != nil {
194+
return nil, err
195+
}
196+
ctx = ghcontext.WithOperationID(ctx, operationID)
197+
198+
return next(ctx, method, req)
199+
}
200+
}
201+
179202
// NewServer creates a new GitHub MCP server with the specified GH client and logger.
180203
func NewServer(version string, opts *mcp.ServerOptions) *mcp.Server {
181204
if opts == nil {
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package github
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
ghcontext "github.com/github/github-mcp-server/pkg/context"
8+
"github.com/modelcontextprotocol/go-sdk/mcp"
9+
"github.com/stretchr/testify/assert"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
func TestWithOperationID_PreservesRequestIDAndAddsOperationID(t *testing.T) {
14+
t.Parallel()
15+
16+
var capturedRequestID string
17+
var capturedOperationID string
18+
handler := withOperationID(func(ctx context.Context, _ string, _ mcp.Request) (mcp.Result, error) {
19+
var ok bool
20+
capturedRequestID, ok = ghcontext.RequestID(ctx)
21+
require.True(t, ok)
22+
23+
capturedOperationID, ok = ghcontext.OperationID(ctx)
24+
require.True(t, ok)
25+
return nil, nil
26+
})
27+
28+
_, err := handler(ghcontext.WithRequestID(context.Background(), "req_client"), "tools/call", nil)
29+
require.NoError(t, err)
30+
31+
assert.Equal(t, "req_client", capturedRequestID)
32+
assert.Regexp(t, `^op_[0-9a-f]+$`, capturedOperationID)
33+
}
34+
35+
func TestWithOperationID_GeneratesUniqueOperationIDs(t *testing.T) {
36+
t.Parallel()
37+
38+
var operationIDs []string
39+
handler := withOperationID(func(ctx context.Context, _ string, _ mcp.Request) (mcp.Result, error) {
40+
operationID, ok := ghcontext.OperationID(ctx)
41+
require.True(t, ok)
42+
operationIDs = append(operationIDs, operationID)
43+
return nil, nil
44+
})
45+
46+
_, err := handler(context.Background(), "tools/call", nil)
47+
require.NoError(t, err)
48+
_, err = handler(context.Background(), "tools/call", nil)
49+
require.NoError(t, err)
50+
51+
require.Len(t, operationIDs, 2)
52+
assert.NotEqual(t, operationIDs[0], operationIDs[1])
53+
}

pkg/http/headers/headers.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ const (
2525
ForwardedHostHeader = "X-Forwarded-Host"
2626
// ForwardedProtoHeader is a standard HTTP Header for preserving the original protocol when proxying.
2727
ForwardedProtoHeader = "X-Forwarded-Proto"
28+
// RequestIDHeader is a standard request-correlation header.
29+
RequestIDHeader = "X-Request-ID"
2830

2931
// RequestHmacHeader is used to authenticate requests to the Raw API.
3032
RequestHmacHeader = "Request-Hmac"

pkg/http/middleware/request_config.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,18 @@ func WithRequestConfig(next http.Handler) http.Handler {
1515
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1616
ctx := r.Context()
1717

18+
requestID := strings.TrimSpace(r.Header.Get(headers.RequestIDHeader))
19+
if requestID == "" {
20+
generatedRequestID, err := ghcontext.GenerateRequestID()
21+
if err != nil {
22+
http.Error(w, "failed to generate request id", http.StatusInternalServerError)
23+
return
24+
}
25+
requestID = generatedRequestID
26+
}
27+
ctx = ghcontext.WithRequestID(ctx, requestID)
28+
w.Header().Set(headers.RequestIDHeader, requestID)
29+
1830
// Readonly mode
1931
if relaxedParseBool(r.Header.Get(headers.MCPReadOnlyHeader)) {
2032
ctx = ghcontext.WithReadonly(ctx, true)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package middleware
2+
3+
import (
4+
"net/http"
5+
"net/http/httptest"
6+
"testing"
7+
8+
ghcontext "github.com/github/github-mcp-server/pkg/context"
9+
"github.com/github/github-mcp-server/pkg/http/headers"
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestWithRequestConfig_PreservesProvidedRequestID(t *testing.T) {
15+
t.Parallel()
16+
17+
recorder := httptest.NewRecorder()
18+
request := httptest.NewRequest(http.MethodGet, "/", nil)
19+
request.Header.Set(headers.RequestIDHeader, "client-request-id")
20+
21+
var requestID string
22+
handler := WithRequestConfig(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
23+
var ok bool
24+
requestID, ok = ghcontext.RequestID(r.Context())
25+
require.True(t, ok)
26+
}))
27+
28+
handler.ServeHTTP(recorder, request)
29+
30+
assert.Equal(t, "client-request-id", requestID)
31+
assert.Equal(t, "client-request-id", recorder.Header().Get(headers.RequestIDHeader))
32+
}
33+
34+
func TestWithRequestConfig_GeneratesRequestIDWhenMissing(t *testing.T) {
35+
t.Parallel()
36+
37+
recorder := httptest.NewRecorder()
38+
request := httptest.NewRequest(http.MethodGet, "/", nil)
39+
40+
var requestID string
41+
handler := WithRequestConfig(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
42+
var ok bool
43+
requestID, ok = ghcontext.RequestID(r.Context())
44+
require.True(t, ok)
45+
}))
46+
47+
handler.ServeHTTP(recorder, request)
48+
49+
assert.NotEmpty(t, requestID)
50+
assert.Equal(t, requestID, recorder.Header().Get(headers.RequestIDHeader))
51+
assert.Regexp(t, `^req_[0-9a-f]+$`, requestID)
52+
}

0 commit comments

Comments
 (0)