Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions go/adk/pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/go-logr/logr"
"github.com/kagent-dev/kagent/go/adk/pkg/mcp"
"github.com/kagent-dev/kagent/go/adk/pkg/models"
"github.com/kagent-dev/kagent/go/adk/pkg/sts"
"github.com/kagent-dev/kagent/go/adk/pkg/tools"
"github.com/kagent-dev/kagent/go/api/adk"
"google.golang.org/adk/agent"
Expand All @@ -33,23 +34,28 @@ const (
// agentName is used as the ADK agent identity (appears in event Author field).
// extraTools are appended to the agent's tool list (e.g. save_memory).
func CreateGoogleADKAgent(ctx context.Context, agentConfig *adk.AgentConfig, agentName string, extraTools ...tool.Tool) (agent.Agent, error) {
a, _, err := CreateGoogleADKAgentWithSubagentSessionIDs(ctx, agentConfig, agentName, extraTools...)
a, _, err := CreateGoogleADKAgentWithSubagentSessionIDs(ctx, agentConfig, agentName, nil, extraTools...)
return a, err
}

// CreateGoogleADKAgentWithSubagentSessionIDs creates a Google ADK agent and a
// map of remote-subagent tool name → A2A context session ID (for stamping
// outbound A2A events). Callers that only need the agent can use
// CreateGoogleADKAgent.
func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig *adk.AgentConfig, agentName string, extraTools ...tool.Tool) (agent.Agent, map[string]string, error) {
// Optional stsPlugin can be provided for token propagation to MCP tools.
func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig *adk.AgentConfig, agentName string, stsPlugin *sts.TokenPropagationPlugin, extraTools ...tool.Tool) (agent.Agent, map[string]string, error) {
log := logr.FromContextOrDiscard(ctx)

if agentConfig == nil {
return nil, nil, fmt.Errorf("agent config is required")
}

propagateToken := strings.ToLower(os.Getenv("KAGENT_PROPAGATE_TOKEN")) == "true"
toolsets := mcp.CreateToolsets(ctx, agentConfig.HttpTools, agentConfig.SseTools, propagateToken)
var dynamicHeaderProvider mcp.DynamicHeaderProvider
if stsPlugin != nil {
dynamicHeaderProvider = stsPlugin.HeaderProvider
}
toolsets := mcp.CreateToolsets(ctx, agentConfig.HttpTools, agentConfig.SseTools, propagateToken, dynamicHeaderProvider)
subagentSessionIDs := make(map[string]string)

var remoteAgentTools []tool.Tool
Expand Down Expand Up @@ -104,6 +110,7 @@ func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig
beforeToolCallbacks := []llmagent.BeforeToolCallback{}
// Strip synthetic HITL tool messages from the model request to avoid unnecessary token usage.
beforeModelCallbacks := []llmagent.BeforeModelCallback{}

if len(approvalSet) > 0 {
log.Info("Wiring approval callback", "toolCount", len(approvalSet))
beforeToolCallbacks = append(beforeToolCallbacks, MakeApprovalCallback(approvalSet))
Expand Down
41 changes: 34 additions & 7 deletions go/adk/pkg/mcp/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ import (
"google.golang.org/adk/tool/mcptoolset"
)

// DynamicHeaderProvider is a function that returns headers to inject into MCP requests.
// It receives the context and should return a map of headers.
// This is used for dynamic token injection (e.g., STS tokens) per session.
type DynamicHeaderProvider func(ctx context.Context) map[string]string

const (
// Default timeout matching Python KAGENT_REMOTE_AGENT_TIMEOUT
defaultTimeout = 30 * time.Minute
Expand Down Expand Up @@ -62,9 +67,10 @@ func allowedRequestHeaders(ctx context.Context, allowed []string) map[string]str
type mcpServerParams struct {
URL string
Headers map[string]string
AllowedHeaders []string // header names to forward from incoming request
PropagateToken bool // when true, Authorization is forwarded independently of AllowedHeaders
ServerType string // "http" or "sse"
AllowedHeaders []string // header names to forward from incoming request
PropagateToken bool // when true, Authorization is forwarded independently of AllowedHeaders
HeaderProvider DynamicHeaderProvider // optional per-request headers derived from invocation context (e.g., STS exchanged access tokens)
ServerType string // "http" or "sse"
Timeout *float64
SseReadTimeout *float64
TLSInsecureSkipVerify *bool
Expand All @@ -79,7 +85,16 @@ type mcpServerParams struct {
// When propagateToken is true, Authorization is forwarded to every MCP server
// independently of AllowedHeaders, mirroring the Python ADKTokenPropagationPlugin
// behaviour triggered by KAGENT_PROPAGATE_TOKEN.
func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, sseTools []adk.SseMcpServerConfig, propagateToken bool) []tool.Toolset {
//
// Optional headerProvider can be used to inject per-request headers
// derived from invocation context (e.g., STS exchanged access tokens).
func CreateToolsets(
ctx context.Context,
httpTools []adk.HttpMcpServerConfig,
sseTools []adk.SseMcpServerConfig,
propagateToken bool,
headerProvider DynamicHeaderProvider,
) []tool.Toolset {
log := logr.FromContextOrDiscard(ctx)
var toolsets []tool.Toolset

Expand All @@ -90,6 +105,7 @@ func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, ss
Headers: httpTool.Params.Headers,
AllowedHeaders: httpTool.AllowedHeaders,
PropagateToken: propagateToken,
HeaderProvider: headerProvider,
ServerType: "http",
Timeout: httpTool.Params.Timeout,
SseReadTimeout: httpTool.Params.SseReadTimeout,
Expand All @@ -111,6 +127,7 @@ func CreateToolsets(ctx context.Context, httpTools []adk.HttpMcpServerConfig, ss
Headers: sseTool.Params.Headers,
AllowedHeaders: sseTool.AllowedHeaders,
PropagateToken: propagateToken,
HeaderProvider: headerProvider,
ServerType: "sse",
Timeout: sseTool.Params.Timeout,
SseReadTimeout: sseTool.Params.SseReadTimeout,
Expand Down Expand Up @@ -208,12 +225,13 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp
}

var httpTransport http.RoundTripper = baseTransport
if len(params.Headers) > 0 || len(params.AllowedHeaders) > 0 || params.PropagateToken {
if len(params.Headers) > 0 || len(params.AllowedHeaders) > 0 || params.PropagateToken || params.HeaderProvider != nil {
httpTransport = &headerRoundTripper{
base: baseTransport,
headers: params.Headers,
allowedHeaders: params.AllowedHeaders,
propagateToken: params.PropagateToken,
headerProvider: params.HeaderProvider,
}
}

Expand All @@ -239,18 +257,20 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp
}

// headerRoundTripper wraps an http.RoundTripper to add custom headers to all
// requests. It supports three sources of headers, applied in this order so that
// requests. It supports four sources of headers, applied in this order so that
// higher-priority sources win on collision:
// 1. propagateToken: when true, Authorization is read from the incoming A2A
// CallContext and forwarded unconditionally (independent of allowedHeaders).
// 2. allowedHeaders: explicit per-header forwarding from the A2A CallContext.
// 3. headers: static key/value pairs configured on the MCP server spec (highest
// 3. headerProvider: runtime headers derived from ADK context, such as STS tokens.
// 4. headers: static key/value pairs configured on the MCP server spec (highest
// priority — always wins).
type headerRoundTripper struct {
base http.RoundTripper
headers map[string]string
allowedHeaders []string // header names (case-insensitive) to forward from A2A context
propagateToken bool // when true, Authorization is forwarded independently
headerProvider DynamicHeaderProvider
}

func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
Expand All @@ -273,6 +293,13 @@ func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
req.Header.Set(k, v)
}

// Dynamic headers (e.g., STS access tokens) override propagated/allowed headers.
if rt.headerProvider != nil {
for key, value := range rt.headerProvider(req.Context()) {
req.Header.Set(key, value)
}
}

// Apply static headers last — they take precedence over all dynamic sources.
for key, value := range rt.headers {
req.Header.Set(key, value)
Expand Down
77 changes: 77 additions & 0 deletions go/adk/pkg/mcp/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,3 +319,80 @@ func TestAllowedRequestHeaders_ReturnsNilWhenNoMatches(t *testing.T) {
t.Errorf("expected nil when no allowed headers are present, got %v", got)
}
}

// TestDynamicHeaders_OverridePropagatedAuthorization verifies dynamic headers
// take precedence over propagated and allowed request headers.
func TestDynamicHeaders_OverridePropagatedAuthorization(t *testing.T) {
t.Parallel()
var capturedAuth, capturedTrace string

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAuth = r.Header.Get("Authorization")
capturedTrace = r.Header.Get("X-Trace-Id")
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: x-trace-id here seems to be a bit of a misrepresentation given that our implementation only overwrites the Authorization header.

w.WriteHeader(http.StatusOK)
}))
defer srv.Close()

ctx := a2aCtx(map[string][]string{
"Authorization": {"Bearer incoming"},
"X-Trace-Id": {"trace-from-request"},
})

rt := &headerRoundTripper{
base: http.DefaultTransport,
propagateToken: true,
allowedHeaders: []string{"Authorization", "X-Trace-Id"},
headerProvider: func(context.Context) map[string]string {
return map[string]string{
"Authorization": "Bearer sts-exchanged",
"X-Trace-Id": "trace-from-dynamic",
}
},
}

req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil)
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("RoundTrip failed: %v", err)
}
resp.Body.Close()

if capturedAuth != "Bearer sts-exchanged" {
t.Errorf("Authorization: got %q, want %q", capturedAuth, "Bearer sts-exchanged")
}
if capturedTrace != "trace-from-dynamic" {
t.Errorf("X-Trace-Id: got %q, want %q", capturedTrace, "trace-from-dynamic")
}
}

