Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions pkg/ratelimit/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ type rateLimitMiddleware struct {
client redis.UniversalClient
}

// ToolNameResolver resolves the rate-limit tool name from a parsed MCP request.
//
// The returned string is the tool name the middleware hands to the limiter
// for the request. Passing a nil resolver to NewMiddleware selects
// DefaultToolNameResolver.
type ToolNameResolver func(*mcp.ParsedMCPRequest) string

// DefaultToolNameResolver is the resolver used when no custom one is supplied.
// It reports the request's ResourceID unchanged, and yields the empty string
// when parsed is nil so callers never dereference a missing request.
func DefaultToolNameResolver(parsed *mcp.ParsedMCPRequest) string {
	if parsed != nil {
		return parsed.ResourceID
	}
	return ""
}

// Handler returns the middleware function used by the proxy.
func (m *rateLimitMiddleware) Handler() types.MiddlewareFunction {
return m.handler
Expand Down Expand Up @@ -99,16 +110,19 @@ func CreateMiddleware(config *types.MiddlewareConfig, runner types.MiddlewareRun
}

mw := &rateLimitMiddleware{
handler: rateLimitHandler(limiter),
handler: NewMiddleware(limiter, nil),
client: client,
}
runner.AddMiddleware(MiddlewareType, mw)
return nil
}

