diff --git a/pkg/auth/remote/handler_test.go b/pkg/auth/remote/handler_test.go index a57c7e874e..b7761964d9 100644 --- a/pkg/auth/remote/handler_test.go +++ b/pkg/auth/remote/handler_test.go @@ -937,3 +937,77 @@ func TestAuthenticate_BearerTokenDiscovery(t *testing.T) { assert.Equal(t, "Bearer", token.TokenType) }) } + +// TestResolveClientCredentials verifies the credential selection priority in +// resolveClientCredentials: CachedCIMDClientID > CachedClientID (DCR) > +// statically-configured ClientID. +func TestResolveClientCredentials(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config *Config + wantClientID string + wantClientSecret string + }{ + { + name: "CachedCIMDClientID takes precedence over DCR and static credentials", + config: &Config{ + ClientID: "static-client-id", + ClientSecret: "static-secret", + CachedClientID: "dcr-client-id", + CachedCIMDClientID: "https://toolhive.dev/oauth/client-metadata.json", + }, + wantClientID: "https://toolhive.dev/oauth/client-metadata.json", + wantClientSecret: "", + }, + { + name: "CachedCIMDClientID returns empty secret (token_endpoint_auth_method=none)", + config: &Config{ + CachedCIMDClientID: "https://toolhive.dev/oauth/client-metadata.json", + }, + wantClientID: "https://toolhive.dev/oauth/client-metadata.json", + wantClientSecret: "", + }, + { + // When CachedClientID is set the DCR client_id is used, but because + // CachedClientSecretRef is empty (no secret reference stored) the + // function falls through to the statically-configured ClientSecret. + name: "CachedClientID used when CachedCIMDClientID is empty", + config: &Config{ + ClientID: "static-client-id", + ClientSecret: "static-secret", + CachedClientID: "dcr-client-id", + }, + wantClientID: "dcr-client-id", + wantClientSecret: "static-secret", + }, + { + name: "static credentials used when no cached credentials exist", + config: &Config{ + ClientID: "static-client-id", + ClientSecret: "static-secret", + }, + wantClientID: "static-client-id", + wantClientSecret: "static-secret", + }, + { + name: "all empty returns empty strings", + config: &Config{}, + wantClientID: "", + wantClientSecret: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + h := &Handler{config: tt.config} + gotClientID, gotClientSecret := h.resolveClientCredentials(context.Background()) + + assert.Equal(t, tt.wantClientID, gotClientID, "clientID mismatch") + assert.Equal(t, tt.wantClientSecret, gotClientSecret, "clientSecret mismatch") + }) + } +} diff --git a/test/e2e/cimd_auth_helpers_test.go b/test/e2e/cimd_auth_helpers_test.go index bf54b6977c..6af96eda2d 100644 --- a/test/e2e/cimd_auth_helpers_test.go +++ b/test/e2e/cimd_auth_helpers_test.go @@ -35,21 +35,27 @@ type cimdMockAuthServer struct { server *httptest.Server authRequestChan chan cimdAuthRequest - mu sync.Mutex - lastClientID string - dcrCalled bool - cimdSupported bool + mu sync.Mutex + lastClientID string + dcrCalled bool + cimdSupported bool + rejectCIMD bool + cimdRejectedOnce bool } // newCIMDMockAuthServer creates and starts a mock authorization server that // advertises client_id_metadata_document_supported. It registers t.Cleanup to -// close the server automatically. -func newCIMDMockAuthServer(tb testHelper, cimdSupported bool) *cimdMockAuthServer { +// close the server automatically. Pass rejectCIMD=true to make the server +// reject the first authorization request that uses a CIMD client_id (an HTTPS +// URL), simulating an AS that advertises CIMD support but rejects it at +// runtime, triggering the DCR fallback path in ToolHive. +func newCIMDMockAuthServer(tb testHelper, cimdSupported bool, rejectCIMD bool) *cimdMockAuthServer { tb.Helper() s := &cimdMockAuthServer{ authRequestChan: make(chan cimdAuthRequest, 4), cimdSupported: cimdSupported, + rejectCIMD: rejectCIMD, } mux := http.NewServeMux() @@ -124,9 +130,24 @@ func (s *cimdMockAuthServer) handleDiscovery(w http.ResponseWriter, _ *http.Requ _ = json.NewEncoder(w).Encode(doc) } +// RejectCIMDWasCalled returns true if the server rejected a CIMD client_id at +// least once. Callers use this to assert that the CIMD path was attempted +// before the DCR fallback fired. +func (s *cimdMockAuthServer) RejectCIMDWasCalled() bool { + s.mu.Lock() + defer s.mu.Unlock() + return s.cimdRejectedOnce +} + // handleAuthorize captures the authorization request and either immediately // redirects (when auto_complete=true) or places the request into the channel // for the test to inspect. +// +// When rejectCIMD is true, the first request whose client_id is an HTTPS URL +// (i.e. a CIMD metadata document URL) is rejected by redirecting to the +// callback with error=invalid_client. This simulates an AS that advertises +// CIMD support but rejects it at the authorization endpoint, triggering the +// DCR fallback path in ToolHive. func (s *cimdMockAuthServer) handleAuthorize(w http.ResponseWriter, r *http.Request) { q := r.URL.Query() req := cimdAuthRequest{ @@ -138,6 +159,32 @@ func (s *cimdMockAuthServer) handleAuthorize(w http.ResponseWriter, r *http.Requ s.mu.Lock() s.lastClientID = req.ClientID + + // If rejectCIMD is armed and this is the first CIMD request, reject it. + // A CIMD client_id is any HTTPS URL (see oauthproto.IsClientIDMetadataDocumentURL). + if s.rejectCIMD && !s.cimdRejectedOnce && isCIMDClientID(req.ClientID) { + s.cimdRejectedOnce = true + s.mu.Unlock() + + redirectURI := req.RedirectURI + if redirectURI == "" { + http.Error(w, "missing redirect_uri", http.StatusBadRequest) + return + } + separator := "?" + for _, ch := range redirectURI { + if ch == '?' { + separator = "&" + break + } + } + http.Redirect(w, r, + fmt.Sprintf("%s%serror=invalid_client&state=%s&error_description=cimd+not+supported", + redirectURI, separator, req.State), + http.StatusFound, + ) + return + } s.mu.Unlock() // Always send into the channel so WaitForAuthRequest can inspect it. @@ -226,6 +273,13 @@ func (s *cimdMockAuthServer) handleResourceMetadata(w http.ResponseWriter, _ *ht _ = json.NewEncoder(w).Encode(meta) } +// isCIMDClientID returns true if clientID looks like a CIMD metadata document +// URL (i.e. any HTTPS URL). This mirrors oauthproto.IsClientIDMetadataDocumentURL +// without importing the production package from a test helper. +func isCIMDClientID(clientID string) bool { + return len(clientID) >= 8 && clientID[:8] == "https://" +} + // newCIMDMockMCPServer creates a minimal httptest MCP server that: // - Returns 401 with WWW-Authenticate header when there is no Authorization header. // - Returns a minimal JSON-RPC success response when an Authorization header is present. diff --git a/test/e2e/cimd_auth_test.go b/test/e2e/cimd_auth_test.go index 113cb804e4..b8d775eec3 100644 --- a/test/e2e/cimd_auth_test.go +++ b/test/e2e/cimd_auth_test.go @@ -92,7 +92,7 @@ var _ = Describe("CIMD Authentication", Label("remote", "auth", "cimd"), Serial, Context("when the authorization server advertises CIMD support", func() { It("uses the CIMD client_id and skips DCR", func() { By("Starting mock authorization server with CIMD support enabled") - mockAS := newCIMDMockAuthServer(GinkgoT(), true) + mockAS := newCIMDMockAuthServer(GinkgoT(), true, false) By("Starting mock MCP server that requires authentication") mockMCP := newCIMDMockMCPServer(GinkgoT(), mockAS.URL()) @@ -160,7 +160,7 @@ var _ = Describe("CIMD Authentication", Label("remote", "auth", "cimd"), Serial, Context("when the authorization server does NOT advertise CIMD support", func() { It("falls back to DCR and does not use the CIMD client_id", func() { By("Starting mock authorization server with CIMD support disabled") - mockAS := newCIMDMockAuthServer(GinkgoT(), false) + mockAS := newCIMDMockAuthServer(GinkgoT(), false, false) By("Starting mock MCP server that requires authentication") mockMCP := newCIMDMockMCPServer(GinkgoT(), mockAS.URL()) @@ -216,4 +216,88 @@ var _ = Describe("CIMD Authentication", Label("remote", "auth", "cimd"), Serial, "DCR registration endpoint must be called when CIMD is not advertised") }) }) + + Context("CIMD fallback and warm-start behaviour", func() { + It("falls back to DCR when AS rejects the CIMD client_id", func() { + By("Starting mock authorization server: CIMD advertised but first CIMD request rejected") + mockAS := newCIMDMockAuthServer(GinkgoT(), true, true) + + By("Starting mock MCP server that requires authentication") + mockMCP := newCIMDMockMCPServer(GinkgoT(), mockAS.URL()) + + serverName := e2e.GenerateUniqueServerName("cimd-reject-fallback") + + By("Starting thv run pointing at the mock MCP server") + cmd, outputBuffer := startCIMDRunCommand(config, serverName, mockMCP.URL, mockAS.IssuerURL()) + + defer func() { + if cmd.Process != nil { + _ = cmd.Process.Kill() + _ = cmd.Wait() + } + if config.CleanupAfter { + _ = e2e.StopAndRemoveMCPServer(config, serverName) + } + }() + + By("Waiting for the first OAuth URL (CIMD attempt) to appear in the output") + var firstAuthURL string + Eventually(func() string { + firstAuthURL = extractAuthURL(outputBuffer.String()) + return firstAuthURL + }, 30*time.Second, 500*time.Millisecond).ShouldNot(BeEmpty(), + "thv run should print 'Please open this URL in your browser' for the CIMD attempt") + + By("Visiting the first URL — the AS will redirect back with error=invalid_client") + client := &http.Client{ + Timeout: 10 * time.Second, + CheckRedirect: func(_ *http.Request, _ []*http.Request) error { + return nil // follow redirects + }, + } + autoFirstURL := appendAutoComplete(firstAuthURL) + resp, err := client.Get(autoFirstURL) //nolint:gosec // URL is test-controlled + Expect(err).ToNot(HaveOccurred(), "GET to first auto-complete URL should not error") + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + // The redirect chain ends at the ToolHive callback; any 2xx/3xx is fine. + Expect(resp.StatusCode).To(BeNumerically("<", 500), + "redirect chain for CIMD rejection should not produce a server error") + + By("Asserting the mock AS registered the CIMD rejection") + Eventually(mockAS.RejectCIMDWasCalled, 10*time.Second, 500*time.Millisecond).Should(BeTrue(), + "mock AS must have rejected the CIMD client_id before DCR retry") + + By("Waiting for the second OAuth URL (DCR retry) to appear in the output") + var secondAuthURL string + Eventually(func() string { + out := outputBuffer.String() + // The second URL appears after the first; find the last occurrence. + allURLs := regexp.MustCompile(`Please open this URL in your browser: (https?://[^\s"]+)`). + FindAllStringSubmatch(out, -1) + if len(allURLs) >= 2 { + secondAuthURL = allURLs[len(allURLs)-1][1] + } + return secondAuthURL + }, 45*time.Second, 500*time.Millisecond).ShouldNot(BeEmpty(), + "thv run should print a second OAuth URL after the CIMD rejection triggers a DCR retry") + + By("Completing the DCR OAuth flow via auto_complete") + autoSecondURL := appendAutoComplete(secondAuthURL) + resp2, err := client.Get(autoSecondURL) //nolint:gosec // URL is test-controlled + Expect(err).ToNot(HaveOccurred(), "GET to second auto-complete URL should succeed") + _, _ = io.Copy(io.Discard, resp2.Body) + _ = resp2.Body.Close() + Expect(resp2.StatusCode).To(BeNumerically("<", 400), + "DCR auto-complete redirect chain should succeed") + + By("Asserting DCR was called during the retry") + Eventually(mockAS.DcrWasCalled, 10*time.Second, 500*time.Millisecond).Should(BeTrue(), + "DCR registration endpoint must be called after CIMD rejection") + + By("Waiting for thv to report the server as running") + err = e2e.WaitForMCPServer(config, serverName, 30*time.Second) + Expect(err).ToNot(HaveOccurred(), "server should appear as running in thv list after CIMD→DCR fallback") + }) + }) })