// TestStaticHeaders_OverrideDynamic verifies static configured headers remain
// the highest-precedence source.
func TestStaticHeaders_OverrideDynamic(t *testing.T) {
t.Parallel()
var capturedAuth string

srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
capturedAuth = r.Header.Get("Authorization")
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()

rt := &headerRoundTripper{
base: http.DefaultTransport,
headers: map[string]string{"Authorization": "Bearer static"},
headerProvider: func(context.Context) map[string]string {
return map[string]string{"Authorization": "Bearer dynamic"}
},
}

req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil)
resp, err := rt.RoundTrip(req)
if err != nil {
t.Fatalf("RoundTrip failed: %v", err)
}
resp.Body.Close()

if capturedAuth != "Bearer static" {
t.Errorf("Authorization: got %q, want %q", capturedAuth, "Bearer static")
}
}
59 changes: 58 additions & 1 deletion go/adk/pkg/runner/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@ package runner
import (
"context"
"fmt"
"os"
"strings"

"github.com/go-logr/logr"
"github.com/kagent-dev/kagent/go/adk/pkg/agent"
kagentmemory "github.com/kagent-dev/kagent/go/adk/pkg/memory"
"github.com/kagent-dev/kagent/go/adk/pkg/session"
"github.com/kagent-dev/kagent/go/adk/pkg/sts"
"github.com/kagent-dev/kagent/go/api/adk"
adkmemory "google.golang.org/adk/memory"
adkplugin "google.golang.org/adk/plugin"
"google.golang.org/adk/runner"
adksession "google.golang.org/adk/session"
adktool "google.golang.org/adk/tool"
Expand All @@ -31,6 +35,8 @@ func CreateRunnerConfig(
appName string,
memoryService *kagentmemory.KagentMemoryService,
) (runner.Config, map[string]string, error) {
log := logr.FromContextOrDiscard(ctx)

var extraTools []adktool.Tool
if memoryService != nil {
saveTool, err := kagentmemory.NewSaveMemoryTool(memoryService)
Expand All @@ -40,7 +46,12 @@ func CreateRunnerConfig(
extraTools = append(extraTools, saveTool)
}

adkAgent, subagentSessionIDs, err := agent.CreateGoogleADKAgentWithSubagentSessionIDs(ctx, agentConfig, agentNameFromAppName(appName), extraTools...)
stsPlugin, err := buildTokenPropagationPlugin(ctx, log)
if err != nil {
return runner.Config{}, nil, err
}

adkAgent, subagentSessionIDs, err := agent.CreateGoogleADKAgentWithSubagentSessionIDs(ctx, agentConfig, agentNameFromAppName(appName), stsPlugin, extraTools...)
if err != nil {
return runner.Config{}, nil, fmt.Errorf("failed to create agent: %w", err)
}
Expand All @@ -61,11 +72,57 @@ func CreateRunnerConfig(
runnerMemory = memoryService
}

var adkPlugins []*adkplugin.Plugin
if stsPlugin != nil {
p, err := stsPlugin.ADKPlugin()
if err != nil {
return runner.Config{}, nil, fmt.Errorf("failed to create STS ADK plugin: %w", err)
}
if p != nil {
adkPlugins = append(adkPlugins, p)
}
}

cfg := runner.Config{
AppName: appName,
Agent: adkAgent,
SessionService: adkSessionService,
MemoryService: runnerMemory,
PluginConfig: runner.PluginConfig{
Plugins: adkPlugins,
},
}

return cfg, subagentSessionIDs, nil
}

func buildTokenPropagationPlugin(ctx context.Context, log logr.Logger) (*sts.TokenPropagationPlugin, error) {
propagateToken := strings.EqualFold(strings.TrimSpace(os.Getenv("KAGENT_PROPAGATE_TOKEN")), "true")
stsWellKnownURI := strings.TrimSpace(os.Getenv("STS_WELL_KNOWN_URI"))
if !propagateToken && stsWellKnownURI == "" {
return nil, nil
}

// Propagate-only mode: keep parity with Python by enabling plugin without STS exchange.
if stsWellKnownURI == "" {
log.Info("Enabling token propagation plugin without STS exchange")
return sts.NewTokenPropagationPlugin(nil, log), nil
}

integration, err := sts.NewSTSIntegration(
stsWellKnownURI,
"",
nil, // fetchActorToken
nil, // getSubjectToken
0, // default timeout
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't the default time out 5s?

true, // default verifySSL
false, // default useIssuerHost
)
if err != nil {
return nil, fmt.Errorf("failed to initialize STS integration: %w", err)
}

log.Info("Enabling STS token propagation plugin",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could we format this log as one line?

"wellKnownURI", stsWellKnownURI)
return sts.NewTokenPropagationPlugin(integration, log), nil
}
54 changes: 54 additions & 0 deletions go/adk/pkg/sts/actor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package sts

import (
"fmt"
"os"
"strings"
)

// DefaultServiceAccountTokenPath is the default path for Kubernetes service account tokens.
const DefaultServiceAccountTokenPath = "/var/run/secrets/kubernetes.io/serviceaccount/token"

// ActorTokenService provides actor tokens for STS delegation.
// It reads Kubernetes service account tokens from a file path.
type ActorTokenService struct {
tokenPath string
}

// NewActorTokenService creates a new ActorTokenService.
// If tokenPath is empty, it uses the default Kubernetes service account token path.
func NewActorTokenService(tokenPath string) *ActorTokenService {
if tokenPath == "" {
tokenPath = DefaultServiceAccountTokenPath
}
return &ActorTokenService{
tokenPath: tokenPath,
}
}

// GetActorToken retrieves the actor token for STS delegation.
// This method reads the token from the file each time it's called.
// If loading fails, it returns an empty string and an error (or nil if file doesn't exist).
func (s *ActorTokenService) GetActorToken() (string, error) {
// Check if file exists first
if _, err := os.Stat(s.tokenPath); os.IsNotExist(err) {
return "", nil
}

data, err := os.ReadFile(s.tokenPath)
if err != nil {
return "", fmt.Errorf("failed to read actor token from %s: %w", s.tokenPath, err)
}

token := strings.TrimSpace(string(data))
if token == "" {
return "", fmt.Errorf("empty actor token found at %s", s.tokenPath)
}

return token, nil
}

// TokenPath returns the configured token path.
func (s *ActorTokenService) TokenPath() string {
return s.tokenPath
}
Loading
Loading