diff --git a/docs/server/docs.go b/docs/server/docs.go index aef6944614..6438a99073 100644 --- a/docs/server/docs.go +++ b/docs/server/docs.go @@ -275,6 +275,10 @@ const docTemplate = `{ "bearer_token_file": { "type": "string" }, + "cached_cimd_client_id": { + "description": "CachedCIMDClientID stores the CIMD metadata URL used as client_id when CIMD\nauthentication was used. Kept separate from CachedClientID (which holds\nDCR-issued IDs) so the two can have independent lifecycles — DCR credential\nrotation clears CachedClientID without touching the stable CIMD URL.\nRead by resolveClientCredentials to send the correct client_id on token refresh.", + "type": "string" + }, "cached_client_id": { "description": "Cached DCR client credentials for persistence across restarts.\nThese are obtained during Dynamic Client Registration and needed to refresh tokens.\nClientID is stored as plain text since it's public information.", "type": "string" diff --git a/docs/server/swagger.json b/docs/server/swagger.json index 082a1611d4..1c3fe7bf2c 100644 --- a/docs/server/swagger.json +++ b/docs/server/swagger.json @@ -268,6 +268,10 @@ "bearer_token_file": { "type": "string" }, + "cached_cimd_client_id": { + "description": "CachedCIMDClientID stores the CIMD metadata URL used as client_id when CIMD\nauthentication was used. Kept separate from CachedClientID (which holds\nDCR-issued IDs) so the two can have independent lifecycles — DCR credential\nrotation clears CachedClientID without touching the stable CIMD URL.\nRead by resolveClientCredentials to send the correct client_id on token refresh.", + "type": "string" + }, "cached_client_id": { "description": "Cached DCR client credentials for persistence across restarts.\nThese are obtained during Dynamic Client Registration and needed to refresh tokens.\nClientID is stored as plain text since it's public information.", "type": "string" diff --git a/docs/server/swagger.yaml b/docs/server/swagger.yaml index 6d973f5cea..6ebca4fea1 100644 --- a/docs/server/swagger.yaml +++ b/docs/server/swagger.yaml @@ -280,6 +280,14 @@ components: type: string bearer_token_file: type: string + cached_cimd_client_id: + description: |- + CachedCIMDClientID stores the CIMD metadata URL used as client_id when CIMD + authentication was used. Kept separate from CachedClientID (which holds + DCR-issued IDs) so the two can have independent lifecycles — DCR credential + rotation clears CachedClientID without touching the stable CIMD URL. + Read by resolveClientCredentials to send the correct client_id on token refresh. + type: string cached_client_id: description: |- Cached DCR client credentials for persistence across restarts. diff --git a/pkg/auth/discovery/discovery.go b/pkg/auth/discovery/discovery.go index 748875fd8d..78677a7749 100644 --- a/pkg/auth/discovery/discovery.go +++ b/pkg/auth/discovery/discovery.go @@ -53,10 +53,11 @@ type AuthInfo struct { // AuthServerInfo contains information about a validated authorization server type AuthServerInfo struct { - Issuer string - AuthorizationURL string - TokenURL string - RegistrationEndpoint string + Issuer string + AuthorizationURL string + TokenURL string + RegistrationEndpoint string + ClientIDMetadataDocumentSupported bool } // Config holds configuration for authentication discovery @@ -539,9 +540,10 @@ func PerformOAuthFlow(ctx context.Context, issuer string, config *OAuthFlowConfi return nil, fmt.Errorf("OAuth flow config cannot be nil") } - // Resolve port availability BEFORE dynamic registration - // This ensures we register the OAuth client with the same port we'll actually use - + // Resolve port availability before registration. DCR clients allow port fallback + // because the actual port is registered after selection. Pre-registered and CIMD + // clients require the configured port to be available as-is — it is already + // published in their IdP application or metadata document redirect URI. if shouldDynamicallyRegisterClient(config) { // For dynamic registration, we can allow fallback to alternative ports // since we can register the client with the actual port we'll use @@ -555,8 +557,9 @@ func PerformOAuthFlow(ctx context.Context, issuer string, config *OAuthFlowConfi } config.CallbackPort = port } else { - // For pre-registered clients, use strict port checking - // The user likely configured this port in their IdP/app + // For pre-registered clients and CIMD, use strict port checking. + // The port is either configured in the IdP app or baked into the + // redirect URI in the hosted metadata document. if !networking.IsAvailable(config.CallbackPort) { return nil, fmt.Errorf( "specified auth callback port %d is not available - please choose a different port or ensure it's not in use", @@ -836,10 +839,11 @@ func ValidateAndDiscoverAuthServer(ctx context.Context, potentialIssuer string) } return &AuthServerInfo{ - Issuer: doc.Issuer, - AuthorizationURL: doc.AuthorizationEndpoint, - TokenURL: doc.TokenEndpoint, - RegistrationEndpoint: doc.RegistrationEndpoint, + Issuer: doc.Issuer, + AuthorizationURL: doc.AuthorizationEndpoint, + TokenURL: doc.TokenEndpoint, + RegistrationEndpoint: doc.RegistrationEndpoint, + ClientIDMetadataDocumentSupported: doc.ClientIDMetadataDocumentSupported, }, nil } diff --git a/pkg/auth/oauth/oidc.go b/pkg/auth/oauth/oidc.go index fe037352d8..8dbfd3f21e 100644 --- a/pkg/auth/oauth/oidc.go +++ b/pkg/auth/oauth/oidc.go @@ -124,6 +124,12 @@ func discoverOIDCEndpointsWithClientAndValidation( if oauthDoc.Issuer == doc.Issuer { doc.RegistrationEndpoint = oauthDoc.RegistrationEndpoint slog.Debug("Found registration_endpoint in OAuth authorization server metadata", "endpoint", doc.RegistrationEndpoint) + // Merge CIMD support flag — some servers (e.g. Granola) only advertise + // client_id_metadata_document_supported in the OAuth AS metadata, not + // in the OIDC discovery document. + if oauthDoc.ClientIDMetadataDocumentSupported { + doc.ClientIDMetadataDocumentSupported = true + } } else { slog.Warn("Issuer mismatch between OIDC and OAuth discovery documents, not merging registration_endpoint", "oidc_issuer", doc.Issuer, "oauth_issuer", oauthDoc.Issuer) diff --git a/pkg/auth/remote/config.go b/pkg/auth/remote/config.go index 98c8fecd2f..7b8f678369 100644 --- a/pkg/auth/remote/config.go +++ b/pkg/auth/remote/config.go @@ -71,6 +71,13 @@ type Config struct { // RegistrationAccessToken is used to update/delete the client registration. // Stored as a secret reference since it's sensitive. CachedRegTokenRef string `json:"cached_reg_token_ref,omitempty" yaml:"cached_reg_token_ref,omitempty"` + + // CachedCIMDClientID stores the CIMD metadata URL used as client_id when CIMD + // authentication was used. Kept separate from CachedClientID (which holds + // DCR-issued IDs) so the two can have independent lifecycles — DCR credential + // rotation clears CachedClientID without touching the stable CIMD URL. + // Read by resolveClientCredentials to send the correct client_id on token refresh. + CachedCIMDClientID string `json:"cached_cimd_client_id,omitempty" yaml:"cached_cimd_client_id,omitempty"` } // BearerTokenEnvVarName is the environment variable name used for bearer token authentication. @@ -164,7 +171,14 @@ func (c *Config) HasCachedClientCredentials() bool { return c.CachedClientID != "" } +// HasCachedCIMDClientID returns true if a CIMD client_id was cached from a prior session. +func (c *Config) HasCachedCIMDClientID() bool { + return c.CachedCIMDClientID != "" +} + // ClearCachedClientCredentials removes any cached DCR client credential references from the config. +// It does not clear CachedCIMDClientID — the CIMD URL is a stable constant that does not +// need to be rotated alongside DCR secrets. func (c *Config) ClearCachedClientCredentials() { c.CachedClientID = "" c.CachedClientSecretRef = "" diff --git a/pkg/auth/remote/config_test.go b/pkg/auth/remote/config_test.go index c6c014ed26..186330a58f 100644 --- a/pkg/auth/remote/config_test.go +++ b/pkg/auth/remote/config_test.go @@ -204,6 +204,37 @@ func TestConfig_HasCachedClientCredentials(t *testing.T) { } } +func TestConfig_HasCachedCIMDClientID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config Config + expected bool + }{ + { + name: "no cached CIMD client_id", + config: Config{}, + expected: false, + }, + { + name: "has cached CIMD client_id", + config: Config{ + CachedCIMDClientID: "https://toolhive.dev/oauth/client-metadata.json", + }, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := tt.config.HasCachedCIMDClientID() + assert.Equal(t, tt.expected, result) + }) + } +} + func TestConfig_ClearCachedClientCredentials(t *testing.T) { t.Parallel() diff --git a/pkg/auth/remote/handler.go b/pkg/auth/remote/handler.go index ccce3fb17b..ce7d271288 100644 --- a/pkg/auth/remote/handler.go +++ b/pkg/auth/remote/handler.go @@ -5,12 +5,15 @@ package remote import ( "context" + "errors" "fmt" "log/slog" + "strings" "golang.org/x/oauth2" "github.com/stacklok/toolhive/pkg/auth/discovery" + "github.com/stacklok/toolhive/pkg/oauthproto" "github.com/stacklok/toolhive/pkg/secrets" ) @@ -131,15 +134,29 @@ func (h *Handler) performOAuthFlow( ) (oauth2.TokenSource, error) { slog.Debug("Starting OAuth authentication flow", "issuer", issuer) - // Create OAuth flow config + // Client registration priority (MCP spec: stored credentials → CIMD → DCR): + // Priority 1: Pre-configured credentials — set by buildOAuthFlowConfig from h.config.ClientID/ClientSecret. + // Priority 2: CIMD — AS advertises support and no credentials are set; use metadata URL as client_id. + // Priority 3: DCR — PerformOAuthFlow handles this when ClientID is still empty after the above. flowConfig := h.buildOAuthFlowConfig(scopes, authServerInfo) + if shouldUseCIMD(authServerInfo, flowConfig) { + flowConfig.ClientID = oauthproto.ToolHiveClientMetadataDocumentURL + slog.Debug("Using CIMD client_id", "url", oauthproto.ToolHiveClientMetadataDocumentURL) + } result, err := discovery.PerformOAuthFlow(ctx, issuer, flowConfig) + if err != nil { + // If we used CIMD and it was rejected, we need to retry with DCR. + if flowConfig.ClientID == oauthproto.ToolHiveClientMetadataDocumentURL && isCIMDRejectionError(err) { + slog.Warn("CIMD client_id rejected by AS, retrying with DCR", "issuer", issuer, "error", err) + flowConfig.ClientID = "" + result, err = discovery.PerformOAuthFlow(ctx, issuer, flowConfig) + } + } if err != nil { return nil, err } - // Persist and wrap the token source return h.wrapWithPersistence(result), nil } @@ -187,7 +204,9 @@ func (h *Handler) wrapWithPersistence(result *discovery.OAuthFlowResult) oauth2. // Persist DCR client credentials if available (for servers that use Dynamic Client Registration) // Only persist if client_id exists - client_secret may be empty for PKCE flows - if h.clientCredentialsPersister != nil && result.ClientID != "" { + // CIMD client IDs (HTTPS URLs) are stable constants and are stored separately below. + if h.clientCredentialsPersister != nil && result.ClientID != "" && + !oauthproto.IsClientIDMetadataDocumentURL(result.ClientID) { if err := h.clientCredentialsPersister(result.ClientID, result.ClientSecret); err != nil { slog.Warn("Failed to persist DCR client credentials", "error", err) } else { @@ -195,6 +214,13 @@ func (h *Handler) wrapWithPersistence(result *discovery.OAuthFlowResult) oauth2. } } + // Persist the CIMD metadata URL separately so it can be used as client_id + // on token refresh without conflating it with DCR-issued credentials. + if oauthproto.IsClientIDMetadataDocumentURL(result.ClientID) { + h.config.CachedCIMDClientID = result.ClientID + slog.Debug("Persisted CIMD client_id for future restarts", "url", result.ClientID) + } + // Wrap the token source to persist refreshed tokens tokenSource := result.TokenSource if h.tokenPersister != nil { @@ -211,6 +237,15 @@ func (h *Handler) resolveClientCredentials(ctx context.Context) (clientID, clien clientID = h.config.ClientID clientSecret = h.config.ClientSecret + // If CIMD was used in a prior session, use the cached metadata URL as client_id. + // CIMD clients have no secret (token_endpoint_auth_method=none). + // Checked before DCR so that DCR credential rotation does not change which + // client_id is sent on token refresh. + if h.config.HasCachedCIMDClientID() { + slog.Debug("Using cached CIMD client_id", "url", h.config.CachedCIMDClientID) + return h.config.CachedCIMDClientID, "" + } + // If we have cached DCR client credentials, use those instead if h.config.HasCachedClientCredentials() { // ClientID is stored as plain text (it's public information) @@ -317,18 +352,23 @@ func (h *Handler) discoverIssuerAndScopes( authInfo *discovery.AuthInfo, remoteURL string, ) (string, []string, *discovery.AuthServerInfo, error) { - // Priority 1: Use configured issuer if available + // Priority 1: Use configured issuer if available. Fetch discovery to populate + // AuthServerInfo (including ClientIDMetadataDocumentSupported) even when the + // issuer is pre-configured, so CIMD detection works on this path. if h.config.Issuer != "" { slog.Debug("Using configured issuer", "issuer", h.config.Issuer) - return h.config.Issuer, h.config.Scopes, nil, nil + authServerInfo, _ := discovery.ValidateAndDiscoverAuthServer(ctx, h.config.Issuer) + return h.config.Issuer, h.config.Scopes, authServerInfo, nil } - // Priority 2: Try to derive from realm (RFC 8414) + // Priority 2: Try to derive from realm (RFC 8414). Fetch discovery for the + // same reason as Priority 1 — the realm path skips resource metadata discovery. if authInfo.Realm != "" { derivedIssuer := discovery.DeriveIssuerFromRealm(authInfo.Realm) if derivedIssuer != "" { slog.Debug("Derived issuer from realm", "issuer", derivedIssuer) - return derivedIssuer, h.config.Scopes, nil, nil + authServerInfo, _ := discovery.ValidateAndDiscoverAuthServer(ctx, derivedIssuer) + return derivedIssuer, h.config.Scopes, authServerInfo, nil } } @@ -457,3 +497,42 @@ func (h *Handler) tryDiscoverFromWellKnown( return authServerInfo.Issuer, scopes, authServerInfo, nil } + +// shouldUseCIMD reports whether the CIMD client_id should be presented to the AS. +// The AS must advertise CIMD support and no pre-configured credentials may be set. +// Mirrors shouldDynamicallyRegisterClient in pkg/auth/discovery for consistency. +func shouldUseCIMD(authServerInfo *discovery.AuthServerInfo, flowConfig *discovery.OAuthFlowConfig) bool { + if authServerInfo == nil || !authServerInfo.ClientIDMetadataDocumentSupported { + return false + } + return flowConfig.ClientID == "" && flowConfig.ClientSecret == "" +} + +// isCIMDRejectionError returns true if err indicates the AS rejected the CIMD +// client_id. Only the RFC 6749 error codes invalid_client and unauthorized_client +// trigger a DCR retry; all other errors — including invalid_request and +// token-exchange failures — surface as-is. +// +// CIMD rejection can surface from two stages: +// - Authorization endpoint: AS redirects to callback with error=invalid_client; +// flow.go formats this as "OAuth error: - " (a plain error). +// - Token endpoint: oauth2.RetrieveError with ErrorCode set. +func isCIMDRejectionError(err error) bool { + if err == nil { + return false + } + // Token endpoint rejection — structured error from golang.org/x/oauth2. + var rerr *oauth2.RetrieveError + if errors.As(err, &rerr) { + switch rerr.ErrorCode { + case "invalid_client", "unauthorized_client": + return true + } + return false + } + // Authorization endpoint rejection — flow.go formats callback errors as + // "OAuth error: - ". Check for the code after the prefix. + msg := err.Error() + return strings.HasPrefix(msg, "OAuth error: invalid_client") || + strings.HasPrefix(msg, "OAuth error: unauthorized_client") +} diff --git a/pkg/auth/remote/handler_test.go b/pkg/auth/remote/handler_test.go index d68b465ed7..a57c7e874e 100644 --- a/pkg/auth/remote/handler_test.go +++ b/pkg/auth/remote/handler_test.go @@ -6,6 +6,7 @@ package remote import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" "strings" @@ -14,6 +15,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/oauth2" "github.com/stacklok/toolhive/pkg/auth/discovery" ) @@ -797,6 +799,88 @@ func TestAuthenticate_BearerTokenPriority(t *testing.T) { assert.Equal(t, "Bearer", token.TokenType) } +// retrieveErr constructs an *oauth2.RetrieveError with the given error code, +// matching what golang.org/x/oauth2 returns for token endpoint errors. +func retrieveErr(code string) *oauth2.RetrieveError { + return &oauth2.RetrieveError{ErrorCode: code} +} + +// TestIsCIMDRejectionError covers the isCIMDRejectionError helper used in the CIMD retry path. +func TestIsCIMDRejectionError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + want bool + }{ + { + name: "nil error returns false", + err: nil, + want: false, + }, + { + name: "invalid_client triggers retry", + err: retrieveErr("invalid_client"), + want: true, + }, + { + name: "unauthorized_client triggers retry", + err: retrieveErr("unauthorized_client"), + want: true, + }, + { + name: "invalid_request does not trigger retry", + err: retrieveErr("invalid_request"), + want: false, + }, + { + name: "access_denied does not trigger retry", + err: retrieveErr("access_denied"), + want: false, + }, + // Authorization-endpoint rejections — flow.go format: "OAuth error: - " + { + name: "auth callback invalid_client triggers retry", + err: fmt.Errorf("OAuth error: invalid_client - client not recognised"), + want: true, + }, + { + name: "auth callback unauthorized_client triggers retry", + err: fmt.Errorf("OAuth error: unauthorized_client - not allowed"), + want: true, + }, + { + name: "auth callback invalid_request does not trigger retry", + err: fmt.Errorf("OAuth error: invalid_request - missing param"), + want: false, + }, + { + name: "auth callback access_denied does not trigger retry", + err: fmt.Errorf("OAuth error: access_denied - user denied"), + want: false, + }, + // Non-OAuth errors must not trigger retry. + { + name: "network error does not trigger retry", + err: fmt.Errorf("dial tcp: connection refused"), + want: false, + }, + { + name: "timeout error does not trigger retry", + err: fmt.Errorf("OAuth flow timed out after 5m0s - user did not complete authentication"), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tt.want, isCIMDRejectionError(tt.err)) + }) + } +} + // TestAuthenticate_BearerTokenDiscovery tests that bearer token discovery works correctly func TestAuthenticate_BearerTokenDiscovery(t *testing.T) { t.Parallel() diff --git a/pkg/oauthproto/cimd.go b/pkg/oauthproto/cimd.go new file mode 100644 index 0000000000..ef8540794a --- /dev/null +++ b/pkg/oauthproto/cimd.go @@ -0,0 +1,26 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package oauthproto + +import "strings" + +// ToolHiveClientMetadataDocumentURL is the stable HTTPS URL where ToolHive's +// client metadata document is hosted. ToolHive presents this URL as its +// client_id to remote authorization servers that support CIMD. The URL must +// be live and serving the client metadata document before this feature can +// be used in production. +const ToolHiveClientMetadataDocumentURL = "https://toolhive.dev/oauth/client-metadata.json" + +// IsClientIDMetadataDocumentURL returns true if clientID is an HTTPS URL. +// Any HTTPS URL is treated as a CIMD client_id; DCR-issued IDs are always +// opaque strings that never begin with "https://". Do not tighten this to an +// exact match against ToolHiveClientMetadataDocumentURL — the embedded AS +// (Phase 2) must accept CIMD URLs from third-party clients too. +// +// TODO(phase2): tighten per draft-ietf-oauth-client-id-metadata-document §3 +// (require host+path, reject fragment/userinfo/dot-segments) before wiring +// into the AS GetClient decorator. +func IsClientIDMetadataDocumentURL(clientID string) bool { + return strings.HasPrefix(clientID, "https://") +} diff --git a/pkg/oauthproto/cimd_test.go b/pkg/oauthproto/cimd_test.go new file mode 100644 index 0000000000..5e8df1bc8f --- /dev/null +++ b/pkg/oauthproto/cimd_test.go @@ -0,0 +1,44 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package oauthproto + +import ( + "testing" +) + +func TestToolHiveClientMetadataDocumentURL(t *testing.T) { + t.Parallel() + + const want = "https://toolhive.dev/oauth/client-metadata.json" + if ToolHiveClientMetadataDocumentURL != want { + t.Errorf("ToolHiveClientMetadataDocumentURL = %q, want %q", ToolHiveClientMetadataDocumentURL, want) + } +} + +func TestIsClientIDMetadataDocumentURL(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + clientID string + want bool + }{ + {"CIMD URL (toolhive)", ToolHiveClientMetadataDocumentURL, true}, + {"arbitrary HTTPS URL", "https://example.com/client-metadata.json", true}, + {"HTTPS URL no path", "https://example.com", true}, + {"DCR-issued UUID", "some-uuid-client-id", false}, + {"HTTP URL", "http://example.com/metadata.json", false}, + {"empty string", "", false}, + {"partial match", "xhttps://example.com", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if got := IsClientIDMetadataDocumentURL(tt.clientID); got != tt.want { + t.Errorf("IsClientIDMetadataDocumentURL(%q) = %v, want %v", tt.clientID, got, tt.want) + } + }) + } +} diff --git a/pkg/oauthproto/discovery.go b/pkg/oauthproto/discovery.go index 042f7291f7..4aef8f6e9a 100644 --- a/pkg/oauthproto/discovery.go +++ b/pkg/oauthproto/discovery.go @@ -298,6 +298,10 @@ type AuthorizationServerMetadata struct { // ScopesSupported lists the OAuth 2.0 scope values supported (RECOMMENDED per RFC 8414). // For MCP authorization servers, this typically includes "openid" and "offline_access". ScopesSupported []string `json:"scopes_supported,omitempty"` + + // ClientIDMetadataDocumentSupported indicates the server accepts HTTPS URLs as client_id + // values per draft-ietf-oauth-client-id-metadata-document. + ClientIDMetadataDocumentSupported bool `json:"client_id_metadata_document_supported,omitempty"` } // OIDCDiscoveryDocument represents the OpenID Connect Discovery 1.0 document. diff --git a/test/e2e/cimd_auth_helpers_test.go b/test/e2e/cimd_auth_helpers_test.go new file mode 100644 index 0000000000..bf54b6977c --- /dev/null +++ b/test/e2e/cimd_auth_helpers_test.go @@ -0,0 +1,257 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package e2e_test + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "sync" + "time" +) + +// testHelper is a minimal subset of testing.TB and ginkgo.GinkgoTInterface that +// the CIMD mock server helpers require. Both *testing.T and GinkgoT() satisfy +// this interface, so helpers can be called from plain Go tests and Ginkgo specs. +type testHelper interface { + Helper() + Cleanup(func()) +} + +// cimdAuthRequest captures parameters from an OAuth authorization request. +type cimdAuthRequest struct { + ClientID string + RedirectURI string + State string + CodeChallenge string +} + +// cimdMockAuthServer is a minimal httptest-based mock authorization server +// for CIMD testing. Unlike OIDCMockServer (Fosite-backed), this server accepts +// any HTTPS URL as a client_id, which is required to verify CIMD behaviour. +type cimdMockAuthServer struct { + server *httptest.Server + authRequestChan chan cimdAuthRequest + + mu sync.Mutex + lastClientID string + dcrCalled bool + cimdSupported bool +} + +// newCIMDMockAuthServer creates and starts a mock authorization server that +// advertises client_id_metadata_document_supported. It registers t.Cleanup to +// close the server automatically. +func newCIMDMockAuthServer(tb testHelper, cimdSupported bool) *cimdMockAuthServer { + tb.Helper() + + s := &cimdMockAuthServer{ + authRequestChan: make(chan cimdAuthRequest, 4), + cimdSupported: cimdSupported, + } + + mux := http.NewServeMux() + mux.HandleFunc("/.well-known/openid-configuration", s.handleDiscovery) + mux.HandleFunc("/oauth/authorize", s.handleAuthorize) + mux.HandleFunc("/oauth/token", s.handleToken) + mux.HandleFunc("/oauth/register", s.handleRegister) + mux.HandleFunc("/.well-known/jwks.json", s.handleJWKS) + mux.HandleFunc("/.well-known/mcp-resource", s.handleResourceMetadata) + + s.server = httptest.NewServer(mux) + tb.Cleanup(s.server.Close) + + return s +} + +// URL returns the base URL of the mock authorization server. +func (s *cimdMockAuthServer) URL() string { + return s.server.URL +} + +// IssuerURL returns the issuer URL (same as URL for this mock). +func (s *cimdMockAuthServer) IssuerURL() string { + return s.server.URL +} + +// ResourceMetadataURL returns the RFC 9728 resource metadata URL for this server. +func (s *cimdMockAuthServer) ResourceMetadataURL() string { + return fmt.Sprintf("%s/.well-known/mcp-resource", s.server.URL) +} + +// WaitForAuthRequest blocks until an authorization request arrives or the timeout +// elapses. +func (s *cimdMockAuthServer) WaitForAuthRequest(timeout time.Duration) (cimdAuthRequest, error) { + select { + case req := <-s.authRequestChan: + return req, nil + case <-time.After(timeout): + return cimdAuthRequest{}, fmt.Errorf("timeout waiting for auth request after %s", timeout) + } +} + +// DcrWasCalled returns true if the DCR /oauth/register endpoint was ever called. +func (s *cimdMockAuthServer) DcrWasCalled() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.dcrCalled +} + +// LastClientID returns the most recent client_id seen in /oauth/authorize. +func (s *cimdMockAuthServer) LastClientID() string { + s.mu.Lock() + defer s.mu.Unlock() + return s.lastClientID +} + +// handleDiscovery serves the OIDC discovery document. It sets +// client_id_metadata_document_supported based on the server's configuration. +func (s *cimdMockAuthServer) handleDiscovery(w http.ResponseWriter, _ *http.Request) { + doc := map[string]interface{}{ + "issuer": s.server.URL, + "authorization_endpoint": fmt.Sprintf("%s/oauth/authorize", s.server.URL), + "token_endpoint": fmt.Sprintf("%s/oauth/token", s.server.URL), + "registration_endpoint": fmt.Sprintf("%s/oauth/register", s.server.URL), + "jwks_uri": fmt.Sprintf("%s/.well-known/jwks.json", s.server.URL), + "code_challenge_methods_supported": []string{"S256"}, + "response_types_supported": []string{"code"}, + "grant_types_supported": []string{"authorization_code", "refresh_token"}, + "client_id_metadata_document_supported": s.cimdSupported, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(doc) +} + +// handleAuthorize captures the authorization request and either immediately +// redirects (when auto_complete=true) or places the request into the channel +// for the test to inspect. +func (s *cimdMockAuthServer) handleAuthorize(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + req := cimdAuthRequest{ + ClientID: q.Get("client_id"), + RedirectURI: q.Get("redirect_uri"), + State: q.Get("state"), + CodeChallenge: q.Get("code_challenge"), + } + + s.mu.Lock() + s.lastClientID = req.ClientID + s.mu.Unlock() + + // Always send into the channel so WaitForAuthRequest can inspect it. + select { + case s.authRequestChan <- req: + default: + // Channel buffer full; drop the duplicate. + } + + if q.Get("auto_complete") == "true" { + redirectURI := req.RedirectURI + if redirectURI == "" { + http.Error(w, "missing redirect_uri", http.StatusBadRequest) + return + } + separator := "&" + if len(q.Get("redirect_uri")) > 0 { + // redirect_uri itself may or may not have a query string already; + // we append to it by adding a '?' if needed. + separator = "?" + for _, ch := range redirectURI { + if ch == '?' { + separator = "&" + break + } + } + } + http.Redirect(w, r, + fmt.Sprintf("%s%scode=test-auth-code&state=%s", redirectURI, separator, req.State), + http.StatusFound, + ) + return + } + + // Without auto_complete the test must drive the flow externally. + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("authorization pending")) +} + +// handleToken accepts any code=test-auth-code and returns a minimal access token. +func (*cimdMockAuthServer) handleToken(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + tokenResp := map[string]interface{}{ + "access_token": "test-access-token-cimd", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "test-refresh-token-cimd", + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(tokenResp) +} + +// handleRegister is the DCR endpoint. Calling it records that DCR was used. +func (s *cimdMockAuthServer) handleRegister(w http.ResponseWriter, _ *http.Request) { + s.mu.Lock() + s.dcrCalled = true + s.mu.Unlock() + + resp := map[string]interface{}{ + "client_id": "dcr-issued-client-id", + "client_secret": "dcr-issued-secret", + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(resp) +} + +// handleJWKS returns an empty JWKS set. +func (*cimdMockAuthServer) handleJWKS(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"keys":[]}`)) +} + +// handleResourceMetadata returns RFC 9728 protected resource metadata pointing +// at this authorization server. +func (s *cimdMockAuthServer) handleResourceMetadata(w http.ResponseWriter, _ *http.Request) { + meta := map[string]interface{}{ + "resource": s.server.URL, + "authorization_servers": []string{s.server.URL}, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(meta) +} + +// newCIMDMockMCPServer creates a minimal httptest MCP server that: +// - Returns 401 with WWW-Authenticate header when there is no Authorization header. +// - Returns a minimal JSON-RPC success response when an Authorization header is present. +// +// asURL is the base URL of the authorization server; it is embedded in the +// WWW-Authenticate header's realm and resource_metadata attributes. +func newCIMDMockMCPServer(tb testHelper, asURL string) *httptest.Server { + tb.Helper() + + resourceMetaURL := fmt.Sprintf("%s/.well-known/mcp-resource", asURL) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") == "" { + w.Header().Set( + "WWW-Authenticate", + fmt.Sprintf(`Bearer realm="%s",resource_metadata="%s"`, asURL, resourceMetaURL), + ) + w.WriteHeader(http.StatusUnauthorized) + return + } + + // Minimal JSON-RPC success response so the proxy can verify connectivity. + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2024-11-05","capabilities":{},"serverInfo":{"name":"cimd-mock-mcp","version":"0.0.1"}}}`)) + })) + + tb.Cleanup(srv.Close) + return srv +} diff --git a/test/e2e/cimd_auth_test.go b/test/e2e/cimd_auth_test.go new file mode 100644 index 0000000000..113cb804e4 --- /dev/null +++ b/test/e2e/cimd_auth_test.go @@ -0,0 +1,219 @@ +// SPDX-FileCopyrightText: Copyright 2025 Stacklok, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package e2e_test + +import ( + "bytes" + "io" + "net/http" + "os" + "os/exec" + "regexp" + "strings" + "time" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + + "github.com/stacklok/toolhive/pkg/oauthproto" + "github.com/stacklok/toolhive/test/e2e" +) + +// startCIMDRunCommand starts `thv run --name --remote-auth …` +// and returns the exec.Cmd together with a buffer that captures combined stdout +// and stderr. The buffer is safe to read concurrently from the test goroutine. +func startCIMDRunCommand( + config *e2e.TestConfig, + serverName string, + mcpURL string, + asIssuerURL string, +) (*exec.Cmd, *bytes.Buffer) { + args := []string{ + "run", + mcpURL, + "--name", serverName, + "--remote-auth", + "--remote-auth-skip-browser", + "--remote-auth-issuer", asIssuerURL, + "--remote-auth-timeout", "30s", + } + + GinkgoWriter.Printf("Starting thv run with args: %v\n", args) + + cmd := exec.Command(config.THVBinary, args...) //nolint:gosec // Intentional for e2e testing + cmd.Env = os.Environ() + + var outputBuffer bytes.Buffer + multiWriter := io.MultiWriter(&outputBuffer, GinkgoWriter) + cmd.Stdout = multiWriter + cmd.Stderr = multiWriter + + err := cmd.Start() + Expect(err).ToNot(HaveOccurred(), "thv run should start without error") + + return cmd, &outputBuffer +} + +// extractAuthURL scans the captured output buffer for the OAuth browser URL +// that ToolHive prints when --remote-auth-skip-browser is set. +func extractAuthURL(output string) string { + urlPattern := regexp.MustCompile(`Please open this URL in your browser: (https?://[^\s"]+)`) + matches := urlPattern.FindStringSubmatch(output) + if len(matches) >= 2 { + return matches[1] + } + return "" +} + +// appendAutoComplete appends or sets auto_complete=true on an authorize URL so +// that the cimdMockAuthServer will immediately redirect to the callback. +func appendAutoComplete(authURL string) string { + if authURL == "" { + return authURL + } + separator := "&" + if !strings.Contains(authURL, "?") { + separator = "?" + } + return authURL + separator + "auto_complete=true" +} + +var _ = Describe("CIMD Authentication", Label("remote", "auth", "cimd"), Serial, func() { + var config *e2e.TestConfig + + BeforeEach(func() { + config = e2e.NewTestConfig() + + err := e2e.CheckTHVBinaryAvailable(config) + Expect(err).ToNot(HaveOccurred(), "thv binary should be available for testing") + }) + + Context("when the authorization server advertises CIMD support", func() { + It("uses the CIMD client_id and skips DCR", func() { + By("Starting mock authorization server with CIMD support enabled") + mockAS := newCIMDMockAuthServer(GinkgoT(), true) + + By("Starting mock MCP server that requires authentication") + mockMCP := newCIMDMockMCPServer(GinkgoT(), mockAS.URL()) + + serverName := e2e.GenerateUniqueServerName("cimd-cimd-supported") + + By("Starting thv run pointing at the mock MCP server") + cmd, outputBuffer := startCIMDRunCommand(config, serverName, mockMCP.URL, mockAS.IssuerURL()) + + defer func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + _ = cmd.Wait() + } + if config.CleanupAfter { + _ = e2e.StopAndRemoveMCPServer(config, serverName) + } + }() + + By("Waiting for the OAuth URL to appear in the output") + var authURL string + Eventually(func() string { + authURL = extractAuthURL(outputBuffer.String()) + return authURL + }, 30*time.Second, 500*time.Millisecond).ShouldNot(BeEmpty(), + "thv run should print 'Please open this URL in your browser'") + + By("Completing the OAuth flow via auto_complete") + autoURL := appendAutoComplete(authURL) + client := &http.Client{ + Timeout: 10 * time.Second, + CheckRedirect: func(_ *http.Request, _ []*http.Request) error { + return nil // follow redirects + }, + } + resp, err := client.Get(autoURL) //nolint:gosec // URL is test-controlled + Expect(err).ToNot(HaveOccurred(), "GET to auto-complete URL should succeed") + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + Expect(resp.StatusCode).To(BeNumerically("<", 400), + "auto-complete redirect chain should succeed") + + By("Waiting for the authorization request to be captured by the mock AS") + authReq, err := mockAS.WaitForAuthRequest(15 * time.Second) + Expect(err).ToNot(HaveOccurred(), "mock AS should receive an authorization request") + + By("Asserting client_id equals the CIMD metadata URL") + Expect(authReq.ClientID).To(Equal(oauthproto.ToolHiveClientMetadataDocumentURL), + "thv run should use the CIMD metadata URL as client_id when AS advertises support") + + By("Asserting PKCE code_challenge was included") + Expect(authReq.CodeChallenge).ToNot(BeEmpty(), + "PKCE code_challenge must be present in the authorization request") + + By("Asserting DCR was NOT called") + Expect(mockAS.DcrWasCalled()).To(BeFalse(), + "DCR registration endpoint must not be called when CIMD is used") + + By("Waiting for thv to report the server as running") + err = e2e.WaitForMCPServer(config, serverName, 30*time.Second) + Expect(err).ToNot(HaveOccurred(), "server should appear as running in thv list") + }) + }) + + Context("when the authorization server does NOT advertise CIMD support", func() { + It("falls back to DCR and does not use the CIMD client_id", func() { + By("Starting mock authorization server with CIMD support disabled") + mockAS := newCIMDMockAuthServer(GinkgoT(), false) + + By("Starting mock MCP server that requires authentication") + mockMCP := newCIMDMockMCPServer(GinkgoT(), mockAS.URL()) + + serverName := e2e.GenerateUniqueServerName("cimd-dcr-fallback") + + By("Starting thv run pointing at the mock MCP server") + cmd, outputBuffer := startCIMDRunCommand(config, serverName, mockMCP.URL, mockAS.IssuerURL()) + + defer func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + _ = cmd.Wait() + } + if config.CleanupAfter { + _ = e2e.StopAndRemoveMCPServer(config, serverName) + } + }() + + By("Waiting for the OAuth URL to appear in the output") + var authURL string + Eventually(func() string { + authURL = extractAuthURL(outputBuffer.String()) + return authURL + }, 30*time.Second, 500*time.Millisecond).ShouldNot(BeEmpty(), + "thv run should print 'Please open this URL in your browser'") + + By("Completing the OAuth flow via auto_complete") + autoURL := appendAutoComplete(authURL) + client := &http.Client{ + Timeout: 10 * time.Second, + CheckRedirect: func(_ *http.Request, _ []*http.Request) error { + return nil + }, + } + resp, err := client.Get(autoURL) //nolint:gosec // URL is test-controlled + Expect(err).ToNot(HaveOccurred(), "GET to auto-complete URL should succeed") + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + Expect(resp.StatusCode).To(BeNumerically("<", 400)) + + By("Waiting for the authorization request to be captured by the mock AS") + authReq, err := mockAS.WaitForAuthRequest(15 * time.Second) + Expect(err).ToNot(HaveOccurred(), "mock AS should receive an authorization request") + + By("Asserting client_id is NOT the CIMD metadata URL") + Expect(authReq.ClientID).ToNot(Equal(oauthproto.ToolHiveClientMetadataDocumentURL), + "thv run must not use the CIMD metadata URL when the AS does not advertise support") + + By("Asserting DCR WAS called") + // Give thv a moment to hit the DCR endpoint before asserting. + Eventually(mockAS.DcrWasCalled, 10*time.Second, 500*time.Millisecond).Should(BeTrue(), + "DCR registration endpoint must be called when CIMD is not advertised") + }) + }) +}) diff --git a/test/e2e/oidc_mock.go b/test/e2e/oidc_mock.go index 9c3651aac0..629cdba8ab 100644 --- a/test/e2e/oidc_mock.go +++ b/test/e2e/oidc_mock.go @@ -214,6 +214,7 @@ func (m *OIDCMockServer) handleDiscovery(w http.ResponseWriter, _ *http.Request) "subject_types_supported": []string{"public"}, "id_token_signing_alg_values_supported": []string{"RS256"}, "scopes_supported": []string{"openid", "profile", "email"}, + "client_id_metadata_document_supported": true, } w.Header().Set("Content-Type", "application/json")