diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index a37c4d940d..1c93b049ee 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -62,8 +62,11 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv } // Construct REST client + rateLimitState := transport.NewRateLimitState() + rateLimitLogger := cfg.Logger.With("component", "rate_limit") + restUATransport := &transport.UserAgentTransport{ - Transport: http.DefaultTransport, + Transport: transport.WrapWithRateLimit(http.DefaultTransport, rateLimitState, rateLimitLogger), Agent: fmt.Sprintf("github-mcp-server/%s", cfg.Version), } restClient, err := gogithub.NewClient( @@ -80,7 +83,7 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv gqlHTTPClient := &http.Client{ Transport: &transport.BearerAuthTransport{ Transport: &transport.GraphQLFeaturesTransport{ - Transport: http.DefaultTransport, + Transport: transport.WrapWithRateLimit(http.DefaultTransport, rateLimitState, rateLimitLogger), }, Token: cfg.Token, }, diff --git a/pkg/github/dependencies.go b/pkg/github/dependencies.go index 1141fbce89..3a1e9d9830 100644 --- a/pkg/github/dependencies.go +++ b/pkg/github/dependencies.go @@ -276,6 +276,8 @@ type RequestDeps struct { // Observability exporters (includes logger) obsv observability.Exporters + + rateLimits *transport.RateLimitRegistry } // NewRequestDeps creates a RequestDeps with the provided clients and configuration. @@ -298,6 +300,7 @@ func NewRequestDeps( ContentWindowSize: contentWindowSize, featureChecker: featureChecker, obsv: obsv, + rateLimits: transport.NewRateLimitRegistry(), } } @@ -320,9 +323,15 @@ func (d *RequestDeps) GetClient(ctx context.Context) (*gogithub.Client, error) { } // Construct REST client + rateLimitLogger := d.obsv.Logger().With("component", "rate_limit") restClient, err := gogithub.NewClient( + gogithub.WithHTTPClient(&http.Client{ + Transport: &transport.UserAgentTransport{ + Transport: transport.WrapWithRateLimit(http.DefaultTransport, d.rateLimits.Get(token), rateLimitLogger), + Agent: fmt.Sprintf("github-mcp-server/%s", d.version), + }, + }), gogithub.WithAuthToken(token), - gogithub.WithUserAgent(fmt.Sprintf("github-mcp-server/%s", d.version)), gogithub.WithEnterpriseURLs(baseRestURL.String(), uploadURL.String()), ) if err != nil { @@ -344,10 +353,11 @@ func (d *RequestDeps) GetGQLClient(ctx context.Context) (*githubv4.Client, error // We use NewEnterpriseClient unconditionally since we already parsed the API host // Wrap transport with GraphQLFeaturesTransport to inject feature flags from context, // matching the transport chain used by the remote server. + rateLimitLogger := d.obsv.Logger().With("component", "rate_limit") gqlHTTPClient := &http.Client{ Transport: &transport.BearerAuthTransport{ Transport: &transport.GraphQLFeaturesTransport{ - Transport: http.DefaultTransport, + Transport: transport.WrapWithRateLimit(http.DefaultTransport, d.rateLimits.Get(token), rateLimitLogger), }, Token: token, }, diff --git a/pkg/github/dependencies_ratelimit_test.go b/pkg/github/dependencies_ratelimit_test.go new file mode 100644 index 0000000000..f68e07f94a --- /dev/null +++ b/pkg/github/dependencies_ratelimit_test.go @@ -0,0 +1,39 @@ +package github + +import ( + "log/slog" + "testing" + + "github.com/github/github-mcp-server/pkg/observability" + "github.com/github/github-mcp-server/pkg/observability/metrics" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewRequestDeps_InitializesRateLimitRegistry(t *testing.T) { + t.Parallel() + + obs, err := observability.NewExporters(slog.New(slog.DiscardHandler), metrics.NewNoopMetrics()) + require.NoError(t, err) + + deps := NewRequestDeps( + nil, + "test", + false, + nil, + translations.NullTranslationHelper, + 0, + nil, + obs, + ) + + require.NotNil(t, deps.rateLimits) + + stateA1 := deps.rateLimits.Get("token-a") + stateA2 := deps.rateLimits.Get("token-a") + stateB := deps.rateLimits.Get("token-b") + + assert.Same(t, stateA1, stateA2) + assert.NotSame(t, stateA1, stateB) +} diff --git a/pkg/http/transport/rate_limit.go b/pkg/http/transport/rate_limit.go new file mode 100644 index 0000000000..a218460e6c --- /dev/null +++ b/pkg/http/transport/rate_limit.go @@ -0,0 +1,230 @@ +package transport + +import ( + "context" + "log/slog" + "net/http" + "strconv" + "sync" + "time" +) + +const ( + DefaultMinRateLimitRemaining = 50 + DefaultMinRequestInterval = 50 * time.Millisecond + DefaultMaxRateLimitRetries = 3 +) + +type RateLimitState struct { + mu sync.Mutex + + remaining int // -1 means unknown + reset time.Time + lastReq time.Time +} + +func NewRateLimitState() *RateLimitState { + return &RateLimitState{remaining: -1} +} + +type RateLimitRegistry struct { + states sync.Map +} + +func NewRateLimitRegistry() *RateLimitRegistry { + return &RateLimitRegistry{} +} + +func (r *RateLimitRegistry) Get(token string) *RateLimitState { + if state, ok := r.states.Load(token); ok { + return state.(*RateLimitState) + } + + state := NewRateLimitState() + actual, _ := r.states.LoadOrStore(token, state) + return actual.(*RateLimitState) +} + +type RateLimitTransport struct { + Transport http.RoundTripper + State *RateLimitState + + MinInterval time.Duration + MinRemaining int + MaxRetries int + Logger *slog.Logger +} + +func WrapWithRateLimit(base http.RoundTripper, state *RateLimitState, logger *slog.Logger) http.RoundTripper { + if state == nil { + state = NewRateLimitState() + } + + return &RateLimitTransport{ + Transport: base, + State: state, + MinInterval: DefaultMinRequestInterval, + MinRemaining: DefaultMinRateLimitRemaining, + MaxRetries: DefaultMaxRateLimitRetries, + Logger: logger, + } +} + +func (t *RateLimitTransport) RoundTrip(req *http.Request) (*http.Response, error) { + transport := t.Transport + if transport == nil { + transport = http.DefaultTransport + } + + maxRetries := t.MaxRetries + if maxRetries < 0 { + maxRetries = DefaultMaxRateLimitRetries + } + + for attempt := 0; attempt <= maxRetries; attempt++ { + t.waitBeforeRequest(req.Context()) + + resp, err := transport.RoundTrip(req) + if err != nil { + return resp, err + } + + t.updateFromResponse(resp) + + if !isRateLimitedResponse(resp) || attempt == maxRetries { + return resp, nil + } + + wait := retryAfterDuration(resp) + if t.Logger != nil { + t.Logger.Warn( + "GitHub API rate limit hit, waiting before retry", + "attempt", attempt+1, + "max_retries", maxRetries, + "wait", wait.Round(time.Second), + "status", resp.StatusCode, + ) + } + + resp.Body.Close() + waitForContext(req.Context(), wait) + } + + return nil, nil +} + +func (t *RateLimitTransport) waitBeforeRequest(ctx context.Context) { + minInterval := t.MinInterval + if minInterval <= 0 { + minInterval = DefaultMinRequestInterval + } + + minRemaining := t.MinRemaining + if minRemaining <= 0 { + minRemaining = DefaultMinRateLimitRemaining + } + + t.State.mu.Lock() + defer t.State.mu.Unlock() + + if wait := time.Until(t.State.lastReq.Add(minInterval)); wait > 0 { + waitForContext(ctx, wait) + } + + if t.State.remaining >= 0 && t.State.remaining < minRemaining && !t.State.reset.IsZero() { + if wait := time.Until(t.State.reset) + time.Second; wait > 0 { + if t.Logger != nil { + t.Logger.Warn( + "GitHub API rate limit nearly exhausted, waiting for reset", + "remaining", t.State.remaining, + "wait", wait.Round(time.Second), + ) + } + waitForContext(ctx, wait) + t.State.remaining = -1 + } + } + + t.State.lastReq = time.Now() +} + +func (t *RateLimitTransport) updateFromResponse(resp *http.Response) { + remaining, reset, ok := parseRateLimitHeaders(resp) + if !ok { + return + } + + t.State.mu.Lock() + defer t.State.mu.Unlock() + t.State.remaining = remaining + t.State.reset = reset +} + +func parseRateLimitHeaders(resp *http.Response) (remaining int, reset time.Time, ok bool) { + remainingStr := resp.Header.Get("X-RateLimit-Remaining") + resetStr := resp.Header.Get("X-RateLimit-Reset") + if remainingStr == "" || resetStr == "" { + return 0, time.Time{}, false + } + + remainingVal, err := strconv.Atoi(remainingStr) + if err != nil { + return 0, time.Time{}, false + } + + resetUnix, err := strconv.ParseInt(resetStr, 10, 64) + if err != nil { + return 0, time.Time{}, false + } + + return remainingVal, time.Unix(resetUnix, 0), true +} + +func isRateLimitedResponse(resp *http.Response) bool { + if resp == nil { + return false + } + + switch resp.StatusCode { + case http.StatusTooManyRequests: + return true + case http.StatusForbidden: + return resp.Header.Get("Retry-After") != "" + default: + return false + } +} + +func retryAfterDuration(resp *http.Response) time.Duration { + if resp == nil { + return time.Second + } + + if retryAfter := resp.Header.Get("Retry-After"); retryAfter != "" { + if seconds, err := strconv.Atoi(retryAfter); err == nil && seconds > 0 { + return time.Duration(seconds) * time.Second + } + } + + if _, reset, ok := parseRateLimitHeaders(resp); ok && !reset.IsZero() { + if wait := time.Until(reset) + time.Second; wait > 0 { + return wait + } + } + + return time.Second +} + +func waitForContext(ctx context.Context, d time.Duration) { + if d <= 0 { + return + } + + timer := time.NewTimer(d) + defer timer.Stop() + + select { + case <-ctx.Done(): + case <-timer.C: + } +} diff --git a/pkg/http/transport/rate_limit_test.go b/pkg/http/transport/rate_limit_test.go new file mode 100644 index 0000000000..8503ff6d11 --- /dev/null +++ b/pkg/http/transport/rate_limit_test.go @@ -0,0 +1,190 @@ +package transport + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strconv" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRateLimitTransport_UpdatesStateFromHeaders(t *testing.T) { + t.Parallel() + + var calls atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + w.Header().Set("X-RateLimit-Remaining", "100") + w.Header().Set("X-RateLimit-Reset", strconv.FormatInt(time.Now().Add(time.Minute).Unix(), 10)) + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(server.Close) + + state := NewRateLimitState() + client := &http.Client{Transport: WrapWithRateLimit(server.Client().Transport, state, nil)} + + resp, err := client.Get(server.URL) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + + state.mu.Lock() + defer state.mu.Unlock() + assert.Equal(t, 100, state.remaining) + assert.False(t, state.reset.IsZero()) + assert.Equal(t, int32(1), calls.Load()) +} + +func TestRateLimitTransport_WaitsWhenRemainingLow(t *testing.T) { + t.Parallel() + + state := NewRateLimitState() + state.mu.Lock() + state.remaining = 10 + state.reset = time.Now().Add(200 * time.Millisecond) + state.mu.Unlock() + + var calls atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(server.Close) + + transport := &RateLimitTransport{ + Transport: server.Client().Transport, + State: state, + MinInterval: 0, + MinRemaining: 50, + MaxRetries: 0, + } + client := &http.Client{Transport: transport} + + start := time.Now() + resp, err := client.Get(server.URL) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + assert.GreaterOrEqual(t, time.Since(start), 150*time.Millisecond) +} + +func TestRateLimitTransport_EnforcesMinInterval(t *testing.T) { + t.Parallel() + + var calls atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(server.Close) + + state := NewRateLimitState() + transport := &RateLimitTransport{ + Transport: server.Client().Transport, + State: state, + MinInterval: 100 * time.Millisecond, + MinRemaining: 0, + MaxRetries: 0, + } + client := &http.Client{Transport: transport} + + start := time.Now() + resp1, err := client.Get(server.URL) + require.NoError(t, err) + require.NoError(t, resp1.Body.Close()) + resp2, err := client.Get(server.URL) + require.NoError(t, err) + require.NoError(t, resp2.Body.Close()) + assert.GreaterOrEqual(t, time.Since(start), 100*time.Millisecond) +} + +func TestRateLimitTransport_RetriesOn429(t *testing.T) { + t.Parallel() + + var calls atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + if calls.Add(1) == 1 { + w.Header().Set("Retry-After", "1") + w.WriteHeader(http.StatusTooManyRequests) + return + } + w.WriteHeader(http.StatusOK) + })) + t.Cleanup(server.Close) + + transport := &RateLimitTransport{ + Transport: server.Client().Transport, + State: NewRateLimitState(), + MinInterval: 0, + MinRemaining: 0, + MaxRetries: 1, + } + resp, err := (&http.Client{Transport: transport}).Get(server.URL) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, int32(2), calls.Load()) +} + +func TestRateLimitTransport_DoesNotRetryOther403(t *testing.T) { + t.Parallel() + + var calls atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + w.WriteHeader(http.StatusForbidden) + })) + t.Cleanup(server.Close) + + transport := &RateLimitTransport{ + Transport: server.Client().Transport, + State: NewRateLimitState(), + MinInterval: 0, + MinRemaining: 0, + MaxRetries: 2, + } + resp, err := (&http.Client{Transport: transport}).Get(server.URL) + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusForbidden, resp.StatusCode) + assert.Equal(t, int32(1), calls.Load()) +} + +func TestRateLimitRegistry_SharesStatePerToken(t *testing.T) { + t.Parallel() + + registry := NewRateLimitRegistry() + stateA1 := registry.Get("token-a") + stateA2 := registry.Get("token-a") + stateB := registry.Get("token-b") + assert.Same(t, stateA1, stateA2) + assert.NotSame(t, stateA1, stateB) +} + +func TestParseRateLimitHeaders(t *testing.T) { + t.Parallel() + + reset := time.Now().Add(time.Minute).Unix() + resp := &http.Response{Header: make(http.Header), Body: io.NopCloser(http.NoBody)} + resp.Header.Set("X-RateLimit-Remaining", "42") + resp.Header.Set("X-RateLimit-Reset", strconv.FormatInt(reset, 10)) + + remaining, resetTime, ok := parseRateLimitHeaders(resp) + require.True(t, ok) + assert.Equal(t, 42, remaining) + assert.Equal(t, time.Unix(reset, 0), resetTime) +} + +func TestWaitForContext_RespectsCancellation(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + start := time.Now() + waitForContext(ctx, time.Second) + assert.Less(t, time.Since(start), 100*time.Millisecond) +}