From 0f4c548a73b3190f35fccc0b79b8c2e2e1b908ca Mon Sep 17 00:00:00 2001 From: mesutoezdil Date: Wed, 3 Jun 2026 22:48:55 +0200 Subject: [PATCH] fix(bedrock): preserve thinking blocks in multi-turn tool use When extended thinking is active, Bedrock returns thinking content blocks together with tool use blocks. The API requires these blocks to be sent back unmodified in the next request. Without them, Bedrock returns a ValidationException with toolUse.input is empty. Only emit thinking blocks for the last assistant turn before tool results. Sending them in all turns causes token counts to compound across long sessions. Truncate tool results in historical turns to 2000 chars. Older turns do not need full fidelity for large kubectl or YAML outputs. Extract inference config construction into a testable helper so that the temperature/top_p exclusion logic for extended thinking can be exercised directly without mocking the AWS Bedrock client. Fixes #1870 Signed-off-by: mesutoezdil --- go/adk/pkg/models/bedrock.go | 154 ++++++++++++++++--- go/adk/pkg/models/bedrock_test.go | 248 ++++++++++++++++++++++++++++++ 2 files changed, 379 insertions(+), 23 deletions(-) diff --git a/go/adk/pkg/models/bedrock.go b/go/adk/pkg/models/bedrock.go index d9db5a842e..216a95dcc4 100644 --- a/go/adk/pkg/models/bedrock.go +++ b/go/adk/pkg/models/bedrock.go @@ -171,20 +171,9 @@ func (m *BedrockModel) GenerateContent(ctx context.Context, req *model.LLMReques // is written with the sanitized name Bedrock already knows about. messages, systemInstruction := convertGenaiContentsToBedrockMessages(req.Contents, nameMap) - // Build inference config - var inferenceConfig *types.InferenceConfiguration - if m.Config.MaxTokens != nil || m.Config.Temperature != nil || m.Config.TopP != nil { - inferenceConfig = &types.InferenceConfiguration{} - if m.Config.MaxTokens != nil { - inferenceConfig.MaxTokens = aws.Int32(int32(*m.Config.MaxTokens)) - } - if m.Config.Temperature != nil { - inferenceConfig.Temperature = aws.Float32(float32(*m.Config.Temperature)) - } - if m.Config.TopP != nil { - inferenceConfig.TopP = aws.Float32(float32(*m.Config.TopP)) - } - } + // temperature/top_p must not be sent when thinking is active. https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html + _, thinkingEnabled := m.Config.AdditionalModelRequestFields["thinking"] + inferenceConfig := buildInferenceConfig(m.Config, thinkingEnabled) // Build system prompt var systemPrompt []types.SystemContentBlock @@ -248,6 +237,10 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me toolCalls := make(map[int32]*streamingToolCall) var completedToolCalls []*genai.Part + // https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html + reasoningBlocks := make(map[int32]*streamingReasoningBlock) + var completedThinkingParts []*genai.Part + // Get the event stream and read events from the channel stream := output.GetStream() defer stream.Close() @@ -295,10 +288,22 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me if tc, ok := toolCalls[blockIdx]; ok && delta.Value.Input != nil { tc.InputJSON += aws.ToString(delta.Value.Input) } + + case *types.ContentBlockDeltaMemberReasoningContent: + if _, ok := reasoningBlocks[blockIdx]; !ok { + reasoningBlocks[blockIdx] = &streamingReasoningBlock{} + } + rb := reasoningBlocks[blockIdx] + switch inner := delta.Value.(type) { + case *types.ReasoningContentBlockDeltaMemberText: + rb.Text.WriteString(inner.Value) + case *types.ReasoningContentBlockDeltaMemberSignature: + rb.Signature = inner.Value + } } } - // Handle content block stop (tool use complete) + // Handle content block stop (tool use or thinking block complete) if stop, ok := event.(*types.ConverseStreamOutputMemberContentBlockStop); ok { blockIdx := aws.ToInt32(stop.Value.ContentBlockIndex) if tc, ok := toolCalls[blockIdx]; ok { @@ -316,7 +321,18 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me Args: args, } completedToolCalls = append(completedToolCalls, &genai.Part{FunctionCall: functionCall}) - delete(toolCalls, blockIdx) // Clean up + delete(toolCalls, blockIdx) + } + if rb, ok := reasoningBlocks[blockIdx]; ok { + part := &genai.Part{ + Thought: true, + Text: rb.Text.String(), + } + if rb.Signature != "" { + part.ThoughtSignature = []byte(rb.Signature) + } + completedThinkingParts = append(completedThinkingParts, part) + delete(reasoningBlocks, blockIdx) } } @@ -337,8 +353,9 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me } } - // Build final response + // thinking parts first; block order must match what Bedrock returned. finalParts := []*genai.Part{} + finalParts = append(finalParts, completedThinkingParts...) text := aggregatedText.String() if text != "" { finalParts = append(finalParts, &genai.Part{Text: text}) @@ -366,6 +383,11 @@ type streamingToolCall struct { InputJSON string // Accumulated JSON input } +type streamingReasoningBlock struct { + Text strings.Builder + Signature string +} + // parseArgs parses the accumulated JSON input into a map func (tc *streamingToolCall) parseArgs() map[string]any { if tc.InputJSON == "" { @@ -403,6 +425,20 @@ func (m *BedrockModel) generateNonStreaming(ctx context.Context, modelId string, parts := []*genai.Part{} if message, ok := output.Output.(*types.ConverseOutputMemberMessage); ok { for _, block := range message.Value.Content { + // https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html + if reasoningBlock, ok := block.(*types.ContentBlockMemberReasoningContent); ok { + if textBlock, ok := reasoningBlock.Value.(*types.ReasoningContentBlockMemberReasoningText); ok { + part := &genai.Part{ + Thought: true, + Text: aws.ToString(textBlock.Value.Text), + } + if textBlock.Value.Signature != nil { + part.ThoughtSignature = []byte(aws.ToString(textBlock.Value.Signature)) + } + parts = append(parts, part) + } + continue + } // Handle text content if textBlock, ok := block.(*types.ContentBlockMemberText); ok { parts = append(parts, &genai.Part{Text: textBlock.Value}) @@ -473,6 +509,15 @@ func documentToMap(doc document.Interface) map[string]any { return result } +const historyToolResultMaxLen = 2000 + +func truncateToolResult(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + fmt.Sprintf("\n... [truncated, %d chars omitted]", len(s)-maxLen) +} + // convertGenaiContentsToBedrockMessages converts genai.Content to Bedrock Converse API message format. // nameMap is the original->sanitized tool name map produced by convertGenaiToolsToBedrock. // Any FunctionCall found in the conversation history is written with the sanitized name so @@ -486,17 +531,40 @@ func convertGenaiContentsToBedrockMessages(contents []*genai.Content, nameMap ma idMap := make(map[string]string) idCounter := 0 - for _, content := range contents { + // Bedrock only requires thinking blocks in the last assistant turn before tool results. + // Sending them in earlier turns causes token counts to compound across long sessions. + // Truncate tool results in all turns except the most recent one carrying them. + lastThinkingIdx, lastToolResultIdx := -1, -1 + for i, c := range contents { + if c == nil { + continue + } + for _, p := range c.Parts { + if p == nil { + continue + } + if p.Thought && (c.Role == "model" || c.Role == "assistant") { + lastThinkingIdx = i + } + if p.FunctionResponse != nil && c.Role == "user" { + lastToolResultIdx = i + } + } + } + + for i, content := range contents { if content == nil || len(content.Parts) == 0 { continue } - // Determine role role := types.ConversationRoleUser if content.Role == "model" || content.Role == "assistant" { role = types.ConversationRoleAssistant } + emitThinking := i == lastThinkingIdx + truncateTools := i != lastToolResultIdx + var contentBlocks []types.ContentBlock for _, part := range content.Parts { @@ -504,9 +572,26 @@ func convertGenaiContentsToBedrockMessages(contents []*genai.Content, nameMap ma continue } - // Handle text + // Thought parts also carry Text; check Thought first. https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html + if part.Thought { + if !emitThinking { + continue + } + reasoningText := &types.ReasoningTextBlock{ + Text: aws.String(part.Text), + } + if len(part.ThoughtSignature) > 0 { + reasoningText.Signature = aws.String(string(part.ThoughtSignature)) + } + contentBlocks = append(contentBlocks, &types.ContentBlockMemberReasoningContent{ + Value: &types.ReasoningContentBlockMemberReasoningText{ + Value: *reasoningText, + }, + }) + continue + } + if part.Text != "" { - // Check if this is a system message if content.Role == "system" { systemInstruction = part.Text continue @@ -536,10 +621,11 @@ func convertGenaiContentsToBedrockMessages(contents []*genai.Content, nameMap ma continue } - // Handle function response (tool result in Bedrock terminology) if part.FunctionResponse != nil { - // Extract response content result := extractFunctionResponseContent(part.FunctionResponse.Response) + if truncateTools { + result = truncateToolResult(result, historyToolResultMaxLen) + } toolResult := types.ToolResultBlock{ ToolUseId: aws.String(sanitizeBedrockToolID(part.FunctionResponse.ID, idMap, &idCounter)), Content: []types.ToolResultContentBlock{ @@ -641,3 +727,25 @@ func bedrockStopReasonToGenai(reason types.StopReason) genai.FinishReason { return genai.FinishReasonStop } } + +// buildInferenceConfig constructs the Bedrock InferenceConfiguration from a +// BedrockConfig. When thinking is enabled, temperature and top_p must be +// omitted per the Bedrock extended-thinking API contract. +func buildInferenceConfig(cfg *BedrockConfig, thinkingEnabled bool) *types.InferenceConfiguration { + if cfg.MaxTokens == nil && (thinkingEnabled || (cfg.Temperature == nil && cfg.TopP == nil)) { + return nil + } + ic := &types.InferenceConfiguration{} + if cfg.MaxTokens != nil { + ic.MaxTokens = aws.Int32(int32(*cfg.MaxTokens)) + } + if !thinkingEnabled { + if cfg.Temperature != nil { + ic.Temperature = aws.Float32(float32(*cfg.Temperature)) + } + if cfg.TopP != nil { + ic.TopP = aws.Float32(float32(*cfg.TopP)) + } + } + return ic +} diff --git a/go/adk/pkg/models/bedrock_test.go b/go/adk/pkg/models/bedrock_test.go index de2d1c3caf..9985e56405 100644 --- a/go/adk/pkg/models/bedrock_test.go +++ b/go/adk/pkg/models/bedrock_test.go @@ -2,6 +2,7 @@ package models import ( "encoding/json" + "strings" "testing" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" @@ -102,6 +103,41 @@ func TestConvertGenaiContentsToBedrockMessages(t *testing.T) { } }, }, + { + name: "thinking block preserved as ReasoningContent", + contents: []*genai.Content{ + { + Role: "model", + Parts: []*genai.Part{ + {Thought: true, Text: "let me think", ThoughtSignature: []byte("sig123")}, + {FunctionCall: &genai.FunctionCall{ID: "c1", Name: "get_weather", Args: map[string]any{"location": "Paris"}}}, + }, + }, + }, + wantMsgCount: 1, + checkMsg: func(t *testing.T, msgs []types.Message) { + if len(msgs[0].Content) != 2 { + t.Fatalf("expected 2 blocks (thinking + toolUse), got %d", len(msgs[0].Content)) + } + rb, ok := msgs[0].Content[0].(*types.ContentBlockMemberReasoningContent) + if !ok { + t.Fatalf("block 0: want *ContentBlockMemberReasoningContent, got %T", msgs[0].Content[0]) + } + rt, ok := rb.Value.(*types.ReasoningContentBlockMemberReasoningText) + if !ok { + t.Fatalf("reasoning value: want *ReasoningContentBlockMemberReasoningText, got %T", rb.Value) + } + if *rt.Value.Text != "let me think" { + t.Errorf("text: want %q, got %q", "let me think", *rt.Value.Text) + } + if *rt.Value.Signature != "sig123" { + t.Errorf("signature: want %q, got %q", "sig123", *rt.Value.Signature) + } + if _, ok := msgs[0].Content[1].(*types.ContentBlockMemberToolUse); !ok { + t.Errorf("block 1: want *ContentBlockMemberToolUse, got %T", msgs[0].Content[1]) + } + }, + }, } for _, tt := range tests { @@ -424,3 +460,215 @@ func TestStreamingToolCallParseArgs(t *testing.T) { }) } } + +func TestThinkingOnlyInLastAssistantTurn(t *testing.T) { + contents := []*genai.Content{ + { + Role: "model", + Parts: []*genai.Part{ + {Thought: true, Text: "first think", ThoughtSignature: []byte("sig1")}, + {FunctionCall: &genai.FunctionCall{ID: "c1", Name: "tool_a", Args: map[string]any{}}}, + }, + }, + { + Role: "user", + Parts: []*genai.Part{{FunctionResponse: &genai.FunctionResponse{ID: "c1", Name: "tool_a", Response: map[string]any{"r": "v1"}}}}, + }, + { + Role: "model", + Parts: []*genai.Part{ + {Thought: true, Text: "second think", ThoughtSignature: []byte("sig2")}, + {FunctionCall: &genai.FunctionCall{ID: "c2", Name: "tool_b", Args: map[string]any{}}}, + }, + }, + { + Role: "user", + Parts: []*genai.Part{{FunctionResponse: &genai.FunctionResponse{ID: "c2", Name: "tool_b", Response: map[string]any{"r": "v2"}}}}, + }, + } + + msgs, _ := convertGenaiContentsToBedrockMessages(contents, nil) + if len(msgs) != 4 { + t.Fatalf("want 4 messages, got %d", len(msgs)) + } + + // First assistant turn must NOT contain reasoning content. + for _, block := range msgs[0].Content { + if _, ok := block.(*types.ContentBlockMemberReasoningContent); ok { + t.Error("first assistant turn must not contain reasoning content") + } + } + + // Last assistant turn (index 2) must contain reasoning content. + hasReasoning := false + for _, block := range msgs[2].Content { + if _, ok := block.(*types.ContentBlockMemberReasoningContent); ok { + hasReasoning = true + } + } + if !hasReasoning { + t.Error("last assistant turn must contain reasoning content") + } +} + +func TestHistoricalToolResultTruncation(t *testing.T) { + longOutput := strings.Repeat("x", historyToolResultMaxLen+500) + contents := []*genai.Content{ + { + Role: "user", + Parts: []*genai.Part{{FunctionResponse: &genai.FunctionResponse{ID: "c1", Name: "tool_a", Response: map[string]any{"result": longOutput}}}}, + }, + { + Role: "user", + Parts: []*genai.Part{{FunctionResponse: &genai.FunctionResponse{ID: "c2", Name: "tool_b", Response: map[string]any{"result": longOutput}}}}, + }, + } + + msgs, _ := convertGenaiContentsToBedrockMessages(contents, nil) + if len(msgs) != 2 { + t.Fatalf("want 2 messages, got %d", len(msgs)) + } + + extractText := func(msg types.Message) string { + for _, block := range msg.Content { + if tr, ok := block.(*types.ContentBlockMemberToolResult); ok { + for _, c := range tr.Value.Content { + if txt, ok := c.(*types.ToolResultContentBlockMemberText); ok { + return txt.Value + } + } + } + } + return "" + } + + first := extractText(msgs[0]) + if len(first) >= len(longOutput) { + t.Errorf("historical tool result should be truncated, got len=%d", len(first)) + } + + last := extractText(msgs[1]) + if len(last) != len(longOutput) { + t.Errorf("latest tool result must not be truncated, got len=%d want %d", len(last), len(longOutput)) + } +} + +func TestTruncateToolResult(t *testing.T) { + cases := []struct { + name string + input string + maxLen int + wantLen int + wantMsg bool + }{ + {"no truncation needed", "short", 100, 5, false}, + {"exact boundary", strings.Repeat("a", 100), 100, 100, false}, + {"truncated", strings.Repeat("a", 150), 100, 0, true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := truncateToolResult(tc.input, tc.maxLen) + if tc.wantMsg { + if len(got) <= tc.maxLen { + t.Errorf("expected truncated result longer than maxLen, got %d", len(got)) + } + if !strings.Contains(got, "truncated") { + t.Error("truncated result must contain truncation notice") + } + } else { + if len(got) != tc.wantLen { + t.Errorf("want len %d, got %d", tc.wantLen, len(got)) + } + } + }) + } +} + +func TestBuildInferenceConfig(t *testing.T) { + f64 := func(v float64) *float64 { return &v } + f32 := func(v float32) *float32 { return &v } + i32 := func(v int32) *int32 { return &v } + + tests := []struct { + name string + cfg BedrockConfig + thinkingActive bool + wantNil bool + wantTemp *float32 + wantTopP *float32 + wantMaxTokens *int32 + }{ + { + name: "thinking drops temperature and topP", + cfg: BedrockConfig{Temperature: f64(0.7), TopP: f64(0.9)}, + thinkingActive: true, + wantNil: true, + }, + { + name: "thinking with maxTokens keeps only maxTokens", + cfg: BedrockConfig{Temperature: f64(0.7), TopP: f64(0.9), MaxTokens: func() *int { v := 1000; return &v }()}, + thinkingActive: true, + wantNil: false, + wantMaxTokens: i32(1000), + }, + { + name: "no thinking passes temperature and topP", + cfg: BedrockConfig{Temperature: f64(0.7), TopP: f64(0.9)}, + thinkingActive: false, + wantNil: false, + wantTemp: f32(0.7), + wantTopP: f32(0.9), + }, + { + name: "all nil returns nil", + cfg: BedrockConfig{}, + thinkingActive: false, + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := buildInferenceConfig(&tt.cfg, tt.thinkingActive) + if tt.wantNil { + if got != nil { + t.Fatalf("want nil, got %+v", got) + } + return + } + if got == nil { + t.Fatal("want non-nil InferenceConfiguration, got nil") + } + if tt.wantTemp == nil && got.Temperature != nil { + t.Errorf("temperature: want nil, got %v", *got.Temperature) + } + if tt.wantTemp != nil { + if got.Temperature == nil { + t.Fatalf("temperature: want %v, got nil", *tt.wantTemp) + } + if *got.Temperature != *tt.wantTemp { + t.Errorf("temperature: want %v, got %v", *tt.wantTemp, *got.Temperature) + } + } + if tt.wantTopP == nil && got.TopP != nil { + t.Errorf("topP: want nil, got %v", *got.TopP) + } + if tt.wantTopP != nil { + if got.TopP == nil { + t.Fatalf("topP: want %v, got nil", *tt.wantTopP) + } + if *got.TopP != *tt.wantTopP { + t.Errorf("topP: want %v, got %v", *tt.wantTopP, *got.TopP) + } + } + if tt.wantMaxTokens != nil { + if got.MaxTokens == nil { + t.Fatalf("maxTokens: want %v, got nil", *tt.wantMaxTokens) + } + if *got.MaxTokens != *tt.wantMaxTokens { + t.Errorf("maxTokens: want %v, got %v", *tt.wantMaxTokens, *got.MaxTokens) + } + } + }) + } +}