From 5ced9fa2ba62485ee922088fb35b2c569681edb3 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 6 Jun 2026 17:24:21 +0000 Subject: [PATCH 1/2] Add HTTP transport safeguards against GitHub API rate limiting Closes #2037 --- internal/ghmcp/server.go | 6 +- pkg/github/dependencies.go | 12 +- pkg/github/dependencies_ratelimit_test.go | 39 ++++ pkg/http/transport/rate_limit.go | 229 ++++++++++++++++++++++ pkg/http/transport/rate_limit_test.go | 190 ++++++++++++++++++ 5 files changed, 472 insertions(+), 4 deletions(-) create mode 100644 pkg/github/dependencies_ratelimit_test.go create mode 100644 pkg/http/transport/rate_limit.go create mode 100644 pkg/http/transport/rate_limit_test.go diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index a37c4d940d..22ed7cded3 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -62,8 +62,10 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv } // Construct REST client + rateLimitState := transport.NewRateLimitState() + restUATransport := &transport.UserAgentTransport{ - Transport: http.DefaultTransport, + Transport: transport.WrapWithRateLimit(http.DefaultTransport, rateLimitState), Agent: fmt.Sprintf("github-mcp-server/%s", cfg.Version), } restClient, err := gogithub.NewClient( @@ -80,7 +82,7 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv gqlHTTPClient := &http.Client{ Transport: &transport.BearerAuthTransport{ Transport: &transport.GraphQLFeaturesTransport{ - Transport: http.DefaultTransport, + Transport: transport.WrapWithRateLimit(http.DefaultTransport, rateLimitState), }, Token: cfg.Token, }, diff --git a/pkg/github/dependencies.go b/pkg/github/dependencies.go index 1141fbce89..f8449bec77 100644 --- a/pkg/github/dependencies.go +++ b/pkg/github/dependencies.go @@ -276,6 +276,8 @@ type RequestDeps struct { // Observability exporters (includes logger) obsv observability.Exporters + + rateLimits *transport.RateLimitRegistry } // NewRequestDeps creates a RequestDeps with the provided clients and configuration. @@ -298,6 +300,7 @@ func NewRequestDeps( ContentWindowSize: contentWindowSize, featureChecker: featureChecker, obsv: obsv, + rateLimits: transport.NewRateLimitRegistry(), } } @@ -321,8 +324,13 @@ func (d *RequestDeps) GetClient(ctx context.Context) (*gogithub.Client, error) { // Construct REST client restClient, err := gogithub.NewClient( + gogithub.WithHTTPClient(&http.Client{ + Transport: &transport.UserAgentTransport{ + Transport: transport.WrapWithRateLimit(http.DefaultTransport, d.rateLimits.Get(token)), + Agent: fmt.Sprintf("github-mcp-server/%s", d.version), + }, + }), gogithub.WithAuthToken(token), - gogithub.WithUserAgent(fmt.Sprintf("github-mcp-server/%s", d.version)), gogithub.WithEnterpriseURLs(baseRestURL.String(), uploadURL.String()), ) if err != nil { @@ -347,7 +355,7 @@ func (d *RequestDeps) GetGQLClient(ctx context.Context) (*githubv4.Client, error gqlHTTPClient := &http.Client{ Transport: &transport.BearerAuthTransport{ Transport: &transport.GraphQLFeaturesTransport{ - Transport: http.DefaultTransport, + Transport: transport.WrapWithRateLimit(http.DefaultTransport, d.rateLimits.Get(token)), }, Token: token, }, diff --git a/pkg/github/dependencies_ratelimit_test.go b/pkg/github/dependencies_ratelimit_test.go new file mode 100644 index 0000000000..f68e07f94a --- /dev/null +++ b/pkg/github/dependencies_ratelimit_test.go @@ -0,0 +1,39 @@ +package github + +import ( + "log/slog" + "testing" + + "github.com/github/github-mcp-server/pkg/observability" + "github.com/github/github-mcp-server/pkg/observability/metrics" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewRequestDeps_InitializesRateLimitRegistry(t *testing.T) { + t.Parallel() + + obs, err := observability.NewExporters(slog.New(slog.DiscardHandler), metrics.NewNoopMetrics()) + require.NoError(t, err) + + deps := NewRequestDeps( + nil, + "test", + false, + nil, + translations.NullTranslationHelper, + 0, + nil, + obs, + ) + + require.NotNil(t, deps.rateLimits) + + stateA1 := deps.rateLimits.Get("token-a") + stateA2 := deps.rateLimits.Get("token-a") + stateB := deps.rateLimits.Get("token-b") + + assert.Same(t, stateA1, stateA2) + assert.NotSame(t, stateA1, stateB) +} diff --git a/pkg/http/transport/rate_limit.go b/pkg/http/transport/rate_limit.go new file mode 100644 index 0000000000..45d5f3d128 --- /dev/null +++ b/pkg/http/transport/rate_limit.go @@ -0,0 +1,229 @@ +package transport + +import ( + "context" + "log/slog" + "net/http" + "strconv" + "sync" + "time" +) + +const ( + DefaultMinRateLimitRemaining = 50 + DefaultMinRequestInterval = 50 * time.Millisecond + DefaultMaxRateLimitRetries = 3 +) + +type RateLimitState struct { + mu sync.Mutex + + remaining int // -1 means unknown + reset time.Time + lastReq time.Time +} + +func NewRateLimitState() *RateLimitState { + return &RateLimitState{remaining: -1} +} + +type RateLimitRegistry struct { + states sync.Map +} + +func NewRateLimitRegistry() *RateLimitRegistry { + return &RateLimitRegistry{} +} + +func (r *RateLimitRegistry) Get(token string) *RateLimitState { + if state, ok := r.states.Load(token); ok { + return state.(*RateLimitState) + } + + state := NewRateLimitState() + actual, _ := r.states.LoadOrStore(token, state) + return actual.(*RateLimitState) +} + +type RateLimitTransport struct { + Transport http.RoundTripper + State *RateLimitState + + MinInterval time.Duration + MinRemaining int + MaxRetries int + Logger *slog.Logger +} + +func WrapWithRateLimit(base http.RoundTripper, state *RateLimitState) http.RoundTripper { + if state == nil { + state = NewRateLimitState() + } + + return &RateLimitTransport{ + Transport: base, + State: state, + MinInterval: DefaultMinRequestInterval, + MinRemaining: DefaultMinRateLimitRemaining, + MaxRetries: DefaultMaxRateLimitRetries, + } +} + +func (t *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, error) { + transport := t.Transport + if transport == nil { + transport = http.DefaultTransport + } + + maxRetries := t.MaxRetries + if maxRetries < 0 { + maxRetries = DefaultMaxRateLimitRetries + } + + for attempt := 0; attempt <= maxRetries; attempt++ { + t.waitBeforeRequest(req.Context()) + + resp, err := transport.RoundTrip(req) + if err != nil { + return resp, err + } + + t.updateFromResponse(resp) + + if !isRateLimitedResponse(resp) || attempt == maxRetries { + return resp, nil + } + + wait := retryAfterDuration(resp) + if t.Logger != nil { + t.Logger.Warn( + "GitHub API rate limit hit, waiting before retry", + "attempt", attempt+1, + "max_retries", maxRetries, + "wait", wait.Round(time.Second), + "status", resp.StatusCode, + ) + } + + resp.Body.Close() + waitForContext(req.Context(), wait) + } + + return nil, nil +} + +func (t *RateLimitTransport) waitBeforeRequest(ctx context.Context) { + minInterval := t.MinInterval + if minInterval <= 0 { + minInterval = DefaultMinRequestInterval + } + + minRemaining := t.MinRemaining + if minRemaining <= 0 { + minRemaining = DefaultMinRateLimitRemaining + } + + t.State.mu.Lock() + defer t.State.mu.Unlock() + + if wait := time.Until(t.State.lastReq.Add(minInterval)); wait > 0 { + waitForContext(ctx, wait) + } + + if t.State.remaining >= 0 && t.State.remaining < minRemaining && !t.State.reset.IsZero() { + if wait := time.Until(t.State.reset) + time.Second; wait > 0 { + if t.Logger != nil { + t.Logger.Warn( + "GitHub API rate limit nearly exhausted, waiting for reset", + "remaining", t.State.remaining, + "wait", wait.Round(time.Second), + ) + } + waitForContext(ctx, wait) + t.State.remaining = -1 + } + } + + t.State.lastReq = time.Now() +} + +func (t *RateLimitTransport) updateFromResponse(resp *http.Response) { + remaining, reset, ok := parseRateLimitHeaders(resp) + if !ok { + return + } + + t.State.mu.Lock() + defer t.State.mu.Unlock() + t.State.remaining = remaining + t.State.reset = reset +} + +func parseRateLimitHeaders(resp *http.Response) (remaining int, reset time.Time, ok bool) { + remainingStr := resp.Header.Get("X-RateLimit-Remaining") + resetStr := resp.Header.Get("X-RateLimit-Reset") + if remainingStr == "" || resetStr == "" { + return 0, time.Time{}, false + } + + remainingVal, err := strconv.Atoi(remainingStr) + if err != nil { + return 0, time.Time{}, false + } + + resetUnix, err := strconv.ParseInt(resetStr, 10, 64) + if err != nil { + return 0, time.Time{}, false + } + + return remainingVal, time.Unix(resetUnix, 0), true +} + +func isRateLimitedResponse(resp *http.Response) bool { + if resp == nil { + return false + } + + switch resp.StatusCode { + case http.StatusTooManyRequests: + return true + case http.StatusForbidden: + return resp.Header.Get("Retry-After") != "" + default: + return false + } +} + +func retryAfterDuration(resp *http.Response) time.Duration { + if resp == nil { + return time.Second + } + + if retryAfter := resp.Header.Get("Retry-After"); retryAfter != "" { + if seconds, err := strconv.Atoi(retryAfter); err == nil && seconds > 0 { + return time.Duration(seconds) * time.Second + } + } + + if _, reset, ok := parseRateLimitHeaders(resp); ok && !reset.IsZero() { + if wait := time.Until(reset) + time.Second; wait > 0 { + return wait + } + } + + return time.Second +} + +func waitForContext(ctx context.Context, d time.Duration) { + if d <= 0 { + return + } + + timer := time.NewTimer(d) + defer timer.Stop() + + select { + case <-ctx.Done(): + case <-timer.C: + } +} diff --git a/pkg/http/transport/rate_limit_test.go b/pkg/http/transport/rate_limit_test.go new file mode 100644 index 0000000000..fd9eeb3018 --- /dev/null +++ b/pkg/http/transport/rate_limit_test.go @@ -0,0 +1,190 @@ +package transport + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strconv" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRateLimitTransport_UpdatesStateFromHeaders(t *testing.T) { + t.Parallel() + + var calls atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + w.Header().Set("X-RateLimit-Remaining", "100") + w.Header().Set("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(time.Minute).Unix(), 10)) + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(server.Close) + + state := NewRateLimitState() + client := &http.Client{Transport: WrapWithRateLimit(server.Client().Transport, state)} + + resp, err := client.Get(server.URL) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + + state.mu.Lock() + defer state.mu.Unlock() + assert.Equal(t, 100, state.remaining) + assert.False(t, state.reset.IsZero()) + assert.Equal(t, int32(1), calls.Load()) +} + +func TestRateLimitTransport_WaitsWhenRemainingLow(t *testing.T) { + t.Parallel() + + state := NewRateLimitState() + state.mu.Lock() + state.remaining = 10 + state.reset = time.Now().Add(200 * time.Millisecond) + state.mu.Unlock() + + var calls atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(server.Close) + + transport := &RateLimitTransport{ + Transport: server.Client().Transport, + State: state, + MinInterval: 0, + MinRemaining: 50, + MaxRetries: 0, + } + client := &http.Client{Transport: transport} + + start := time.Now() + resp, err := client.Get(server.URL) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + assert.GreaterOrEqual(t, time.Since(start), 150*time.Millisecond) +} + +func TestRateLimitTransport_EnforcesMinInterval(t *testing.T) { + t.Parallel() + + var calls atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(server.Close) + + state := NewRateLimitState() + transport := &RateLimitTransport{ + Transport: server.Client().Transport, + State: state, + MinInterval: 100 * time.Millisecond, + MinRemaining: 0, + MaxRetries: 0, + } + client := &http.Client{Transport: transport} + + start := time.Now() + resp1, err := client.Get(server.URL) + require.NoError(t, err) + require.NoError(t, resp1.Body.Close()) + resp2, err := client.Get(server.URL) + require.NoError(t, err) + require.NoError(t, resp2.Body.Close()) + assert.GreaterOrEqual(t, time.Since(start), 100*time.Millisecond) +} + +func TestRateLimitTransport_RetriesOn429(t *testing.T) { + t.Parallel() + + var calls atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + if calls.Add(1) == 1 { + w.Header().Set("Retry-After", "1") + w.WriteHeader(http.StatusTooManyRequests) + return + } + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(server.Close) + + transport := &RateLimitTransport{ + Transport: server.Client().Transport, + State: NewRateLimitState(), + MinInterval: 0, + MinRemaining: 0, + MaxRetries: 1, + } + resp, err := (&http.Client{Transport: transport}).Get(server.URL) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, int32(2), calls.Load()) +} + +func TestRateLimitTransport_DoesNotRetryOther403(t *testing.T) { + t.Parallel() + + var calls atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + w.WriteHeader(http.StatusForbidden) + })) + t.Cleanup(server.Close) + + transport := &RateLimitTransport{ + Transport: server.Client().Transport, + State: NewRateLimitState(), + MinInterval: 0, + MinRemaining: 0, + MaxRetries: 2, + } + resp, err := (&http.Client{Transport: transport}).Get(server.URL) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusForbidden, resp.StatusCode) + assert.Equal(t, int32(1), calls.Load()) +} + +func TestRateLimitRegistry_SharesStatePerToken(t *testing.T) { + t.Parallel() + + registry := NewRateLimitRegistry() + stateA1 := registry.Get("token-a") + stateA2 := registry.Get("token-a") + stateB := registry.Get("token-b") + assert.Same(t, stateA1, stateA2) + assert.NotSame(t, stateA1, stateB) +} + +func TestParseRateLimitHeaders(t *testing.T) { + t.Parallel() + + reset := time.Now().Add(time.Minute).Unix() + resp := &http.Response{Header: make(http.Header), Body: io.NopCloser(http.NoBody)} + resp.Header.Set("X-RateLimit-Remaining", "42") + resp.Header.Set("X-RateLimit-Reset", strconv.FormatInt(reset, 10)) + + remaining, resetTime, ok := parseRateLimitHeaders(resp) + require.True(t, ok) + assert.Equal(t, 42, remaining) + assert.Equal(t, time.Unix(reset, 0), resetTime) +} + +func TestWaitForContext_RespectsCancellation(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + start := time.Now() + waitForContext(ctx, time.Second) + assert.Less(t, time.Since(start), 100*time.Millisecond) +} From 08bb47d6b431e694a61ae1425d343cb9bd5d05f9 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 6 Jun 2026 21:27:34 +0000 Subject: [PATCH 2/2] Wire rate limit transport logger for throttling visibility Log proactive waits and retry backoff so operators can see when the server is slowing down to avoid GitHub API rate limit lockouts. --- internal/ghmcp/server.go | 5 +++-- pkg/github/dependencies.go | 6 ++++-- pkg/http/transport/rate_limit.go | 3 ++- pkg/http/transport/rate_limit_test.go | 2 +- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 22ed7cded3..1c93b049ee 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -63,9 +63,10 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv // Construct REST client rateLimitState := transport.NewRateLimitState() + rateLimitLogger := cfg.Logger.With("component", "rate_limit") restUATransport := &transport.UserAgentTransport{ - Transport: transport.WrapWithRateLimit(http.DefaultTransport, rateLimitState), + Transport: transport.WrapWithRateLimit(http.DefaultTransport, rateLimitState, rateLimitLogger), Agent: fmt.Sprintf("github-mcp-server/%s", cfg.Version), } restClient, err := gogithub.NewClient( @@ -82,7 +83,7 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv gqlHTTPClient := &http.Client{ Transport: &transport.BearerAuthTransport{ Transport: &transport.GraphQLFeaturesTransport{ - Transport: transport.WrapWithRateLimit(http.DefaultTransport, rateLimitState), + Transport: transport.WrapWithRateLimit(http.DefaultTransport, rateLimitState, rateLimitLogger), }, Token: cfg.Token, }, diff --git a/pkg/github/dependencies.go b/pkg/github/dependencies.go index f8449bec77..3a1e9d9830 100644 --- a/pkg/github/dependencies.go +++ b/pkg/github/dependencies.go @@ -323,10 +323,11 @@ func (d *RequestDeps) GetClient(ctx context.Context) (*gogithub.Client, error) { } // Construct REST client + rateLimitLogger := d.obsv.Logger().With("component", "rate_limit") restClient, err := gogithub.NewClient( gogithub.WithHTTPClient(&http.Client{ Transport: &transport.UserAgentTransport{ - Transport: transport.WrapWithRateLimit(http.DefaultTransport, d.rateLimits.Get(token)), + Transport: transport.WrapWithRateLimit(http.DefaultTransport, d.rateLimits.Get(token), rateLimitLogger), Agent: fmt.Sprintf("github-mcp-server/%s", d.version), }, }), @@ -352,10 +353,11 @@ func (d *RequestDeps) GetGQLClient(ctx context.Context) (*githubv4.Client, error // We use NewEnterpriseClient unconditionally since we already parsed the API host // Wrap transport with GraphQLFeaturesTransport to inject feature flags from context, // matching the transport chain used by the remote server. + rateLimitLogger := d.obsv.Logger().With("component", "rate_limit") gqlHTTPClient := &http.Client{ Transport: &transport.BearerAuthTransport{ Transport: &transport.GraphQLFeaturesTransport{ - Transport: transport.WrapWithRateLimit(http.DefaultTransport, d.rateLimits.Get(token)), + Transport: transport.WrapWithRateLimit(http.DefaultTransport, d.rateLimits.Get(token), rateLimitLogger), }, Token: token, }, diff --git a/pkg/http/transport/rate_limit.go b/pkg/http/transport/rate_limit.go index 45d5f3d128..a218460e6c 100644 --- a/pkg/http/transport/rate_limit.go +++ b/pkg/http/transport/rate_limit.go @@ -55,7 +55,7 @@ type RateLimitTransport struct { Logger *slog.Logger } -func WrapWithRateLimit(base http.RoundTripper, state *RateLimitState) http.RoundTripper { +func WrapWithRateLimit(base http.RoundTripper, state *RateLimitState, logger *slog.Logger) http.RoundTripper { if state == nil { state = NewRateLimitState() } @@ -66,6 +66,7 @@ func WrapWithRateLimit(base http.RoundTripper, state *RateLimitState) http.Round MinInterval: DefaultMinRequestInterval, MinRemaining: DefaultMinRateLimitRemaining, MaxRetries: DefaultMaxRateLimitRetries, + Logger: logger, } } diff --git a/pkg/http/transport/rate_limit_test.go b/pkg/http/transport/rate_limit_test.go index fd9eeb3018..8503ff6d11 100644 --- a/pkg/http/transport/rate_limit_test.go +++ b/pkg/http/transport/rate_limit_test.go @@ -27,7 +27,7 @@ func TestRateLimitTransport_UpdatesStateFromHeaders(t *testing.T) { t.Cleanup(server.Close) state := NewRateLimitState() - client := &http.Client{Transport: WrapWithRateLimit(server.Client().Transport, state)} + client := &http.Client{Transport: WrapWithRateLimit(server.Client().Transport, state, nil)} resp, err := client.Get(server.URL) require.NoError(t, err)