Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions pkg/auth/remote/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -937,3 +937,77 @@ func TestAuthenticate_BearerTokenDiscovery(t *testing.T) {
assert.Equal(t, "Bearer", token.TokenType)
})
}

// TestResolveClientCredentials verifies the credential selection priority in
// resolveClientCredentials: CachedCIMDClientID > CachedClientID (DCR) >
// statically-configured ClientID.
func TestResolveClientCredentials(t *testing.T) {
t.Parallel()

tests := []struct {
name string
config *Config
wantClientID string
wantClientSecret string
}{
{
name: "CachedCIMDClientID takes precedence over DCR and static credentials",
config: &Config{
ClientID: "static-client-id",
ClientSecret: "static-secret",
CachedClientID: "dcr-client-id",
CachedCIMDClientID: "https://toolhive.dev/oauth/client-metadata.json",
},
wantClientID: "https://toolhive.dev/oauth/client-metadata.json",
wantClientSecret: "",
},
{
name: "CachedCIMDClientID returns empty secret (token_endpoint_auth_method=none)",
config: &Config{
CachedCIMDClientID: "https://toolhive.dev/oauth/client-metadata.json",
},
wantClientID: "https://toolhive.dev/oauth/client-metadata.json",
wantClientSecret: "",
},
{
// When CachedClientID is set the DCR client_id is used, but because
// CachedClientSecretRef is empty (no secret reference stored) the
// function falls through to the statically-configured ClientSecret.
name: "CachedClientID used when CachedCIMDClientID is empty",
config: &Config{
ClientID: "static-client-id",
ClientSecret: "static-secret",
CachedClientID: "dcr-client-id",
},
wantClientID: "dcr-client-id",
wantClientSecret: "static-secret",
},
{
name: "static credentials used when no cached credentials exist",
config: &Config{
ClientID: "static-client-id",
ClientSecret: "static-secret",
},
wantClientID: "static-client-id",
wantClientSecret: "static-secret",
},
{
name: "all empty returns empty strings",
config: &Config{},
wantClientID: "",
wantClientSecret: "",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

h := &Handler{config: tt.config}
gotClientID, gotClientSecret := h.resolveClientCredentials(context.Background())

assert.Equal(t, tt.wantClientID, gotClientID, "clientID mismatch")
assert.Equal(t, tt.wantClientSecret, gotClientSecret, "clientSecret mismatch")
})
}
}
66 changes: 60 additions & 6 deletions test/e2e/cimd_auth_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,27 @@ type cimdMockAuthServer struct {
server *httptest.Server
authRequestChan chan cimdAuthRequest

mu sync.Mutex
lastClientID string
dcrCalled bool
cimdSupported bool
mu sync.Mutex
lastClientID string
dcrCalled bool
cimdSupported bool
rejectCIMD bool
cimdRejectedOnce bool
}

