fix: preserve tool_calls in _stream_llm for native tool calling #1864
getglad wants to merge 1 commit into NVIDIA:develop
Conversation
`_stream_llm` reconstructed `AIMessage` from content parts only, dropping `tool_calls` from streamed chunks. This broke `use_native_tool_calling` on the ReAct agent: the agent treated every response as a final text answer because `output_message.tool_calls` was always empty. Use LangChain chunk accumulation (`chunk + chunk`), which preserves `tool_call_chunks`, then convert via `_chunk_to_message`. Move `_chunk_to_message` from `tool_calling_agent` to `base` so both agent types share the implementation.
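To illustrate why chunk accumulation matters here, the following is a minimal, self-contained sketch. The `Chunk` class is a toy stand-in for LangChain's `AIMessageChunk` (it is not the real class); it only mimics the relevant behavior of the `+` operator: content concatenates, and `tool_call_chunks` with the same index have their argument fragments merged. Joining content alone, as the old code did, silently loses the tool call.

```python
from dataclasses import dataclass, field


@dataclass
class Chunk:
    """Toy stand-in for AIMessageChunk (illustration only, not the LangChain class)."""
    content: str = ""
    tool_call_chunks: list = field(default_factory=list)

    def __add__(self, other):
        # Mimics LangChain chunk addition: content concatenates, and
        # tool_call_chunks sharing an index have their args fragments joined.
        merged = {c["index"]: dict(c) for c in self.tool_call_chunks}
        for c in other.tool_call_chunks:
            if c["index"] in merged:
                merged[c["index"]]["args"] += c["args"]
            else:
                merged[c["index"]] = dict(c)
        return Chunk(self.content + other.content, list(merged.values()))


chunks = [
    Chunk(content="I'll check the time."),
    Chunk(tool_call_chunks=[{"name": "get_time", "args": '{"tz": ', "index": 0}]),
    Chunk(tool_call_chunks=[{"name": None, "args": '"UTC"}', "index": 0}]),
]

# Content-only reconstruction (the old bug): the tool call vanishes.
content_only = "".join(c.content for c in chunks)

# Chunk accumulation (the fix): tool_call_chunks survive intact.
accumulated = chunks[0]
for c in chunks[1:]:
    accumulated = accumulated + c
```

With this toy model, `content_only` carries the text but no tool call, while `accumulated` retains both the text and the reassembled `{"tz": "UTC"}` arguments under index 0.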
Walkthrough
The changes refactor streaming chunk handling in the LangChain agent codebase by extracting the duplicated `_chunk_to_message` conversion into the shared base module.
Changes
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks: ✅ 2 passed | ❌ 1 failed (warning)
Actionable comments posted: 1
🧹 Nitpick comments (1)
packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/base.py (1)
**134-145:** Consider one-pass accumulation to reduce peak memory on long streams.

You can preserve behavior while avoiding the intermediate `chunks` list and second accumulation loop.

Refactor sketch:

```diff
-        chunks: list[AIMessageChunk] = []
-        async for chunk in runnable.astream(inputs, config=self._runnable_config):
-            chunks.append(chunk)
-
-        if not chunks:
-            return AIMessage(content="")
-
-        # Accumulate using LangChain's + operator (preserves tool_call_chunks)
-        accumulated = chunks[0]
-        for c in chunks[1:]:
-            accumulated = accumulated + c
+        accumulated: AIMessageChunk | None = None
+        async for chunk in runnable.astream(inputs, config=self._runnable_config):
+            accumulated = chunk if accumulated is None else (accumulated + chunk)
+
+        if accumulated is None:
+            return AIMessage(content="")

         return _chunk_to_message(accumulated)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/base.py` around lines 134 - 145, The current code collects all chunks into a list (chunks) and then performs a second loop to fold them with the LangChain + operator, which increases peak memory for long streams; change to one-pass accumulation while iterating runnable.astream: initialize a sentinel (e.g., accumulated = None) before the async for, on first chunk set accumulated = chunk, on subsequent chunks do accumulated = accumulated + chunk, and after the loop return AIMessage(content="") if accumulated is still None otherwise use accumulated; update references to chunks, runnable.astream, accumulated, and AIMessage accordingly so behavior (including preservation of tool_call_chunks) is unchanged.
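The one-pass sentinel pattern the reviewer describes can be demonstrated without LangChain at all. In this sketch, plain strings stand in for `AIMessageChunk` instances, since `+` is all the pattern relies on; `fake_astream` and `stream_accumulate` are hypothetical names, not functions from the codebase.

```python
import asyncio


async def fake_astream(parts):
    """Stand-in for runnable.astream: yields chunks one at a time."""
    for part in parts:
        yield part


async def stream_accumulate(parts):
    # One-pass accumulation: no intermediate list, constant extra memory.
    accumulated = None  # sentinel meaning "no chunk seen yet"
    async for chunk in fake_astream(parts):
        accumulated = chunk if accumulated is None else accumulated + chunk
    if accumulated is None:
        return ""  # stand-in for AIMessage(content="") on an empty stream
    return accumulated


print(asyncio.run(stream_accumulate(["I'll check ", "the time."])))
```

The empty-stream branch is what replaces the original `if not chunks:` check: the sentinel stays `None` only when the async generator yields nothing.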
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@packages/nvidia_nat_langchain/tests/agent/test_base.py`:
- Around line 114-139: Update the test_streaming_preserves_tool_calls regression
test to also assert that the reconstructed OpenAI wire-format is present in the
returned message: after calling base_agent._stream_llm(mock_runnable, inputs)
and validating result.tool_calls, add an assertion that
result.additional_kwargs["tool_calls"] exists, is a list, and contains an entry
whose "name" == "get_time" (or otherwise matches the expected wire-format
payload produced by mock_astream). This change ensures _stream_llm preserves
both the parsed tool_calls and the serialized wire-format under
additional_kwargs["tool_calls"] for downstream compatibility.
---
Nitpick comments:
In `@packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/base.py`:
- Around line 134-145: The current code collects all chunks into a list (chunks)
and then performs a second loop to fold them with the LangChain + operator,
which increases peak memory for long streams; change to one-pass accumulation
while iterating runnable.astream: initialize a sentinel (e.g., accumulated =
None) before the async for, on first chunk set accumulated = chunk, on
subsequent chunks do accumulated = accumulated + chunk, and after the loop
return AIMessage(content="") if accumulated is still None otherwise use
accumulated; update references to chunks, runnable.astream, accumulated, and
AIMessage accordingly so behavior (including preservation of tool_call_chunks)
is unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 58f3c097-1fa9-4892-a3ae-f0af1782d00d
📒 Files selected for processing (3)
- packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/base.py
- packages/nvidia_nat_langchain/src/nat/plugins/langchain/agent/tool_calling_agent/agent.py
- packages/nvidia_nat_langchain/tests/agent/test_base.py
```python
async def test_streaming_preserves_tool_calls(self, base_agent):
    """Test that tool_calls from native tool calling are preserved."""
    mock_runnable = Mock()

    async def mock_astream(inputs, **kwargs):
        yield AIMessageChunk(
            content="I'll check the time.",
            tool_call_chunks=[{
                "name": "get_time",
                "args": '{"tz": "UTC"}',
                "id": "call_123",
                "index": 0,
                "type": "tool_call_chunk",
            }],
        )

    mock_runnable.astream = mock_astream

    inputs = {"messages": [HumanMessage(content="test")]}
    result = await base_agent._stream_llm(mock_runnable, inputs)

    assert isinstance(result, AIMessage)
    assert result.content == "I'll check the time."
    assert len(result.tool_calls) == 1
    assert result.tool_calls[0]["name"] == "get_time"
```
Assert OpenAI wire-format tool_calls in this regression test.
This test confirms result.tool_calls, but it does not verify result.additional_kwargs["tool_calls"], which is the critical reconstructed wire format for downstream provider compatibility.
Suggested test addition:

```diff
 assert isinstance(result, AIMessage)
 assert result.content == "I'll check the time."
 assert len(result.tool_calls) == 1
 assert result.tool_calls[0]["name"] == "get_time"
+assert result.additional_kwargs.get("tool_calls"), "Expected OpenAI wire-format tool_calls to be preserved"
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@packages/nvidia_nat_langchain/tests/agent/test_base.py` around lines 114 -
139, Update the test_streaming_preserves_tool_calls regression test to also
assert that the reconstructed OpenAI wire-format is present in the returned
message: after calling base_agent._stream_llm(mock_runnable, inputs) and
validating result.tool_calls, add an assertion that
result.additional_kwargs["tool_calls"] exists, is a list, and contains an entry
whose "name" == "get_time" (or otherwise matches the expected wire-format
payload produced by mock_astream). This change ensures _stream_llm preserves
both the parsed tool_calls and the serialized wire-format under
additional_kwargs["tool_calls"] for downstream compatibility.
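The distinction the reviewer is drawing can be made concrete: parsed `tool_calls` on a LangChain `AIMessage` are flat dicts (`name`, `args`, `id`), while the OpenAI wire format under `additional_kwargs["tool_calls"]` nests the name and arguments under a `function` key. The sketch below shows that wire-format shape and a hypothetical helper (`has_wire_format_tool_call` is not part of the test suite) implementing the check the review asks for.

```python
def has_wire_format_tool_call(additional_kwargs, expected_name):
    """Hypothetical helper: check that additional_kwargs carries a non-empty
    OpenAI wire-format tool_calls list containing the expected function name."""
    wire = additional_kwargs.get("tool_calls")
    if not isinstance(wire, list) or not wire:
        return False
    return any(t.get("function", {}).get("name") == expected_name for t in wire)


# The OpenAI wire format: name/arguments nested under "function",
# arguments carried as a JSON string rather than a parsed dict.
kwargs = {
    "tool_calls": [{
        "id": "call_123",
        "type": "function",
        "function": {"name": "get_time", "arguments": '{"tz": "UTC"}'},
    }]
}
```

In the regression test, the added assertion would pass `result.additional_kwargs` and `"get_time"` through a check of this shape.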
Description
Closes #1865
By submitting this PR I confirm:
Summary by CodeRabbit
Release Notes