Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions cmd/thv/app/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/stacklok/toolhive/pkg/oauthproto/tokenexchange"
"github.com/stacklok/toolhive/pkg/transport"
"github.com/stacklok/toolhive/pkg/transport/middleware"
"github.com/stacklok/toolhive/pkg/transport/middleware/origin"
"github.com/stacklok/toolhive/pkg/transport/proxy/transparent"
"github.com/stacklok/toolhive/pkg/transport/types"
)
Expand Down Expand Up @@ -110,9 +111,10 @@ Dynamic client registration (automatic OAuth client setup):
}

var (
proxyHost string
proxyPort int
proxyTargetURI string
proxyHost string
proxyPort int
proxyTargetURI string
proxyAllowedOrigins []string

resourceURL string // Explicit resource URL for OAuth discovery endpoint (RFC 9728)

Expand All @@ -133,6 +135,10 @@ const (
func init() {
proxyCmd.Flags().StringVar(&proxyHost, "host", transport.LocalhostIPv4, "Host for the HTTP proxy to listen on (IP or hostname)")
proxyCmd.Flags().IntVar(&proxyPort, "port", 0, "Port for the HTTP proxy to listen on (host port)")
proxyCmd.Flags().StringArrayVar(&proxyAllowedOrigins, "allowed-origins", nil,
"Exact-match allowlist for the HTTP Origin header (repeatable). Recommended when binding publicly; "+
"loopback binds derive a default allowlist automatically, non-loopback binds log a warning when "+
"no value is supplied. Example: https://my-mcp.example.com")
proxyCmd.Flags().StringVar(
&proxyTargetURI,
"target-uri",
Expand Down Expand Up @@ -226,6 +232,22 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error {
// Create middlewares slice for incoming request authentication
var middlewares []types.NamedMiddleware

// Origin-header validation (DNS-rebinding protection per MCP 2025-11-25
// §"Security Warning"). Added first so disallowed Origins are rejected
// before authentication or any outbound token acquisition runs.
if allowed := origin.ResolveAllowedOrigins(proxyHost, port, proxyAllowedOrigins); len(allowed) > 0 {
middlewares = append(middlewares, types.NamedMiddleware{
Name: origin.MiddlewareType,
Function: origin.NewHandler(allowed),
})
} else {
slog.Warn("Origin validation disabled — no allowlist configured for non-loopback bind",
"host", proxyHost,
"port", port,
"hint", "pass --allowed-origins=https://your-client.example to enable DNS-rebind protection",
)
}

// Get OIDC configuration if enabled (for protecting the proxy endpoint)
oidcConfig := getProxyOIDCConfig(cmd)

Expand Down
11 changes: 11 additions & 0 deletions cmd/thv/app/run_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,12 @@ type RunFlags struct {
RemoteForwardHeaders []string
RemoteForwardHeadersSecret []string

// AllowedOrigins is the HTTP Origin-header allowlist for DNS-rebinding protection
// (MCP 2025-11-25 §"Security Warning"). Empty with a loopback host auto-derives
// loopback-only defaults; empty with a non-loopback host disables the check
// (operator must supply explicit origins for public bind).
AllowedOrigins []string

// Runtime configuration
RuntimeImage string
RuntimeAddPackages []string
Expand All @@ -160,6 +166,10 @@ func AddRunFlags(cmd *cobra.Command, config *RunFlags) {
cmd.Flags().StringVar(&config.Name, "name", "", "Name of the MCP server (default to auto-generated from image)")
cmd.Flags().StringVar(&config.Group, "group", "default", "Name of the group this workload should belong to")
cmd.Flags().StringVar(&config.Host, "host", transport.LocalhostIPv4, "Host for the HTTP proxy to listen on (IP or hostname)")
cmd.Flags().StringArrayVar(&config.AllowedOrigins, "allowed-origins", nil,
"Exact-match allowlist for the HTTP Origin header (repeatable). Recommended when binding publicly; "+
"loopback binds derive a default allowlist automatically, non-loopback binds log a warning when "+
"no value is supplied. Example: https://my-mcp.example.com")
cmd.Flags().IntVar(&config.ProxyPort, "proxy-port", 0, "Port for the HTTP proxy to listen on (host port)")
cmd.Flags().IntVar(&config.TargetPort, "target-port", 0,
"Port for the container to expose (only applicable to SSE or Streamable HTTP transport)")
Expand Down Expand Up @@ -685,6 +695,7 @@ func buildRunnerConfig(
PrintOverlays: runFlags.PrintOverlays,
}),
runner.WithPublish(runFlags.Publish),
runner.WithAllowedOrigins(runFlags.AllowedOrigins),
}
opts = append(opts, extraOpts...)

Expand Down
1 change: 1 addition & 0 deletions docs/cli/thv_proxy.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions docs/cli/thv_run.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions docs/server/docs.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions docs/server/swagger.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 12 additions & 0 deletions docs/server/swagger.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions pkg/runner/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,14 @@ type RunConfig struct {
// TargetHost is the host to forward traffic to (only applicable to SSE transport)
TargetHost string `json:"target_host,omitempty" yaml:"target_host,omitempty"`

// AllowedOrigins is the allowlist of values accepted on the HTTP Origin header,
// used for DNS-rebinding protection per MCP 2025-11-25 §"Security Warning".
// When empty and Host is loopback (127.0.0.1 / localhost / [::1]), a default
// loopback-only allowlist is derived at middleware-wiring time.
// When empty and Host is non-loopback, the middleware is disabled — operators
// exposing the proxy publicly must configure an explicit allowlist.
AllowedOrigins []string `json:"allowed_origins,omitempty" yaml:"allowed_origins,omitempty"`
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The operator path uses PopulateMiddlewareConfigs so the factory side is wired, but MCPServerSpec / MCPRemoteProxySpec / VirtualMCPServerSpec have no AllowedOrigins field. Combined with operator-deployed pods binding to non-loopback addresses, ResolveAllowedOrigins returns nil and addOriginMiddleware skips registration with a WARN — so K8s deployments ship with Origin validation disabled.

Is this expected, planned for a follow-up PR, or considered out of scope for the CRDs? If a follow-up, would it be worth a // TODO or a note in the PR description so it's tracked?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deferring CRD wiring to a follow-up to keep this PR focused on CLI/proxyrunner. Documented explicitly in the commit message and in the prependOriginMiddleware doc comment: operator non-loopback pods log the WARN rather than enforcing until the CRD field lands. Tracked in a follow-up issue: "Add allowedOrigins to MCPServer/MCPRemoteProxy/VirtualMCPServer CRDs + operator wiring."


// Publish lists ports to publish to the host in format "hostPort:containerPort"
Publish []string `json:"publish,omitempty" yaml:"publish,omitempty"`

Expand Down
12 changes: 12 additions & 0 deletions pkg/runner/config_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,18 @@ func WithAllowDockerGateway(allow bool) RunConfigBuilderOption {
}
}

// WithAllowedOrigins sets the HTTP Origin-header allowlist used for
// DNS-rebinding protection (MCP 2025-11-25 §"Security Warning").
// An empty slice defers the choice to middleware wiring, which derives a
// loopback-only default when the bind host is loopback and otherwise leaves
// the middleware disabled.
func WithAllowedOrigins(origins []string) RunConfigBuilderOption {
return func(b *runConfigBuilder) error {
b.config.AllowedOrigins = origins
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new addOriginMiddleware is added to PopulateMiddlewareConfigs, but WithMiddlewareFromFlags / addCoreMiddlewares (the path used by thv run) doesn't include it. Since runner.Run skips PopulateMiddlewareConfigs when MiddlewareConfigs is pre-populated (runner.go:232), thv run --allowed-origins=... plumbs the flag into RunConfig.AllowedOrigins — which is exactly what this builder option does — but the middleware never registers at runtime.

Is this intentional, or an omission? If intentional, what's the rationale for excluding thv run from the protection?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — and the fix needed to go deeper than moving the call into WithMiddlewareFromFlags. The builder resolves the effective port in validateConfig (via WithPorts), which runs after all options, so b.config.Port is still 0 during WithMiddlewareFromFlagsResolveAllowedOrigins would have returned nil and silently disabled loopback defaults for thv run. Fixed by wiring Origin validation centrally in runner.Run (new prependOriginMiddleware), after both population paths, where Host/Port/AllowedOrigins are fully resolved. Removed it from PopulateMiddlewareConfigs so there's one wiring site, and prepended it so it still runs first (rejects before auth). Added TestPrependOriginMiddleware covering the thv run path.

return nil
}
}

// WithTrustProxyHeaders sets whether to trust X-Forwarded-* headers from reverse proxies
func WithTrustProxyHeaders(trust bool) RunConfigBuilderOption {
return func(b *runConfigBuilder) error {
Expand Down
56 changes: 52 additions & 4 deletions pkg/runner/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package runner

import (
"fmt"
"log/slog"

"github.com/stacklok/toolhive/pkg/audit"
"github.com/stacklok/toolhive/pkg/auth"
Expand All @@ -21,6 +22,7 @@ import (
"github.com/stacklok/toolhive/pkg/recovery"
"github.com/stacklok/toolhive/pkg/telemetry"
headerfwd "github.com/stacklok/toolhive/pkg/transport/middleware"
"github.com/stacklok/toolhive/pkg/transport/middleware/origin"
"github.com/stacklok/toolhive/pkg/transport/types"
"github.com/stacklok/toolhive/pkg/usagemetrics"
"github.com/stacklok/toolhive/pkg/webhook/mutating"
Expand All @@ -45,6 +47,7 @@ func GetSupportedMiddlewareFactories() map[string]types.MiddlewareFactory {
audit.MiddlewareType: audit.CreateMiddleware,
recovery.MiddlewareType: recovery.CreateMiddleware,
headerfwd.HeaderForwardMiddlewareName: headerfwd.CreateMiddleware,
origin.MiddlewareType: origin.CreateMiddleware,
validating.MiddlewareType: validating.CreateMiddleware,
mutating.MiddlewareType: mutating.CreateMiddleware,
}
Expand All @@ -57,22 +60,28 @@ func GetSupportedMiddlewareFactories() map[string]types.MiddlewareFactory {
func PopulateMiddlewareConfigs(config *RunConfig) error {
var middlewareConfigs []types.MiddlewareConfig
// TODO: Consider extracting other middleware setup into helper functions like addUsageMetricsMiddleware
//
// NOTE: Origin-validation middleware is intentionally NOT added here. It is
// wired centrally in runner.Run (via prependOriginMiddleware) for both the
// operator/proxyrunner path (this function) and the CLI path
// (WithMiddlewareFromFlags), because that is the only place where the
// effective Host/Port/AllowedOrigins are fully resolved.

// Authentication middleware (always present)
authParams := auth.MiddlewareParams{
OIDCConfig: config.OIDCConfig,
}
authConfig, err := types.NewMiddlewareConfig(auth.MiddlewareType, authParams)
if err != nil {
return fmt.Errorf("failed to create auth middleware config: %w", err)
authConfig, authErr := types.NewMiddlewareConfig(auth.MiddlewareType, authParams)
if authErr != nil {
return fmt.Errorf("failed to create auth middleware config: %w", authErr)
}
middlewareConfigs = append(middlewareConfigs, *authConfig)

// Upstream swap middleware (if embedded auth server is configured)
// This exchanges ToolHive JWTs for upstream IdP tokens when embedded auth server is used.
// IMPORTANT: Must run BEFORE token exchange middleware so it can read the `tsid` claim
// from the original ToolHive JWT before any token modification occurs.
middlewareConfigs, err = addUpstreamSwapMiddleware(middlewareConfigs, config)
middlewareConfigs, err := addUpstreamSwapMiddleware(middlewareConfigs, config)
if err != nil {
return err
}
Expand Down Expand Up @@ -421,6 +430,45 @@ func addAWSStsMiddleware(middlewares []types.MiddlewareConfig, config *RunConfig
return append(middlewares, *awsStsMwConfig), nil
}

// prependOriginMiddleware prepends Origin-header validation middleware for
// DNS-rebind protection per MCP 2025-11-25 §"Security Warning". It is placed at
// the front of the chain so disallowed Origin values are rejected before
// authentication or any business logic runs. Default-derivation logic lives in
// origin.ResolveAllowedOrigins so the standalone `thv proxy` command and the
// runner path agree on behavior.
//
// This is called from runner.Run after both middleware-population paths
// (PopulateMiddlewareConfigs and WithMiddlewareFromFlags) have run, because
// that is the only point where the effective Host/Port/AllowedOrigins are
// fully resolved — the CLI builder defers port resolution to validateConfig.
//
// When the effective allowlist is empty — which happens when the operator
// binds to a non-loopback host without supplying --allowed-origins — the
// middleware is skipped entirely and a WARN is logged so the security-disabled
// state is visible in operator logs. A follow-up PR hardens the non-loopback
// path by requiring an explicit opt-in flag (see audit row 22).
func prependOriginMiddleware(middlewares []types.MiddlewareConfig, config *RunConfig) ([]types.MiddlewareConfig, error) {
allowed := origin.ResolveAllowedOrigins(config.Host, config.Port, config.AllowedOrigins)
if len(allowed) == 0 {
slog.Warn("Origin validation disabled — no allowlist configured for non-loopback bind",
"host", config.Host,
"port", config.Port,
"hint", "pass --allowed-origins=https://your-client.example to enable DNS-rebind protection",
)
return middlewares, nil
}

params := origin.MiddlewareParams{AllowedOrigins: allowed}
mwCfg, err := types.NewMiddlewareConfig(origin.MiddlewareType, params)
if err != nil {
return nil, fmt.Errorf("failed to create origin middleware config: %w", err)
}
// Prepend so Origin validation is the outermost wrapper (runs first at
// request time). Build a new slice to avoid mutating the caller's backing
// array.
return append([]types.MiddlewareConfig{*mwCfg}, middlewares...), nil
}

// addRateLimitMiddleware adds rate limit middleware if configured.
func addRateLimitMiddleware(middlewares []types.MiddlewareConfig, config *RunConfig) ([]types.MiddlewareConfig, error) {
if config.RateLimitConfig == nil {
Expand Down
60 changes: 60 additions & 0 deletions pkg/runner/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/stacklok/toolhive/pkg/recovery"
"github.com/stacklok/toolhive/pkg/telemetry"
headerfwd "github.com/stacklok/toolhive/pkg/transport/middleware"
"github.com/stacklok/toolhive/pkg/transport/middleware/origin"
"github.com/stacklok/toolhive/pkg/transport/types"
"github.com/stacklok/toolhive/pkg/webhook"
"github.com/stacklok/toolhive/pkg/webhook/mutating"
Expand Down Expand Up @@ -123,6 +124,65 @@ func TestAddHeaderForwardMiddleware(t *testing.T) {
}
}

func TestPrependOriginMiddleware(t *testing.T) {
t.Parallel()

tests := []struct {
name string
config *RunConfig
wantPrepended bool
wantAllowedCount int
}{
{
name: "non-loopback bind without explicit allowlist skips middleware",
config: &RunConfig{Host: "0.0.0.0", Port: 8080},
wantPrepended: false,
},
{
name: "zero port skips middleware",
config: &RunConfig{Host: "127.0.0.1", Port: 0},
wantPrepended: false,
},
{
name: "loopback bind derives default allowlist and prepends",
config: &RunConfig{Host: "127.0.0.1", Port: 8080},
wantPrepended: true,
wantAllowedCount: 3, // localhost + 127.0.0.1 + [::1]
},
{
name: "explicit allowlist on non-loopback bind prepends",
config: &RunConfig{Host: "0.0.0.0", Port: 8080, AllowedOrigins: []string{"https://app.example.com"}},
wantPrepended: true,
wantAllowedCount: 1,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

// Seed with an existing entry so we can prove origin is prepended,
// not appended — the security intent requires it to run first.
initial := []types.MiddlewareConfig{{Type: auth.MiddlewareType}}
got, err := prependOriginMiddleware(initial, tt.config)
require.NoError(t, err)

if !tt.wantPrepended {
assert.Equal(t, initial, got, "middleware slice should be unchanged")
return
}

require.Len(t, got, len(initial)+1)
assert.Equal(t, origin.MiddlewareType, got[0].Type, "origin middleware must be first in the chain")
assert.Equal(t, auth.MiddlewareType, got[1].Type, "pre-existing middleware must follow origin")

var params origin.MiddlewareParams
require.NoError(t, json.Unmarshal(got[0].Parameters, &params))
assert.Len(t, params.AllowedOrigins, tt.wantAllowedCount)
})
}
}

func TestPopulateMiddlewareConfigs_HeaderForward(t *testing.T) {
t.Parallel()

Expand Down
14 changes: 14 additions & 0 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,20 @@ func (r *Runner) Run(ctx context.Context) error {
}
}

// Origin-header validation (DNS-rebinding protection per MCP 2025-11-25
// §"Security Warning") is wired here, after both middleware-population
// paths, because it is the single place where Host/Port/AllowedOrigins are
// fully resolved: the CLI builder (WithMiddlewareFromFlags) defers port
// resolution to validateConfig, so the effective port is not known at
// builder time. Prepending keeps Origin validation at the front of the
// chain so disallowed Origins are rejected before authentication or any
// business logic runs.
var err error
r.Config.MiddlewareConfigs, err = prependOriginMiddleware(r.Config.MiddlewareConfigs, r.Config)
if err != nil {
return fmt.Errorf("failed to add origin middleware: %w", err)
}

// Initialize embedded auth server if configured.
// This must happen before middleware creation so that the upstream token
// service is available to middleware factories (e.g., upstreamswap).
Expand Down
Loading
Loading