Skip to content

Commit 5ced9fa

Browse files
author
root
committed
Add HTTP transport safeguards against GitHub API rate limiting
Closes #2037
1 parent 457f599 commit 5ced9fa

5 files changed

Lines changed: 472 additions & 4 deletions

File tree

internal/ghmcp/server.go

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,10 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv
6262
}
6363

6464
// Construct REST client
65+
rateLimitState := transport.NewRateLimitState()
66+
6567
restUATransport := &transport.UserAgentTransport{
66-
Transport: http.DefaultTransport,
68+
Transport: transport.WrapWithRateLimit(http.DefaultTransport, rateLimitState),
6769
Agent: fmt.Sprintf("github-mcp-server/%s", cfg.Version),
6870
}
6971
restClient, err := gogithub.NewClient(
@@ -80,7 +82,7 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv
8082
gqlHTTPClient := &http.Client{
8183
Transport: &transport.BearerAuthTransport{
8284
Transport: &transport.GraphQLFeaturesTransport{
83-
Transport: http.DefaultTransport,
85+
Transport: transport.WrapWithRateLimit(http.DefaultTransport, rateLimitState),
8486
},
8587
Token: cfg.Token,
8688
},

pkg/github/dependencies.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,8 @@ type RequestDeps struct {
276276

277277
// Observability exporters (includes logger)
278278
obsv observability.Exporters
279+
280+
rateLimits *transport.RateLimitRegistry
279281
}
280282

281283
// NewRequestDeps creates a RequestDeps with the provided clients and configuration.
@@ -298,6 +300,7 @@ func NewRequestDeps(
298300
ContentWindowSize: contentWindowSize,
299301
featureChecker: featureChecker,
300302
obsv: obsv,
303+
rateLimits: transport.NewRateLimitRegistry(),
301304
}
302305
}
303306

@@ -321,8 +324,13 @@ func (d *RequestDeps) GetClient(ctx context.Context) (*gogithub.Client, error) {
321324

322325
// Construct REST client
323326
restClient, err := gogithub.NewClient(
327+
gogithub.WithHTTPClient(&http.Client{
328+
Transport: &transport.UserAgentTransport{
329+
Transport: transport.WrapWithRateLimit(http.DefaultTransport, d.rateLimits.Get(token)),
330+
Agent: fmt.Sprintf("github-mcp-server/%s", d.version),
331+
},
332+
}),
324333
gogithub.WithAuthToken(token),
325-
gogithub.WithUserAgent(fmt.Sprintf("github-mcp-server/%s", d.version)),
326334
gogithub.WithEnterpriseURLs(baseRestURL.String(), uploadURL.String()),
327335
)
328336
if err != nil {
@@ -347,7 +355,7 @@ func (d *RequestDeps) GetGQLClient(ctx context.Context) (*githubv4.Client, error
347355
gqlHTTPClient := &http.Client{
348356
Transport: &transport.BearerAuthTransport{
349357
Transport: &transport.GraphQLFeaturesTransport{
350-
Transport: http.DefaultTransport,
358+
Transport: transport.WrapWithRateLimit(http.DefaultTransport, d.rateLimits.Get(token)),
351359
},
352360
Token: token,
353361
},
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package github
2+
3+
import (
4+
"log/slog"
5+
"testing"
6+
7+
"github.com/github/github-mcp-server/pkg/observability"
8+
"github.com/github/github-mcp-server/pkg/observability/metrics"
9+
"github.com/github/github-mcp-server/pkg/translations"
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func TestNewRequestDeps_InitializesRateLimitRegistry(t *testing.T) {
15+
t.Parallel()
16+
17+
obs, err := observability.NewExporters(slog.New(slog.DiscardHandler), metrics.NewNoopMetrics())
18+
require.NoError(t, err)
19+
20+
deps := NewRequestDeps(
21+
nil,
22+
"test",
23+
false,
24+
nil,
25+
translations.NullTranslationHelper,
26+
0,
27+
nil,
28+
obs,
29+
)
30+
31+
require.NotNil(t, deps.rateLimits)
32+
33+
stateA1 := deps.rateLimits.Get("token-a")
34+
stateA2 := deps.rateLimits.Get("token-a")
35+
stateB := deps.rateLimits.Get("token-b")
36+
37+
assert.Same(t, stateA1, stateA2)
38+
assert.NotSame(t, stateA1, stateB)
39+
}

pkg/http/transport/rate_limit.go

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
package transport
2+
3+
import (
4+
"context"
5+
"log/slog"
6+
"net/http"
7+
"strconv"
8+
"sync"
9+
"time"
10+
)
11+
12+
const (
13+
DefaultMinRateLimitRemaining = 50
14+
DefaultMinRequestInterval = 50 * time.Millisecond
15+
DefaultMaxRateLimitRetries = 3
16+
)
17+
18+
type RateLimitState struct {
19+
mu sync.Mutex
20+
21+
remaining int // -1 means unknown
22+
reset time.Time
23+
lastReq time.Time
24+
}
25+
26+
func NewRateLimitState() *RateLimitState {
27+
return &RateLimitState{remaining: -1}
28+
}
29+
30+
type RateLimitRegistry struct {
31+
states sync.Map
32+
}
33+
34+
func NewRateLimitRegistry() *RateLimitRegistry {
35+
return &RateLimitRegistry{}
36+
}
37+
38+
func (r *RateLimitRegistry) Get(token string) *RateLimitState {
39+
if state, ok := r.states.Load(token); ok {
40+
return state.(*RateLimitState)
41+
}
42+
43+
state := NewRateLimitState()
44+
actual, _ := r.states.LoadOrStore(token, state)
45+
return actual.(*RateLimitState)
46+
}
47+
48+
type RateLimitTransport struct {
49+
Transport http.RoundTripper
50+
State *RateLimitState
51+
52+
MinInterval time.Duration
53+
MinRemaining int
54+
MaxRetries int
55+
Logger *slog.Logger
56+
}
57+
58+
func WrapWithRateLimit(base http.RoundTripper, state *RateLimitState) http.RoundTripper {
59+
if state == nil {
60+
state = NewRateLimitState()
61+
}
62+
63+
return &RateLimitTransport{
64+
Transport: base,
65+
State: state,
66+
MinInterval: DefaultMinRequestInterval,
67+
MinRemaining: DefaultMinRateLimitRemaining,
68+
MaxRetries: DefaultMaxRateLimitRetries,
69+
}
70+
}
71+
72+
func (t *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, error) {
73+
transport := t.Transport
74+
if transport == nil {
75+
transport = http.DefaultTransport
76+
}
77+
78+
maxRetries := t.MaxRetries
79+
if maxRetries < 0 {
80+
maxRetries = DefaultMaxRateLimitRetries
81+
}
82+
83+
for attempt := 0; attempt <= maxRetries; attempt++ {
84+
t.waitBeforeRequest(req.Context())
85+
86+
resp, err := transport.RoundTrip(req)
87+
if err != nil {
88+
return resp, err
89+
}
90+
91+
t.updateFromResponse(resp)
92+
93+
if !isRateLimitedResponse(resp) || attempt == maxRetries {
94+
return resp, nil
95+
}
96+
97+
wait := retryAfterDuration(resp)
98+
if t.Logger != nil {
99+
t.Logger.Warn(
100+
"GitHub API rate limit hit, waiting before retry",
101+
"attempt", attempt+1,
102+
"max_retries", maxRetries,
103+
"wait", wait.Round(time.Second),
104+
"status", resp.StatusCode,
105+
)
106+
}
107+
108+
resp.Body.Close()
109+
waitForContext(req.Context(), wait)
110+
}
111+
112+
return nil, nil
113+
}
114+
115+
func (t *RateLimitTransport) waitBeforeRequest(ctx context.Context) {
116+
minInterval := t.MinInterval
117+
if minInterval <= 0 {
118+
minInterval = DefaultMinRequestInterval
119+
}
120+
121+
minRemaining := t.MinRemaining
122+
if minRemaining <= 0 {
123+
minRemaining = DefaultMinRateLimitRemaining
124+
}
125+
126+
t.State.mu.Lock()
127+
defer t.State.mu.Unlock()
128+
129+
if wait := time.Until(t.State.lastReq.Add(minInterval)); wait > 0 {
130+
waitForContext(ctx, wait)
131+
}
132+
133+
if t.State.remaining >= 0 && t.State.remaining < minRemaining && !t.State.reset.IsZero() {
134+
if wait := time.Until(t.State.reset) + time.Second; wait > 0 {
135+
if t.Logger != nil {
136+
t.Logger.Warn(
137+
"GitHub API rate limit nearly exhausted, waiting for reset",
138+
"remaining", t.State.remaining,
139+
"wait", wait.Round(time.Second),
140+
)
141+
}
142+
waitForContext(ctx, wait)
143+
t.State.remaining = -1
144+
}
145+
}
146+
147+
t.State.lastReq = time.Now()
148+
}
149+
150+
func (t *RateLimitTransport) updateFromResponse(resp *http.Response) {
151+
remaining, reset, ok := parseRateLimitHeaders(resp)
152+
if !ok {
153+
return
154+
}
155+
156+
t.State.mu.Lock()
157+
defer t.State.mu.Unlock()
158+
t.State.remaining = remaining
159+
t.State.reset = reset
160+
}
161+
162+
func parseRateLimitHeaders(resp *http.Response) (remaining int, reset time.Time, ok bool) {
163+
remainingStr := resp.Header.Get("X-RateLimit-Remaining")
164+
resetStr := resp.Header.Get("X-RateLimit-Reset")
165+
if remainingStr == "" || resetStr == "" {
166+
return 0, time.Time{}, false
167+
}
168+
169+
remainingVal, err := strconv.Atoi(remainingStr)
170+
if err != nil {
171+
return 0, time.Time{}, false
172+
}
173+
174+
resetUnix, err := strconv.ParseInt(resetStr, 10, 64)
175+
if err != nil {
176+
return 0, time.Time{}, false
177+
}
178+
179+
return remainingVal, time.Unix(resetUnix, 0), true
180+
}
181+
182+
func isRateLimitedResponse(resp *http.Response) bool {
183+
if resp == nil {
184+
return false
185+
}
186+
187+
switch resp.StatusCode {
188+
case http.StatusTooManyRequests:
189+
return true
190+
case http.StatusForbidden:
191+
return resp.Header.Get("Retry-After") != ""
192+
default:
193+
return false
194+
}
195+
}
196+
197+
func retryAfterDuration(resp *http.Response) time.Duration {
198+
if resp == nil {
199+
return time.Second
200+
}
201+
202+
if retryAfter := resp.Header.Get("Retry-After"); retryAfter != "" {
203+
if seconds, err := strconv.Atoi(retryAfter); err == nil && seconds > 0 {
204+
return time.Duration(seconds) * time.Second
205+
}
206+
}
207+
208+
if _, reset, ok := parseRateLimitHeaders(resp); ok && !reset.IsZero() {
209+
if wait := time.Until(reset) + time.Second; wait > 0 {
210+
return wait
211+
}
212+
}
213+
214+
return time.Second
215+
}
216+
217+
func waitForContext(ctx context.Context, d time.Duration) {
218+
if d <= 0 {
219+
return
220+
}
221+
222+
timer := time.NewTimer(d)
223+
defer timer.Stop()
224+
225+
select {
226+
case <-ctx.Done():
227+
case <-timer.C:
228+
}
229+
}

0 commit comments

Comments
 (0)