From e60d139dc88e37398bef7de7b404a14470012c4f Mon Sep 17 00:00:00 2001 From: Jet Chiang Date: Fri, 17 Apr 2026 13:16:33 -0400 Subject: [PATCH 1/3] go sts client Signed-off-by: Jet Chiang --- go/adk/pkg/sts/actor.go | 54 +++++ go/adk/pkg/sts/client.go | 280 +++++++++++++++++++++++ go/adk/pkg/sts/client_test.go | 232 +++++++++++++++++++ go/adk/pkg/sts/errors.go | 87 ++++++++ go/adk/pkg/sts/integration.go | 185 ++++++++++++++++ go/adk/pkg/sts/integration_test.go | 84 +++++++ go/adk/pkg/sts/models.go | 133 +++++++++++ go/adk/pkg/sts/plugin.go | 344 +++++++++++++++++++++++++++++ go/adk/pkg/sts/plugin_test.go | 139 ++++++++++++ go/adk/pkg/sts/utils.go | 139 ++++++++++++ 10 files changed, 1677 insertions(+) create mode 100644 go/adk/pkg/sts/actor.go create mode 100644 go/adk/pkg/sts/client.go create mode 100644 go/adk/pkg/sts/client_test.go create mode 100644 go/adk/pkg/sts/errors.go create mode 100644 go/adk/pkg/sts/integration.go create mode 100644 go/adk/pkg/sts/integration_test.go create mode 100644 go/adk/pkg/sts/models.go create mode 100644 go/adk/pkg/sts/plugin.go create mode 100644 go/adk/pkg/sts/plugin_test.go create mode 100644 go/adk/pkg/sts/utils.go diff --git a/go/adk/pkg/sts/actor.go b/go/adk/pkg/sts/actor.go new file mode 100644 index 000000000..fb42cac42 --- /dev/null +++ b/go/adk/pkg/sts/actor.go @@ -0,0 +1,54 @@ +package sts + +import ( + "fmt" + "os" + "strings" +) + +// DefaultServiceAccountTokenPath is the default path for Kubernetes service account tokens. +const DefaultServiceAccountTokenPath = "/var/run/secrets/kubernetes.io/serviceaccount/token" + +// ActorTokenService provides actor tokens for STS delegation. +// It reads Kubernetes service account tokens from a file path. +type ActorTokenService struct { + tokenPath string +} + +// NewActorTokenService creates a new ActorTokenService. +// If tokenPath is empty, it uses the default Kubernetes service account token path. +func NewActorTokenService(tokenPath string) *ActorTokenService { + if tokenPath == "" { + tokenPath = DefaultServiceAccountTokenPath + } + return &ActorTokenService{ + tokenPath: tokenPath, + } +} + +// GetActorToken retrieves the actor token for STS delegation. +// This method reads the token from the file each time it's called. +// If loading fails, it returns an empty string and an error (or nil if file doesn't exist). +func (s *ActorTokenService) GetActorToken() (string, error) { + // Check if file exists first + if _, err := os.Stat(s.tokenPath); os.IsNotExist(err) { + return "", nil + } + + data, err := os.ReadFile(s.tokenPath) + if err != nil { + return "", fmt.Errorf("failed to read actor token from %s: %w", s.tokenPath, err) + } + + token := strings.TrimSpace(string(data)) + if token == "" { + return "", fmt.Errorf("empty actor token found at %s", s.tokenPath) + } + + return token, nil +} + +// TokenPath returns the configured token path. +func (s *ActorTokenService) TokenPath() string { + return s.tokenPath +} diff --git a/go/adk/pkg/sts/client.go b/go/adk/pkg/sts/client.go new file mode 100644 index 000000000..59be15b30 --- /dev/null +++ b/go/adk/pkg/sts/client.go @@ -0,0 +1,280 @@ +package sts + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "sync" + "time" +) + +// STSClient implements a Security Token Service client for RFC 8693 OAuth 2.0 Token Exchange. +type STSClient struct { + config STSConfig + wellKnownConfig *WellKnownConfiguration + httpClient *http.Client + initMu sync.Mutex +} + +// NewSTSClient creates a new STS client with the given configuration. +func NewSTSClient(config STSConfig) *STSClient { + config = normalizeSTSConfig(config) + return &STSClient{ + config: config, + } +} + +func normalizeSTSConfig(config STSConfig) STSConfig { + if config.Timeout == 0 { + config.Timeout = 5 + } + if config.VerifySSL == nil { + config.VerifySSL = boolPtr(true) + } + return config +} + +// initialize performs lazy initialization of the client. +// Fetches well-known configuration if not already cached. +func (c *STSClient) initialize(ctx context.Context) error { + c.initMu.Lock() + defer c.initMu.Unlock() + + if c.wellKnownConfig != nil && c.httpClient != nil { + return nil + } + + if c.wellKnownConfig == nil { + wellKnownConfig, err := FetchWellKnownConfiguration( + ctx, + c.config.WellKnownURI, + c.config.Timeout, + *c.config.VerifySSL, + c.config.UseIssuerHost, + ) + if err != nil { + return err + } + c.wellKnownConfig = wellKnownConfig + } + + if c.httpClient == nil { + transport := &http.Transport{} + if !*c.config.VerifySSL { + transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + c.httpClient = &http.Client{ + Timeout: time.Duration(c.config.Timeout) * time.Second, + Transport: transport, + } + } + + return nil +} + +// buildFormData creates the form-encoded request data from a TokenExchangeRequest. +func (c *STSClient) buildFormData(req *TokenExchangeRequest) url.Values { + data := url.Values{} + data.Set("grant_type", string(req.GrantType)) + data.Set("subject_token", req.SubjectToken) + data.Set("subject_token_type", string(req.SubjectTokenType)) + + // Add actor token for delegation requests + if req.ActorToken != "" { + data.Set("actor_token", req.ActorToken) + if req.ActorTokenType != "" { + data.Set("actor_token_type", string(req.ActorTokenType)) + } + } + + // Add optional parameters + if req.Resource != nil { + switch v := req.Resource.(type) { + case string: + data.Set("resource", v) + case []string: + for _, r := range v { + data.Add("resource", r) + } + } + } + + if req.Audience != nil { + switch v := req.Audience.(type) { + case string: + data.Set("audience", v) + case []string: + for _, a := range v { + data.Add("audience", a) + } + } + } + + if req.Scope != "" { + data.Set("scope", req.Scope) + } + + if req.RequestedTokenType != "" { + data.Set("requested_token_type", string(req.RequestedTokenType)) + } + + // Add additional parameters + if req.AdditionalParameters != nil { + for key, value := range req.AdditionalParameters { + if v, ok := value.(string); ok { + data.Set(key, v) + } + } + } + + return data +} + +// ExchangeToken performs a token exchange using RFC 8693 OAuth 2.0 Token Exchange. +// +// NOTE: The actor_token and actor_token_type parameters enable delegation scenarios. +// For impersonation (no delegation), omit these parameters. +func (c *STSClient) ExchangeToken( + ctx context.Context, + subjectToken string, + subjectTokenType TokenType, + actorToken string, + actorTokenType TokenType, + resource interface{}, + audience interface{}, + scope string, + requestedTokenType TokenType, + additionalParameters map[string]interface{}, +) (*TokenExchangeResponse, error) { + if err := c.initialize(ctx); err != nil { + return nil, err + } + + // Validate actor token type requirement + if actorToken != "" && actorTokenType == "" { + return nil, NewConfigurationError("actor_token_type is required when actor_token is provided") + } + + req := &TokenExchangeRequest{ + GrantType: GrantTypeTokenExchange, + SubjectToken: subjectToken, + SubjectTokenType: subjectTokenType, + ActorToken: actorToken, + ActorTokenType: actorTokenType, + Resource: resource, + Audience: audience, + Scope: scope, + RequestedTokenType: requestedTokenType, + AdditionalParameters: additionalParameters, + } + + formData := c.buildFormData(req) + + postReq, err := http.NewRequestWithContext(ctx, http.MethodPost, c.wellKnownConfig.TokenEndpoint, strings.NewReader(formData.Encode())) + if err != nil { + return nil, NewNetworkError("failed to create token exchange request", err) + } + + postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + postReq.Header.Set("Accept", "application/json") + + resp, err := c.httpClient.Do(postReq) + if err != nil { + return nil, NewNetworkError("network error during token exchange", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + var result TokenExchangeResponse + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, NewConfigurationError(fmt.Sprintf("invalid token exchange response: %v", err)) + } + return &result, nil + } + + // Parse error response + var responseData map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&responseData); err != nil { + // Could not parse error as JSON + return nil, NewTokenExchangeError( + "invalid_response", + fmt.Sprintf("Invalid error response from STS server: status %d", resp.StatusCode), + resp.StatusCode, + ) + } + + // Extract error details + errorCode := "unknown_error" + if ec, ok := responseData["error"].(string); ok { + errorCode = ec + } + + errorDescription := "" + if ed, ok := responseData["error_description"].(string); ok { + errorDescription = ed + } + + return nil, NewTokenExchangeError(errorCode, errorDescription, resp.StatusCode) +} + +// Impersonate performs an impersonation token exchange (no actor token). +func (c *STSClient) Impersonate( + ctx context.Context, + subjectToken string, + subjectTokenType TokenType, + resource interface{}, + audience interface{}, + scope string, + requestedTokenType TokenType, + additionalParameters map[string]interface{}, +) (*TokenExchangeResponse, error) { + return c.ExchangeToken( + ctx, + subjectToken, + subjectTokenType, + "", // no actor token + "", // no actor token type + resource, + audience, + scope, + requestedTokenType, + additionalParameters, + ) +} + +// Delegate performs a delegation token exchange (with actor token). +func (c *STSClient) Delegate( + ctx context.Context, + subjectToken string, + subjectTokenType TokenType, + actorToken string, + actorTokenType TokenType, + resource interface{}, + audience interface{}, + scope string, + requestedTokenType TokenType, + additionalParameters map[string]interface{}, +) (*TokenExchangeResponse, error) { + if subjectToken == "" { + return nil, NewAuthenticationError("subject token required for delegation") + } + if actorToken == "" { + return nil, NewAuthenticationError("actor token required for delegation") + } + return c.ExchangeToken( + ctx, + subjectToken, + subjectTokenType, + actorToken, + actorTokenType, + resource, + audience, + scope, + requestedTokenType, + additionalParameters, + ) +} diff --git a/go/adk/pkg/sts/client_test.go b/go/adk/pkg/sts/client_test.go new file mode 100644 index 000000000..424f4346d --- /dev/null +++ b/go/adk/pkg/sts/client_test.go @@ -0,0 +1,232 @@ +package sts + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" +) + +func newMockSTSClientServer(t *testing.T, handler func(w http.ResponseWriter, r *http.Request)) *httptest.Server { + t.Helper() + var srv *httptest.Server + srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/oauth-authorization-server" { + _ = json.NewEncoder(w).Encode(map[string]any{ + "issuer": srv.URL, + "token_endpoint": srv.URL + "/token", + }) + return + } + handler(w, r) + })) + return srv +} + +func TestSTSClientImpersonateSuccess(t *testing.T) { + t.Parallel() + srv := newMockSTSClientServer(t, func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/token" || r.Method != http.MethodPost { + http.NotFound(w, r) + return + } + if err := r.ParseForm(); err != nil { + t.Fatalf("ParseForm() error = %v", err) + } + if got := r.FormValue("subject_token"); got != "subject" { + t.Fatalf("subject_token = %q, want %q", got, "subject") + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "access-token", + "issued_token_type": string(TokenTypeJWT), + "token_type": "Bearer", + "expires_in": 3600, + }) + }) + defer srv.Close() + + client := NewSTSClient(STSConfig{ + WellKnownURI: srv.URL + "/.well-known/oauth-authorization-server", + Timeout: 2, + }) + resp, err := client.Impersonate(context.Background(), "subject", TokenTypeJWT, nil, nil, "", "", nil) + if err != nil { + t.Fatalf("Impersonate() error = %v", err) + } + if resp.AccessToken != "access-token" { + t.Fatalf("AccessToken = %q, want %q", resp.AccessToken, "access-token") + } +} + +func TestSTSClientDelegateBuildsRequestData(t *testing.T) { + t.Parallel() + srv := newMockSTSClientServer(t, func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Fatalf("ParseForm() error = %v", err) + } + assertFormValue(t, r.Form, "grant_type", string(GrantTypeTokenExchange)) + assertFormValue(t, r.Form, "subject_token", "subject-token") + assertFormValue(t, r.Form, "subject_token_type", string(TokenTypeJWT)) + assertFormValue(t, r.Form, "actor_token", "actor-token") + assertFormValue(t, r.Form, "actor_token_type", string(TokenTypeJWT)) + assertFormValue(t, r.Form, "audience", "https://api.example.com") + assertFormValue(t, r.Form, "scope", "read write") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "delegated-token", + "issued_token_type": string(TokenTypeJWT), + "token_type": "Bearer", + }) + }) + defer srv.Close() + + client := NewSTSClient(STSConfig{ + WellKnownURI: srv.URL + "/.well-known/oauth-authorization-server", + Timeout: 2, + }) + _, err := client.Delegate( + context.Background(), + "subject-token", + TokenTypeJWT, + "actor-token", + TokenTypeJWT, + nil, + "https://api.example.com", + "read write", + "", + nil, + ) + if err != nil { + t.Fatalf("Delegate() error = %v", err) + } +} + +func TestSTSClientExchangeTokenErrorResponse(t *testing.T) { + t.Parallel() + srv := newMockSTSClientServer(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": "invalid_request", + "error_description": "missing required parameter", + }) + }) + defer srv.Close() + + client := NewSTSClient(STSConfig{ + WellKnownURI: srv.URL + "/.well-known/oauth-authorization-server", + Timeout: 2, + }) + _, err := client.Impersonate(context.Background(), "subject", TokenTypeJWT, nil, nil, "", "", nil) + if err == nil { + t.Fatalf("Impersonate() error = nil, want non-nil") + } + exchangeErr, ok := err.(*TokenExchangeError) + if !ok { + t.Fatalf("error type = %T, want *TokenExchangeError", err) + } + if exchangeErr.ErrorCode != "invalid_request" { + t.Fatalf("ErrorCode = %q, want %q", exchangeErr.ErrorCode, "invalid_request") + } +} + +func TestSTSClientNetworkError(t *testing.T) { + t.Parallel() + // Use a closed local listener address to trigger a network failure quickly. + client := NewSTSClient(STSConfig{ + WellKnownURI: "http://127.0.0.1:1/.well-known/oauth-authorization-server", + Timeout: 1, + }) + _, err := client.Impersonate(context.Background(), "subject", TokenTypeJWT, nil, nil, "", "", nil) + if err == nil { + t.Fatalf("Impersonate() error = nil, want non-nil") + } + if !strings.Contains(err.Error(), "failed to fetch well-known configuration") && + !strings.Contains(err.Error(), "network error") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestNewSTSClientAppliesSecureDefaults(t *testing.T) { + t.Parallel() + client := NewSTSClient(STSConfig{ + WellKnownURI: "http://example.com/.well-known/oauth-authorization-server", + }) + + if client.config.Timeout != 5 { + t.Fatalf("Timeout = %d, want 5", client.config.Timeout) + } + if client.config.VerifySSL == nil || !*client.config.VerifySSL { + t.Fatalf("VerifySSL = %v, want true", client.config.VerifySSL) + } +} + +func TestSTSClientInitializeRetriesAfterDiscoveryFailure(t *testing.T) { + t.Parallel() + wellKnownCalls := 0 + var srv *httptest.Server + srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/.well-known/oauth-authorization-server": + wellKnownCalls++ + if wellKnownCalls == 1 { + http.Error(w, "temporary failure", http.StatusServiceUnavailable) + return + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "issuer": srv.URL, + "token_endpoint": srv.URL + "/token", + }) + case "/token": + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "access-token", + "issued_token_type": string(TokenTypeJWT), + }) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + + client := NewSTSClient(STSConfig{ + WellKnownURI: srv.URL + "/.well-known/oauth-authorization-server", + Timeout: 2, + }) + + if _, err := client.Impersonate(context.Background(), "subject", TokenTypeJWT, nil, nil, "", "", nil); err == nil { + t.Fatalf("first Impersonate() error = nil, want discovery error") + } + resp, err := client.Impersonate(context.Background(), "subject", TokenTypeJWT, nil, nil, "", "", nil) + if err != nil { + t.Fatalf("second Impersonate() error = %v", err) + } + if resp.AccessToken != "access-token" { + t.Fatalf("AccessToken = %q, want %q", resp.AccessToken, "access-token") + } + if wellKnownCalls != 2 { + t.Fatalf("well-known calls = %d, want 2", wellKnownCalls) + } +} + +func TestSTSClientDelegateWithoutSubjectToken(t *testing.T) { + t.Parallel() + client := NewSTSClient(STSConfig{ + WellKnownURI: "http://unused", + Timeout: 1, + }) + _, err := client.Delegate(context.Background(), "", TokenTypeJWT, "actor", TokenTypeJWT, nil, nil, "", "", nil) + if err == nil { + t.Fatalf("Delegate() error = nil, want non-nil") + } + if _, ok := err.(*AuthenticationError); !ok { + t.Fatalf("error type = %T, want *AuthenticationError", err) + } +} + +func assertFormValue(t *testing.T, form url.Values, key, want string) { + t.Helper() + if got := form.Get(key); got != want { + t.Fatalf("%s = %q, want %q", key, got, want) + } +} diff --git a/go/adk/pkg/sts/errors.go b/go/adk/pkg/sts/errors.go new file mode 100644 index 000000000..228ce5ac3 --- /dev/null +++ b/go/adk/pkg/sts/errors.go @@ -0,0 +1,87 @@ +package sts + +import "fmt" + +// STSError is the base error type for STS client errors. +type STSError struct { + Message string +} + +func (e *STSError) Error() string { + return e.Message +} + +// TokenExchangeError is raised when token exchange fails. +type TokenExchangeError struct { + STSError + ErrorCode string + ErrorDescription string + StatusCode int +} + +func (e *TokenExchangeError) Error() string { + if e.ErrorDescription != "" { + return fmt.Sprintf("token exchange failed: %s - %s (status: %d)", e.ErrorCode, e.ErrorDescription, e.StatusCode) + } + return fmt.Sprintf("token exchange failed: %s (status: %d)", e.ErrorCode, e.StatusCode) +} + +// NewTokenExchangeError creates a new TokenExchangeError. +func NewTokenExchangeError(errorCode, errorDescription string, statusCode int) *TokenExchangeError { + return &TokenExchangeError{ + STSError: STSError{Message: fmt.Sprintf("token exchange failed: %s", errorCode)}, + ErrorCode: errorCode, + ErrorDescription: errorDescription, + StatusCode: statusCode, + } +} + +// ConfigurationError is raised when STS configuration is invalid. +type ConfigurationError struct { + STSError +} + +// NewConfigurationError creates a new ConfigurationError. +func NewConfigurationError(message string) *ConfigurationError { + return &ConfigurationError{ + STSError: STSError{Message: fmt.Sprintf("STS configuration error: %s", message)}, + } +} + +// AuthenticationError is raised when authentication fails. +type AuthenticationError struct { + STSError +} + +// NewAuthenticationError creates a new AuthenticationError. +func NewAuthenticationError(message string) *AuthenticationError { + return &AuthenticationError{ + STSError: STSError{Message: fmt.Sprintf("STS authentication error: %s", message)}, + } +} + +// NetworkError is raised when network operations fail. +type NetworkError struct { + STSError + Cause error +} + +func (e *NetworkError) Error() string { + if e.Cause != nil { + return fmt.Sprintf("STS network error: %s: %v", e.STSError.Message, e.Cause) + } + return fmt.Sprintf("STS network error: %s", e.STSError.Message) +} + +// Unwrap returns the underlying error for errors.Is/As support. +func (e *NetworkError) Unwrap() error { + return e.Cause +} + +// NewNetworkError creates a new NetworkError. +func NewNetworkError(message string, cause error) *NetworkError { + return &NetworkError{ + STSError: STSError{Message: message}, + Cause: cause, + } +} diff --git a/go/adk/pkg/sts/integration.go b/go/adk/pkg/sts/integration.go new file mode 100644 index 000000000..446e4edbd --- /dev/null +++ b/go/adk/pkg/sts/integration.go @@ -0,0 +1,185 @@ +package sts + +import ( + "context" + "fmt" +) + +// GetSubjectTokenFunc is a function type for extracting subject tokens. +// It receives the bearer token (from Authorization header) and should return +// the subject token for STS exchange, or empty string if not available. +type GetSubjectTokenFunc func(bearerToken string) string + +// DefaultGetSubjectToken extracts the JWT token from the Authorization header. +// It expects the bearerToken to already be the JWT (without "Bearer " prefix). +// This matches how executor.go stores the token in context. +func DefaultGetSubjectToken(bearerToken string) string { + return bearerToken +} + +// FetchActorTokenFunc is a function type for fetching actor tokens dynamically. +// This can be used for scenarios where the actor token needs to be fetched +// at runtime rather than being a static Kubernetes service account token. +type FetchActorTokenFunc func(ctx context.Context) (string, error) + +// STSIntegration provides framework-agnostic STS integration. +// It wires together the STS client, actor token service, and subject token extraction. +type STSIntegration struct { + client *STSClient + actorTokenService *ActorTokenService + fetchActorToken FetchActorTokenFunc + getSubjectToken GetSubjectTokenFunc + staticActorToken string // cached static actor token from service +} + +// NewSTSIntegration creates a new STS integration. +// +// Parameters: +// - wellKnownURI: The well-known configuration URI for the STS server +// - serviceAccountTokenPath: Path to K8s service account token (ignored if fetchActorToken is set) +// - fetchActorToken: Optional function to fetch actor token dynamically +// - getSubjectToken: Optional function to extract subject token from context +// - timeout: Request timeout in seconds (default: 5) +// - verifySSL: Whether to verify SSL certificates (default: true) +// - useIssuerHost: Replace host:port in token_endpoint with host:port from well_known_uri +// +// NOTE: If fetchActorToken is provided, serviceAccountTokenPath is ignored. +// If getSubjectToken is not provided, DefaultGetSubjectToken is used. +func NewSTSIntegration( + wellKnownURI string, + serviceAccountTokenPath string, + fetchActorToken FetchActorTokenFunc, + getSubjectToken GetSubjectTokenFunc, + timeout int, + verifySSL bool, + useIssuerHost bool, +) (*STSIntegration, error) { + config := STSConfig{ + WellKnownURI: wellKnownURI, + Timeout: timeout, + VerifySSL: &verifySSL, + UseIssuerHost: useIssuerHost, + } + + if config.Timeout == 0 { + config.Timeout = 5 + } + + integration := &STSIntegration{ + client: NewSTSClient(config), + fetchActorToken: fetchActorToken, + getSubjectToken: getSubjectToken, + } + + // Only set up actor token service if no dynamic fetch function provided + if fetchActorToken == nil { + integration.actorTokenService = NewActorTokenService(serviceAccountTokenPath) + } + + // Use default subject token extraction if not provided + if integration.getSubjectToken == nil { + integration.getSubjectToken = DefaultGetSubjectToken + } + + return integration, nil +} + +// GetSubjectToken extracts the subject token from the bearer token. +func (i *STSIntegration) GetSubjectToken(bearerToken string) string { + return i.getSubjectToken(bearerToken) +} + +// getActorToken retrieves the actor token, either from cache, dynamic fetch, or file. +func (i *STSIntegration) getActorToken(ctx context.Context) (string, error) { + // Use dynamic fetch if provided + if i.fetchActorToken != nil { + return i.fetchActorToken(ctx) + } + + // Use cached static token if available + if i.staticActorToken != "" { + return i.staticActorToken, nil + } + + // Load from service account token file (one-time load for static tokens) + if i.actorTokenService != nil { + token, err := i.actorTokenService.GetActorToken() + if err != nil { + return "", fmt.Errorf("failed to get actor token from service: %w", err) + } + i.staticActorToken = token + return token, nil + } + + return "", nil +} + +func (i *STSIntegration) actorTokenForExchange(ctx context.Context) (string, error) { + actorToken, err := i.getActorToken(ctx) + if err != nil { + if i.fetchActorToken != nil { + return "", fmt.Errorf("failed to fetch actor token dynamically: %w", err) + } + return "", nil + } + if actorToken == "" { + return "", nil + } + return actorToken, nil +} + +// ExchangeToken performs a token exchange using the STS client. +// It automatically handles actor token retrieval when needed. +func (i *STSIntegration) ExchangeToken( + ctx context.Context, + subjectToken string, + subjectTokenType TokenType, + resource interface{}, + audience interface{}, + scope string, + requestedTokenType TokenType, +) (*TokenExchangeResponse, error) { + actorToken, err := i.actorTokenForExchange(ctx) + if err != nil { + return nil, err + } + + return i.ExchangeTokenWithActorToken(ctx, subjectToken, subjectTokenType, actorToken, resource, audience, scope, requestedTokenType) +} + +// ExchangeTokenWithActorToken performs a token exchange using an actor token +// selected by the caller. This lets ADK-specific integrations own dynamic +// actor-token caching while keeping raw STS client calls behind STSIntegration. +func (i *STSIntegration) ExchangeTokenWithActorToken( + ctx context.Context, + subjectToken string, + subjectTokenType TokenType, + actorToken string, + resource interface{}, + audience interface{}, + scope string, + requestedTokenType TokenType, +) (*TokenExchangeResponse, error) { + var actorTokenType TokenType + if actorToken != "" { + actorTokenType = TokenTypeJWT + } + + return i.client.ExchangeToken( + ctx, + subjectToken, + subjectTokenType, + actorToken, + actorTokenType, + resource, + audience, + scope, + requestedTokenType, + nil, + ) +} + +// Client returns the underlying STS client for advanced use cases. +func (i *STSIntegration) Client() *STSClient { + return i.client +} diff --git a/go/adk/pkg/sts/integration_test.go b/go/adk/pkg/sts/integration_test.go new file mode 100644 index 000000000..f1acae78e --- /dev/null +++ b/go/adk/pkg/sts/integration_test.go @@ -0,0 +1,84 @@ +package sts + +import ( + "context" + "os" + "path/filepath" + "testing" +) + +func TestNewSTSIntegrationDefaultSubjectToken(t *testing.T) { + t.Parallel() + i, err := NewSTSIntegration("http://example.com/.well-known", "", nil, nil, 0, true, false) + if err != nil { + t.Fatalf("NewSTSIntegration() error = %v", err) + } + got := i.GetSubjectToken("bearer-token") + if got != "bearer-token" { + t.Fatalf("GetSubjectToken() = %q, want %q", got, "bearer-token") + } +} + +func TestSTSIntegrationDynamicActorTokenFetch(t *testing.T) { + t.Parallel() + calls := 0 + i, err := NewSTSIntegration( + "http://example.com/.well-known", + "", + func(context.Context) (string, error) { + calls++ + return "dynamic-actor", nil + }, + nil, + 5, + true, + false, + ) + if err != nil { + t.Fatalf("NewSTSIntegration() error = %v", err) + } + got, err := i.getActorToken(context.Background()) + if err != nil { + t.Fatalf("getActorToken() error = %v", err) + } + if got != "dynamic-actor" { + t.Fatalf("getActorToken() = %q, want %q", got, "dynamic-actor") + } + if calls != 1 { + t.Fatalf("fetchActorToken calls = %d, want 1", calls) + } +} + +func TestSTSIntegrationStaticActorTokenCached(t *testing.T) { + t.Parallel() + dir := t.TempDir() + tokenPath := filepath.Join(dir, "actor-token") + if err := os.WriteFile(tokenPath, []byte("static-token"), 0o600); err != nil { + t.Fatalf("failed to write token file: %v", err) + } + + i, err := NewSTSIntegration("http://example.com/.well-known", tokenPath, nil, nil, 5, true, false) + if err != nil { + t.Fatalf("NewSTSIntegration() error = %v", err) + } + + got1, err := i.getActorToken(context.Background()) + if err != nil { + t.Fatalf("first getActorToken() error = %v", err) + } + if got1 != "static-token" { + t.Fatalf("first getActorToken() = %q, want %q", got1, "static-token") + } + + // Change underlying file; cached static token should still be returned. + if err := os.WriteFile(tokenPath, []byte("new-token"), 0o600); err != nil { + t.Fatalf("failed to update token file: %v", err) + } + got2, err := i.getActorToken(context.Background()) + if err != nil { + t.Fatalf("second getActorToken() error = %v", err) + } + if got2 != "static-token" { + t.Fatalf("second getActorToken() = %q, want cached %q", got2, "static-token") + } +} diff --git a/go/adk/pkg/sts/models.go b/go/adk/pkg/sts/models.go new file mode 100644 index 000000000..713fc1765 --- /dev/null +++ b/go/adk/pkg/sts/models.go @@ -0,0 +1,133 @@ +// Package sts implements OAuth 2.0 Token Exchange (RFC 8693) for the Go ADK. +// This package provides a Security Token Service (STS) client with Kubernetes +// service account token support and ADK plugin integration for token propagation. +package sts + +// TokenType represents RFC 8693 defined token types. +type TokenType string + +const ( + // TokenTypeJWT is the JWT token type + TokenTypeJWT TokenType = "urn:ietf:params:oauth:token-type:jwt" + // TokenTypeSAML2 is the SAML2 token type + TokenTypeSAML2 TokenType = "urn:ietf:params:oauth:token-type:saml2" + // TokenTypeSAML1 is the SAML1 token type + TokenTypeSAML1 TokenType = "urn:ietf:params:oauth:token-type:saml1" + // TokenTypeIDToken is the ID token type + TokenTypeIDToken TokenType = "urn:ietf:params:oauth:token-type:id_token" + // TokenTypeAccessToken is the access token type + TokenTypeAccessToken TokenType = "urn:ietf:params:oauth:token-type:access_token" +) + +// GrantType represents OAuth 2.0 grant types. +type GrantType string + +const ( + // GrantTypeTokenExchange is the RFC 8693 token exchange grant type + GrantTypeTokenExchange GrantType = "urn:ietf:params:oauth:grant-type:token-exchange" +) + +// TokenExchangeRequest represents an RFC 8693 Token Exchange Request. +type TokenExchangeRequest struct { + // GrantType is the OAuth 2.0 grant type (required) + GrantType GrantType `json:"grant_type"` + // SubjectToken is the security token representing the identity of the party + // on behalf of whom the new token is being requested (required) + SubjectToken string `json:"subject_token"` + // SubjectTokenType is the type of the subject_token (required) + SubjectTokenType TokenType `json:"subject_token_type"` + // ActorToken is the security token representing the identity of the acting party (optional) + ActorToken string `json:"actor_token,omitempty"` + // ActorTokenType is the type of the actor_token (required if ActorToken is set) + ActorTokenType TokenType `json:"actor_token_type,omitempty"` + // Resource is the logical name of the target service or resource (optional) + Resource interface{} `json:"resource,omitempty"` // Can be string or []string + // Audience is the logical name of the target service or resource (optional) + Audience interface{} `json:"audience,omitempty"` // Can be string or []string + // Scope is the scope of the requested token (optional) + Scope string `json:"scope,omitempty"` + // RequestedTokenType is the type of the requested token (optional) + RequestedTokenType TokenType `json:"requested_token_type,omitempty"` + // AdditionalParameters contains additional parameters for the request (optional) + AdditionalParameters map[string]interface{} `json:"-"` // Not serialized directly, merged into form data +} + +// IsDelegationRequest checks if this is a delegation request (has actor_token). +func (r *TokenExchangeRequest) IsDelegationRequest() bool { + return r.ActorToken != "" +} + +// IsImpersonationRequest checks if this is an impersonation request (no actor_token). +func (r *TokenExchangeRequest) IsImpersonationRequest() bool { + return r.ActorToken == "" +} + +// TokenExchangeResponse represents an RFC 8693 Token Exchange Response. +type TokenExchangeResponse struct { + // AccessToken is the issued security token (required) + AccessToken string `json:"access_token"` + // IssuedTokenType is the type of the issued token (required) + IssuedTokenType TokenType `json:"issued_token_type"` + // TokenType is the type of the access token (default: Bearer) + TokenType string `json:"token_type,omitempty"` + // ExpiresIn is the lifetime in seconds of the access token (optional) + ExpiresIn int `json:"expires_in,omitempty"` + // Scope is the scope of the access token (optional) + Scope string `json:"scope,omitempty"` + // RefreshToken is the refresh token if applicable (optional) + RefreshToken string `json:"refresh_token,omitempty"` + // AdditionalParameters contains additional response parameters (optional) + AdditionalParameters map[string]interface{} `json:"-"` +} + +// TokenExchangeErrorResponse represents an RFC 8693 Token Exchange Error response. +type TokenExchangeErrorResponse struct { + // Error is the error code (required) + Error string `json:"error"` + // ErrorDescription is a human-readable error description (optional) + ErrorDescription string `json:"error_description,omitempty"` + // ErrorURI is a URI identifying the error (optional) + ErrorURI string `json:"error_uri,omitempty"` + // AdditionalParameters contains additional error parameters (optional) + AdditionalParameters map[string]interface{} `json:"-"` +} + +// WellKnownConfiguration represents OAuth 2.0 Authorization Server Metadata. +type WellKnownConfiguration struct { + // Issuer is the authorization server's issuer identifier (required) + Issuer string `json:"issuer"` + // TokenEndpoint is the token endpoint URL (required) + TokenEndpoint string `json:"token_endpoint"` + // TokenEndpointAuthMethodsSupported is the list of supported auth methods (optional) + TokenEndpointAuthMethodsSupported []string `json:"token_endpoint_auth_methods_supported,omitempty"` + // TokenEndpointAuthSigningAlgValuesSupported is the list of supported signing algorithms (optional) + TokenEndpointAuthSigningAlgValuesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported,omitempty"` + // AdditionalParameters contains additional configuration parameters (optional) + AdditionalParameters map[string]interface{} `json:"-"` +} + +// STSConfig holds configuration for the STS client. +type STSConfig struct { + // WellKnownURI is the well-known configuration URI (required) + WellKnownURI string + // Timeout is the request timeout in seconds (default: 5) + Timeout int + // VerifySSL controls whether to verify SSL certificates (default: true) + VerifySSL *bool + // UseIssuerHost replaces the host:port in token_endpoint with the host:port from well_known_uri + UseIssuerHost bool +} + +// DefaultSTSConfig returns a default STS configuration. +func DefaultSTSConfig(wellKnownURI string) STSConfig { + return STSConfig{ + WellKnownURI: wellKnownURI, + Timeout: 5, + VerifySSL: boolPtr(true), + UseIssuerHost: false, + } +} + +func boolPtr(v bool) *bool { + return &v +} diff --git a/go/adk/pkg/sts/plugin.go b/go/adk/pkg/sts/plugin.go new file mode 100644 index 000000000..2c3c5d288 --- /dev/null +++ b/go/adk/pkg/sts/plugin.go @@ -0,0 +1,344 @@ +package sts + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/go-logr/logr" + "github.com/golang-jwt/jwt/v5" + "github.com/kagent-dev/kagent/go/adk/pkg/models" + "google.golang.org/adk/agent" + adkplugin "google.golang.org/adk/plugin" + "google.golang.org/genai" +) + +// TokenCacheEntry holds a cached token with its expiry time. +type TokenCacheEntry struct { + Token string + Expiry int64 // Unix timestamp, 0 if no expiry +} + +// HasExpired checks if the token has expired or will expire soon. +func (e *TokenCacheEntry) HasExpired(bufferSeconds int64) bool { + if e.Expiry == 0 { + return false + } + return e.Expiry <= time.Now().Unix()+bufferSeconds +} + +// TokenPropagationPlugin propagates STS tokens to ADK tools. +// It registers as a Go ADK plugin for run-level token preparation and exposes +// a header provider used by MCP tool transports. +type TokenPropagationPlugin struct { + integration *STSIntegration + tokenCache map[string]*TokenCacheEntry // keyed by session ID + actorTokenCache *TokenCacheEntry // used only for dynamic fetchActorToken providers + mu sync.RWMutex + logger logr.Logger + bufferSeconds int64 +} + +// NewTokenPropagationPlugin creates a new token propagation plugin. +// If integration is nil, the plugin will pass through tokens without exchange. +func NewTokenPropagationPlugin(integration *STSIntegration, logger logr.Logger) *TokenPropagationPlugin { + return &TokenPropagationPlugin{ + integration: integration, + tokenCache: make(map[string]*TokenCacheEntry), + logger: logger.WithName("sts-plugin"), + bufferSeconds: 5, + } +} + +// getCachedToken retrieves a valid cached token for the session. +func (p *TokenPropagationPlugin) getCachedToken(sessionID string) (*TokenCacheEntry, bool) { + p.mu.RLock() + defer p.mu.RUnlock() + + entry, ok := p.tokenCache[sessionID] + if !ok { + return nil, false + } + + if entry.HasExpired(p.bufferSeconds) { + return nil, false + } + + return entry, true +} + +// setCachedToken caches a token for the session. +func (p *TokenPropagationPlugin) setCachedToken(sessionID string, token string, expiry int64) { + p.mu.Lock() + defer p.mu.Unlock() + + p.tokenCache[sessionID] = &TokenCacheEntry{ + Token: token, + Expiry: expiry, + } +} + +func (p *TokenPropagationPlugin) getCachedActorToken() (*TokenCacheEntry, bool) { + p.mu.RLock() + defer p.mu.RUnlock() + + if p.actorTokenCache == nil || p.actorTokenCache.HasExpired(p.bufferSeconds) { + return nil, false + } + return p.actorTokenCache, true +} + +func (p *TokenPropagationPlugin) setCachedActorToken(token string) { + p.mu.Lock() + defer p.mu.Unlock() + + p.actorTokenCache = &TokenCacheEntry{ + Token: token, + Expiry: extractJWTExpiry(token), + } +} + +func (p *TokenPropagationPlugin) actorTokenForExchange(ctx context.Context) (string, error) { + if p.integration == nil { + return "", nil + } + + if p.integration.fetchActorToken == nil { + return p.integration.actorTokenForExchange(ctx) + } + + if entry, ok := p.getCachedActorToken(); ok { + return entry.Token, nil + } + + actorToken, err := p.integration.actorTokenForExchange(ctx) + if err != nil || actorToken == "" { + return actorToken, err + } + + p.setCachedActorToken(actorToken) + return actorToken, nil +} + +// BeforeRunCallback is called before the ADK run starts. +// It extracts the subject token, performs STS exchange if needed, and caches the result. +func (p *TokenPropagationPlugin) BeforeRunCallback(ctx agent.InvocationContext) (*genai.Content, error) { + sessionID := "" + if session := ctx.Session(); session != nil { + sessionID = session.ID() + } + if sessionID == "" { + p.logger.V(1).Info("No session ID available, skipping token propagation") + return nil, nil + } + + // Check if we already have a valid cached token for this session. + if entry, ok := p.getCachedToken(sessionID); ok { + p.logger.V(1).Info("Using cached STS token", "sessionID", sessionID) + if entry.Expiry > 0 { + p.logger.V(1).Info("Token expiry remaining", + "expiresIn", time.Until(time.Unix(entry.Expiry, 0)).String()) + } + return nil, nil + } + + // Extract bearer token from context. executor.go stores it with models.BearerTokenKey. + bearerToken := "" + if v := ctx.Value(models.BearerTokenKey); v != nil { + if token, ok := v.(string); ok { + bearerToken = token + } + } + + if bearerToken == "" { + p.logger.V(1).Info("No bearer token in context, skipping token propagation", "sessionID", sessionID) + return nil, nil + } + + // Get subject token + subjectToken := bearerToken + if p.integration != nil { + subjectToken = p.integration.GetSubjectToken(bearerToken) + } + + if subjectToken == "" { + p.logger.V(1).Info("Empty subject token extracted, skipping", "sessionID", sessionID) + return nil, nil + } + + if p.integration != nil { + actorToken, err := p.actorTokenForExchange(ctx) + if err != nil { + p.logger.Error(err, "Failed to fetch actor token dynamically, skipping STS token exchange", "sessionID", sessionID) + return nil, nil + } + + resp, err := p.integration.ExchangeTokenWithActorToken( + ctx, + subjectToken, + TokenTypeJWT, + actorToken, + nil, // resource + nil, // audience + "", // scope + "", // requestedTokenType + ) + if err != nil { + p.logger.Error(err, "STS token exchange failed, tools may not authenticate", "sessionID", sessionID) + return nil, nil + } + + // Cache the exchanged token. + exchangedToken := resp.AccessToken + expiry := int64(0) + if resp.ExpiresIn > 0 { + expiry = time.Now().Unix() + int64(resp.ExpiresIn) + } else { + // Fall back to JWT exp claim for cache TTL. + expiry = extractJWTExpiry(exchangedToken) + } + p.setCachedToken(sessionID, exchangedToken, expiry) + p.logger.Info("Successfully exchanged and cached STS token", "sessionID", sessionID) + } else { + // No STS integration — cache the raw subject token for header injection. + expiry := extractJWTExpiry(subjectToken) + p.setCachedToken(sessionID, subjectToken, expiry) + p.logger.V(1).Info("Cached subject token (no STS exchange)", "sessionID", sessionID) + } + + return nil, nil +} + +// AfterRunCallback is called after the ADK run finishes. +// It cleans up expired tokens from the cache. +func (p *TokenPropagationPlugin) AfterRunCallback(ctx agent.InvocationContext) { + sessionID := "" + if session := ctx.Session(); session != nil { + sessionID = session.ID() + } + if sessionID == "" { + return + } + + p.mu.Lock() + defer p.mu.Unlock() + + // Remove expired subject token. + if entry, ok := p.tokenCache[sessionID]; ok { + if entry.HasExpired(p.bufferSeconds) { + p.logger.V(1).Info("Removing expired subject token from cache", "sessionID", sessionID) + delete(p.tokenCache, sessionID) + } + } + if p.actorTokenCache != nil && p.actorTokenCache.HasExpired(p.bufferSeconds) { + p.logger.V(1).Info("Removing expired actor token from cache") + p.actorTokenCache = nil + } + +} + +// HeaderProvider returns a map of headers to inject into MCP tool HTTP requests. +// It is called by the dynamicHeaderRoundTripper on every MCP HTTP request. +func (p *TokenPropagationPlugin) HeaderProvider(ctx context.Context) map[string]string { + if ctx == nil { + return nil + } + + sessionID := sessionIDFromContext(ctx) + if sessionID == "" { + p.logger.V(1).Info("No session ID in context, MCP request will use existing headers") + return nil + } + + entry, ok := p.getCachedToken(sessionID) + if !ok { + p.logger.V(1).Info("No cached STS token for session, MCP request will use existing headers", "sessionID", sessionID) + return nil + } + + p.logger.V(1).Info("Injecting STS token into MCP request headers", "sessionID", sessionID) + return map[string]string{ + "Authorization": fmt.Sprintf("Bearer %s", entry.Token), + } +} + +// Extract session ID from ADK tool / invocation context, which implements SessionID(). +func sessionIDFromContext(ctx context.Context) string { + type sessionContext interface { + SessionID() string + } + sessionCtx, ok := ctx.(sessionContext) + if !ok { + return "" + } + return sessionCtx.SessionID() +} + +// GetTokenForSession retrieves the cached token for a specific session. +// Returns empty string if no valid token is cached. +func (p *TokenPropagationPlugin) GetTokenForSession(sessionID string) string { + entry, ok := p.getCachedToken(sessionID) + if !ok { + return "" + } + return entry.Token +} + +// ClearCache clears all cached tokens. +func (p *TokenPropagationPlugin) ClearCache() { + p.mu.Lock() + defer p.mu.Unlock() + + p.tokenCache = make(map[string]*TokenCacheEntry) + p.actorTokenCache = nil + p.logger.Info("Cleared STS token cache") +} + +// ADKPlugin returns the Go ADK plugin registered with runner.PluginConfig. +func (p *TokenPropagationPlugin) ADKPlugin() (*adkplugin.Plugin, error) { + return adkplugin.New(adkplugin.Config{ + Name: "kagent-sts-token-propagation", + BeforeRunCallback: p.BeforeRunCallback, + AfterRunCallback: p.AfterRunCallback, + }) +} + +// extractJWTExpiry extracts the 'exp' claim from a JWT token without verifying its signature. +// This is ONLY used for cache TTL management, not for security decisions. +// Token validation happens server-side during STS exchange. +func extractJWTExpiry(token string) int64 { + if token == "" { + return 0 + } + + // Parse without signature verification — we only need the exp claim. + parsed, err := jwt.Parse(token, + func(t *jwt.Token) (interface{}, error) { return nil, nil }, + jwt.WithoutClaimsValidation(), + ) + // err is expected (no key), but parsed may still carry the claims. + if parsed == nil || parsed.Claims == nil { + if err != nil { + // Truly unparseable token (not a JWT, etc.) + return 0 + } + } + + if parsed != nil { + if claims, ok := parsed.Claims.(jwt.MapClaims); ok { + if exp, ok := claims["exp"]; ok { + switch v := exp.(type) { + case float64: + return int64(v) + case int64: + return v + case int: + return int64(v) + } + } + } + } + + return 0 +} diff --git a/go/adk/pkg/sts/plugin_test.go b/go/adk/pkg/sts/plugin_test.go new file mode 100644 index 000000000..71d85113f --- /dev/null +++ b/go/adk/pkg/sts/plugin_test.go @@ -0,0 +1,139 @@ +package sts + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/go-logr/logr" + kagentmodels "github.com/kagent-dev/kagent/go/adk/pkg/models" + "google.golang.org/adk/agent" + "google.golang.org/adk/session" + "google.golang.org/genai" +) + +type fakeSessionContext struct { + context.Context + sessionID string +} + +func (f fakeSessionContext) SessionID() string { + return f.sessionID +} + +type fakeInvocationContext struct { + context.Context + sessionID string + ended bool +} + +func (f fakeInvocationContext) Agent() agent.Agent { return nil } +func (f fakeInvocationContext) Artifacts() agent.Artifacts { return nil } +func (f fakeInvocationContext) Memory() agent.Memory { return nil } +func (f fakeInvocationContext) Session() session.Session { return fakeSession{id: f.sessionID} } +func (f fakeInvocationContext) InvocationID() string { return "" } +func (f fakeInvocationContext) Branch() string { return "" } +func (f fakeInvocationContext) UserContent() *genai.Content { return nil } +func (f fakeInvocationContext) RunConfig() *agent.RunConfig { return nil } +func (f *fakeInvocationContext) EndInvocation() { f.ended = true } +func (f fakeInvocationContext) Ended() bool { return f.ended } +func (f fakeInvocationContext) WithContext(ctx context.Context) agent.InvocationContext { + f.Context = ctx + return &f +} + +type fakeSession struct { + id string +} + +func (f fakeSession) ID() string { return f.id } +func (f fakeSession) AppName() string { return "" } +func (f fakeSession) UserID() string { return "" } +func (f fakeSession) State() session.State { return nil } +func (f fakeSession) Events() session.Events { return nil } +func (f fakeSession) LastUpdateTime() time.Time { return time.Time{} } + +func TestHeaderProvider_UsesSessionIDMethod(t *testing.T) { + t.Parallel() + plugin := NewTokenPropagationPlugin(nil, logr.Discard()) + plugin.setCachedToken("sess-123", "token-abc", 0) + + headers := plugin.HeaderProvider(fakeSessionContext{ + Context: context.Background(), + sessionID: "sess-123", + }) + + if headers["Authorization"] != "Bearer token-abc" { + t.Fatalf("Authorization header = %q, want %q", headers["Authorization"], "Bearer token-abc") + } +} + +func TestBeforeRunCallback_ReusesCachedDynamicActorTokenForExchange(t *testing.T) { + t.Parallel() + + fetchCount := 0 + exchangeCount := 0 + var srv *httptest.Server + srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/oauth-authorization-server" { + _ = json.NewEncoder(w).Encode(map[string]any{ + "issuer": srv.URL, + "token_endpoint": srv.URL + "/token", + }) + return + } + if r.URL.Path != "/token" { + http.NotFound(w, r) + return + } + exchangeCount++ + if err := r.ParseForm(); err != nil { + t.Fatalf("ParseForm() error = %v", err) + } + if got := r.FormValue("actor_token"); got != "dynamic-actor" { + t.Fatalf("actor_token = %q, want %q", got, "dynamic-actor") + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "access-token", + "issued_token_type": string(TokenTypeJWT), + }) + })) + defer srv.Close() + + integration, err := NewSTSIntegration( + srv.URL+"/.well-known/oauth-authorization-server", + "", + func(context.Context) (string, error) { + fetchCount++ + return "dynamic-actor", nil + }, + nil, + 5, + true, + false, + ) + if err != nil { + t.Fatalf("NewSTSIntegration() error = %v", err) + } + + plugin := NewTokenPropagationPlugin(integration, logr.Discard()) + for _, sessionID := range []string{"sess-one", "sess-two"} { + ctx := context.WithValue(context.Background(), kagentmodels.BearerTokenKey, "subject-token") + if _, err := plugin.BeforeRunCallback(&fakeInvocationContext{ + Context: ctx, + sessionID: sessionID, + }); err != nil { + t.Fatalf("BeforeRunCallback() error = %v", err) + } + } + + if fetchCount != 1 { + t.Fatalf("fetchActorToken calls = %d, want 1", fetchCount) + } + if exchangeCount != 2 { + t.Fatalf("token exchange calls = %d, want 2", exchangeCount) + } +} diff --git a/go/adk/pkg/sts/utils.go b/go/adk/pkg/sts/utils.go new file mode 100644 index 000000000..0e4514758 --- /dev/null +++ b/go/adk/pkg/sts/utils.go @@ -0,0 +1,139 @@ +package sts + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" +) + +const ( + httpProtocol = "http://" + httpsProtocol = "https://" +) + +// FetchWellKnownConfiguration retrieves the OAuth 2.0 Authorization Server Metadata +// from the well-known configuration URI. +// +// NOTE: This makes an HTTP request. Callers should cache the result. +func FetchWellKnownConfiguration(ctx context.Context, wellKnownURI string, timeout int, verifySSL bool, useIssuerHost bool) (*WellKnownConfiguration, error) { + client := &http.Client{ + Timeout: time.Duration(timeout) * time.Second, + } + + if !verifySSL { + client.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnownURI, nil) + if err != nil { + return nil, NewNetworkError("failed to create request", err) + } + + resp, err := client.Do(req) + if err != nil { + return nil, NewNetworkError("failed to fetch well-known configuration", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, NewNetworkError(fmt.Sprintf("failed to fetch well-known configuration: HTTP %d", resp.StatusCode), nil) + } + + var data map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { + return nil, NewConfigurationError(fmt.Sprintf("invalid well-known configuration response: %v", err)) + } + + // Add protocol to token_endpoint if it's missing + if tokenEndpoint, ok := data["token_endpoint"].(string); ok { + if !strings.HasPrefix(tokenEndpoint, httpProtocol) && !strings.HasPrefix(tokenEndpoint, httpsProtocol) { + // Use the protocol from the well_known_uri + protocol := httpProtocol + if strings.HasPrefix(wellKnownURI, httpsProtocol) { + protocol = httpsProtocol + } + data["token_endpoint"] = protocol + tokenEndpoint + } + + // Replace host:port in token_endpoint with the host:port from well_known_uri + // Protocol is already resolved above, so token_endpoint always has a scheme here + if useIssuerHost { + normalizedTokenEndpoint, _ := data["token_endpoint"].(string) + issuer, err := url.Parse(wellKnownURI) + if err != nil { + return nil, NewConfigurationError(fmt.Sprintf("invalid well-known URI: %v", err)) + } + + endpoint, err := url.Parse(normalizedTokenEndpoint) + if err != nil { + return nil, NewConfigurationError(fmt.Sprintf("invalid token endpoint in configuration: %v", err)) + } + + // Replace netloc (host:port) with issuer's netloc + newEndpoint := *endpoint + newEndpoint.Host = issuer.Host + data["token_endpoint"] = newEndpoint.String() + } + } + + config := &WellKnownConfiguration{ + Issuer: getString(data, "issuer"), + TokenEndpoint: getString(data, "token_endpoint"), + AdditionalParameters: data, + TokenEndpointAuthMethodsSupported: getStringSlice(data, "token_endpoint_auth_methods_supported"), + TokenEndpointAuthSigningAlgValuesSupported: getStringSlice(data, "token_endpoint_auth_signing_alg_values_supported"), + } + + // Validate required fields + if config.Issuer == "" { + return nil, NewConfigurationError("well-known configuration missing 'issuer' field") + } + if config.TokenEndpoint == "" { + return nil, NewConfigurationError("well-known configuration missing 'token_endpoint' field") + } + + return config, nil +} + +// ParseTokenExchangeError parses a token exchange error response. +func ParseTokenExchangeError(responseData map[string]interface{}) *TokenExchangeError { + errorCode := "unknown_error" + if ec, ok := responseData["error"].(string); ok { + errorCode = ec + } + + errorDescription := "" + if ed, ok := responseData["error_description"].(string); ok { + errorDescription = ed + } + + return NewTokenExchangeError(errorCode, errorDescription, 0) +} + +// Helper functions to safely extract values from map[string]interface{} +func getString(m map[string]interface{}, key string) string { + if v, ok := m[key].(string); ok { + return v + } + return "" +} + +func getStringSlice(m map[string]interface{}, key string) []string { + if v, ok := m[key].([]interface{}); ok { + result := make([]string, 0, len(v)) + for _, item := range v { + if s, ok := item.(string); ok { + result = append(result, s) + } + } + return result + } + return nil +} From 4425c87e05c7785c7a82cc14608746fc6dee304d Mon Sep 17 00:00:00 2001 From: Jet Chiang Date: Fri, 15 May 2026 17:42:33 -0400 Subject: [PATCH 2/3] integration and tests of sts plugin Signed-off-by: Jet Chiang --- go/adk/pkg/agent/agent.go | 13 +++-- go/adk/pkg/mcp/registry.go | 41 ++++++++++++--- go/adk/pkg/mcp/registry_test.go | 77 +++++++++++++++++++++++++++++ go/adk/pkg/runner/adapter.go | 59 +++++++++++++++++++++- go/core/test/e2e/invoke_api_test.go | 18 ++++++- go/go.mod | 2 +- 6 files changed, 196 insertions(+), 14 deletions(-) diff --git a/go/adk/pkg/agent/agent.go b/go/adk/pkg/agent/agent.go index 1aaa21af4..1aae3637d 100644 --- a/go/adk/pkg/agent/agent.go +++ b/go/adk/pkg/agent/agent.go @@ -10,6 +10,7 @@ import ( "github.com/go-logr/logr" "github.com/kagent-dev/kagent/go/adk/pkg/mcp" "github.com/kagent-dev/kagent/go/adk/pkg/models" + "github.com/kagent-dev/kagent/go/adk/pkg/sts" "github.com/kagent-dev/kagent/go/adk/pkg/tools" "github.com/kagent-dev/kagent/go/api/adk" "google.golang.org/adk/agent" @@ -33,7 +34,7 @@ const ( // agentName is used as the ADK agent identity (appears in event Author field). // extraTools are appended to the agent's tool list (e.g. save_memory). func CreateGoogleADKAgent(ctx context.Context, agentConfig *adk.AgentConfig, agentName string, extraTools ...tool.Tool) (agent.Agent, error) { - a, _, err := CreateGoogleADKAgentWithSubagentSessionIDs(ctx, agentConfig, agentName, extraTools...) + a, _, err := CreateGoogleADKAgentWithSubagentSessionIDs(ctx, agentConfig, agentName, nil, extraTools...) return a, err } @@ -41,7 +42,8 @@ func CreateGoogleADKAgent(ctx context.Context, agentConfig *adk.AgentConfig, age // map of remote-subagent tool name → A2A context session ID (for stamping // outbound A2A events). Callers that only need the agent can use // CreateGoogleADKAgent. -func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig *adk.AgentConfig, agentName string, extraTools ...tool.Tool) (agent.Agent, map[string]string, error) { +// Optional stsPlugin can be provided for token propagation to MCP tools. +func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig *adk.AgentConfig, agentName string, stsPlugin *sts.TokenPropagationPlugin, extraTools ...tool.Tool) (agent.Agent, map[string]string, error) { log := logr.FromContextOrDiscard(ctx) if agentConfig == nil { @@ -49,7 +51,11 @@ func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig } propagateToken := strings.ToLower(os.Getenv("KAGENT_PROPAGATE_TOKEN")) == "true" - toolsets := mcp.CreateToolsets(ctx, agentConfig.HttpTools, agentConfig.SseTools, propagateToken) + var dynamicHeaderProvider mcp.DynamicHeaderProvider + if stsPlugin != nil { + dynamicHeaderProvider = stsPlugin.HeaderProvider + } + toolsets := mcp.CreateToolsets(ctx, agentConfig.HttpTools, agentConfig.SseTools, propagateToken, dynamicHeaderProvider) subagentSessionIDs := make(map[string]string) var remoteAgentTools []tool.Tool @@ -104,6 +110,7 @@ func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig beforeToolCallbacks := []llmagent.BeforeToolCallback{} // Strip synthetic HITL tool messages from the model request to avoid unnecessary token usage. beforeModelCallbacks := []llmagent.BeforeModelCallback{} + if len(approvalSet) > 0 { log.Info("Wiring approval callback", "toolCount", len(approvalSet)) beforeToolCallbacks = append(beforeToolCallbacks, MakeApprovalCallback(approvalSet)) diff --git a/go/adk/pkg/mcp/registry.go b/go/adk/pkg/mcp/registry.go index 1dd2a4d04..97cf20ea4 100644 --- a/go/adk/pkg/mcp/registry.go +++ b/go/adk/pkg/mcp/registry.go @@ -18,6 +18,11 @@ import ( "google.golang.org/adk/tool/mcptoolset" ) +// DynamicHeaderProvider is a function that returns headers to inject into MCP requests. +// It receives the context and should return a map of headers. +// This is used for dynamic token injection (e.g., STS tokens) per session. +type DynamicHeaderProvider func(ctx context.Context) map[string]string + const ( // Default timeout matching Python KAGENT_REMOTE_AGENT_TIMEOUT defaultTimeout = 30 * time.Minute @@ -62,9 +67,10 @@ func allowedRequestHeaders(ctx context.Context, allowed []string) map[string]str type mcpServerParams struct { URL string Headers map[string]string - AllowedHeaders []string // header names to forward from incoming request - PropagateToken bool // when true, Authorization is forwarded independently of AllowedHeaders - ServerType string // "http" or "sse" + AllowedHeaders []string // header names to forward from incoming request + PropagateToken bool // when true, Authorization is forwarded independently of AllowedHeaders + HeaderProvider DynamicHeaderProvider // optional per-request headers derived from invocation context (e.g., STS exchanged access tokens) + ServerType string // "http" or "sse" Timeout *float64 SseReadTimeout *float64 TLSInsecureSkipVerify *bool @@ -79,7 +85,16 @@ type mcpServerParams struct { // When propagateToken is true, Authorization is forwarded to every MCP server // independently of AllowedHeaders, mirroring the Python ADKTokenPropagationPlugin // behaviour triggered by KAGENT_PROPAGATE_TOKEN. -func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, sseTools []adk.SseMcpServerConfig, propagateToken bool) []tool.Toolset { +// +// Optional headerProvider can be used to inject per-request headers +// derived from invocation context (e.g., STS exchanged access tokens). +func CreateToolsets( + ctx context.Context, + httpTools []adk.HttpMcpServerConfig, + sseTools []adk.SseMcpServerConfig, + propagateToken bool, + headerProvider DynamicHeaderProvider, +) []tool.Toolset { log := logr.FromContextOrDiscard(ctx) var toolsets []tool.Toolset @@ -90,6 +105,7 @@ func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, ss Headers: httpTool.Params.Headers, AllowedHeaders: httpTool.AllowedHeaders, PropagateToken: propagateToken, + HeaderProvider: headerProvider, ServerType: "http", Timeout: httpTool.Params.Timeout, SseReadTimeout: httpTool.Params.SseReadTimeout, @@ -111,6 +127,7 @@ func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, ss Headers: sseTool.Params.Headers, AllowedHeaders: sseTool.AllowedHeaders, PropagateToken: propagateToken, + HeaderProvider: headerProvider, ServerType: "sse", Timeout: sseTool.Params.Timeout, SseReadTimeout: sseTool.Params.SseReadTimeout, @@ -208,12 +225,13 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp } var httpTransport http.RoundTripper = baseTransport - if len(params.Headers) > 0 || len(params.AllowedHeaders) > 0 || params.PropagateToken { + if len(params.Headers) > 0 || len(params.AllowedHeaders) > 0 || params.PropagateToken || params.HeaderProvider != nil { httpTransport = &headerRoundTripper{ base: baseTransport, headers: params.Headers, allowedHeaders: params.AllowedHeaders, propagateToken: params.PropagateToken, + headerProvider: params.HeaderProvider, } } @@ -239,18 +257,20 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp } // headerRoundTripper wraps an http.RoundTripper to add custom headers to all -// requests. It supports three sources of headers, applied in this order so that +// requests. It supports four sources of headers, applied in this order so that // higher-priority sources win on collision: // 1. propagateToken: when true, Authorization is read from the incoming A2A // CallContext and forwarded unconditionally (independent of allowedHeaders). // 2. allowedHeaders: explicit per-header forwarding from the A2A CallContext. -// 3. headers: static key/value pairs configured on the MCP server spec (highest +// 3. headerProvider: runtime headers derived from ADK context, such as STS tokens. +// 4. headers: static key/value pairs configured on the MCP server spec (highest // priority — always wins). type headerRoundTripper struct { base http.RoundTripper headers map[string]string allowedHeaders []string // header names (case-insensitive) to forward from A2A context propagateToken bool // when true, Authorization is forwarded independently + headerProvider DynamicHeaderProvider } func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { @@ -273,6 +293,13 @@ func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro req.Header.Set(k, v) } + // Dynamic headers (e.g., STS access tokens) override propagated/allowed headers. + if rt.headerProvider != nil { + for key, value := range rt.headerProvider(req.Context()) { + req.Header.Set(key, value) + } + } + // Apply static headers last — they take precedence over all dynamic sources. for key, value := range rt.headers { req.Header.Set(key, value) diff --git a/go/adk/pkg/mcp/registry_test.go b/go/adk/pkg/mcp/registry_test.go index 7a0cdc0d3..864b4c1b6 100644 --- a/go/adk/pkg/mcp/registry_test.go +++ b/go/adk/pkg/mcp/registry_test.go @@ -319,3 +319,80 @@ func TestAllowedRequestHeaders_ReturnsNilWhenNoMatches(t *testing.T) { t.Errorf("expected nil when no allowed headers are present, got %v", got) } } + +// TestDynamicHeaders_OverridePropagatedAuthorization verifies dynamic headers +// take precedence over propagated and allowed request headers. +func TestDynamicHeaders_OverridePropagatedAuthorization(t *testing.T) { + t.Parallel() + var capturedAuth, capturedTrace string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAuth = r.Header.Get("Authorization") + capturedTrace = r.Header.Get("X-Trace-Id") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + ctx := a2aCtx(map[string][]string{ + "Authorization": {"Bearer incoming"}, + "X-Trace-Id": {"trace-from-request"}, + }) + + rt := &headerRoundTripper{ + base: http.DefaultTransport, + propagateToken: true, + allowedHeaders: []string{"Authorization", "X-Trace-Id"}, + headerProvider: func(context.Context) map[string]string { + return map[string]string{ + "Authorization": "Bearer sts-exchanged", + "X-Trace-Id": "trace-from-dynamic", + } + }, + } + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip failed: %v", err) + } + resp.Body.Close() + + if capturedAuth != "Bearer sts-exchanged" { + t.Errorf("Authorization: got %q, want %q", capturedAuth, "Bearer sts-exchanged") + } + if capturedTrace != "trace-from-dynamic" { + t.Errorf("X-Trace-Id: got %q, want %q", capturedTrace, "trace-from-dynamic") + } +} + +// TestStaticHeaders_OverrideDynamic verifies static configured headers remain +// the highest-precedence source. +func TestStaticHeaders_OverrideDynamic(t *testing.T) { + t.Parallel() + var capturedAuth string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAuth = r.Header.Get("Authorization") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + rt := &headerRoundTripper{ + base: http.DefaultTransport, + headers: map[string]string{"Authorization": "Bearer static"}, + headerProvider: func(context.Context) map[string]string { + return map[string]string{"Authorization": "Bearer dynamic"} + }, + } + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip failed: %v", err) + } + resp.Body.Close() + + if capturedAuth != "Bearer static" { + t.Errorf("Authorization: got %q, want %q", capturedAuth, "Bearer static") + } +} diff --git a/go/adk/pkg/runner/adapter.go b/go/adk/pkg/runner/adapter.go index d6230ccad..92314e27a 100644 --- a/go/adk/pkg/runner/adapter.go +++ b/go/adk/pkg/runner/adapter.go @@ -3,13 +3,17 @@ package runner import ( "context" "fmt" + "os" "strings" + "github.com/go-logr/logr" "github.com/kagent-dev/kagent/go/adk/pkg/agent" kagentmemory "github.com/kagent-dev/kagent/go/adk/pkg/memory" "github.com/kagent-dev/kagent/go/adk/pkg/session" + "github.com/kagent-dev/kagent/go/adk/pkg/sts" "github.com/kagent-dev/kagent/go/api/adk" adkmemory "google.golang.org/adk/memory" + adkplugin "google.golang.org/adk/plugin" "google.golang.org/adk/runner" adksession "google.golang.org/adk/session" adktool "google.golang.org/adk/tool" @@ -31,6 +35,8 @@ func CreateRunnerConfig( appName string, memoryService *kagentmemory.KagentMemoryService, ) (runner.Config, map[string]string, error) { + log := logr.FromContextOrDiscard(ctx) + var extraTools []adktool.Tool if memoryService != nil { saveTool, err := kagentmemory.NewSaveMemoryTool(memoryService) @@ -40,7 +46,12 @@ func CreateRunnerConfig( extraTools = append(extraTools, saveTool) } - adkAgent, subagentSessionIDs, err := agent.CreateGoogleADKAgentWithSubagentSessionIDs(ctx, agentConfig, agentNameFromAppName(appName), extraTools...) + stsPlugin, err := buildTokenPropagationPlugin(ctx, log) + if err != nil { + return runner.Config{}, nil, err + } + + adkAgent, subagentSessionIDs, err := agent.CreateGoogleADKAgentWithSubagentSessionIDs(ctx, agentConfig, agentNameFromAppName(appName), stsPlugin, extraTools...) if err != nil { return runner.Config{}, nil, fmt.Errorf("failed to create agent: %w", err) } @@ -61,11 +72,57 @@ func CreateRunnerConfig( runnerMemory = memoryService } + var adkPlugins []*adkplugin.Plugin + if stsPlugin != nil { + p, err := stsPlugin.ADKPlugin() + if err != nil { + return runner.Config{}, nil, fmt.Errorf("failed to create STS ADK plugin: %w", err) + } + if p != nil { + adkPlugins = append(adkPlugins, p) + } + } + cfg := runner.Config{ AppName: appName, Agent: adkAgent, SessionService: adkSessionService, MemoryService: runnerMemory, + PluginConfig: runner.PluginConfig{ + Plugins: adkPlugins, + }, } + return cfg, subagentSessionIDs, nil } + +func buildTokenPropagationPlugin(ctx context.Context, log logr.Logger) (*sts.TokenPropagationPlugin, error) { + propagateToken := strings.EqualFold(strings.TrimSpace(os.Getenv("KAGENT_PROPAGATE_TOKEN")), "true") + stsWellKnownURI := strings.TrimSpace(os.Getenv("STS_WELL_KNOWN_URI")) + if !propagateToken && stsWellKnownURI == "" { + return nil, nil + } + + // Propagate-only mode: keep parity with Python by enabling plugin without STS exchange. + if stsWellKnownURI == "" { + log.Info("Enabling token propagation plugin without STS exchange") + return sts.NewTokenPropagationPlugin(nil, log), nil + } + + integration, err := sts.NewSTSIntegration( + stsWellKnownURI, + "", + nil, // fetchActorToken + nil, // getSubjectToken + 0, // default timeout + true, // default verifySSL + false, // default useIssuerHost + ) + if err != nil { + return nil, fmt.Errorf("failed to initialize STS integration: %w", err) + } + + log.Info("Enabling STS token propagation plugin", + "wellKnownURI", stsWellKnownURI) + return sts.NewTokenPropagationPlugin(integration, log), nil +} diff --git a/go/core/test/e2e/invoke_api_test.go b/go/core/test/e2e/invoke_api_test.go index 540f362ec..528fce288 100644 --- a/go/core/test/e2e/invoke_api_test.go +++ b/go/core/test/e2e/invoke_api_test.go @@ -1044,6 +1044,15 @@ func TestE2EInvokeCrewAIAgent(t *testing.T) { } func TestE2EInvokeSTSIntegration(t *testing.T) { + runE2EInvokeSTSIntegration(t, "python", nil) +} + +func TestE2EGoInvokeSTSIntegration(t *testing.T) { + goRuntime := v1alpha2.DeclarativeRuntime_Go + runE2EInvokeSTSIntegration(t, "go", &goRuntime) +} + +func runE2EInvokeSTSIntegration(t *testing.T, runtimeName string, runtimeOverride *v1alpha2.DeclarativeRuntime) { // Setup mock STS server agentName := "test-sts" agentServiceAccount := fmt.Sprintf("system:serviceaccount:kagent:%s", agentName) @@ -1079,8 +1088,9 @@ func TestE2EInvokeSTSIntegration(t *testing.T) { modelCfg := setupModelConfig(t, cli, baseURL) agent := setupAgentWithOptions(t, cli, modelCfg.Name, tools, AgentOptions{ - Name: "test-sts-agent", + Name: "test-sts-agent-" + runtimeName, SystemMessage: "You are an agent that adds numbers using the add tool available to you through the everything-mcp-server.", + Runtime: runtimeOverride, Env: []corev1.EnvVar{ { Name: "STS_WELL_KNOWN_URI", @@ -1111,7 +1121,7 @@ func TestE2EInvokeSTSIntegration(t *testing.T) { a2aclient.WithHTTPClient(httpClient)) require.NoError(t, err) - t.Run("sync_invocation", func(t *testing.T) { + t.Run(runtimeName+"/sts_exchange_sync_invocation", func(t *testing.T) { runSyncTest(t, a2aClient, "add 3 and 5", "8", nil) // verify our mock STS server received the token exchange request @@ -1122,6 +1132,10 @@ func TestE2EInvokeSTSIntegration(t *testing.T) { // which contains the may act claim stsRequest := stsRequests[0] require.Equal(t, subjectToken, stsRequest.SubjectToken) + require.Equal(t, "urn:ietf:params:oauth:grant-type:token-exchange", stsRequest.GrantType) + require.Equal(t, "urn:ietf:params:oauth:token-type:jwt", stsRequest.SubjectTokenType) + require.NotEmpty(t, stsRequest.ActorToken) + require.Equal(t, "urn:ietf:params:oauth:token-type:jwt", stsRequest.ActorTokenType) }) } diff --git a/go/go.mod b/go/go.mod index c44f073c7..446180130 100644 --- a/go/go.mod +++ b/go/go.mod @@ -18,6 +18,7 @@ require ( github.com/fatih/color v1.19.0 github.com/go-logr/logr v1.4.3 github.com/go-logr/zapr v1.3.0 + github.com/golang-jwt/jwt/v5 v5.3.1 github.com/golang-migrate/migrate/v4 v4.19.1 // api dependencies github.com/google/uuid v1.6.0 @@ -221,7 +222,6 @@ require ( github.com/goccy/go-json v0.10.3 // indirect github.com/godoc-lint/godoc-lint v0.11.2 // indirect github.com/gofrs/flock v0.13.0 // indirect - github.com/golang-jwt/jwt/v5 v5.3.1 // indirect github.com/golangci/asciicheck v0.5.0 // indirect github.com/golangci/dupl v0.0.0-20260401084720-c99c5cf5c202 // indirect github.com/golangci/go-printf-func-name v0.1.1 // indirect From ad23d5c3ded02dbce1d28b5d1ad2947965c913a5 Mon Sep 17 00:00:00 2001 From: Jet Chiang Date: Fri, 15 May 2026 18:11:36 -0400 Subject: [PATCH 3/3] review comments Signed-off-by: Jet Chiang --- go/adk/pkg/sts/client.go | 29 ++++++++------------ go/adk/pkg/sts/client_test.go | 19 +++++++++++++ go/adk/pkg/sts/errors.go | 4 +-- go/adk/pkg/sts/integration.go | 18 ++++++------ go/adk/pkg/sts/integration_test.go | 19 +++++++++++++ go/adk/pkg/sts/models.go | 18 +++++------- go/adk/pkg/sts/plugin.go | 35 ++++++++---------------- go/adk/pkg/sts/plugin_test.go | 16 +++++++++++ go/adk/pkg/sts/utils.go | 44 +++++++++++++++++++++--------- 9 files changed, 127 insertions(+), 75 deletions(-) diff --git a/go/adk/pkg/sts/client.go b/go/adk/pkg/sts/client.go index 59be15b30..0055cd90a 100644 --- a/go/adk/pkg/sts/client.go +++ b/go/adk/pkg/sts/client.go @@ -2,7 +2,6 @@ package sts import ( "context" - "crypto/tls" "encoding/json" "fmt" "net/http" @@ -33,7 +32,7 @@ func normalizeSTSConfig(config STSConfig) STSConfig { config.Timeout = 5 } if config.VerifySSL == nil { - config.VerifySSL = boolPtr(true) + config.VerifySSL = new(true) } return config } @@ -63,13 +62,9 @@ func (c *STSClient) initialize(ctx context.Context) error { } if c.httpClient == nil { - transport := &http.Transport{} - if !*c.config.VerifySSL { - transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } c.httpClient = &http.Client{ Timeout: time.Duration(c.config.Timeout) * time.Second, - Transport: transport, + Transport: transportWithTLSVerification(*c.config.VerifySSL), } } @@ -144,11 +139,11 @@ func (c *STSClient) ExchangeToken( subjectTokenType TokenType, actorToken string, actorTokenType TokenType, - resource interface{}, - audience interface{}, + resource any, + audience any, scope string, requestedTokenType TokenType, - additionalParameters map[string]interface{}, + additionalParameters map[string]any, ) (*TokenExchangeResponse, error) { if err := c.initialize(ctx); err != nil { return nil, err @@ -197,7 +192,7 @@ func (c *STSClient) ExchangeToken( } // Parse error response - var responseData map[string]interface{} + var responseData map[string]any if err := json.NewDecoder(resp.Body).Decode(&responseData); err != nil { // Could not parse error as JSON return nil, NewTokenExchangeError( @@ -226,11 +221,11 @@ func (c *STSClient) Impersonate( ctx context.Context, subjectToken string, subjectTokenType TokenType, - resource interface{}, - audience interface{}, + resource any, + audience any, scope string, requestedTokenType TokenType, - additionalParameters map[string]interface{}, + additionalParameters map[string]any, ) (*TokenExchangeResponse, error) { return c.ExchangeToken( ctx, @@ -253,11 +248,11 @@ func (c *STSClient) Delegate( subjectTokenType TokenType, actorToken string, actorTokenType TokenType, - resource interface{}, - audience interface{}, + resource any, + audience any, scope string, requestedTokenType TokenType, - additionalParameters map[string]interface{}, + additionalParameters map[string]any, ) (*TokenExchangeResponse, error) { if subjectToken == "" { return nil, NewAuthenticationError("subject token required for delegation") diff --git a/go/adk/pkg/sts/client_test.go b/go/adk/pkg/sts/client_test.go index 424f4346d..0ff8cd976 100644 --- a/go/adk/pkg/sts/client_test.go +++ b/go/adk/pkg/sts/client_test.go @@ -209,6 +209,25 @@ func TestSTSClientInitializeRetriesAfterDiscoveryFailure(t *testing.T) { } } +func TestTransportWithTLSVerificationClonesDefaultTransport(t *testing.T) { + t.Parallel() + defaultTransport, ok := http.DefaultTransport.(*http.Transport) + if !ok { + t.Skip("http.DefaultTransport is not *http.Transport") + } + + transport := transportWithTLSVerification(false) + if transport == defaultTransport { + t.Fatal("transportWithTLSVerification returned http.DefaultTransport directly") + } + if transport.TLSHandshakeTimeout != defaultTransport.TLSHandshakeTimeout { + t.Fatalf("TLSHandshakeTimeout = %v, want %v", transport.TLSHandshakeTimeout, defaultTransport.TLSHandshakeTimeout) + } + if transport.TLSClientConfig == nil || !transport.TLSClientConfig.InsecureSkipVerify { + t.Fatal("TLSClientConfig.InsecureSkipVerify = false, want true") + } +} + func TestSTSClientDelegateWithoutSubjectToken(t *testing.T) { t.Parallel() client := NewSTSClient(STSConfig{ diff --git a/go/adk/pkg/sts/errors.go b/go/adk/pkg/sts/errors.go index 228ce5ac3..668bb39f4 100644 --- a/go/adk/pkg/sts/errors.go +++ b/go/adk/pkg/sts/errors.go @@ -68,9 +68,9 @@ type NetworkError struct { func (e *NetworkError) Error() string { if e.Cause != nil { - return fmt.Sprintf("STS network error: %s: %v", e.STSError.Message, e.Cause) + return fmt.Sprintf("STS network error: %s: %v", e.Message, e.Cause) } - return fmt.Sprintf("STS network error: %s", e.STSError.Message) + return fmt.Sprintf("STS network error: %s", e.Message) } // Unwrap returns the underlying error for errors.Is/As support. diff --git a/go/adk/pkg/sts/integration.go b/go/adk/pkg/sts/integration.go index 446e4edbd..f41bbfa57 100644 --- a/go/adk/pkg/sts/integration.go +++ b/go/adk/pkg/sts/integration.go @@ -3,6 +3,7 @@ package sts import ( "context" "fmt" + "sync" ) // GetSubjectTokenFunc is a function type for extracting subject tokens. @@ -30,6 +31,7 @@ type STSIntegration struct { fetchActorToken FetchActorTokenFunc getSubjectToken GetSubjectTokenFunc staticActorToken string // cached static actor token from service + actorTokenMu sync.Mutex } // NewSTSIntegration creates a new STS integration. @@ -96,6 +98,9 @@ func (i *STSIntegration) getActorToken(ctx context.Context) (string, error) { return i.fetchActorToken(ctx) } + i.actorTokenMu.Lock() + defer i.actorTokenMu.Unlock() + // Use cached static token if available if i.staticActorToken != "" { return i.staticActorToken, nil @@ -117,10 +122,7 @@ func (i *STSIntegration) getActorToken(ctx context.Context) (string, error) { func (i *STSIntegration) actorTokenForExchange(ctx context.Context) (string, error) { actorToken, err := i.getActorToken(ctx) if err != nil { - if i.fetchActorToken != nil { - return "", fmt.Errorf("failed to fetch actor token dynamically: %w", err) - } - return "", nil + return "", fmt.Errorf("failed to fetch actor token: %w", err) } if actorToken == "" { return "", nil @@ -134,8 +136,8 @@ func (i *STSIntegration) ExchangeToken( ctx context.Context, subjectToken string, subjectTokenType TokenType, - resource interface{}, - audience interface{}, + resource any, + audience any, scope string, requestedTokenType TokenType, ) (*TokenExchangeResponse, error) { @@ -155,8 +157,8 @@ func (i *STSIntegration) ExchangeTokenWithActorToken( subjectToken string, subjectTokenType TokenType, actorToken string, - resource interface{}, - audience interface{}, + resource any, + audience any, scope string, requestedTokenType TokenType, ) (*TokenExchangeResponse, error) { diff --git a/go/adk/pkg/sts/integration_test.go b/go/adk/pkg/sts/integration_test.go index f1acae78e..a31bd91ac 100644 --- a/go/adk/pkg/sts/integration_test.go +++ b/go/adk/pkg/sts/integration_test.go @@ -82,3 +82,22 @@ func TestSTSIntegrationStaticActorTokenCached(t *testing.T) { t.Fatalf("second getActorToken() = %q, want cached %q", got2, "static-token") } } + +func TestSTSIntegrationStaticActorTokenErrorPropagates(t *testing.T) { + t.Parallel() + dir := t.TempDir() + tokenPath := filepath.Join(dir, "empty-token") + if err := os.WriteFile(tokenPath, []byte(" \n\t "), 0o600); err != nil { + t.Fatalf("failed to write token file: %v", err) + } + + i, err := NewSTSIntegration("http://example.com/.well-known", tokenPath, nil, nil, 5, true, false) + if err != nil { + t.Fatalf("NewSTSIntegration() error = %v", err) + } + + _, err = i.actorTokenForExchange(context.Background()) + if err == nil { + t.Fatalf("actorTokenForExchange() error = nil, want non-nil") + } +} diff --git a/go/adk/pkg/sts/models.go b/go/adk/pkg/sts/models.go index 713fc1765..6a18bd283 100644 --- a/go/adk/pkg/sts/models.go +++ b/go/adk/pkg/sts/models.go @@ -41,15 +41,15 @@ type TokenExchangeRequest struct { // ActorTokenType is the type of the actor_token (required if ActorToken is set) ActorTokenType TokenType `json:"actor_token_type,omitempty"` // Resource is the logical name of the target service or resource (optional) - Resource interface{} `json:"resource,omitempty"` // Can be string or []string + Resource any `json:"resource,omitempty"` // Can be string or []string // Audience is the logical name of the target service or resource (optional) - Audience interface{} `json:"audience,omitempty"` // Can be string or []string + Audience any `json:"audience,omitempty"` // Can be string or []string // Scope is the scope of the requested token (optional) Scope string `json:"scope,omitempty"` // RequestedTokenType is the type of the requested token (optional) RequestedTokenType TokenType `json:"requested_token_type,omitempty"` // AdditionalParameters contains additional parameters for the request (optional) - AdditionalParameters map[string]interface{} `json:"-"` // Not serialized directly, merged into form data + AdditionalParameters map[string]any `json:"-"` // Not serialized directly, merged into form data } // IsDelegationRequest checks if this is a delegation request (has actor_token). @@ -77,7 +77,7 @@ type TokenExchangeResponse struct { // RefreshToken is the refresh token if applicable (optional) RefreshToken string `json:"refresh_token,omitempty"` // AdditionalParameters contains additional response parameters (optional) - AdditionalParameters map[string]interface{} `json:"-"` + AdditionalParameters map[string]any `json:"-"` } // TokenExchangeErrorResponse represents an RFC 8693 Token Exchange Error response. @@ -89,7 +89,7 @@ type TokenExchangeErrorResponse struct { // ErrorURI is a URI identifying the error (optional) ErrorURI string `json:"error_uri,omitempty"` // AdditionalParameters contains additional error parameters (optional) - AdditionalParameters map[string]interface{} `json:"-"` + AdditionalParameters map[string]any `json:"-"` } // WellKnownConfiguration represents OAuth 2.0 Authorization Server Metadata. @@ -103,7 +103,7 @@ type WellKnownConfiguration struct { // TokenEndpointAuthSigningAlgValuesSupported is the list of supported signing algorithms (optional) TokenEndpointAuthSigningAlgValuesSupported []string `json:"token_endpoint_auth_signing_alg_values_supported,omitempty"` // AdditionalParameters contains additional configuration parameters (optional) - AdditionalParameters map[string]interface{} `json:"-"` + AdditionalParameters map[string]any `json:"-"` } // STSConfig holds configuration for the STS client. @@ -123,11 +123,7 @@ func DefaultSTSConfig(wellKnownURI string) STSConfig { return STSConfig{ WellKnownURI: wellKnownURI, Timeout: 5, - VerifySSL: boolPtr(true), + VerifySSL: new(true), UseIssuerHost: false, } } - -func boolPtr(v bool) *bool { - return &v -} diff --git a/go/adk/pkg/sts/plugin.go b/go/adk/pkg/sts/plugin.go index 2c3c5d288..39fe26f5f 100644 --- a/go/adk/pkg/sts/plugin.go +++ b/go/adk/pkg/sts/plugin.go @@ -235,7 +235,6 @@ func (p *TokenPropagationPlugin) AfterRunCallback(ctx agent.InvocationContext) { p.logger.V(1).Info("Removing expired actor token from cache") p.actorTokenCache = nil } - } // HeaderProvider returns a map of headers to inject into MCP tool HTTP requests. @@ -312,31 +311,19 @@ func extractJWTExpiry(token string) int64 { return 0 } - // Parse without signature verification — we only need the exp claim. - parsed, err := jwt.Parse(token, - func(t *jwt.Token) (interface{}, error) { return nil, nil }, - jwt.WithoutClaimsValidation(), - ) - // err is expected (no key), but parsed may still carry the claims. - if parsed == nil || parsed.Claims == nil { - if err != nil { - // Truly unparseable token (not a JWT, etc.) - return 0 - } + claims := jwt.MapClaims{} + if _, _, err := jwt.NewParser(jwt.WithoutClaimsValidation()).ParseUnverified(token, claims); err != nil { + return 0 } - if parsed != nil { - if claims, ok := parsed.Claims.(jwt.MapClaims); ok { - if exp, ok := claims["exp"]; ok { - switch v := exp.(type) { - case float64: - return int64(v) - case int64: - return v - case int: - return int64(v) - } - } + if exp, ok := claims["exp"]; ok { + switch v := exp.(type) { + case float64: + return int64(v) + case int64: + return v + case int: + return int64(v) } } diff --git a/go/adk/pkg/sts/plugin_test.go b/go/adk/pkg/sts/plugin_test.go index 71d85113f..8bf15cb0b 100644 --- a/go/adk/pkg/sts/plugin_test.go +++ b/go/adk/pkg/sts/plugin_test.go @@ -9,6 +9,7 @@ import ( "time" "github.com/go-logr/logr" + "github.com/golang-jwt/jwt/v5" kagentmodels "github.com/kagent-dev/kagent/go/adk/pkg/models" "google.golang.org/adk/agent" "google.golang.org/adk/session" @@ -137,3 +138,18 @@ func TestBeforeRunCallback_ReusesCachedDynamicActorTokenForExchange(t *testing.T t.Fatalf("token exchange calls = %d, want 2", exchangeCount) } } + +func TestExtractJWTExpiryUsesUnverifiedClaims(t *testing.T) { + t.Parallel() + want := time.Now().Add(time.Hour).Unix() + token, err := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "exp": want, + }).SignedString([]byte("secret")) + if err != nil { + t.Fatalf("failed to sign token: %v", err) + } + + if got := extractJWTExpiry(token); got != want { + t.Fatalf("extractJWTExpiry() = %d, want %d", got, want) + } +} diff --git a/go/adk/pkg/sts/utils.go b/go/adk/pkg/sts/utils.go index 0e4514758..7bea5dd49 100644 --- a/go/adk/pkg/sts/utils.go +++ b/go/adk/pkg/sts/utils.go @@ -22,13 +22,8 @@ const ( // NOTE: This makes an HTTP request. Callers should cache the result. func FetchWellKnownConfiguration(ctx context.Context, wellKnownURI string, timeout int, verifySSL bool, useIssuerHost bool) (*WellKnownConfiguration, error) { client := &http.Client{ - Timeout: time.Duration(timeout) * time.Second, - } - - if !verifySSL { - client.Transport = &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } + Timeout: time.Duration(timeout) * time.Second, + Transport: transportWithTLSVerification(verifySSL), } req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnownURI, nil) @@ -46,7 +41,7 @@ func FetchWellKnownConfiguration(ctx context.Context, wellKnownURI string, timeo return nil, NewNetworkError(fmt.Sprintf("failed to fetch well-known configuration: HTTP %d", resp.StatusCode), nil) } - var data map[string]interface{} + var data map[string]any if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { return nil, NewConfigurationError(fmt.Sprintf("invalid well-known configuration response: %v", err)) } @@ -103,7 +98,7 @@ func FetchWellKnownConfiguration(ctx context.Context, wellKnownURI string, timeo } // ParseTokenExchangeError parses a token exchange error response. -func ParseTokenExchangeError(responseData map[string]interface{}) *TokenExchangeError { +func ParseTokenExchangeError(responseData map[string]any) *TokenExchangeError { errorCode := "unknown_error" if ec, ok := responseData["error"].(string); ok { errorCode = ec @@ -117,16 +112,16 @@ func ParseTokenExchangeError(responseData map[string]interface{}) *TokenExchange return NewTokenExchangeError(errorCode, errorDescription, 0) } -// Helper functions to safely extract values from map[string]interface{} -func getString(m map[string]interface{}, key string) string { +// Helper functions to safely extract values from map[string]any +func getString(m map[string]any, key string) string { if v, ok := m[key].(string); ok { return v } return "" } -func getStringSlice(m map[string]interface{}, key string) []string { - if v, ok := m[key].([]interface{}); ok { +func getStringSlice(m map[string]any, key string) []string { + if v, ok := m[key].([]any); ok { result := make([]string, 0, len(v)) for _, item := range v { if s, ok := item.(string); ok { @@ -137,3 +132,26 @@ func getStringSlice(m map[string]interface{}, key string) []string { } return nil } + +func transportWithTLSVerification(verifySSL bool) *http.Transport { + transport := cloneDefaultHTTPTransport() + if verifySSL { + return transport + } + + tlsConfig := &tls.Config{} + if transport.TLSClientConfig != nil { + tlsConfig = transport.TLSClientConfig.Clone() + } + tlsConfig.InsecureSkipVerify = true + transport.TLSClientConfig = tlsConfig + return transport +} + +func cloneDefaultHTTPTransport() *http.Transport { + defaultTransport, ok := http.DefaultTransport.(*http.Transport) + if !ok { + return &http.Transport{} + } + return defaultTransport.Clone() +}