// newCIMDMockAuthServer creates and starts a mock authorization server that
// advertises client_id_metadata_document_supported. It registers t.Cleanup to
// close the server automatically.
func newCIMDMockAuthServer(tb testHelper, cimdSupported bool) *cimdMockAuthServer {
// close the server automatically. Pass rejectCIMD=true to make the server
// reject the first authorization request that uses a CIMD client_id (an HTTPS
// URL), simulating an AS that advertises CIMD support but rejects it at
// runtime, triggering the DCR fallback path in ToolHive.
func newCIMDMockAuthServer(tb testHelper, cimdSupported bool, rejectCIMD bool) *cimdMockAuthServer {
tb.Helper()

s := &cimdMockAuthServer{
authRequestChan: make(chan cimdAuthRequest, 4),
cimdSupported: cimdSupported,
rejectCIMD: rejectCIMD,
}

mux := http.NewServeMux()
Expand Down Expand Up @@ -124,9 +130,24 @@ func (s *cimdMockAuthServer) handleDiscovery(w http.ResponseWriter, _ *http.Requ
_ = json.NewEncoder(w).Encode(doc)
}

// RejectCIMDWasCalled returns true if the server rejected a CIMD client_id at
// least once. Callers use this to assert that the CIMD path was attempted
// before the DCR fallback fired.
func (s *cimdMockAuthServer) RejectCIMDWasCalled() bool {
s.mu.Lock()
defer s.mu.Unlock()
return s.cimdRejectedOnce
}

// handleAuthorize captures the authorization request and either immediately
// redirects (when auto_complete=true) or places the request into the channel
// for the test to inspect.
//
// When rejectCIMD is true, the first request whose client_id is an HTTPS URL
// (i.e. a CIMD metadata document URL) is rejected by redirecting to the
// callback with error=invalid_client. This simulates an AS that advertises
// CIMD support but rejects it at the authorization endpoint, triggering the
// DCR fallback path in ToolHive.
func (s *cimdMockAuthServer) handleAuthorize(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
req := cimdAuthRequest{
Expand All @@ -138,6 +159,32 @@ func (s *cimdMockAuthServer) handleAuthorize(w http.ResponseWriter, r *http.Requ

s.mu.Lock()
s.lastClientID = req.ClientID

// If rejectCIMD is armed and this is the first CIMD request, reject it.
// A CIMD client_id is any HTTPS URL (see oauthproto.IsClientIDMetadataDocumentURL).
if s.rejectCIMD && !s.cimdRejectedOnce && isCIMDClientID(req.ClientID) {
s.cimdRejectedOnce = true
s.mu.Unlock()

redirectURI := req.RedirectURI
if redirectURI == "" {
http.Error(w, "missing redirect_uri", http.StatusBadRequest)
return
}
separator := "?"
for _, ch := range redirectURI {
if ch == '?' {
separator = "&"
break
}
}
http.Redirect(w, r,
fmt.Sprintf("%s%serror=invalid_client&state=%s&error_description=cimd+not+supported",
redirectURI, separator, req.State),
http.StatusFound,
)
return
}
s.mu.Unlock()

// Always send into the channel so WaitForAuthRequest can inspect it.
Expand Down Expand Up @@ -226,6 +273,13 @@ func (s *cimdMockAuthServer) handleResourceMetadata(w http.ResponseWriter, _ *ht
_ = json.NewEncoder(w).Encode(meta)
}

// isCIMDClientID returns true if clientID looks like a CIMD metadata document
// URL (i.e. any HTTPS URL). This mirrors oauthproto.IsClientIDMetadataDocumentURL
// without importing the production package from a test helper.
func isCIMDClientID(clientID string) bool {
return len(clientID) >= 8 && clientID[:8] == "https://"
}

// newCIMDMockMCPServer creates a minimal httptest MCP server that:
// - Returns 401 with WWW-Authenticate header when there is no Authorization header.
// - Returns a minimal JSON-RPC success response when an Authorization header is present.
Expand Down
88 changes: 86 additions & 2 deletions test/e2e/cimd_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ var _ = Describe("CIMD Authentication", Label("remote", "auth", "cimd"), Serial,
Context("when the authorization server advertises CIMD support", func() {
It("uses the CIMD client_id and skips DCR", func() {
By("Starting mock authorization server with CIMD support enabled")
mockAS := newCIMDMockAuthServer(GinkgoT(), true)
mockAS := newCIMDMockAuthServer(GinkgoT(), true, false)

By("Starting mock MCP server that requires authentication")
mockMCP := newCIMDMockMCPServer(GinkgoT(), mockAS.URL())
Expand Down Expand Up @@ -160,7 +160,7 @@ var _ = Describe("CIMD Authentication", Label("remote", "auth", "cimd"), Serial,
Context("when the authorization server does NOT advertise CIMD support", func() {
It("falls back to DCR and does not use the CIMD client_id", func() {
By("Starting mock authorization server with CIMD support disabled")
mockAS := newCIMDMockAuthServer(GinkgoT(), false)
mockAS := newCIMDMockAuthServer(GinkgoT(), false, false)

By("Starting mock MCP server that requires authentication")
mockMCP := newCIMDMockMCPServer(GinkgoT(), mockAS.URL())
Expand Down Expand Up @@ -216,4 +216,88 @@ var _ = Describe("CIMD Authentication", Label("remote", "auth", "cimd"), Serial,
"DCR registration endpoint must be called when CIMD is not advertised")
})
})

Context("CIMD fallback and warm-start behaviour", func() {
It("falls back to DCR when AS rejects the CIMD client_id", func() {
By("Starting mock authorization server: CIMD advertised but first CIMD request rejected")
mockAS := newCIMDMockAuthServer(GinkgoT(), true, true)

By("Starting mock MCP server that requires authentication")
mockMCP := newCIMDMockMCPServer(GinkgoT(), mockAS.URL())

serverName := e2e.GenerateUniqueServerName("cimd-reject-fallback")

By("Starting thv run pointing at the mock MCP server")
cmd, outputBuffer := startCIMDRunCommand(config, serverName, mockMCP.URL, mockAS.IssuerURL())

defer func() {
if cmd.Process != nil {
_ = cmd.Process.Kill()
_ = cmd.Wait()
}
if config.CleanupAfter {
_ = e2e.StopAndRemoveMCPServer(config, serverName)
}
}()

By("Waiting for the first OAuth URL (CIMD attempt) to appear in the output")
var firstAuthURL string
Eventually(func() string {
firstAuthURL = extractAuthURL(outputBuffer.String())
return firstAuthURL
}, 30*time.Second, 500*time.Millisecond).ShouldNot(BeEmpty(),
"thv run should print 'Please open this URL in your browser' for the CIMD attempt")

By("Visiting the first URL — the AS will redirect back with error=invalid_client")
client := &http.Client{
Timeout: 10 * time.Second,
CheckRedirect: func(_ *http.Request, _ []*http.Request) error {
return nil // follow redirects
},
}
autoFirstURL := appendAutoComplete(firstAuthURL)
resp, err := client.Get(autoFirstURL) //nolint:gosec // URL is test-controlled
Expect(err).ToNot(HaveOccurred(), "GET to first auto-complete URL should not error")
_, _ = io.Copy(io.Discard, resp.Body)
_ = resp.Body.Close()
// The redirect chain ends at the ToolHive callback; any 2xx/3xx is fine.
Expect(resp.StatusCode).To(BeNumerically("<", 500),
"redirect chain for CIMD rejection should not produce a server error")

By("Asserting the mock AS registered the CIMD rejection")
Eventually(mockAS.RejectCIMDWasCalled, 10*time.Second, 500*time.Millisecond).Should(BeTrue(),
"mock AS must have rejected the CIMD client_id before DCR retry")

By("Waiting for the second OAuth URL (DCR retry) to appear in the output")
var secondAuthURL string
Eventually(func() string {
out := outputBuffer.String()
// The second URL appears after the first; find the last occurrence.
allURLs := regexp.MustCompile(`Please open this URL in your browser: (https?://[^\s"]+)`).
FindAllStringSubmatch(out, -1)
if len(allURLs) >= 2 {
secondAuthURL = allURLs[len(allURLs)-1][1]
}
return secondAuthURL
}, 45*time.Second, 500*time.Millisecond).ShouldNot(BeEmpty(),
"thv run should print a second OAuth URL after the CIMD rejection triggers a DCR retry")

By("Completing the DCR OAuth flow via auto_complete")
autoSecondURL := appendAutoComplete(secondAuthURL)
resp2, err := client.Get(autoSecondURL) //nolint:gosec // URL is test-controlled
Expect(err).ToNot(HaveOccurred(), "GET to second auto-complete URL should succeed")
_, _ = io.Copy(io.Discard, resp2.Body)
_ = resp2.Body.Close()
Expect(resp2.StatusCode).To(BeNumerically("<", 400),
"DCR auto-complete redirect chain should succeed")

By("Asserting DCR was called during the retry")
Eventually(mockAS.DcrWasCalled, 10*time.Second, 500*time.Millisecond).Should(BeTrue(),
"DCR registration endpoint must be called after CIMD rejection")

By("Waiting for thv to report the server as running")
err = e2e.WaitForMCPServer(config, serverName, 30*time.Second)
Expect(err).ToNot(HaveOccurred(), "server should appear as running in thv list after CIMD→DCR fallback")
})
})
})
Loading