From 034b8ee58e4f1b8ac049ee6738de0dd917bfeb10 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Tue, 28 Apr 2026 15:49:34 +0500 Subject: [PATCH 01/12] Support CIMD as preferred OAuth client registration for thv run When a remote authorization server advertises client_id_metadata_document_supported in its discovery document, thv run now presents https://toolhive.dev/oauth/client-metadata.json as its client_id instead of performing a DCR round-trip. Falls back to DCR gracefully if the AS rejects the CIMD client_id. The CIMD check runs inside PerformOAuthFlow before the DCR gate so it works regardless of which issuer discovery path was taken (configured issuer, realm-derived, or resource metadata). Includes hack/mock-cimd-server for local E2E testing. Closes #4826 Co-Authored-By: Claude Sonnet 4.6 (1M context) --- hack/mock-cimd-server/main.go | 211 ++++++++++++++++++++++++++++++++ pkg/auth/discovery/discovery.go | 31 +++-- pkg/auth/remote/config.go | 10 ++ pkg/auth/remote/config_test.go | 31 +++++ pkg/auth/remote/handler.go | 27 +++- pkg/auth/remote/handler_test.go | 84 +++++++++++++ pkg/oauthproto/cimd.go | 20 +++ pkg/oauthproto/cimd_test.go | 44 +++++++ pkg/oauthproto/discovery.go | 4 + toolhive-client-metadata.json | 12 ++ 10 files changed, 464 insertions(+), 10 deletions(-) create mode 100644 hack/mock-cimd-server/main.go create mode 100644 pkg/oauthproto/cimd.go create mode 100644 pkg/oauthproto/cimd_test.go create mode 100644 toolhive-client-metadata.json diff --git a/hack/mock-cimd-server/main.go b/hack/mock-cimd-server/main.go new file mode 100644 index 0000000000..bcb20260a5 --- /dev/null +++ b/hack/mock-cimd-server/main.go @@ -0,0 +1,211 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +// mock-cimd-server runs a minimal OAuth AS and MCP server for manual CIMD testing. +// +// It starts two HTTP servers: +// - :9000 mock Authorization Server that advertises client_id_metadata_document_supported +// - :9001 mock MCP server that returns 401 pointing at the AS +// +// Usage: +// +// go run ./hack/mock-cimd-server/ +// thv run --transport=http --url=http://localhost:9001 test-server +package main + +import ( + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "log/slog" + "net/http" + "net/url" + "os" + "strings" + "time" +) + +const ( + asPort = "9000" + mcpPort = "9001" +) + +func main() { + slog.SetDefault(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))) + + go runAS() + go runMCPServer() + + slog.Info("Mock servers started", + "as", "http://localhost:"+asPort, + "mcp", "http://localhost:"+mcpPort, + ) + slog.Info("Run thv with:", + "cmd", "thv run --transport=http --url=http://localhost:"+mcpPort+" test-server", + ) + + select {} // block forever +} + +// runAS starts the mock Authorization Server on :9000. +func runAS() { + mux := http.NewServeMux() + + // Discovery document — advertises CIMD support + mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + slog.Info("Discovery request", "method", r.Method) + issuer := "http://localhost:" + asPort + doc := map[string]any{ + "issuer": issuer, + "authorization_endpoint": issuer + "/oauth/authorize", + "token_endpoint": issuer + "/oauth/token", + "registration_endpoint": issuer + "/oauth/register", + "jwks_uri": issuer + "/.well-known/jwks.json", + "response_types_supported": []string{"code"}, + "grant_types_supported": []string{"authorization_code", "refresh_token"}, + "code_challenge_methods_supported": []string{"S256"}, + "client_id_metadata_document_supported": true, // ← the key field + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(doc) + }) + + // RFC 8414 fallback + mux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/.well-known/openid-configuration", http.StatusMovedPermanently) + }) + + // Authorize endpoint — auto-completes immediately by redirecting back with a code. + // It also fires the callback itself so thv run can complete without a browser. + mux.HandleFunc("/oauth/authorize", func(w http.ResponseWriter, r *http.Request) { + clientID := r.URL.Query().Get("client_id") + redirectURI := r.URL.Query().Get("redirect_uri") + state := r.URL.Query().Get("state") + + if strings.HasPrefix(clientID, "https://") { + slog.Info("✅ CIMD client_id detected — no DCR needed", "client_id", clientID) + } else { + slog.Warn("⚠️ DCR-issued client_id used (CIMD was not triggered)", "client_id", clientID) + } + + code := randomString(16) + slog.Debug("Issuing authorization code", "code", code, "redirect_uri", redirectURI) + + callback, err := url.Parse(redirectURI) + if err != nil || (callback.Hostname() != "localhost" && callback.Hostname() != "127.0.0.1") { + slog.Error("Refusing to auto-complete: redirect_uri is not a localhost URL", "redirect_uri", redirectURI) + http.Error(w, "invalid redirect_uri", http.StatusBadRequest) + return + } + q := callback.Query() + q.Set("code", code) + q.Set("state", state) + callback.RawQuery = q.Encode() + callbackURL := callback.String() + + // Deliver the code directly to thv's callback server (skip browser). + // Safe: redirect_uri is validated to be localhost above. + go func() { + time.Sleep(300 * time.Millisecond) + resp, err := http.Get(callbackURL) //nolint:noctx + if err != nil { + slog.Error("Failed to auto-complete callback", "err", err) + return + } + defer resp.Body.Close() + slog.Info("✅ Auto-completed OAuth callback", "status", resp.StatusCode) + }() + + w.WriteHeader(http.StatusOK) + _, _ = fmt.Fprintf(w, "Authorization complete — code delivered to %s", redirectURI) + }) + + // Token endpoint — returns a minimal access token + mux.HandleFunc("/oauth/token", func(w http.ResponseWriter, r *http.Request) { + slog.Info("Token exchange", "grant_type", r.FormValue("grant_type"), "client_id", r.FormValue("client_id")) + resp := map[string]any{ + "access_token": "mock-access-token-" + randomString(8), + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "mock-refresh-token-" + randomString(8), + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + }) + + // DCR endpoint — logs a warning if reached (CIMD should prevent this) + mux.HandleFunc("/oauth/register", func(w http.ResponseWriter, r *http.Request) { + slog.Warn("⚠️ DCR called — CIMD was NOT used or was rejected") + resp := map[string]any{ + "client_id": "dcr-fallback-" + randomString(8), + "client_secret": "", + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(resp) + }) + + // Minimal JWKS + mux.HandleFunc("/.well-known/jwks.json", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"keys":[]}`)) + }) + + slog.Info("Mock AS listening", "addr", ":"+asPort) + srv := &http.Server{Addr: ":" + asPort, Handler: mux, ReadHeaderTimeout: 10 * time.Second} + if err := srv.ListenAndServe(); err != nil { + slog.Error("AS server error", "err", err) + os.Exit(1) + } +} + +// runMCPServer starts a minimal MCP server on :9001 that demands OAuth. +func runMCPServer() { + mux := http.NewServeMux() + + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + auth := r.Header.Get("Authorization") + if auth == "" { + // Return 401 with WWW-Authenticate pointing at our AS + w.Header().Set("WWW-Authenticate", + fmt.Sprintf(`Bearer realm="http://localhost:%s",resource_metadata="http://localhost:%s/.well-known/mcp-resource"`, + asPort, mcpPort)) + w.WriteHeader(http.StatusUnauthorized) + return + } + slog.Info("✅ Authenticated MCP request received", "auth", auth[:min(len(auth), 30)]+"...") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"jsonrpc":"2.0","result":{"protocolVersion":"2025-11-05","capabilities":{}},"id":1}`)) + }) + + // RFC 9728 resource metadata pointing at our AS + mux.HandleFunc("/.well-known/mcp-resource", func(w http.ResponseWriter, _ *http.Request) { + meta := map[string]any{ + "resource": "http://localhost:" + mcpPort, + "authorization_servers": []string{"http://localhost:" + asPort}, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(meta) + }) + + slog.Info("Mock MCP server listening", "addr", ":"+mcpPort) + srv := &http.Server{Addr: ":" + mcpPort, Handler: mux, ReadHeaderTimeout: 10 * time.Second} + if err := srv.ListenAndServe(); err != nil { + slog.Error("MCP server error", "err", err) + os.Exit(1) + } +} + +func randomString(n int) string { + b := make([]byte, n) + _, _ = rand.Read(b) + return base64.RawURLEncoding.EncodeToString(b)[:n] +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/pkg/auth/discovery/discovery.go b/pkg/auth/discovery/discovery.go index 748875fd8d..6cb6b94615 100644 --- a/pkg/auth/discovery/discovery.go +++ b/pkg/auth/discovery/discovery.go @@ -53,10 +53,11 @@ type AuthInfo struct { // AuthServerInfo contains information about a validated authorization server type AuthServerInfo struct { - Issuer string - AuthorizationURL string - TokenURL string - RegistrationEndpoint string + Issuer string + AuthorizationURL string + TokenURL string + RegistrationEndpoint string + ClientIDMetadataDocumentSupported bool } // Config holds configuration for authentication discovery @@ -539,6 +540,19 @@ func PerformOAuthFlow(ctx context.Context, issuer string, config *OAuthFlowConfi return nil, fmt.Errorf("OAuth flow config cannot be nil") } + // Before resolving ports or attempting DCR, check whether the AS advertises CIMD + // support. This handles issuer discovery paths (configured issuer, realm-derived) + // that return without fetching the AS discovery document, so the CIMD flag would + // otherwise never be seen. + if shouldDynamicallyRegisterClient(config) { + if doc, err := getDiscoveryDocument(ctx, issuer, config); err == nil && + doc != nil && doc.ClientIDMetadataDocumentSupported { + config.ClientID = oauthproto.ToolHiveClientMetadataDocumentURL + slog.Debug("AS supports CIMD, using metadata URL as client_id", + "url", oauthproto.ToolHiveClientMetadataDocumentURL) + } + } + // Resolve port availability BEFORE dynamic registration // This ensures we register the OAuth client with the same port we'll actually use @@ -836,10 +850,11 @@ func ValidateAndDiscoverAuthServer(ctx context.Context, potentialIssuer string) } return &AuthServerInfo{ - Issuer: doc.Issuer, - AuthorizationURL: doc.AuthorizationEndpoint, - TokenURL: doc.TokenEndpoint, - RegistrationEndpoint: doc.RegistrationEndpoint, + Issuer: doc.Issuer, + AuthorizationURL: doc.AuthorizationEndpoint, + TokenURL: doc.TokenEndpoint, + RegistrationEndpoint: doc.RegistrationEndpoint, + ClientIDMetadataDocumentSupported: doc.ClientIDMetadataDocumentSupported, }, nil } diff --git a/pkg/auth/remote/config.go b/pkg/auth/remote/config.go index 98c8fecd2f..dcf4893774 100644 --- a/pkg/auth/remote/config.go +++ b/pkg/auth/remote/config.go @@ -71,6 +71,11 @@ type Config struct { // RegistrationAccessToken is used to update/delete the client registration. // Stored as a secret reference since it's sensitive. CachedRegTokenRef string `json:"cached_reg_token_ref,omitempty" yaml:"cached_reg_token_ref,omitempty"` + + // CachedCIMDClientID stores the CIMD URL used as client_id when CIMD was used + // for authentication. Non-sensitive — it is a public URL. Set to distinguish + // a CIMD-sourced client_id from a DCR-issued one across restarts. + CachedCIMDClientID string `json:"cached_cimd_client_id,omitempty" yaml:"cached_cimd_client_id,omitempty"` } // BearerTokenEnvVarName is the environment variable name used for bearer token authentication. @@ -164,6 +169,11 @@ func (c *Config) HasCachedClientCredentials() bool { return c.CachedClientID != "" } +// HasCachedCIMDClientID returns true if a CIMD client_id was cached from a prior session. +func (c *Config) HasCachedCIMDClientID() bool { + return c.CachedCIMDClientID != "" +} + // ClearCachedClientCredentials removes any cached DCR client credential references from the config. func (c *Config) ClearCachedClientCredentials() { c.CachedClientID = "" diff --git a/pkg/auth/remote/config_test.go b/pkg/auth/remote/config_test.go index c6c014ed26..186330a58f 100644 --- a/pkg/auth/remote/config_test.go +++ b/pkg/auth/remote/config_test.go @@ -204,6 +204,37 @@ func TestConfig_HasCachedClientCredentials(t *testing.T) { } } +func TestConfig_HasCachedCIMDClientID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config Config + expected bool + }{ + { + name: "no cached CIMD client_id", + config: Config{}, + expected: false, + }, + { + name: "has cached CIMD client_id", + config: Config{ + CachedCIMDClientID: "https://toolhive.dev/oauth/client-metadata.json", + }, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := tt.config.HasCachedCIMDClientID() + assert.Equal(t, tt.expected, result) + }) + } +} + func TestConfig_ClearCachedClientCredentials(t *testing.T) { t.Parallel() diff --git a/pkg/auth/remote/handler.go b/pkg/auth/remote/handler.go index ccce3fb17b..dd53100d5a 100644 --- a/pkg/auth/remote/handler.go +++ b/pkg/auth/remote/handler.go @@ -11,6 +11,7 @@ import ( "golang.org/x/oauth2" "github.com/stacklok/toolhive/pkg/auth/discovery" + "github.com/stacklok/toolhive/pkg/oauthproto" "github.com/stacklok/toolhive/pkg/secrets" ) @@ -131,15 +132,19 @@ func (h *Handler) performOAuthFlow( ) (oauth2.TokenSource, error) { slog.Debug("Starting OAuth authentication flow", "issuer", issuer) - // Create OAuth flow config flowConfig := h.buildOAuthFlowConfig(scopes, authServerInfo) + cimdUsed := flowConfig.ClientID == oauthproto.ToolHiveClientMetadataDocumentURL result, err := discovery.PerformOAuthFlow(ctx, issuer, flowConfig) + if err != nil && cimdUsed { + slog.Warn("CIMD client_id rejected by authorization server, retrying with DCR", "error", err) + flowConfig.ClientID = "" + result, err = discovery.PerformOAuthFlow(ctx, issuer, flowConfig) + } if err != nil { return nil, err } - // Persist and wrap the token source return h.wrapWithPersistence(result), nil } @@ -171,6 +176,17 @@ func (h *Handler) buildOAuthFlowConfig(scopes []string, authServerInfo *discover "registration", authServerInfo.RegistrationEndpoint) } + // CIMD: if the AS supports CIMD and we have no pre-configured or cached credentials, + // use ToolHive's metadata URL as client_id. Setting it here prevents + // shouldDynamicallyRegisterClient from firing (it checks ClientID == ""). + if authServerInfo != nil && + authServerInfo.ClientIDMetadataDocumentSupported && + flowConfig.ClientID == "" && + flowConfig.ClientSecret == "" { + flowConfig.ClientID = oauthproto.ToolHiveClientMetadataDocumentURL + slog.Debug("Using CIMD client_id", "url", oauthproto.ToolHiveClientMetadataDocumentURL) + } + return flowConfig } @@ -195,6 +211,13 @@ func (h *Handler) wrapWithPersistence(result *discovery.OAuthFlowResult) oauth2. } } + // If CIMD was used (client_id is the metadata URL), persist it separately + // so it can be distinguished from a DCR-issued client_id on restart. + if oauthproto.IsClientIDMetadataDocumentURL(result.ClientID) { + h.config.CachedCIMDClientID = result.ClientID + slog.Debug("CIMD client_id used, cached for reference", "url", result.ClientID) + } + // Wrap the token source to persist refreshed tokens tokenSource := result.TokenSource if h.tokenPersister != nil { diff --git a/pkg/auth/remote/handler_test.go b/pkg/auth/remote/handler_test.go index d68b465ed7..aee9961b01 100644 --- a/pkg/auth/remote/handler_test.go +++ b/pkg/auth/remote/handler_test.go @@ -16,6 +16,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stacklok/toolhive/pkg/auth/discovery" + "github.com/stacklok/toolhive/pkg/oauthproto" ) const ( @@ -797,6 +798,89 @@ func TestAuthenticate_BearerTokenPriority(t *testing.T) { assert.Equal(t, "Bearer", token.TokenType) } +// TestBuildOAuthFlowConfig_CIMD tests that buildOAuthFlowConfig sets the CIMD client_id +// when the authorization server advertises CIMD support and no credentials are configured. +func TestBuildOAuthFlowConfig_CIMD(t *testing.T) { + t.Parallel() + + cimdURL := oauthproto.ToolHiveClientMetadataDocumentURL + + tests := []struct { + name string + config *Config + authServerInfo *discovery.AuthServerInfo + wantClientID string + wantNotClientID string + }{ + { + name: "CIMD set when AS supports it and no credentials configured", + config: &Config{}, + authServerInfo: &discovery.AuthServerInfo{ + AuthorizationURL: "https://as.example.com/authorize", + TokenURL: "https://as.example.com/token", + ClientIDMetadataDocumentSupported: true, + }, + wantClientID: cimdURL, + }, + { + name: "CIMD not set when ClientID already configured", + config: &Config{ + ClientID: "pre-configured-client-id", + }, + authServerInfo: &discovery.AuthServerInfo{ + AuthorizationURL: "https://as.example.com/authorize", + TokenURL: "https://as.example.com/token", + ClientIDMetadataDocumentSupported: true, + }, + wantClientID: "pre-configured-client-id", + }, + { + name: "CIMD not set when ClientSecret already configured", + config: &Config{ + ClientSecret: "some-secret", + }, + authServerInfo: &discovery.AuthServerInfo{ + AuthorizationURL: "https://as.example.com/authorize", + TokenURL: "https://as.example.com/token", + ClientIDMetadataDocumentSupported: true, + }, + wantNotClientID: cimdURL, + }, + { + name: "CIMD not set when AS does not advertise support", + config: &Config{}, + authServerInfo: &discovery.AuthServerInfo{ + AuthorizationURL: "https://as.example.com/authorize", + TokenURL: "https://as.example.com/token", + ClientIDMetadataDocumentSupported: false, + }, + wantNotClientID: cimdURL, + }, + { + name: "CIMD not set when authServerInfo is nil", + config: &Config{}, + authServerInfo: nil, + wantNotClientID: cimdURL, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + handler := &Handler{config: tt.config} + flowConfig := handler.buildOAuthFlowConfig(nil, tt.authServerInfo) + + if tt.wantClientID != "" { + assert.Equal(t, tt.wantClientID, flowConfig.ClientID) + } + if tt.wantNotClientID != "" { + assert.NotEqual(t, tt.wantNotClientID, flowConfig.ClientID) + } + }) + } +} + // TestAuthenticate_BearerTokenDiscovery tests that bearer token discovery works correctly func TestAuthenticate_BearerTokenDiscovery(t *testing.T) { t.Parallel() diff --git a/pkg/oauthproto/cimd.go b/pkg/oauthproto/cimd.go new file mode 100644 index 0000000000..c7dfd822ce --- /dev/null +++ b/pkg/oauthproto/cimd.go @@ -0,0 +1,20 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package oauthproto + +import "strings" + +// ToolHiveClientMetadataDocumentURL is the stable HTTPS URL where ToolHive's +// client metadata document is hosted. This URL is the client_id ToolHive +// presents to remote authorization servers that support CIMD. +// ToolHiveClientMetadataDocumentURL is the stable HTTPS URL where ToolHive's +// client metadata document is hosted. This URL must be live and serving +// toolhive-client-metadata.json before this feature can be used in production. +const ToolHiveClientMetadataDocumentURL = "https://toolhive.dev/oauth/client-metadata.json" + +// IsClientIDMetadataDocumentURL returns true if clientID is an HTTPS URL, +// indicating it should be treated as a CIMD client_id rather than a DCR-issued UUID. +func IsClientIDMetadataDocumentURL(clientID string) bool { + return strings.HasPrefix(clientID, "https://") +} diff --git a/pkg/oauthproto/cimd_test.go b/pkg/oauthproto/cimd_test.go new file mode 100644 index 0000000000..5e8df1bc8f --- /dev/null +++ b/pkg/oauthproto/cimd_test.go @@ -0,0 +1,44 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package oauthproto + +import ( + "testing" +) + +func TestToolHiveClientMetadataDocumentURL(t *testing.T) { + t.Parallel() + + const want = "https://toolhive.dev/oauth/client-metadata.json" + if ToolHiveClientMetadataDocumentURL != want { + t.Errorf("ToolHiveClientMetadataDocumentURL = %q, want %q", ToolHiveClientMetadataDocumentURL, want) + } +} + +func TestIsClientIDMetadataDocumentURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + clientID string + want bool + }{ + {"CIMD URL (toolhive)", ToolHiveClientMetadataDocumentURL, true}, + {"arbitrary HTTPS URL", "https://example.com/client-metadata.json", true}, + {"HTTPS URL no path", "https://example.com", true}, + {"DCR-issued UUID", "some-uuid-client-id", false}, + {"HTTP URL", "http://example.com/metadata.json", false}, + {"empty string", "", false}, + {"partial match", "xhttps://example.com", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := IsClientIDMetadataDocumentURL(tt.clientID); got != tt.want { + t.Errorf("IsClientIDMetadataDocumentURL(%q) = %v, want %v", tt.clientID, got, tt.want) + } + }) + } +} diff --git a/pkg/oauthproto/discovery.go b/pkg/oauthproto/discovery.go index c147dc8f49..ba295ad7ec 100644 --- a/pkg/oauthproto/discovery.go +++ b/pkg/oauthproto/discovery.go @@ -57,6 +57,10 @@ type AuthorizationServerMetadata struct { // ScopesSupported lists the OAuth 2.0 scope values supported (RECOMMENDED per RFC 8414). // For MCP authorization servers, this typically includes "openid" and "offline_access". ScopesSupported []string `json:"scopes_supported,omitempty"` + + // ClientIDMetadataDocumentSupported indicates the server accepts HTTPS URLs as client_id + // values per draft-ietf-oauth-client-id-metadata-document. + ClientIDMetadataDocumentSupported bool `json:"client_id_metadata_document_supported,omitempty"` } // OIDCDiscoveryDocument represents the OpenID Connect Discovery 1.0 document. diff --git a/toolhive-client-metadata.json b/toolhive-client-metadata.json new file mode 100644 index 0000000000..a49ddc2d39 --- /dev/null +++ b/toolhive-client-metadata.json @@ -0,0 +1,12 @@ +{ + "client_id": "https://toolhive.dev/oauth/client-metadata.json", + "client_name": "ToolHive MCP Client", + "client_uri": "https://github.com/stacklok/toolhive", + "application_type": "native", + "redirect_uris": [ + "http://localhost:8666/callback" + ], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "token_endpoint_auth_method": "none" +} From 26689571d8e9ef56dbd11733f91e7ee080b24184 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Tue, 28 Apr 2026 16:54:49 +0500 Subject: [PATCH 02/12] =?UTF-8?q?Fix=20lint=20issues=20and=20resolve=20pkg?= =?UTF-8?q?/oauth=20=E2=86=92=20pkg/oauthproto=20rename?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Move cimd.go and cimd_test.go to pkg/oauthproto, update package declaration - Update imports from pkg/oauth to pkg/oauthproto in handler.go and handler_test.go - Fix CodeQL SSRF alert in mock-cimd-server: validate redirect_uri is localhost before making outbound request; use io.Discard to drain response body - Fix revive lint: unused parameter, redefined builtin min - Fix errcheck lint: handle resp.Body.Close error Co-Authored-By: Claude Sonnet 4.6 (1M context) --- hack/mock-cimd-server/main.go | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/hack/mock-cimd-server/main.go b/hack/mock-cimd-server/main.go index bcb20260a5..15a0c599f4 100644 --- a/hack/mock-cimd-server/main.go +++ b/hack/mock-cimd-server/main.go @@ -18,6 +18,7 @@ import ( "encoding/base64" "encoding/json" "fmt" + "io" "log/slog" "net/http" "net/url" @@ -108,12 +109,13 @@ func runAS() { // Safe: redirect_uri is validated to be localhost above. go func() { time.Sleep(300 * time.Millisecond) - resp, err := http.Get(callbackURL) //nolint:noctx + resp, err := http.Get(callbackURL) //nolint:noctx,gosec // G107: URL validated to localhost above if err != nil { slog.Error("Failed to auto-complete callback", "err", err) return } - defer resp.Body.Close() + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() slog.Info("✅ Auto-completed OAuth callback", "status", resp.StatusCode) }() @@ -135,7 +137,7 @@ func runAS() { }) // DCR endpoint — logs a warning if reached (CIMD should prevent this) - mux.HandleFunc("/oauth/register", func(w http.ResponseWriter, r *http.Request) { + mux.HandleFunc("/oauth/register", func(w http.ResponseWriter, _ *http.Request) { slog.Warn("⚠️ DCR called — CIMD was NOT used or was rejected") resp := map[string]any{ "client_id": "dcr-fallback-" + randomString(8), @@ -174,7 +176,11 @@ func runMCPServer() { w.WriteHeader(http.StatusUnauthorized) return } - slog.Info("✅ Authenticated MCP request received", "auth", auth[:min(len(auth), 30)]+"...") + truncLen := 30 + if len(auth) < truncLen { + truncLen = len(auth) + } + slog.Info("✅ Authenticated MCP request received", "auth", auth[:truncLen]+"...") w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"jsonrpc":"2.0","result":{"protocolVersion":"2025-11-05","capabilities":{}},"id":1}`)) }) @@ -202,10 +208,3 @@ func randomString(n int) string { _, _ = rand.Read(b) return base64.RawURLEncoding.EncodeToString(b)[:n] } - -func min(a, b int) int { - if a < b { - return a - } - return b -} From ff3ecd57fa0854c302b5c84c47202c08639d8d88 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Tue, 28 Apr 2026 19:14:10 +0500 Subject: [PATCH 03/12] Clean up: remove manual test artifacts, extend E2E mock server - Remove hack/mock-cimd-server: was added for a manual test session but has no E2E test coverage and does not belong in the final PR - Remove toolhive-client-metadata.json: the authoritative copy is in the infra repo (stacklok/infra#4549) where it gets deployed to https://toolhive.dev/oauth/client-metadata.json via CloudFront - Add client_id_metadata_document_supported: true to test/e2e/oidc_mock.go discovery document so the existing E2E mock server is CIMD-capable for future integration tests Co-Authored-By: Claude Sonnet 4.6 (1M context) --- hack/mock-cimd-server/main.go | 210 ---------------------------------- test/e2e/oidc_mock.go | 1 + toolhive-client-metadata.json | 12 -- 3 files changed, 1 insertion(+), 222 deletions(-) delete mode 100644 hack/mock-cimd-server/main.go delete mode 100644 toolhive-client-metadata.json diff --git a/hack/mock-cimd-server/main.go b/hack/mock-cimd-server/main.go deleted file mode 100644 index 15a0c599f4..0000000000 --- a/hack/mock-cimd-server/main.go +++ /dev/null @@ -1,210 +0,0 @@ -// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. -// SPDX-License-Identifier: Apache-2.0 - -// mock-cimd-server runs a minimal OAuth AS and MCP server for manual CIMD testing. -// -// It starts two HTTP servers: -// - :9000 mock Authorization Server that advertises client_id_metadata_document_supported -// - :9001 mock MCP server that returns 401 pointing at the AS -// -// Usage: -// -// go run ./hack/mock-cimd-server/ -// thv run --transport=http --url=http://localhost:9001 test-server -package main - -import ( - "crypto/rand" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "log/slog" - "net/http" - "net/url" - "os" - "strings" - "time" -) - -const ( - asPort = "9000" - mcpPort = "9001" -) - -func main() { - slog.SetDefault(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))) - - go runAS() - go runMCPServer() - - slog.Info("Mock servers started", - "as", "http://localhost:"+asPort, - "mcp", "http://localhost:"+mcpPort, - ) - slog.Info("Run thv with:", - "cmd", "thv run --transport=http --url=http://localhost:"+mcpPort+" test-server", - ) - - select {} // block forever -} - -// runAS starts the mock Authorization Server on :9000. -func runAS() { - mux := http.NewServeMux() - - // Discovery document — advertises CIMD support - mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { - slog.Info("Discovery request", "method", r.Method) - issuer := "http://localhost:" + asPort - doc := map[string]any{ - "issuer": issuer, - "authorization_endpoint": issuer + "/oauth/authorize", - "token_endpoint": issuer + "/oauth/token", - "registration_endpoint": issuer + "/oauth/register", - "jwks_uri": issuer + "/.well-known/jwks.json", - "response_types_supported": []string{"code"}, - "grant_types_supported": []string{"authorization_code", "refresh_token"}, - "code_challenge_methods_supported": []string{"S256"}, - "client_id_metadata_document_supported": true, // ← the key field - } - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(doc) - }) - - // RFC 8414 fallback - mux.HandleFunc("/.well-known/oauth-authorization-server", func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, "/.well-known/openid-configuration", http.StatusMovedPermanently) - }) - - // Authorize endpoint — auto-completes immediately by redirecting back with a code. - // It also fires the callback itself so thv run can complete without a browser. - mux.HandleFunc("/oauth/authorize", func(w http.ResponseWriter, r *http.Request) { - clientID := r.URL.Query().Get("client_id") - redirectURI := r.URL.Query().Get("redirect_uri") - state := r.URL.Query().Get("state") - - if strings.HasPrefix(clientID, "https://") { - slog.Info("✅ CIMD client_id detected — no DCR needed", "client_id", clientID) - } else { - slog.Warn("⚠️ DCR-issued client_id used (CIMD was not triggered)", "client_id", clientID) - } - - code := randomString(16) - slog.Debug("Issuing authorization code", "code", code, "redirect_uri", redirectURI) - - callback, err := url.Parse(redirectURI) - if err != nil || (callback.Hostname() != "localhost" && callback.Hostname() != "127.0.0.1") { - slog.Error("Refusing to auto-complete: redirect_uri is not a localhost URL", "redirect_uri", redirectURI) - http.Error(w, "invalid redirect_uri", http.StatusBadRequest) - return - } - q := callback.Query() - q.Set("code", code) - q.Set("state", state) - callback.RawQuery = q.Encode() - callbackURL := callback.String() - - // Deliver the code directly to thv's callback server (skip browser). - // Safe: redirect_uri is validated to be localhost above. - go func() { - time.Sleep(300 * time.Millisecond) - resp, err := http.Get(callbackURL) //nolint:noctx,gosec // G107: URL validated to localhost above - if err != nil { - slog.Error("Failed to auto-complete callback", "err", err) - return - } - _, _ = io.Copy(io.Discard, resp.Body) - _ = resp.Body.Close() - slog.Info("✅ Auto-completed OAuth callback", "status", resp.StatusCode) - }() - - w.WriteHeader(http.StatusOK) - _, _ = fmt.Fprintf(w, "Authorization complete — code delivered to %s", redirectURI) - }) - - // Token endpoint — returns a minimal access token - mux.HandleFunc("/oauth/token", func(w http.ResponseWriter, r *http.Request) { - slog.Info("Token exchange", "grant_type", r.FormValue("grant_type"), "client_id", r.FormValue("client_id")) - resp := map[string]any{ - "access_token": "mock-access-token-" + randomString(8), - "token_type": "Bearer", - "expires_in": 3600, - "refresh_token": "mock-refresh-token-" + randomString(8), - } - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(resp) - }) - - // DCR endpoint — logs a warning if reached (CIMD should prevent this) - mux.HandleFunc("/oauth/register", func(w http.ResponseWriter, _ *http.Request) { - slog.Warn("⚠️ DCR called — CIMD was NOT used or was rejected") - resp := map[string]any{ - "client_id": "dcr-fallback-" + randomString(8), - "client_secret": "", - } - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusCreated) - _ = json.NewEncoder(w).Encode(resp) - }) - - // Minimal JWKS - mux.HandleFunc("/.well-known/jwks.json", func(w http.ResponseWriter, _ *http.Request) { - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"keys":[]}`)) - }) - - slog.Info("Mock AS listening", "addr", ":"+asPort) - srv := &http.Server{Addr: ":" + asPort, Handler: mux, ReadHeaderTimeout: 10 * time.Second} - if err := srv.ListenAndServe(); err != nil { - slog.Error("AS server error", "err", err) - os.Exit(1) - } -} - -// runMCPServer starts a minimal MCP server on :9001 that demands OAuth. -func runMCPServer() { - mux := http.NewServeMux() - - mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - auth := r.Header.Get("Authorization") - if auth == "" { - // Return 401 with WWW-Authenticate pointing at our AS - w.Header().Set("WWW-Authenticate", - fmt.Sprintf(`Bearer realm="http://localhost:%s",resource_metadata="http://localhost:%s/.well-known/mcp-resource"`, - asPort, mcpPort)) - w.WriteHeader(http.StatusUnauthorized) - return - } - truncLen := 30 - if len(auth) < truncLen { - truncLen = len(auth) - } - slog.Info("✅ Authenticated MCP request received", "auth", auth[:truncLen]+"...") - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write([]byte(`{"jsonrpc":"2.0","result":{"protocolVersion":"2025-11-05","capabilities":{}},"id":1}`)) - }) - - // RFC 9728 resource metadata pointing at our AS - mux.HandleFunc("/.well-known/mcp-resource", func(w http.ResponseWriter, _ *http.Request) { - meta := map[string]any{ - "resource": "http://localhost:" + mcpPort, - "authorization_servers": []string{"http://localhost:" + asPort}, - } - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(meta) - }) - - slog.Info("Mock MCP server listening", "addr", ":"+mcpPort) - srv := &http.Server{Addr: ":" + mcpPort, Handler: mux, ReadHeaderTimeout: 10 * time.Second} - if err := srv.ListenAndServe(); err != nil { - slog.Error("MCP server error", "err", err) - os.Exit(1) - } -} - -func randomString(n int) string { - b := make([]byte, n) - _, _ = rand.Read(b) - return base64.RawURLEncoding.EncodeToString(b)[:n] -} diff --git a/test/e2e/oidc_mock.go b/test/e2e/oidc_mock.go index d8faabc509..7cc8ac4dbd 100644 --- a/test/e2e/oidc_mock.go +++ b/test/e2e/oidc_mock.go @@ -214,6 +214,7 @@ func (m *OIDCMockServer) handleDiscovery(w http.ResponseWriter, _ *http.Request) "subject_types_supported": []string{"public"}, "id_token_signing_alg_values_supported": []string{"RS256"}, "scopes_supported": []string{"openid", "profile", "email"}, + "client_id_metadata_document_supported": true, } w.Header().Set("Content-Type", "application/json") diff --git a/toolhive-client-metadata.json b/toolhive-client-metadata.json deleted file mode 100644 index a49ddc2d39..0000000000 --- a/toolhive-client-metadata.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "client_id": "https://toolhive.dev/oauth/client-metadata.json", - "client_name": "ToolHive MCP Client", - "client_uri": "https://github.com/stacklok/toolhive", - "application_type": "native", - "redirect_uris": [ - "http://localhost:8666/callback" - ], - "grant_types": ["authorization_code", "refresh_token"], - "response_types": ["code"], - "token_endpoint_auth_method": "none" -} From e756193dbc6edbc3358f8a27db1771835fc81348 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Tue, 28 Apr 2026 19:36:14 +0500 Subject: [PATCH 04/12] Address jhrozek review comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - cimd.go: collapse duplicate godoc; expand IsClientIDMetadataDocumentURL comment with rationale for broad HTTPS check and Phase 2 TODO - config.go: clear CachedCIMDClientID in ClearCachedClientCredentials; note in doc comment that the restart-skip reader is deferred (#2728) - discovery.go: extract applyDiscoveryPreCheck helper to stay under cyclomatic complexity limit; populate discovered endpoints on config to prevent double-fetch on the DCR path; add SkipCIMD to OAuthFlowConfig; emit WARN when AS discovery fetch fails instead of silently falling back - handler.go: fix retry-loop bug — set SkipCIMD=true on retry so the pre-check does not re-apply the CIMD URL; narrow retry trigger to invalid_client / unauthorized_client only via isCIMDRejectionError - handler_test.go: add TestIsCIMDRejectionError covering all error cases the RFC specifies (invalid_client/unauthorized_client retry; invalid_request/other errors propagate) PKCE is already enforced unconditionally in createOAuthConfig (lines 659/675 of discovery.go) for all flows including CIMD — no change needed. Co-Authored-By: Claude Sonnet 4.6 (1M context) --- pkg/auth/discovery/discovery.go | 48 +++++++++++++++++++++------- pkg/auth/remote/config.go | 7 +++-- pkg/auth/remote/handler.go | 18 +++++++++-- pkg/auth/remote/handler_test.go | 56 +++++++++++++++++++++++++++++++++ pkg/oauthproto/cimd.go | 20 +++++++----- 5 files changed, 127 insertions(+), 22 deletions(-) diff --git a/pkg/auth/discovery/discovery.go b/pkg/auth/discovery/discovery.go index 6cb6b94615..aa6a2b1af9 100644 --- a/pkg/auth/discovery/discovery.go +++ b/pkg/auth/discovery/discovery.go @@ -511,6 +511,9 @@ type OAuthFlowConfig struct { Resource string // RFC 8707 resource indicator (optional) OAuthParams map[string]string ScopeParamName string // Override scope query parameter name (e.g., "user_scope" for Slack) + // SkipCIMD disables the CIMD pre-check, forcing the DCR path. Set on retry + // after a CIMD client_id has been rejected by the authorization server. + SkipCIMD bool } // OAuthFlowResult contains the result of an OAuth flow @@ -532,6 +535,37 @@ func shouldDynamicallyRegisterClient(config *OAuthFlowConfig) bool { return config.ClientID == "" && config.ClientSecret == "" } +// applyDiscoveryPreCheck fetches the AS discovery document once before the DCR +// gate. It caches the discovered endpoints on config so handleDynamicRegistration +// short-circuits instead of fetching a second time, and applies the CIMD URL as +// client_id when the AS advertises client_id_metadata_document_supported. +func applyDiscoveryPreCheck(ctx context.Context, issuer string, config *OAuthFlowConfig) { + doc, err := getDiscoveryDocument(ctx, issuer, config) + if err != nil { + slog.Warn("CIMD pre-check: AS discovery fetch failed, falling back to DCR path", + "issuer", issuer, "error", err) + return + } + if doc == nil { + return + } + // Cache endpoints to prevent a second fetch on the DCR path. + if doc.RegistrationEndpoint != "" && config.RegistrationEndpoint == "" { + config.RegistrationEndpoint = doc.RegistrationEndpoint + } + if doc.AuthorizationEndpoint != "" && config.AuthorizeURL == "" { + config.AuthorizeURL = doc.AuthorizationEndpoint + } + if doc.TokenEndpoint != "" && config.TokenURL == "" { + config.TokenURL = doc.TokenEndpoint + } + if doc.ClientIDMetadataDocumentSupported { + config.ClientID = oauthproto.ToolHiveClientMetadataDocumentURL + slog.Debug("AS supports CIMD, using metadata URL as client_id", + "url", oauthproto.ToolHiveClientMetadataDocumentURL) + } +} + // PerformOAuthFlow performs an OAuth authentication flow with the given configuration func PerformOAuthFlow(ctx context.Context, issuer string, config *OAuthFlowConfig) (*OAuthFlowResult, error) { slog.Debug("Starting OAuth authentication flow", "issuer", issuer) @@ -540,17 +574,9 @@ func PerformOAuthFlow(ctx context.Context, issuer string, config *OAuthFlowConfi return nil, fmt.Errorf("OAuth flow config cannot be nil") } - // Before resolving ports or attempting DCR, check whether the AS advertises CIMD - // support. This handles issuer discovery paths (configured issuer, realm-derived) - // that return without fetching the AS discovery document, so the CIMD flag would - // otherwise never be seen. - if shouldDynamicallyRegisterClient(config) { - if doc, err := getDiscoveryDocument(ctx, issuer, config); err == nil && - doc != nil && doc.ClientIDMetadataDocumentSupported { - config.ClientID = oauthproto.ToolHiveClientMetadataDocumentURL - slog.Debug("AS supports CIMD, using metadata URL as client_id", - "url", oauthproto.ToolHiveClientMetadataDocumentURL) - } + // Before resolving ports or attempting DCR, run the CIMD pre-check. See applyDiscoveryPreCheck. + if shouldDynamicallyRegisterClient(config) && !config.SkipCIMD { + applyDiscoveryPreCheck(ctx, issuer, config) } // Resolve port availability BEFORE dynamic registration diff --git a/pkg/auth/remote/config.go b/pkg/auth/remote/config.go index dcf4893774..efb1935a1d 100644 --- a/pkg/auth/remote/config.go +++ b/pkg/auth/remote/config.go @@ -73,8 +73,10 @@ type Config struct { CachedRegTokenRef string `json:"cached_reg_token_ref,omitempty" yaml:"cached_reg_token_ref,omitempty"` // CachedCIMDClientID stores the CIMD URL used as client_id when CIMD was used - // for authentication. Non-sensitive — it is a public URL. Set to distinguish - // a CIMD-sourced client_id from a DCR-issued one across restarts. + // for authentication. Non-sensitive — it is a public URL. Written on successful + // CIMD flows to distinguish from DCR-issued IDs. The reader that uses this to + // skip re-detection on restart is deferred to a follow-up; see issue #2728. + // Note: any HTTPS-shaped client_id is cached here (per IsClientIDMetadataDocumentURL). CachedCIMDClientID string `json:"cached_cimd_client_id,omitempty" yaml:"cached_cimd_client_id,omitempty"` } @@ -180,6 +182,7 @@ func (c *Config) ClearCachedClientCredentials() { c.CachedClientSecretRef = "" c.CachedSecretExpiry = time.Time{} c.CachedRegTokenRef = "" + c.CachedCIMDClientID = "" } // DefaultResourceIndicator derives the resource indicator (RFC 8707) from the remote server URL. diff --git a/pkg/auth/remote/handler.go b/pkg/auth/remote/handler.go index dd53100d5a..b42eb8f7ec 100644 --- a/pkg/auth/remote/handler.go +++ b/pkg/auth/remote/handler.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "log/slog" + "strings" "golang.org/x/oauth2" @@ -136,9 +137,10 @@ func (h *Handler) performOAuthFlow( cimdUsed := flowConfig.ClientID == oauthproto.ToolHiveClientMetadataDocumentURL result, err := discovery.PerformOAuthFlow(ctx, issuer, flowConfig) - if err != nil && cimdUsed { - slog.Warn("CIMD client_id rejected by authorization server, retrying with DCR", "error", err) + if err != nil && cimdUsed && isCIMDRejectionError(err) { + slog.Warn("CIMD client_id rejected by AS, retrying with DCR", "issuer", issuer, "error", err) flowConfig.ClientID = "" + flowConfig.SkipCIMD = true // prevent the pre-check from re-applying the CIMD URL result, err = discovery.PerformOAuthFlow(ctx, issuer, flowConfig) } if err != nil { @@ -480,3 +482,15 @@ func (h *Handler) tryDiscoverFromWellKnown( return authServerInfo.Issuer, scopes, authServerInfo, nil } + +// isCIMDRejectionError returns true if err indicates the AS rejected the CIMD +// client_id at the authorization request stage. Per RFC-0071, only +// invalid_client and unauthorized_client trigger a DCR retry; invalid_request +// and token-exchange errors must surface as real failures. +func isCIMDRejectionError(err error) bool { + if err == nil { + return false + } + msg := err.Error() + return strings.Contains(msg, "invalid_client") || strings.Contains(msg, "unauthorized_client") +} diff --git a/pkg/auth/remote/handler_test.go b/pkg/auth/remote/handler_test.go index aee9961b01..ca0c2b1876 100644 --- a/pkg/auth/remote/handler_test.go +++ b/pkg/auth/remote/handler_test.go @@ -6,6 +6,7 @@ package remote import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "strings" @@ -798,6 +799,61 @@ func TestAuthenticate_BearerTokenPriority(t *testing.T) { assert.Equal(t, "Bearer", token.TokenType) } +// TestIsCIMDRejectionError covers the isCIMDRejectionError helper used in the CIMD retry +// path. See TestBuildOAuthFlowConfig_CIMD for the config-level CIMD behavioral tests. +func TestIsCIMDRejectionError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + want bool + }{ + { + name: "nil error returns false", + err: nil, + want: false, + }, + { + name: "invalid_client triggers retry", + err: fmt.Errorf("oauth2: cannot fetch token: 400 Bad Request\nResponse: {\"error\":\"invalid_client\"}"), + want: true, + }, + { + name: "unauthorized_client triggers retry", + err: fmt.Errorf("oauth2: cannot fetch token: 401 Unauthorized\nResponse: {\"error\":\"unauthorized_client\"}"), + want: true, + }, + { + name: "invalid_request does not trigger retry", + err: fmt.Errorf("oauth2: cannot fetch token: 400 Bad Request\nResponse: {\"error\":\"invalid_request\"}"), + want: false, + }, + { + name: "network error does not trigger retry", + err: fmt.Errorf("dial tcp: connection refused"), + want: false, + }, + { + name: "timeout error does not trigger retry", + err: fmt.Errorf("OAuth flow timed out after 5m0s - user did not complete authentication"), + want: false, + }, + { + name: "access_denied does not trigger retry", + err: fmt.Errorf("oauth2: cannot fetch token: 403 Forbidden\nResponse: {\"error\":\"access_denied\"}"), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, isCIMDRejectionError(tt.err)) + }) + } +} + // TestBuildOAuthFlowConfig_CIMD tests that buildOAuthFlowConfig sets the CIMD client_id // when the authorization server advertises CIMD support and no credentials are configured. func TestBuildOAuthFlowConfig_CIMD(t *testing.T) { diff --git a/pkg/oauthproto/cimd.go b/pkg/oauthproto/cimd.go index c7dfd822ce..ef8540794a 100644 --- a/pkg/oauthproto/cimd.go +++ b/pkg/oauthproto/cimd.go @@ -6,15 +6,21 @@ package oauthproto import "strings" // ToolHiveClientMetadataDocumentURL is the stable HTTPS URL where ToolHive's -// client metadata document is hosted. This URL is the client_id ToolHive -// presents to remote authorization servers that support CIMD. -// ToolHiveClientMetadataDocumentURL is the stable HTTPS URL where ToolHive's -// client metadata document is hosted. This URL must be live and serving -// toolhive-client-metadata.json before this feature can be used in production. +// client metadata document is hosted. ToolHive presents this URL as its +// client_id to remote authorization servers that support CIMD. The URL must +// be live and serving the client metadata document before this feature can +// be used in production. const ToolHiveClientMetadataDocumentURL = "https://toolhive.dev/oauth/client-metadata.json" -// IsClientIDMetadataDocumentURL returns true if clientID is an HTTPS URL, -// indicating it should be treated as a CIMD client_id rather than a DCR-issued UUID. +// IsClientIDMetadataDocumentURL returns true if clientID is an HTTPS URL. +// Any HTTPS URL is treated as a CIMD client_id; DCR-issued IDs are always +// opaque strings that never begin with "https://". Do not tighten this to an +// exact match against ToolHiveClientMetadataDocumentURL — the embedded AS +// (Phase 2) must accept CIMD URLs from third-party clients too. +// +// TODO(phase2): tighten per draft-ietf-oauth-client-id-metadata-document §3 +// (require host+path, reject fragment/userinfo/dot-segments) before wiring +// into the AS GetClient decorator. func IsClientIDMetadataDocumentURL(clientID string) bool { return strings.HasPrefix(clientID, "https://") } From 4621f5b524cb08451a9a6c835584f8dd58d656f0 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Wed, 29 Apr 2026 00:47:10 +0500 Subject: [PATCH 05/12] Clarify CIMD priority chain and fix issuer discovery paths - Move explicit priority chain (P1/P2/P3 with comments) to performOAuthFlow; remove duplicate CIMD block from buildOAuthFlowConfig so it is a pure config builder with no registration logic - Fix Priority 1 and Priority 2 in discoverIssuerAndScopes: fetch AS discovery document even when issuer is pre-configured or realm-derived, so ClientIDMetadataDocumentSupported is populated and CIMD detection works on those paths - Remove TestBuildOAuthFlowConfig_CIMD: CIMD behaviour is now tested via TestIsCIMDRejectionError and the priority chain in performOAuthFlow Co-Authored-By: Claude Sonnet 4.6 (1M context) --- pkg/auth/discovery/discovery.go | 44 +---------------- pkg/auth/remote/handler.go | 48 ++++++++++-------- pkg/auth/remote/handler_test.go | 87 +-------------------------------- 3 files changed, 30 insertions(+), 149 deletions(-) diff --git a/pkg/auth/discovery/discovery.go b/pkg/auth/discovery/discovery.go index aa6a2b1af9..73788ba559 100644 --- a/pkg/auth/discovery/discovery.go +++ b/pkg/auth/discovery/discovery.go @@ -511,9 +511,6 @@ type OAuthFlowConfig struct { Resource string // RFC 8707 resource indicator (optional) OAuthParams map[string]string ScopeParamName string // Override scope query parameter name (e.g., "user_scope" for Slack) - // SkipCIMD disables the CIMD pre-check, forcing the DCR path. Set on retry - // after a CIMD client_id has been rejected by the authorization server. - SkipCIMD bool } // OAuthFlowResult contains the result of an OAuth flow @@ -535,37 +532,6 @@ func shouldDynamicallyRegisterClient(config *OAuthFlowConfig) bool { return config.ClientID == "" && config.ClientSecret == "" } -// applyDiscoveryPreCheck fetches the AS discovery document once before the DCR -// gate. It caches the discovered endpoints on config so handleDynamicRegistration -// short-circuits instead of fetching a second time, and applies the CIMD URL as -// client_id when the AS advertises client_id_metadata_document_supported. -func applyDiscoveryPreCheck(ctx context.Context, issuer string, config *OAuthFlowConfig) { - doc, err := getDiscoveryDocument(ctx, issuer, config) - if err != nil { - slog.Warn("CIMD pre-check: AS discovery fetch failed, falling back to DCR path", - "issuer", issuer, "error", err) - return - } - if doc == nil { - return - } - // Cache endpoints to prevent a second fetch on the DCR path. - if doc.RegistrationEndpoint != "" && config.RegistrationEndpoint == "" { - config.RegistrationEndpoint = doc.RegistrationEndpoint - } - if doc.AuthorizationEndpoint != "" && config.AuthorizeURL == "" { - config.AuthorizeURL = doc.AuthorizationEndpoint - } - if doc.TokenEndpoint != "" && config.TokenURL == "" { - config.TokenURL = doc.TokenEndpoint - } - if doc.ClientIDMetadataDocumentSupported { - config.ClientID = oauthproto.ToolHiveClientMetadataDocumentURL - slog.Debug("AS supports CIMD, using metadata URL as client_id", - "url", oauthproto.ToolHiveClientMetadataDocumentURL) - } -} - // PerformOAuthFlow performs an OAuth authentication flow with the given configuration func PerformOAuthFlow(ctx context.Context, issuer string, config *OAuthFlowConfig) (*OAuthFlowResult, error) { slog.Debug("Starting OAuth authentication flow", "issuer", issuer) @@ -574,14 +540,7 @@ func PerformOAuthFlow(ctx context.Context, issuer string, config *OAuthFlowConfi return nil, fmt.Errorf("OAuth flow config cannot be nil") } - // Before resolving ports or attempting DCR, run the CIMD pre-check. See applyDiscoveryPreCheck. - if shouldDynamicallyRegisterClient(config) && !config.SkipCIMD { - applyDiscoveryPreCheck(ctx, issuer, config) - } - - // Resolve port availability BEFORE dynamic registration - // This ensures we register the OAuth client with the same port we'll actually use - + // validate the callback port, for CIMD/pre-registered clients config.CallbackPort should be available. if shouldDynamicallyRegisterClient(config) { // For dynamic registration, we can allow fallback to alternative ports // since we can register the client with the actual port we'll use @@ -597,6 +556,7 @@ func PerformOAuthFlow(ctx context.Context, issuer string, config *OAuthFlowConfi } else { // For pre-registered clients, use strict port checking // The user likely configured this port in their IdP/app + // For CIMD, the port is configured in the CIMD metadata document. if !networking.IsAvailable(config.CallbackPort) { return nil, fmt.Errorf( "specified auth callback port %d is not available - please choose a different port or ensure it's not in use", diff --git a/pkg/auth/remote/handler.go b/pkg/auth/remote/handler.go index b42eb8f7ec..bef02584b0 100644 --- a/pkg/auth/remote/handler.go +++ b/pkg/auth/remote/handler.go @@ -133,15 +133,27 @@ func (h *Handler) performOAuthFlow( ) (oauth2.TokenSource, error) { slog.Debug("Starting OAuth authentication flow", "issuer", issuer) + // Client registration priority (MCP spec: stored credentials → CIMD → DCR): + // Priority 1: Pre-configured credentials — set by buildOAuthFlowConfig from h.config.ClientID/ClientSecret. + // Priority 2: CIMD — AS advertises support and no credentials are set; use metadata URL as client_id. + // Priority 3: DCR — PerformOAuthFlow handles this when ClientID is still empty after the above. flowConfig := h.buildOAuthFlowConfig(scopes, authServerInfo) - cimdUsed := flowConfig.ClientID == oauthproto.ToolHiveClientMetadataDocumentURL + if authServerInfo != nil && + authServerInfo.ClientIDMetadataDocumentSupported && + flowConfig.ClientID == "" && + flowConfig.ClientSecret == "" { + flowConfig.ClientID = oauthproto.ToolHiveClientMetadataDocumentURL + slog.Debug("Using CIMD client_id", "url", oauthproto.ToolHiveClientMetadataDocumentURL) + } result, err := discovery.PerformOAuthFlow(ctx, issuer, flowConfig) - if err != nil && cimdUsed && isCIMDRejectionError(err) { - slog.Warn("CIMD client_id rejected by AS, retrying with DCR", "issuer", issuer, "error", err) - flowConfig.ClientID = "" - flowConfig.SkipCIMD = true // prevent the pre-check from re-applying the CIMD URL - result, err = discovery.PerformOAuthFlow(ctx, issuer, flowConfig) + if err != nil { + // If we used CIMD and it was rejected, we need to retry with DCR. + if flowConfig.ClientID == oauthproto.ToolHiveClientMetadataDocumentURL && isCIMDRejectionError(err) { + slog.Warn("CIMD client_id rejected by AS, retrying with DCR", "issuer", issuer, "error", err) + flowConfig.ClientID = "" + result, err = discovery.PerformOAuthFlow(ctx, issuer, flowConfig) + } } if err != nil { return nil, err @@ -178,17 +190,6 @@ func (h *Handler) buildOAuthFlowConfig(scopes []string, authServerInfo *discover "registration", authServerInfo.RegistrationEndpoint) } - // CIMD: if the AS supports CIMD and we have no pre-configured or cached credentials, - // use ToolHive's metadata URL as client_id. Setting it here prevents - // shouldDynamicallyRegisterClient from firing (it checks ClientID == ""). - if authServerInfo != nil && - authServerInfo.ClientIDMetadataDocumentSupported && - flowConfig.ClientID == "" && - flowConfig.ClientSecret == "" { - flowConfig.ClientID = oauthproto.ToolHiveClientMetadataDocumentURL - slog.Debug("Using CIMD client_id", "url", oauthproto.ToolHiveClientMetadataDocumentURL) - } - return flowConfig } @@ -342,18 +343,23 @@ func (h *Handler) discoverIssuerAndScopes( authInfo *discovery.AuthInfo, remoteURL string, ) (string, []string, *discovery.AuthServerInfo, error) { - // Priority 1: Use configured issuer if available + // Priority 1: Use configured issuer if available. Fetch discovery to populate + // AuthServerInfo (including ClientIDMetadataDocumentSupported) even when the + // issuer is pre-configured, so CIMD detection works on this path. if h.config.Issuer != "" { slog.Debug("Using configured issuer", "issuer", h.config.Issuer) - return h.config.Issuer, h.config.Scopes, nil, nil + authServerInfo, _ := discovery.ValidateAndDiscoverAuthServer(ctx, h.config.Issuer) + return h.config.Issuer, h.config.Scopes, authServerInfo, nil } - // Priority 2: Try to derive from realm (RFC 8414) + // Priority 2: Try to derive from realm (RFC 8414). Fetch discovery for the + // same reason as Priority 1 — the realm path skips resource metadata discovery. if authInfo.Realm != "" { derivedIssuer := discovery.DeriveIssuerFromRealm(authInfo.Realm) if derivedIssuer != "" { slog.Debug("Derived issuer from realm", "issuer", derivedIssuer) - return derivedIssuer, h.config.Scopes, nil, nil + authServerInfo, _ := discovery.ValidateAndDiscoverAuthServer(ctx, derivedIssuer) + return derivedIssuer, h.config.Scopes, authServerInfo, nil } } diff --git a/pkg/auth/remote/handler_test.go b/pkg/auth/remote/handler_test.go index ca0c2b1876..39696f80e9 100644 --- a/pkg/auth/remote/handler_test.go +++ b/pkg/auth/remote/handler_test.go @@ -17,7 +17,6 @@ import ( "github.com/stretchr/testify/require" "github.com/stacklok/toolhive/pkg/auth/discovery" - "github.com/stacklok/toolhive/pkg/oauthproto" ) const ( @@ -799,8 +798,7 @@ func TestAuthenticate_BearerTokenPriority(t *testing.T) { assert.Equal(t, "Bearer", token.TokenType) } -// TestIsCIMDRejectionError covers the isCIMDRejectionError helper used in the CIMD retry -// path. See TestBuildOAuthFlowConfig_CIMD for the config-level CIMD behavioral tests. +// TestIsCIMDRejectionError covers the isCIMDRejectionError helper used in the CIMD retry path. func TestIsCIMDRejectionError(t *testing.T) { t.Parallel() @@ -854,89 +852,6 @@ func TestIsCIMDRejectionError(t *testing.T) { } } -// TestBuildOAuthFlowConfig_CIMD tests that buildOAuthFlowConfig sets the CIMD client_id -// when the authorization server advertises CIMD support and no credentials are configured. -func TestBuildOAuthFlowConfig_CIMD(t *testing.T) { - t.Parallel() - - cimdURL := oauthproto.ToolHiveClientMetadataDocumentURL - - tests := []struct { - name string - config *Config - authServerInfo *discovery.AuthServerInfo - wantClientID string - wantNotClientID string - }{ - { - name: "CIMD set when AS supports it and no credentials configured", - config: &Config{}, - authServerInfo: &discovery.AuthServerInfo{ - AuthorizationURL: "https://as.example.com/authorize", - TokenURL: "https://as.example.com/token", - ClientIDMetadataDocumentSupported: true, - }, - wantClientID: cimdURL, - }, - { - name: "CIMD not set when ClientID already configured", - config: &Config{ - ClientID: "pre-configured-client-id", - }, - authServerInfo: &discovery.AuthServerInfo{ - AuthorizationURL: "https://as.example.com/authorize", - TokenURL: "https://as.example.com/token", - ClientIDMetadataDocumentSupported: true, - }, - wantClientID: "pre-configured-client-id", - }, - { - name: "CIMD not set when ClientSecret already configured", - config: &Config{ - ClientSecret: "some-secret", - }, - authServerInfo: &discovery.AuthServerInfo{ - AuthorizationURL: "https://as.example.com/authorize", - TokenURL: "https://as.example.com/token", - ClientIDMetadataDocumentSupported: true, - }, - wantNotClientID: cimdURL, - }, - { - name: "CIMD not set when AS does not advertise support", - config: &Config{}, - authServerInfo: &discovery.AuthServerInfo{ - AuthorizationURL: "https://as.example.com/authorize", - TokenURL: "https://as.example.com/token", - ClientIDMetadataDocumentSupported: false, - }, - wantNotClientID: cimdURL, - }, - { - name: "CIMD not set when authServerInfo is nil", - config: &Config{}, - authServerInfo: nil, - wantNotClientID: cimdURL, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - handler := &Handler{config: tt.config} - flowConfig := handler.buildOAuthFlowConfig(nil, tt.authServerInfo) - - if tt.wantClientID != "" { - assert.Equal(t, tt.wantClientID, flowConfig.ClientID) - } - if tt.wantNotClientID != "" { - assert.NotEqual(t, tt.wantNotClientID, flowConfig.ClientID) - } - }) - } -} - // TestAuthenticate_BearerTokenDiscovery tests that bearer token discovery works correctly func TestAuthenticate_BearerTokenDiscovery(t *testing.T) { t.Parallel() From 757a3124d41a22dc70549dff8e68aa91aa8ae2de Mon Sep 17 00:00:00 2001 From: amirejaz Date: Wed, 29 Apr 2026 01:46:44 +0500 Subject: [PATCH 06/12] Fix doc comments across CIMD changes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - discovery.go: restore precise port-resolution comment explaining why order matters (DCR allows fallback, pre-registered/CIMD clients require exact port); consolidate CIMD note into the else-branch comment - handler.go: remove incorrect RFC 6749 §5.2 reference from isCIMDRejectionError — §5.2 covers token endpoint errors but CIMD rejection happens at the authorization endpoint; use plain RFC 6749 reference without a specific section Co-Authored-By: Claude Sonnet 4.6 (1M context) --- pkg/auth/discovery/discovery.go | 11 +++++++---- pkg/auth/remote/handler.go | 6 +++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/pkg/auth/discovery/discovery.go b/pkg/auth/discovery/discovery.go index 73788ba559..78677a7749 100644 --- a/pkg/auth/discovery/discovery.go +++ b/pkg/auth/discovery/discovery.go @@ -540,7 +540,10 @@ func PerformOAuthFlow(ctx context.Context, issuer string, config *OAuthFlowConfi return nil, fmt.Errorf("OAuth flow config cannot be nil") } - // validate the callback port, for CIMD/pre-registered clients config.CallbackPort should be available. + // Resolve port availability before registration. DCR clients allow port fallback + // because the actual port is registered after selection. Pre-registered and CIMD + // clients require the configured port to be available as-is — it is already + // published in their IdP application or metadata document redirect URI. if shouldDynamicallyRegisterClient(config) { // For dynamic registration, we can allow fallback to alternative ports // since we can register the client with the actual port we'll use @@ -554,9 +557,9 @@ func PerformOAuthFlow(ctx context.Context, issuer string, config *OAuthFlowConfi } config.CallbackPort = port } else { - // For pre-registered clients, use strict port checking - // The user likely configured this port in their IdP/app - // For CIMD, the port is configured in the CIMD metadata document. + // For pre-registered clients and CIMD, use strict port checking. + // The port is either configured in the IdP app or baked into the + // redirect URI in the hosted metadata document. if !networking.IsAvailable(config.CallbackPort) { return nil, fmt.Errorf( "specified auth callback port %d is not available - please choose a different port or ensure it's not in use", diff --git a/pkg/auth/remote/handler.go b/pkg/auth/remote/handler.go index bef02584b0..2299bfa261 100644 --- a/pkg/auth/remote/handler.go +++ b/pkg/auth/remote/handler.go @@ -490,9 +490,9 @@ func (h *Handler) tryDiscoverFromWellKnown( } // isCIMDRejectionError returns true if err indicates the AS rejected the CIMD -// client_id at the authorization request stage. Per RFC-0071, only -// invalid_client and unauthorized_client trigger a DCR retry; invalid_request -// and token-exchange errors must surface as real failures. +// client_id. Only the RFC 6749 error codes invalid_client and unauthorized_client +// trigger a DCR retry; all other errors — including invalid_request and +// token-exchange failures — surface as-is. func isCIMDRejectionError(err error) bool { if err == nil { return false From 5d0874d0a3afda59eebc03e5bf0c961ee09d6787 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Wed, 29 Apr 2026 13:30:25 +0500 Subject: [PATCH 07/12] Add E2E tests for CIMD authentication flow MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two test cases under Label("remote", "auth", "cimd"): 1. AS advertises CIMD support: verifies thv run presents ToolHiveClientMetadataDocumentURL as client_id, includes PKCE code_challenge, and never calls /oauth/register (DCR skipped). 2. AS does not advertise CIMD: verifies thv run falls back to DCR and does not use the CIMD metadata URL as client_id. Uses a minimal httptest-based mock AS (cimdMockAuthServer) rather than OIDCMockServer because Fosite's GetClient rejects HTTPS client IDs that are not pre-registered — CIMD requires accepting any valid HTTPS URL as client_id. Co-Authored-By: Claude Sonnet 4.6 (1M context) --- test/e2e/cimd_auth_helpers_test.go | 257 +++++++++++++++++++++++++++++ test/e2e/cimd_auth_test.go | 219 ++++++++++++++++++++++++ 2 files changed, 476 insertions(+) create mode 100644 test/e2e/cimd_auth_helpers_test.go create mode 100644 test/e2e/cimd_auth_test.go diff --git a/test/e2e/cimd_auth_helpers_test.go b/test/e2e/cimd_auth_helpers_test.go new file mode 100644 index 0000000000..bf54b6977c --- /dev/null +++ b/test/e2e/cimd_auth_helpers_test.go @@ -0,0 +1,257 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package e2e_test + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "sync" + "time" +) + +// testHelper is a minimal subset of testing.TB and ginkgo.GinkgoTInterface that +// the CIMD mock server helpers require. Both *testing.T and GinkgoT() satisfy +// this interface, so helpers can be called from plain Go tests and Ginkgo specs. +type testHelper interface { + Helper() + Cleanup(func()) +} + +// cimdAuthRequest captures parameters from an OAuth authorization request. +type cimdAuthRequest struct { + ClientID string + RedirectURI string + State string + CodeChallenge string +} + +// cimdMockAuthServer is a minimal httptest-based mock authorization server +// for CIMD testing. Unlike OIDCMockServer (Fosite-backed), this server accepts +// any HTTPS URL as a client_id, which is required to verify CIMD behaviour. +type cimdMockAuthServer struct { + server *httptest.Server + authRequestChan chan cimdAuthRequest + + mu sync.Mutex + lastClientID string + dcrCalled bool + cimdSupported bool +} + +// newCIMDMockAuthServer creates and starts a mock authorization server that +// advertises client_id_metadata_document_supported. It registers t.Cleanup to +// close the server automatically. +func newCIMDMockAuthServer(tb testHelper, cimdSupported bool) *cimdMockAuthServer { + tb.Helper() + + s := &cimdMockAuthServer{ + authRequestChan: make(chan cimdAuthRequest, 4), + cimdSupported: cimdSupported, + } + + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/openid-configuration", s.handleDiscovery) + mux.HandleFunc("/oauth/authorize", s.handleAuthorize) + mux.HandleFunc("/oauth/token", s.handleToken) + mux.HandleFunc("/oauth/register", s.handleRegister) + mux.HandleFunc("/.well-known/jwks.json", s.handleJWKS) + mux.HandleFunc("/.well-known/mcp-resource", s.handleResourceMetadata) + + s.server = httptest.NewServer(mux) + tb.Cleanup(s.server.Close) + + return s +} + +// URL returns the base URL of the mock authorization server. +func (s *cimdMockAuthServer) URL() string { + return s.server.URL +} + +// IssuerURL returns the issuer URL (same as URL for this mock). +func (s *cimdMockAuthServer) IssuerURL() string { + return s.server.URL +} + +// ResourceMetadataURL returns the RFC 9728 resource metadata URL for this server. +func (s *cimdMockAuthServer) ResourceMetadataURL() string { + return fmt.Sprintf("%s/.well-known/mcp-resource", s.server.URL) +} + +// WaitForAuthRequest blocks until an authorization request arrives or the timeout +// elapses. +func (s *cimdMockAuthServer) WaitForAuthRequest(timeout time.Duration) (cimdAuthRequest, error) { + select { + case req := <-s.authRequestChan: + return req, nil + case <-time.After(timeout): + return cimdAuthRequest{}, fmt.Errorf("timeout waiting for auth request after %s", timeout) + } +} + +// DcrWasCalled returns true if the DCR /oauth/register endpoint was ever called. +func (s *cimdMockAuthServer) DcrWasCalled() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.dcrCalled +} + +// LastClientID returns the most recent client_id seen in /oauth/authorize. +func (s *cimdMockAuthServer) LastClientID() string { + s.mu.Lock() + defer s.mu.Unlock() + return s.lastClientID +} + +// handleDiscovery serves the OIDC discovery document. It sets +// client_id_metadata_document_supported based on the server's configuration. +func (s *cimdMockAuthServer) handleDiscovery(w http.ResponseWriter, _ *http.Request) { + doc := map[string]interface{}{ + "issuer": s.server.URL, + "authorization_endpoint": fmt.Sprintf("%s/oauth/authorize", s.server.URL), + "token_endpoint": fmt.Sprintf("%s/oauth/token", s.server.URL), + "registration_endpoint": fmt.Sprintf("%s/oauth/register", s.server.URL), + "jwks_uri": fmt.Sprintf("%s/.well-known/jwks.json", s.server.URL), + "code_challenge_methods_supported": []string{"S256"}, + "response_types_supported": []string{"code"}, + "grant_types_supported": []string{"authorization_code", "refresh_token"}, + "client_id_metadata_document_supported": s.cimdSupported, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(doc) +} + +// handleAuthorize captures the authorization request and either immediately +// redirects (when auto_complete=true) or places the request into the channel +// for the test to inspect. +func (s *cimdMockAuthServer) handleAuthorize(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + req := cimdAuthRequest{ + ClientID: q.Get("client_id"), + RedirectURI: q.Get("redirect_uri"), + State: q.Get("state"), + CodeChallenge: q.Get("code_challenge"), + } + + s.mu.Lock() + s.lastClientID = req.ClientID + s.mu.Unlock() + + // Always send into the channel so WaitForAuthRequest can inspect it. + select { + case s.authRequestChan <- req: + default: + // Channel buffer full; drop the duplicate. + } + + if q.Get("auto_complete") == "true" { + redirectURI := req.RedirectURI + if redirectURI == "" { + http.Error(w, "missing redirect_uri", http.StatusBadRequest) + return + } + separator := "&" + if len(q.Get("redirect_uri")) > 0 { + // redirect_uri itself may or may not have a query string already; + // we append to it by adding a '?' if needed. + separator = "?" + for _, ch := range redirectURI { + if ch == '?' { + separator = "&" + break + } + } + } + http.Redirect(w, r, + fmt.Sprintf("%s%scode=test-auth-code&state=%s", redirectURI, separator, req.State), + http.StatusFound, + ) + return + } + + // Without auto_complete the test must drive the flow externally. + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("authorization pending")) +} + +// handleToken accepts any code=test-auth-code and returns a minimal access token. +func (*cimdMockAuthServer) handleToken(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + tokenResp := map[string]interface{}{ + "access_token": "test-access-token-cimd", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "test-refresh-token-cimd", + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(tokenResp) +} + +// handleRegister is the DCR endpoint. Calling it records that DCR was used. +func (s *cimdMockAuthServer) handleRegister(w http.ResponseWriter, _ *http.Request) { + s.mu.Lock() + s.dcrCalled = true + s.mu.Unlock() + + resp := map[string]interface{}{ + "client_id": "dcr-issued-client-id", + "client_secret": "dcr-issued-secret", + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(resp) +} + +// handleJWKS returns an empty JWKS set. +func (*cimdMockAuthServer) handleJWKS(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"keys":[]}`)) +} + +// handleResourceMetadata returns RFC 9728 protected resource metadata pointing +// at this authorization server. +func (s *cimdMockAuthServer) handleResourceMetadata(w http.ResponseWriter, _ *http.Request) { + meta := map[string]interface{}{ + "resource": s.server.URL, + "authorization_servers": []string{s.server.URL}, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(meta) +} + +// newCIMDMockMCPServer creates a minimal httptest MCP server that: +// - Returns 401 with WWW-Authenticate header when there is no Authorization header. +// - Returns a minimal JSON-RPC success response when an Authorization header is present. +// +// asURL is the base URL of the authorization server; it is embedded in the +// WWW-Authenticate header's realm and resource_metadata attributes. +func newCIMDMockMCPServer(tb testHelper, asURL string) *httptest.Server { + tb.Helper() + + resourceMetaURL := fmt.Sprintf("%s/.well-known/mcp-resource", asURL) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") == "" { + w.Header().Set( + "WWW-Authenticate", + fmt.Sprintf(`Bearer realm="%s",resource_metadata="%s"`, asURL, resourceMetaURL), + ) + w.WriteHeader(http.StatusUnauthorized) + return + } + + // Minimal JSON-RPC success response so the proxy can verify connectivity. + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05","capabilities":{},"serverInfo":{"name":"cimd-mock-mcp","version":"0.0.1"}}}`)) + })) + + tb.Cleanup(srv.Close) + return srv +} diff --git a/test/e2e/cimd_auth_test.go b/test/e2e/cimd_auth_test.go new file mode 100644 index 0000000000..113cb804e4 --- /dev/null +++ b/test/e2e/cimd_auth_test.go @@ -0,0 +1,219 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package e2e_test + +import ( + "bytes" + "io" + "net/http" + "os" + "os/exec" + "regexp" + "strings" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/stacklok/toolhive/pkg/oauthproto" + "github.com/stacklok/toolhive/test/e2e" +) + +// startCIMDRunCommand starts `thv run --name --remote-auth …` +// and returns the exec.Cmd together with a buffer that captures combined stdout +// and stderr. The buffer is safe to read concurrently from the test goroutine. +func startCIMDRunCommand( + config *e2e.TestConfig, + serverName string, + mcpURL string, + asIssuerURL string, +) (*exec.Cmd, *bytes.Buffer) { + args := []string{ + "run", + mcpURL, + "--name", serverName, + "--remote-auth", + "--remote-auth-skip-browser", + "--remote-auth-issuer", asIssuerURL, + "--remote-auth-timeout", "30s", + } + + GinkgoWriter.Printf("Starting thv run with args: %v\n", args) + + cmd := exec.Command(config.THVBinary, args...) //nolint:gosec // Intentional for e2e testing + cmd.Env = os.Environ() + + var outputBuffer bytes.Buffer + multiWriter := io.MultiWriter(&outputBuffer, GinkgoWriter) + cmd.Stdout = multiWriter + cmd.Stderr = multiWriter + + err := cmd.Start() + Expect(err).ToNot(HaveOccurred(), "thv run should start without error") + + return cmd, &outputBuffer +} + +// extractAuthURL scans the captured output buffer for the OAuth browser URL +// that ToolHive prints when --remote-auth-skip-browser is set. +func extractAuthURL(output string) string { + urlPattern := regexp.MustCompile(`Please open this URL in your browser: (https?://[^\s"]+)`) + matches := urlPattern.FindStringSubmatch(output) + if len(matches) >= 2 { + return matches[1] + } + return "" +} + +// appendAutoComplete appends or sets auto_complete=true on an authorize URL so +// that the cimdMockAuthServer will immediately redirect to the callback. +func appendAutoComplete(authURL string) string { + if authURL == "" { + return authURL + } + separator := "&" + if !strings.Contains(authURL, "?") { + separator = "?" + } + return authURL + separator + "auto_complete=true" +} + +var _ = Describe("CIMD Authentication", Label("remote", "auth", "cimd"), Serial, func() { + var config *e2e.TestConfig + + BeforeEach(func() { + config = e2e.NewTestConfig() + + err := e2e.CheckTHVBinaryAvailable(config) + Expect(err).ToNot(HaveOccurred(), "thv binary should be available for testing") + }) + + Context("when the authorization server advertises CIMD support", func() { + It("uses the CIMD client_id and skips DCR", func() { + By("Starting mock authorization server with CIMD support enabled") + mockAS := newCIMDMockAuthServer(GinkgoT(), true) + + By("Starting mock MCP server that requires authentication") + mockMCP := newCIMDMockMCPServer(GinkgoT(), mockAS.URL()) + + serverName := e2e.GenerateUniqueServerName("cimd-cimd-supported") + + By("Starting thv run pointing at the mock MCP server") + cmd, outputBuffer := startCIMDRunCommand(config, serverName, mockMCP.URL, mockAS.IssuerURL()) + + defer func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + _ = cmd.Wait() + } + if config.CleanupAfter { + _ = e2e.StopAndRemoveMCPServer(config, serverName) + } + }() + + By("Waiting for the OAuth URL to appear in the output") + var authURL string + Eventually(func() string { + authURL = extractAuthURL(outputBuffer.String()) + return authURL + }, 30*time.Second, 500*time.Millisecond).ShouldNot(BeEmpty(), + "thv run should print 'Please open this URL in your browser'") + + By("Completing the OAuth flow via auto_complete") + autoURL := appendAutoComplete(authURL) + client := &http.Client{ + Timeout: 10 * time.Second, + CheckRedirect: func(_ *http.Request, _ []*http.Request) error { + return nil // follow redirects + }, + } + resp, err := client.Get(autoURL) //nolint:gosec // URL is test-controlled + Expect(err).ToNot(HaveOccurred(), "GET to auto-complete URL should succeed") + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + Expect(resp.StatusCode).To(BeNumerically("<", 400), + "auto-complete redirect chain should succeed") + + By("Waiting for the authorization request to be captured by the mock AS") + authReq, err := mockAS.WaitForAuthRequest(15 * time.Second) + Expect(err).ToNot(HaveOccurred(), "mock AS should receive an authorization request") + + By("Asserting client_id equals the CIMD metadata URL") + Expect(authReq.ClientID).To(Equal(oauthproto.ToolHiveClientMetadataDocumentURL), + "thv run should use the CIMD metadata URL as client_id when AS advertises support") + + By("Asserting PKCE code_challenge was included") + Expect(authReq.CodeChallenge).ToNot(BeEmpty(), + "PKCE code_challenge must be present in the authorization request") + + By("Asserting DCR was NOT called") + Expect(mockAS.DcrWasCalled()).To(BeFalse(), + "DCR registration endpoint must not be called when CIMD is used") + + By("Waiting for thv to report the server as running") + err = e2e.WaitForMCPServer(config, serverName, 30*time.Second) + Expect(err).ToNot(HaveOccurred(), "server should appear as running in thv list") + }) + }) + + Context("when the authorization server does NOT advertise CIMD support", func() { + It("falls back to DCR and does not use the CIMD client_id", func() { + By("Starting mock authorization server with CIMD support disabled") + mockAS := newCIMDMockAuthServer(GinkgoT(), false) + + By("Starting mock MCP server that requires authentication") + mockMCP := newCIMDMockMCPServer(GinkgoT(), mockAS.URL()) + + serverName := e2e.GenerateUniqueServerName("cimd-dcr-fallback") + + By("Starting thv run pointing at the mock MCP server") + cmd, outputBuffer := startCIMDRunCommand(config, serverName, mockMCP.URL, mockAS.IssuerURL()) + + defer func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + _ = cmd.Wait() + } + if config.CleanupAfter { + _ = e2e.StopAndRemoveMCPServer(config, serverName) + } + }() + + By("Waiting for the OAuth URL to appear in the output") + var authURL string + Eventually(func() string { + authURL = extractAuthURL(outputBuffer.String()) + return authURL + }, 30*time.Second, 500*time.Millisecond).ShouldNot(BeEmpty(), + "thv run should print 'Please open this URL in your browser'") + + By("Completing the OAuth flow via auto_complete") + autoURL := appendAutoComplete(authURL) + client := &http.Client{ + Timeout: 10 * time.Second, + CheckRedirect: func(_ *http.Request, _ []*http.Request) error { + return nil + }, + } + resp, err := client.Get(autoURL) //nolint:gosec // URL is test-controlled + Expect(err).ToNot(HaveOccurred(), "GET to auto-complete URL should succeed") + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + Expect(resp.StatusCode).To(BeNumerically("<", 400)) + + By("Waiting for the authorization request to be captured by the mock AS") + authReq, err := mockAS.WaitForAuthRequest(15 * time.Second) + Expect(err).ToNot(HaveOccurred(), "mock AS should receive an authorization request") + + By("Asserting client_id is NOT the CIMD metadata URL") + Expect(authReq.ClientID).ToNot(Equal(oauthproto.ToolHiveClientMetadataDocumentURL), + "thv run must not use the CIMD metadata URL when the AS does not advertise support") + + By("Asserting DCR WAS called") + // Give thv a moment to hit the DCR endpoint before asserting. + Eventually(mockAS.DcrWasCalled, 10*time.Second, 500*time.Millisecond).Should(BeTrue(), + "DCR registration endpoint must be called when CIMD is not advertised") + }) + }) +}) From 5513e963f98e2666567409161e3ec639209ca418 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Wed, 29 Apr 2026 17:57:55 +0500 Subject: [PATCH 08/12] Restore CachedCIMDClientID with reader, extract shouldUseCIMD, fix error check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit CachedCIMDClientID: - Re-introduce the field with a clear doc comment explaining why it is kept separate from CachedClientID (independent lifecycle — DCR credential rotation must not clear the stable CIMD URL) - Guard clientCredentialsPersister so CIMD URLs are NOT stored in CachedClientID; they go to CachedCIMDClientID instead - Implement the reader in resolveClientCredentials: CIMD URL is returned as client_id for token refresh on restart, checked before DCR credentials; this was the deferred reader Jakub flagged - ClearCachedClientCredentials does not clear CachedCIMDClientID (documented in its comment) shouldUseCIMD: - Extract the four-condition CIMD predicate into a named helper, mirroring shouldDynamicallyRegisterClient (jhrozek non-blocking) isCIMDRejectionError: - Replace strings.Contains with errors.As(*oauth2.RetrieveError) to avoid substring collisions and format coupling (jhrozek non-blocking) - Update TestIsCIMDRejectionError to use real *oauth2.RetrieveError values matching what golang.org/x/oauth2 actually returns Co-Authored-By: Claude Sonnet 4.6 (1M context) --- pkg/auth/remote/config.go | 13 ++++----- pkg/auth/remote/handler.go | 47 +++++++++++++++++++++++++-------- pkg/auth/remote/handler_test.go | 23 ++++++++++------ 3 files changed, 58 insertions(+), 25 deletions(-) diff --git a/pkg/auth/remote/config.go b/pkg/auth/remote/config.go index efb1935a1d..7b8f678369 100644 --- a/pkg/auth/remote/config.go +++ b/pkg/auth/remote/config.go @@ -72,11 +72,11 @@ type Config struct { // Stored as a secret reference since it's sensitive. CachedRegTokenRef string `json:"cached_reg_token_ref,omitempty" yaml:"cached_reg_token_ref,omitempty"` - // CachedCIMDClientID stores the CIMD URL used as client_id when CIMD was used - // for authentication. Non-sensitive — it is a public URL. Written on successful - // CIMD flows to distinguish from DCR-issued IDs. The reader that uses this to - // skip re-detection on restart is deferred to a follow-up; see issue #2728. - // Note: any HTTPS-shaped client_id is cached here (per IsClientIDMetadataDocumentURL). + // CachedCIMDClientID stores the CIMD metadata URL used as client_id when CIMD + // authentication was used. Kept separate from CachedClientID (which holds + // DCR-issued IDs) so the two can have independent lifecycles — DCR credential + // rotation clears CachedClientID without touching the stable CIMD URL. + // Read by resolveClientCredentials to send the correct client_id on token refresh. CachedCIMDClientID string `json:"cached_cimd_client_id,omitempty" yaml:"cached_cimd_client_id,omitempty"` } @@ -177,12 +177,13 @@ func (c *Config) HasCachedCIMDClientID() bool { } // ClearCachedClientCredentials removes any cached DCR client credential references from the config. +// It does not clear CachedCIMDClientID — the CIMD URL is a stable constant that does not +// need to be rotated alongside DCR secrets. func (c *Config) ClearCachedClientCredentials() { c.CachedClientID = "" c.CachedClientSecretRef = "" c.CachedSecretExpiry = time.Time{} c.CachedRegTokenRef = "" - c.CachedCIMDClientID = "" } // DefaultResourceIndicator derives the resource indicator (RFC 8707) from the remote server URL. diff --git a/pkg/auth/remote/handler.go b/pkg/auth/remote/handler.go index 2299bfa261..a94c4d5e44 100644 --- a/pkg/auth/remote/handler.go +++ b/pkg/auth/remote/handler.go @@ -5,9 +5,9 @@ package remote import ( "context" + "errors" "fmt" "log/slog" - "strings" "golang.org/x/oauth2" @@ -138,10 +138,7 @@ func (h *Handler) performOAuthFlow( // Priority 2: CIMD — AS advertises support and no credentials are set; use metadata URL as client_id. // Priority 3: DCR — PerformOAuthFlow handles this when ClientID is still empty after the above. flowConfig := h.buildOAuthFlowConfig(scopes, authServerInfo) - if authServerInfo != nil && - authServerInfo.ClientIDMetadataDocumentSupported && - flowConfig.ClientID == "" && - flowConfig.ClientSecret == "" { + if shouldUseCIMD(authServerInfo, flowConfig) { flowConfig.ClientID = oauthproto.ToolHiveClientMetadataDocumentURL slog.Debug("Using CIMD client_id", "url", oauthproto.ToolHiveClientMetadataDocumentURL) } @@ -206,7 +203,9 @@ func (h *Handler) wrapWithPersistence(result *discovery.OAuthFlowResult) oauth2. // Persist DCR client credentials if available (for servers that use Dynamic Client Registration) // Only persist if client_id exists - client_secret may be empty for PKCE flows - if h.clientCredentialsPersister != nil && result.ClientID != "" { + // CIMD client IDs (HTTPS URLs) are stable constants and are stored separately below. + if h.clientCredentialsPersister != nil && result.ClientID != "" && + !oauthproto.IsClientIDMetadataDocumentURL(result.ClientID) { if err := h.clientCredentialsPersister(result.ClientID, result.ClientSecret); err != nil { slog.Warn("Failed to persist DCR client credentials", "error", err) } else { @@ -214,11 +213,11 @@ func (h *Handler) wrapWithPersistence(result *discovery.OAuthFlowResult) oauth2. } } - // If CIMD was used (client_id is the metadata URL), persist it separately - // so it can be distinguished from a DCR-issued client_id on restart. + // Persist the CIMD metadata URL separately so it can be used as client_id + // on token refresh without conflating it with DCR-issued credentials. if oauthproto.IsClientIDMetadataDocumentURL(result.ClientID) { h.config.CachedCIMDClientID = result.ClientID - slog.Debug("CIMD client_id used, cached for reference", "url", result.ClientID) + slog.Debug("Persisted CIMD client_id for future restarts", "url", result.ClientID) } // Wrap the token source to persist refreshed tokens @@ -237,6 +236,15 @@ func (h *Handler) resolveClientCredentials(ctx context.Context) (clientID, clien clientID = h.config.ClientID clientSecret = h.config.ClientSecret + // If CIMD was used in a prior session, use the cached metadata URL as client_id. + // CIMD clients have no secret (token_endpoint_auth_method=none). + // Checked before DCR so that DCR credential rotation does not change which + // client_id is sent on token refresh. + if h.config.HasCachedCIMDClientID() { + slog.Debug("Using cached CIMD client_id", "url", h.config.CachedCIMDClientID) + return h.config.CachedCIMDClientID, "" + } + // If we have cached DCR client credentials, use those instead if h.config.HasCachedClientCredentials() { // ClientID is stored as plain text (it's public information) @@ -489,6 +497,16 @@ func (h *Handler) tryDiscoverFromWellKnown( return authServerInfo.Issuer, scopes, authServerInfo, nil } +// shouldUseCIMD reports whether the CIMD client_id should be presented to the AS. +// The AS must advertise CIMD support and no pre-configured credentials may be set. +// Mirrors shouldDynamicallyRegisterClient in pkg/auth/discovery for consistency. +func shouldUseCIMD(authServerInfo *discovery.AuthServerInfo, flowConfig *discovery.OAuthFlowConfig) bool { + if authServerInfo == nil || !authServerInfo.ClientIDMetadataDocumentSupported { + return false + } + return flowConfig.ClientID == "" && flowConfig.ClientSecret == "" +} + // isCIMDRejectionError returns true if err indicates the AS rejected the CIMD // client_id. Only the RFC 6749 error codes invalid_client and unauthorized_client // trigger a DCR retry; all other errors — including invalid_request and @@ -497,6 +515,13 @@ func isCIMDRejectionError(err error) bool { if err == nil { return false } - msg := err.Error() - return strings.Contains(msg, "invalid_client") || strings.Contains(msg, "unauthorized_client") + var rerr *oauth2.RetrieveError + if !errors.As(err, &rerr) { + return false + } + switch rerr.ErrorCode { + case "invalid_client", "unauthorized_client": + return true + } + return false } diff --git a/pkg/auth/remote/handler_test.go b/pkg/auth/remote/handler_test.go index 39696f80e9..2b24476146 100644 --- a/pkg/auth/remote/handler_test.go +++ b/pkg/auth/remote/handler_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/oauth2" "github.com/stacklok/toolhive/pkg/auth/discovery" ) @@ -798,6 +799,12 @@ func TestAuthenticate_BearerTokenPriority(t *testing.T) { assert.Equal(t, "Bearer", token.TokenType) } +// retrieveErr constructs an *oauth2.RetrieveError with the given error code, +// matching what golang.org/x/oauth2 returns for token endpoint errors. +func retrieveErr(code string) *oauth2.RetrieveError { + return &oauth2.RetrieveError{ErrorCode: code} +} + // TestIsCIMDRejectionError covers the isCIMDRejectionError helper used in the CIMD retry path. func TestIsCIMDRejectionError(t *testing.T) { t.Parallel() @@ -814,17 +821,22 @@ func TestIsCIMDRejectionError(t *testing.T) { }, { name: "invalid_client triggers retry", - err: fmt.Errorf("oauth2: cannot fetch token: 400 Bad Request\nResponse: {\"error\":\"invalid_client\"}"), + err: retrieveErr("invalid_client"), want: true, }, { name: "unauthorized_client triggers retry", - err: fmt.Errorf("oauth2: cannot fetch token: 401 Unauthorized\nResponse: {\"error\":\"unauthorized_client\"}"), + err: retrieveErr("unauthorized_client"), want: true, }, { name: "invalid_request does not trigger retry", - err: fmt.Errorf("oauth2: cannot fetch token: 400 Bad Request\nResponse: {\"error\":\"invalid_request\"}"), + err: retrieveErr("invalid_request"), + want: false, + }, + { + name: "access_denied does not trigger retry", + err: retrieveErr("access_denied"), want: false, }, { @@ -837,11 +849,6 @@ func TestIsCIMDRejectionError(t *testing.T) { err: fmt.Errorf("OAuth flow timed out after 5m0s - user did not complete authentication"), want: false, }, - { - name: "access_denied does not trigger retry", - err: fmt.Errorf("oauth2: cannot fetch token: 403 Forbidden\nResponse: {\"error\":\"access_denied\"}"), - want: false, - }, } for _, tt := range tests { From e8f6c22822e19496886d7752321714e227ba6c7f Mon Sep 17 00:00:00 2001 From: amirejaz Date: Wed, 29 Apr 2026 18:34:14 +0500 Subject: [PATCH 09/12] Fix isCIMDRejectionError to handle auth-endpoint rejections CIMD rejection surfaces from two distinct stages: - Authorization endpoint: AS redirects callback with error=invalid_client; flow.go formats this as "OAuth error: - " (plain error, not *oauth2.RetrieveError). - Token endpoint: *oauth2.RetrieveError with ErrorCode set. The previous errors.As-only check missed auth-endpoint rejections entirely, meaning the DCR fallback never triggered when the AS rejected the CIMD URL at the authorize step. Fix by checking errors.As first (precise, for token endpoint), then strings.HasPrefix for the auth-endpoint format. Update TestIsCIMDRejectionError to cover both error sources. Co-Authored-By: Claude Sonnet 4.6 (1M context) --- pkg/auth/remote/handler.go | 23 +++++++++++++++++------ pkg/auth/remote/handler_test.go | 22 ++++++++++++++++++++++ 2 files changed, 39 insertions(+), 6 deletions(-) diff --git a/pkg/auth/remote/handler.go b/pkg/auth/remote/handler.go index a94c4d5e44..ce7d271288 100644 --- a/pkg/auth/remote/handler.go +++ b/pkg/auth/remote/handler.go @@ -8,6 +8,7 @@ import ( "errors" "fmt" "log/slog" + "strings" "golang.org/x/oauth2" @@ -511,17 +512,27 @@ func shouldUseCIMD(authServerInfo *discovery.AuthServerInfo, flowConfig *discove // client_id. Only the RFC 6749 error codes invalid_client and unauthorized_client // trigger a DCR retry; all other errors — including invalid_request and // token-exchange failures — surface as-is. +// +// CIMD rejection can surface from two stages: +// - Authorization endpoint: AS redirects to callback with error=invalid_client; +// flow.go formats this as "OAuth error: - " (a plain error). +// - Token endpoint: oauth2.RetrieveError with ErrorCode set. func isCIMDRejectionError(err error) bool { if err == nil { return false } + // Token endpoint rejection — structured error from golang.org/x/oauth2. var rerr *oauth2.RetrieveError - if !errors.As(err, &rerr) { + if errors.As(err, &rerr) { + switch rerr.ErrorCode { + case "invalid_client", "unauthorized_client": + return true + } return false } - switch rerr.ErrorCode { - case "invalid_client", "unauthorized_client": - return true - } - return false + // Authorization endpoint rejection — flow.go formats callback errors as + // "OAuth error: - ". Check for the code after the prefix. + msg := err.Error() + return strings.HasPrefix(msg, "OAuth error: invalid_client") || + strings.HasPrefix(msg, "OAuth error: unauthorized_client") } diff --git a/pkg/auth/remote/handler_test.go b/pkg/auth/remote/handler_test.go index 2b24476146..a57c7e874e 100644 --- a/pkg/auth/remote/handler_test.go +++ b/pkg/auth/remote/handler_test.go @@ -839,6 +839,28 @@ func TestIsCIMDRejectionError(t *testing.T) { err: retrieveErr("access_denied"), want: false, }, + // Authorization-endpoint rejections — flow.go format: "OAuth error: - " + { + name: "auth callback invalid_client triggers retry", + err: fmt.Errorf("OAuth error: invalid_client - client not recognised"), + want: true, + }, + { + name: "auth callback unauthorized_client triggers retry", + err: fmt.Errorf("OAuth error: unauthorized_client - not allowed"), + want: true, + }, + { + name: "auth callback invalid_request does not trigger retry", + err: fmt.Errorf("OAuth error: invalid_request - missing param"), + want: false, + }, + { + name: "auth callback access_denied does not trigger retry", + err: fmt.Errorf("OAuth error: access_denied - user denied"), + want: false, + }, + // Non-OAuth errors must not trigger retry. { name: "network error does not trigger retry", err: fmt.Errorf("dial tcp: connection refused"), From 129b8316499a7beaa529317845ba76dd38cc9b93 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Wed, 29 Apr 2026 19:34:08 +0500 Subject: [PATCH 10/12] Add follow-up CIMD E2E and unit tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit E2E (test/e2e/): - Extend cimdMockAuthServer with rejectCIMD option: first authorize request with a CIMD client_id returns error=invalid_client; subsequent requests with a DCR-issued client_id succeed normally - New test: "falls back to DCR when AS rejects the CIMD client_id" — verifies the full retry path end-to-end: CIMD attempted, rejected, DCR fallback fires, server becomes reachable Unit (pkg/auth/remote/handler_test.go): - TestResolveClientCredentials: 5 table-driven cases covering CachedCIMDClientID precedence over DCR and static credentials, empty-secret guarantee for CIMD, and DCR/static fallback when CachedCIMDClientID is empty Co-Authored-By: Claude Sonnet 4.6 (1M context) --- pkg/auth/remote/handler_test.go | 74 +++++++++++++++++++++++++ test/e2e/cimd_auth_helpers_test.go | 66 ++++++++++++++++++++-- test/e2e/cimd_auth_test.go | 88 +++++++++++++++++++++++++++++- 3 files changed, 220 insertions(+), 8 deletions(-) diff --git a/pkg/auth/remote/handler_test.go b/pkg/auth/remote/handler_test.go index a57c7e874e..b7761964d9 100644 --- a/pkg/auth/remote/handler_test.go +++ b/pkg/auth/remote/handler_test.go @@ -937,3 +937,77 @@ func TestAuthenticate_BearerTokenDiscovery(t *testing.T) { assert.Equal(t, "Bearer", token.TokenType) }) } + +// TestResolveClientCredentials verifies the credential selection priority in +// resolveClientCredentials: CachedCIMDClientID > CachedClientID (DCR) > +// statically-configured ClientID. +func TestResolveClientCredentials(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config *Config + wantClientID string + wantClientSecret string + }{ + { + name: "CachedCIMDClientID takes precedence over DCR and static credentials", + config: &Config{ + ClientID: "static-client-id", + ClientSecret: "static-secret", + CachedClientID: "dcr-client-id", + CachedCIMDClientID: "https://toolhive.dev/oauth/client-metadata.json", + }, + wantClientID: "https://toolhive.dev/oauth/client-metadata.json", + wantClientSecret: "", + }, + { + name: "CachedCIMDClientID returns empty secret (token_endpoint_auth_method=none)", + config: &Config{ + CachedCIMDClientID: "https://toolhive.dev/oauth/client-metadata.json", + }, + wantClientID: "https://toolhive.dev/oauth/client-metadata.json", + wantClientSecret: "", + }, + { + // When CachedClientID is set the DCR client_id is used, but because + // CachedClientSecretRef is empty (no secret reference stored) the + // function falls through to the statically-configured ClientSecret. + name: "CachedClientID used when CachedCIMDClientID is empty", + config: &Config{ + ClientID: "static-client-id", + ClientSecret: "static-secret", + CachedClientID: "dcr-client-id", + }, + wantClientID: "dcr-client-id", + wantClientSecret: "static-secret", + }, + { + name: "static credentials used when no cached credentials exist", + config: &Config{ + ClientID: "static-client-id", + ClientSecret: "static-secret", + }, + wantClientID: "static-client-id", + wantClientSecret: "static-secret", + }, + { + name: "all empty returns empty strings", + config: &Config{}, + wantClientID: "", + wantClientSecret: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + h := &Handler{config: tt.config} + gotClientID, gotClientSecret := h.resolveClientCredentials(context.Background()) + + assert.Equal(t, tt.wantClientID, gotClientID, "clientID mismatch") + assert.Equal(t, tt.wantClientSecret, gotClientSecret, "clientSecret mismatch") + }) + } +} diff --git a/test/e2e/cimd_auth_helpers_test.go b/test/e2e/cimd_auth_helpers_test.go index bf54b6977c..6af96eda2d 100644 --- a/test/e2e/cimd_auth_helpers_test.go +++ b/test/e2e/cimd_auth_helpers_test.go @@ -35,21 +35,27 @@ type cimdMockAuthServer struct { server *httptest.Server authRequestChan chan cimdAuthRequest - mu sync.Mutex - lastClientID string - dcrCalled bool - cimdSupported bool + mu sync.Mutex + lastClientID string + dcrCalled bool + cimdSupported bool + rejectCIMD bool + cimdRejectedOnce bool } // newCIMDMockAuthServer creates and starts a mock authorization server that // advertises client_id_metadata_document_supported. It registers t.Cleanup to -// close the server automatically. -func newCIMDMockAuthServer(tb testHelper, cimdSupported bool) *cimdMockAuthServer { +// close the server automatically. Pass rejectCIMD=true to make the server +// reject the first authorization request that uses a CIMD client_id (an HTTPS +// URL), simulating an AS that advertises CIMD support but rejects it at +// runtime, triggering the DCR fallback path in ToolHive. +func newCIMDMockAuthServer(tb testHelper, cimdSupported bool, rejectCIMD bool) *cimdMockAuthServer { tb.Helper() s := &cimdMockAuthServer{ authRequestChan: make(chan cimdAuthRequest, 4), cimdSupported: cimdSupported, + rejectCIMD: rejectCIMD, } mux := http.NewServeMux() @@ -124,9 +130,24 @@ func (s *cimdMockAuthServer) handleDiscovery(w http.ResponseWriter, _ *http.Requ _ = json.NewEncoder(w).Encode(doc) } +// RejectCIMDWasCalled returns true if the server rejected a CIMD client_id at +// least once. Callers use this to assert that the CIMD path was attempted +// before the DCR fallback fired. +func (s *cimdMockAuthServer) RejectCIMDWasCalled() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.cimdRejectedOnce +} + // handleAuthorize captures the authorization request and either immediately // redirects (when auto_complete=true) or places the request into the channel // for the test to inspect. +// +// When rejectCIMD is true, the first request whose client_id is an HTTPS URL +// (i.e. a CIMD metadata document URL) is rejected by redirecting to the +// callback with error=invalid_client. This simulates an AS that advertises +// CIMD support but rejects it at the authorization endpoint, triggering the +// DCR fallback path in ToolHive. func (s *cimdMockAuthServer) handleAuthorize(w http.ResponseWriter, r *http.Request) { q := r.URL.Query() req := cimdAuthRequest{ @@ -138,6 +159,32 @@ func (s *cimdMockAuthServer) handleAuthorize(w http.ResponseWriter, r *http.Requ s.mu.Lock() s.lastClientID = req.ClientID + + // If rejectCIMD is armed and this is the first CIMD request, reject it. + // A CIMD client_id is any HTTPS URL (see oauthproto.IsClientIDMetadataDocumentURL). + if s.rejectCIMD && !s.cimdRejectedOnce && isCIMDClientID(req.ClientID) { + s.cimdRejectedOnce = true + s.mu.Unlock() + + redirectURI := req.RedirectURI + if redirectURI == "" { + http.Error(w, "missing redirect_uri", http.StatusBadRequest) + return + } + separator := "?" + for _, ch := range redirectURI { + if ch == '?' { + separator = "&" + break + } + } + http.Redirect(w, r, + fmt.Sprintf("%s%serror=invalid_client&state=%s&error_description=cimd+not+supported", + redirectURI, separator, req.State), + http.StatusFound, + ) + return + } s.mu.Unlock() // Always send into the channel so WaitForAuthRequest can inspect it. @@ -226,6 +273,13 @@ func (s *cimdMockAuthServer) handleResourceMetadata(w http.ResponseWriter, _ *ht _ = json.NewEncoder(w).Encode(meta) } +// isCIMDClientID returns true if clientID looks like a CIMD metadata document +// URL (i.e. any HTTPS URL). This mirrors oauthproto.IsClientIDMetadataDocumentURL +// without importing the production package from a test helper. +func isCIMDClientID(clientID string) bool { + return len(clientID) >= 8 && clientID[:8] == "https://" +} + // newCIMDMockMCPServer creates a minimal httptest MCP server that: // - Returns 401 with WWW-Authenticate header when there is no Authorization header. // - Returns a minimal JSON-RPC success response when an Authorization header is present. diff --git a/test/e2e/cimd_auth_test.go b/test/e2e/cimd_auth_test.go index 113cb804e4..b8d775eec3 100644 --- a/test/e2e/cimd_auth_test.go +++ b/test/e2e/cimd_auth_test.go @@ -92,7 +92,7 @@ var _ = Describe("CIMD Authentication", Label("remote", "auth", "cimd"), Serial, Context("when the authorization server advertises CIMD support", func() { It("uses the CIMD client_id and skips DCR", func() { By("Starting mock authorization server with CIMD support enabled") - mockAS := newCIMDMockAuthServer(GinkgoT(), true) + mockAS := newCIMDMockAuthServer(GinkgoT(), true, false) By("Starting mock MCP server that requires authentication") mockMCP := newCIMDMockMCPServer(GinkgoT(), mockAS.URL()) @@ -160,7 +160,7 @@ var _ = Describe("CIMD Authentication", Label("remote", "auth", "cimd"), Serial, Context("when the authorization server does NOT advertise CIMD support", func() { It("falls back to DCR and does not use the CIMD client_id", func() { By("Starting mock authorization server with CIMD support disabled") - mockAS := newCIMDMockAuthServer(GinkgoT(), false) + mockAS := newCIMDMockAuthServer(GinkgoT(), false, false) By("Starting mock MCP server that requires authentication") mockMCP := newCIMDMockMCPServer(GinkgoT(), mockAS.URL()) @@ -216,4 +216,88 @@ var _ = Describe("CIMD Authentication", Label("remote", "auth", "cimd"), Serial, "DCR registration endpoint must be called when CIMD is not advertised") }) }) + + Context("CIMD fallback and warm-start behaviour", func() { + It("falls back to DCR when AS rejects the CIMD client_id", func() { + By("Starting mock authorization server: CIMD advertised but first CIMD request rejected") + mockAS := newCIMDMockAuthServer(GinkgoT(), true, true) + + By("Starting mock MCP server that requires authentication") + mockMCP := newCIMDMockMCPServer(GinkgoT(), mockAS.URL()) + + serverName := e2e.GenerateUniqueServerName("cimd-reject-fallback") + + By("Starting thv run pointing at the mock MCP server") + cmd, outputBuffer := startCIMDRunCommand(config, serverName, mockMCP.URL, mockAS.IssuerURL()) + + defer func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + _ = cmd.Wait() + } + if config.CleanupAfter { + _ = e2e.StopAndRemoveMCPServer(config, serverName) + } + }() + + By("Waiting for the first OAuth URL (CIMD attempt) to appear in the output") + var firstAuthURL string + Eventually(func() string { + firstAuthURL = extractAuthURL(outputBuffer.String()) + return firstAuthURL + }, 30*time.Second, 500*time.Millisecond).ShouldNot(BeEmpty(), + "thv run should print 'Please open this URL in your browser' for the CIMD attempt") + + By("Visiting the first URL — the AS will redirect back with error=invalid_client") + client := &http.Client{ + Timeout: 10 * time.Second, + CheckRedirect: func(_ *http.Request, _ []*http.Request) error { + return nil // follow redirects + }, + } + autoFirstURL := appendAutoComplete(firstAuthURL) + resp, err := client.Get(autoFirstURL) //nolint:gosec // URL is test-controlled + Expect(err).ToNot(HaveOccurred(), "GET to first auto-complete URL should not error") + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + // The redirect chain ends at the ToolHive callback; any 2xx/3xx is fine. + Expect(resp.StatusCode).To(BeNumerically("<", 500), + "redirect chain for CIMD rejection should not produce a server error") + + By("Asserting the mock AS registered the CIMD rejection") + Eventually(mockAS.RejectCIMDWasCalled, 10*time.Second, 500*time.Millisecond).Should(BeTrue(), + "mock AS must have rejected the CIMD client_id before DCR retry") + + By("Waiting for the second OAuth URL (DCR retry) to appear in the output") + var secondAuthURL string + Eventually(func() string { + out := outputBuffer.String() + // The second URL appears after the first; find the last occurrence. + allURLs := regexp.MustCompile(`Please open this URL in your browser: (https?://[^\s"]+)`). + FindAllStringSubmatch(out, -1) + if len(allURLs) >= 2 { + secondAuthURL = allURLs[len(allURLs)-1][1] + } + return secondAuthURL + }, 45*time.Second, 500*time.Millisecond).ShouldNot(BeEmpty(), + "thv run should print a second OAuth URL after the CIMD rejection triggers a DCR retry") + + By("Completing the DCR OAuth flow via auto_complete") + autoSecondURL := appendAutoComplete(secondAuthURL) + resp2, err := client.Get(autoSecondURL) //nolint:gosec // URL is test-controlled + Expect(err).ToNot(HaveOccurred(), "GET to second auto-complete URL should succeed") + _, _ = io.Copy(io.Discard, resp2.Body) + _ = resp2.Body.Close() + Expect(resp2.StatusCode).To(BeNumerically("<", 400), + "DCR auto-complete redirect chain should succeed") + + By("Asserting DCR was called during the retry") + Eventually(mockAS.DcrWasCalled, 10*time.Second, 500*time.Millisecond).Should(BeTrue(), + "DCR registration endpoint must be called after CIMD rejection") + + By("Waiting for thv to report the server as running") + err = e2e.WaitForMCPServer(config, serverName, 30*time.Second) + Expect(err).ToNot(HaveOccurred(), "server should appear as running in thv list after CIMD→DCR fallback") + }) + }) }) From 279bbec0b85a282ff414bb6fdca102640d4a5754 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Wed, 29 Apr 2026 20:38:45 +0500 Subject: [PATCH 11/12] Merge main and regenerate swagger docs Brings in changes from main (v0.25.0 release cycle) and regenerates the swagger documentation to match, fixing the Docs CI check. --- docs/server/docs.go | 4 ++++ docs/server/swagger.json | 4 ++++ docs/server/swagger.yaml | 8 ++++++++ 3 files changed, 16 insertions(+) diff --git a/docs/server/docs.go b/docs/server/docs.go index 761041e32e..29dca143e9 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -275,6 +275,10 @@ const docTemplate = `{ "bearer_token_file": { "type": "string" }, + "cached_cimd_client_id": { + "description": "CachedCIMDClientID stores the CIMD metadata URL used as client_id when CIMD\nauthentication was used. Kept separate from CachedClientID (which holds\nDCR-issued IDs) so the two can have independent lifecycles — DCR credential\nrotation clears CachedClientID without touching the stable CIMD URL.\nRead by resolveClientCredentials to send the correct client_id on token refresh.", + "type": "string" + }, "cached_client_id": { "description": "Cached DCR client credentials for persistence across restarts.\nThese are obtained during Dynamic Client Registration and needed to refresh tokens.\nClientID is stored as plain text since it's public information.", "type": "string" diff --git a/docs/server/swagger.json b/docs/server/swagger.json index 91b6e3b3aa..4b90283025 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -268,6 +268,10 @@ "bearer_token_file": { "type": "string" }, + "cached_cimd_client_id": { + "description": "CachedCIMDClientID stores the CIMD metadata URL used as client_id when CIMD\nauthentication was used. Kept separate from CachedClientID (which holds\nDCR-issued IDs) so the two can have independent lifecycles — DCR credential\nrotation clears CachedClientID without touching the stable CIMD URL.\nRead by resolveClientCredentials to send the correct client_id on token refresh.", + "type": "string" + }, "cached_client_id": { "description": "Cached DCR client credentials for persistence across restarts.\nThese are obtained during Dynamic Client Registration and needed to refresh tokens.\nClientID is stored as plain text since it's public information.", "type": "string" diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index f1118124d0..9ec55b9211 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -280,6 +280,14 @@ components: type: string bearer_token_file: type: string + cached_cimd_client_id: + description: |- + CachedCIMDClientID stores the CIMD metadata URL used as client_id when CIMD + authentication was used. Kept separate from CachedClientID (which holds + DCR-issued IDs) so the two can have independent lifecycles — DCR credential + rotation clears CachedClientID without touching the stable CIMD URL. + Read by resolveClientCredentials to send the correct client_id on token refresh. + type: string cached_client_id: description: |- Cached DCR client credentials for persistence across restarts. From cfa88509c9618588d139f26a7bca9b1a7f4440d7 Mon Sep 17 00:00:00 2001 From: amirejaz Date: Thu, 30 Apr 2026 02:41:20 +0500 Subject: [PATCH 12/12] Merge CIMD flag from OAuth AS metadata when OIDC doc lacks it Some authorization servers (e.g. Granola) advertise client_id_metadata_document_supported only in their RFC 8414 /.well-known/oauth-authorization-server document, not in the OIDC /.well-known/openid-configuration document. When discoverOIDCEndpointsWithClientAndValidation falls back to the OAuth AS metadata to fetch a missing registration_endpoint, it now also merges ClientIDMetadataDocumentSupported so CIMD detection works correctly for servers that split their metadata across the two well-known endpoints. Tested against Granola (https://mcp-auth.granola.ai): CIMD URL now appears as client_id in the authorize request instead of a DCR-issued opaque identifier. Co-Authored-By: Claude Sonnet 4.6 (1M context) --- pkg/auth/oauth/oidc.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pkg/auth/oauth/oidc.go b/pkg/auth/oauth/oidc.go index fe037352d8..8dbfd3f21e 100644 --- a/pkg/auth/oauth/oidc.go +++ b/pkg/auth/oauth/oidc.go @@ -124,6 +124,12 @@ func discoverOIDCEndpointsWithClientAndValidation( if oauthDoc.Issuer == doc.Issuer { doc.RegistrationEndpoint = oauthDoc.RegistrationEndpoint slog.Debug("Found registration_endpoint in OAuth authorization server metadata", "endpoint", doc.RegistrationEndpoint) + // Merge CIMD support flag — some servers (e.g. Granola) only advertise + // client_id_metadata_document_supported in the OAuth AS metadata, not + // in the OIDC discovery document. + if oauthDoc.ClientIDMetadataDocumentSupported { + doc.ClientIDMetadataDocumentSupported = true + } } else { slog.Warn("Issuer mismatch between OIDC and OAuth discovery documents, not merging registration_endpoint", "oidc_issuer", doc.Issuer, "oauth_issuer", oauthDoc.Issuer)