// rateLimitHandler returns a middleware function that enforces rate limits
// NewMiddleware returns a middleware function that enforces rate limits
// on tools/call requests.
func rateLimitHandler(limiter Limiter) types.MiddlewareFunction {
func NewMiddleware(limiter Limiter, resolveToolName ToolNameResolver) types.MiddlewareFunction {
if resolveToolName == nil {
resolveToolName = DefaultToolNameResolver
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Rate limits only apply to parsed tools/call requests.
Expand All @@ -127,7 +141,7 @@ func rateLimitHandler(limiter Limiter) types.MiddlewareFunction {
if identity, ok := auth.IdentityFromContext(r.Context()); ok {
userID = identity.Subject
}
decision, err := limiter.Allow(r.Context(), parsed.ResourceID, userID)
decision, err := limiter.Allow(r.Context(), resolveToolName(parsed), userID)
if err != nil {
slog.Warn("rate limit check failed, allowing request", "error", err)
next.ServeHTTP(w, r)
Expand All @@ -142,6 +156,11 @@ func rateLimitHandler(limiter Limiter) types.MiddlewareFunction {
}
}

// rateLimitHandler builds the rate-limit middleware with default tool-name
// resolution. It is kept as a thin wrapper for tests and legacy callers.
func rateLimitHandler(limiter Limiter) types.MiddlewareFunction {
	// A nil resolver asks NewMiddleware to fall back to its default.
	middleware := NewMiddleware(limiter, nil)
	return middleware
}

// writeRateLimited writes an HTTP 429 response with a JSON-RPC error body.
func writeRateLimited(w http.ResponseWriter, requestID any, retryAfter time.Duration) {
retrySeconds := int(math.Ceil(retryAfter.Seconds()))
Expand Down
73 changes: 73 additions & 0 deletions pkg/ratelimit/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,17 @@ import (
"testing"
"time"

"github.com/alicebob/miniredis/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"

v1beta1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1beta1"
"github.com/stacklok/toolhive/pkg/auth"
"github.com/stacklok/toolhive/pkg/mcp"
transporttypes "github.com/stacklok/toolhive/pkg/transport/types"
transportmocks "github.com/stacklok/toolhive/pkg/transport/types/mocks"
)

// dummyLimiter is a test double for the Limiter interface.
Expand Down Expand Up @@ -208,3 +214,70 @@ func TestRateLimitHandler_NoIdentityPassesEmptyUserID(t *testing.T) {
assert.Equal(t, "echo", recorder.toolName)
assert.Empty(t, recorder.userID, "unauthenticated requests should pass empty userID")
}

func TestDefaultToolNameResolverNilParsedRequest(t *testing.T) {
t.Parallel()

assert.Empty(t, DefaultToolNameResolver(nil))
}

// TestNewMiddlewareUsesCustomToolNameResolver checks that the tool name handed
// to the limiter comes from the supplied resolver rather than from the raw
// resource ID carried by the parsed request.
func TestNewMiddlewareUsesCustomToolNameResolver(t *testing.T) {
	t.Parallel()

	const resolved = "resolved-tool"

	limiter := &recordingLimiter{}
	resolver := func(*mcp.ParsedMCPRequest) string { return resolved }
	next := http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		rw.WriteHeader(http.StatusOK)
	})
	handler := NewMiddleware(limiter, resolver)(next)

	// The request itself names "raw-tool"; the resolver must override it.
	req := withParsedMCPRequest(
		httptest.NewRequest(http.MethodPost, "/mcp", nil),
		"tools/call", "raw-tool", 1,
	)
	rec := httptest.NewRecorder()

	handler.ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, resolved, limiter.toolName)
}

func TestRateLimitMiddlewareHandlerReturnsConfiguredHandler(t *testing.T) {
t.Parallel()

expected := rateLimitHandler(&dummyLimiter{decision: &Decision{Allowed: true}})
mw := &rateLimitMiddleware{handler: expected}

assert.NotNil(t, mw.Handler())
}

// TestCreateMiddlewareRegistersUsableMiddleware runs CreateMiddleware against
// an in-process miniredis instance and checks that the middleware it registers
// with the runner exposes a handler and closes without error.
func TestCreateMiddlewareRegistersUsableMiddleware(t *testing.T) {
	t.Parallel()

	redisServer := miniredis.RunT(t)
	sharedBucket := &v1beta1.RateLimitBucket{
		MaxTokens:    1,
		RefillPeriod: metav1.Duration{Duration: time.Minute},
	}
	params := MiddlewareParams{
		Namespace:  "default",
		ServerName: "server",
		RedisAddr:  redisServer.Addr(),
		Config:     &v1beta1.RateLimitConfig{Shared: sharedBucket},
	}
	cfg, err := transporttypes.NewMiddlewareConfig(MiddlewareType, params)
	require.NoError(t, err)

	// Capture whatever CreateMiddleware registers so it can be inspected below.
	var captured transporttypes.Middleware
	runner := transportmocks.NewMockMiddlewareRunner(gomock.NewController(t))
	runner.EXPECT().
		AddMiddleware(MiddlewareType, gomock.AssignableToTypeOf(&rateLimitMiddleware{})).
		Do(func(_ string, mw transporttypes.Middleware) {
			captured = mw
		})

	require.NoError(t, CreateMiddleware(cfg, runner))
	require.NotNil(t, captured)
	require.NotNil(t, captured.Handler())
	require.NoError(t, captured.Close())
}
10 changes: 10 additions & 0 deletions pkg/vmcp/cli/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ func Serve(ctx context.Context, cfg ServeConfig) error {

serverCfg := &vmcpserver.Config{
Name: vmcpCfg.Name,
Namespace: vmcpNamespace(),
Version: versions.Version,
GroupRef: vmcpCfg.Group,
Host: cfg.Host,
Expand All @@ -386,6 +387,7 @@ func Serve(ctx context.Context, cfg ServeConfig) error {
OptimizerConfig: optCfg,
SessionFactory: sessionFactory,
SessionStorage: vmcpCfg.SessionStorage,
RateLimiting: vmcpCfg.RateLimiting,
}

// Assign Watcher only when backendWatcher is non-nil. A typed nil
Expand Down Expand Up @@ -521,6 +523,14 @@ func generateQuickModeConfig(groupRef string) (*config.Config, error) {
return cfg, nil
}

// vmcpNamespace returns the namespace taken from the VMCP_NAMESPACE
// environment variable, or "local" when the variable is unset or empty.
func vmcpNamespace() string {
	if ns := os.Getenv("VMCP_NAMESPACE"); ns != "" {
		return ns
	}
	return "local"
}

// loadAuthServerConfig loads the auth server RunConfig from a sibling file
// alongside the main config. The operator serializes authserver.RunConfig as a
// separate ConfigMap key (authserver-config.yaml).
Expand Down
14 changes: 14 additions & 0 deletions pkg/vmcp/cli/serve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,20 @@ func TestValidateQuickModeHost(t *testing.T) {
}
}

// TestVMCPNamespace covers both branches of vmcpNamespace: the "local"
// fallback for an empty VMCP_NAMESPACE and pass-through of a configured value.
// Subtests use t.Setenv, so they intentionally do not run in parallel.
func TestVMCPNamespace(t *testing.T) {
	cases := []struct {
		name string
		env  string
		want string
	}{
		{name: "defaults to local", env: "", want: "local"},
		{name: "uses environment value", env: "toolhive-system", want: "toolhive-system"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Setenv("VMCP_NAMESPACE", tc.env)

			assert.Equal(t, tc.want, vmcpNamespace())
		})
	}
}

// TestRunDiscovery_ZeroBackends exercises the branch in runDiscovery where the
// discoverer succeeds but returns no backends. The function must return a
// non-error, an empty (non-nil) backend slice, and pass through the client and
Expand Down
81 changes: 81 additions & 0 deletions pkg/vmcp/server/ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc.
// SPDX-License-Identifier: Apache-2.0

package server

import (
"context"
"fmt"
"net/http"
"os"
"time"

"github.com/redis/go-redis/v9"

mcpparser "github.com/stacklok/toolhive/pkg/mcp"
"github.com/stacklok/toolhive/pkg/ratelimit"
vmcpconfig "github.com/stacklok/toolhive/pkg/vmcp/config"
)

// rateLimitRedisPingTimeout bounds the Redis connectivity check performed
// while building the rate-limit middleware.
const rateLimitRedisPingTimeout = 5 * time.Second

// buildRateLimitMiddleware constructs the HTTP rate-limit middleware from the
// server configuration.
//
// It returns (nil, nil, nil) when rate limiting is not configured. Otherwise it
// requires Redis session storage with a non-empty address, opens a Redis
// client using that address and DB (password read from the environment
// variable named by vmcpconfig.RedisPasswordEnvVar), verifies connectivity
// with a ping bounded by rateLimitRedisPingTimeout, and builds the limiter.
//
// On success it returns the middleware together with a cleanup function that
// closes the Redis client; the cleanup function ignores its context argument.
// On any failure after the client is created, the client is closed before the
// error is returned.
func (s *Server) buildRateLimitMiddleware(
	ctx context.Context,
) (func(http.Handler) http.Handler, func(context.Context) error, error) {
	storage := s.config.SessionStorage
	switch {
	case s.config.RateLimiting == nil:
		// Rate limiting is disabled: nothing to build, nothing to clean up.
		return nil, nil, nil
	case storage == nil || storage.Provider != "redis":
		return nil, nil, fmt.Errorf("rate limiting requires Redis session storage")
	case storage.Address == "":
		return nil, nil, fmt.Errorf("rate limiting requires Redis session storage address")
	}

	addr := storage.Address
	rdb := redis.NewClient(&redis.Options{
		Addr:     addr,
		DB:       int(storage.DB),
		Password: os.Getenv(vmcpconfig.RedisPasswordEnvVar),
	})

	// Fail fast if Redis is unreachable instead of discovering it per request.
	pingCtx, cancel := context.WithTimeout(ctx, rateLimitRedisPingTimeout)
	defer cancel()
	if pingErr := rdb.Ping(pingCtx).Err(); pingErr != nil {
		// Best-effort close; the ping failure is the error worth reporting.
		_ = rdb.Close()
		return nil, nil, fmt.Errorf("rate limit middleware: failed to connect to Redis at %s: %w",
			addr, pingErr)
	}

	limiter, limiterErr := ratelimit.NewLimiter(rdb, s.config.Namespace, s.config.Name, s.config.RateLimiting)
	if limiterErr != nil {
		// Best-effort close; the construction failure is the error worth reporting.
		_ = rdb.Close()
		return nil, nil, fmt.Errorf("failed to create rate limiter: %w", limiterErr)
	}

	closeClient := func(context.Context) error {
		return rdb.Close()
	}
	return ratelimit.NewMiddleware(limiter, s.rateLimitToolName), closeClient, nil
}

// rateLimitToolName picks the tool name that rate limiting is keyed on.
//
// A nil request yields "". Normally the request's ResourceID is returned
// as-is. When the optimizer is enabled and the request targets "call_tool",
// the name is taken from the "tool_name" argument instead, provided that
// argument is a non-empty string; otherwise the ResourceID is kept.
func (s *Server) rateLimitToolName(parsed *mcpparser.ParsedMCPRequest) string {
	if parsed == nil {
		return ""
	}

	outer := parsed.ResourceID
	if outer != "call_tool" || !s.optimizerEnabled() {
		return outer
	}

	// Indexing a nil Arguments map yields the zero value, so the type
	// assertion below also covers the "no arguments" case.
	if inner, isString := parsed.Arguments["tool_name"].(string); isString && inner != "" {
		return inner
	}
	return outer
}

// optimizerEnabled reports whether the server configuration carries either an
// OptimizerConfig or an OptimizerFactory.
func (s *Server) optimizerEnabled() bool {
	if s.config.OptimizerConfig != nil {
		return true
	}
	return s.config.OptimizerFactory != nil
}
Loading
Loading