Skip to content

Commit d1f8142

Browse files
authored
Merge pull request #2 from AgentR1/blackbox_dev
Feat: Offline Blackbox Agent Support and GSM8k Specialized Agent
2 parents 78b2a8c + 2970744 commit d1f8142

File tree

13 files changed

+631
-19
lines changed

13 files changed

+631
-19
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111

1212
## News
1313

14-
- **[2026.03]** 🚧 **Claw-R1 Project Init.** We are actively updating the framework. Stay tuned for more features and documentation.
14+
- **[2026.03.06]** 📖 **Claw-R1 Documentation Released.** Project page and documentation are now available at [Claw-R1 Project Page](https://agentr1.github.io/) and [Claw-R1 docs](https://agentr1.github.io/Claw-R1/).
15+
16+
- **[2026.03.03]** 🚧 **Claw-R1 Project Init.** We are actively updating the framework. Stay tuned for more features and documentation.
1517

1618
## Overview
1719

claw_r1/async_rollouter.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,11 +264,22 @@ def _init_gateway(self):
264264
str(gateway_port),
265265
]
266266

267-
self._gateway_process = subprocess.Popen(cmd)
267+
self._gateway_process = subprocess.Popen(
268+
cmd,
269+
stdout=subprocess.DEVNULL,
270+
stderr=subprocess.PIPE,
271+
text=True,
272+
)
268273
self._gateway_url = f"http://localhost:{gateway_port}"
269274
atexit.register(self._stop_gateway)
270275

271276
for _ in range(120):
277+
if self._gateway_process.poll() is not None:
278+
_, err = self._gateway_process.communicate()
279+
err = (err or "").strip() or "(no stderr)"
280+
raise RuntimeError(
281+
f"Gateway process exited before ready ({self._gateway_url}). stderr:\n{err}"
282+
)
272283
try:
273284
resp = httpx.get(f"{self._gateway_url}/docs", timeout=2.0)
274285
if resp.status_code == 200:
@@ -277,7 +288,11 @@ def _init_gateway(self):
277288
except Exception:
278289
pass
279290
time.sleep(1)
280-
raise RuntimeError(f"Gateway did not start within 120s ({self._gateway_url})")
291+
raise RuntimeError(
292+
f"Gateway did not start within 120s ({self._gateway_url}). "
293+
"Check that port %s is free and no firewall blocks it."
294+
% gateway_port
295+
)
281296

282297
def _stop_gateway(self):
283298
proc = getattr(self, "_gateway_process", None)

claw_r1/async_trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def fit(self):
270270
def _process_batch(self, batch: DataProto, metrics: dict, timing_raw: dict) -> DataProto:
271271
"""Run the full PPO pipeline on a single batch."""
272272
batch.meta_info["global_token_num"] = batch.batch["attention_mask"].sum(dim=-1).tolist()
273+
batch.meta_info.setdefault("temperature", self.config.actor_rollout_ref.rollout.temperature)
273274

274275
if "response_mask" not in batch.batch:
275276
batch.batch["response_mask"] = compute_response_mask(batch)

claw_r1/blackbox_agent/__init__.py

Whitespace-only changes.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
- name: blackbox_gsm8k_agent
2+
_target_: claw_r1.blackbox_agent.gsm8k_agent_flow.BlackBoxGSM8KAgentFlow
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""Black-box agent flow — base class.
2+
3+
BlackBoxAgentFlowBase handles the full protocol with Gateway (init_trajectory,
4+
register_trajectory, complete) and delegates agent execution to subclasses via
5+
_run_agent. Subclasses only create and run the concrete Agent; they do not
6+
touch Gateway or implement any task logic. Concrete strategies live in
7+
separate modules (e.g. gsm8k_agent_flow.py).
8+
"""
9+
10+
import json
11+
import logging
12+
import os
13+
from abc import abstractmethod
14+
from typing import Any
15+
16+
import httpx
17+
import numpy as np
18+
19+
from claw_r1.agent_flow.agent_flow import AgentFlowBase, register
20+
21+
logger = logging.getLogger(__name__)
22+
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
23+
24+
_DEFAULT_SKIP_KEYS = frozenset({"raw_prompt", "multi_modal_data", "channel", "agent_name"})
25+
26+
27+
class _NumpyEncoder(json.JSONEncoder):
28+
"""JSON encoder that converts numpy scalars to native Python types for HTTP requests."""
29+
30+
def default(self, o):
31+
if isinstance(o, np.integer):
32+
return int(o)
33+
if isinstance(o, np.floating):
34+
return float(o)
35+
if isinstance(o, np.ndarray):
36+
return o.tolist()
37+
return super().default(o)
38+
39+
40+
class BlackBoxAgentFlowBase(AgentFlowBase):
    """Base class for black-box agent flows.

    Handles generic parameter processing and the full Gateway protocol:
    ``init_trajectory`` (get ``base_url``) -> ``register_trajectory``
    (channel + metadata) -> subclass ``_run_agent`` -> ``complete_trajectory``.
    Subclasses only implement ``_run_agent`` to create and run the concrete
    agent.
    """

    def _prepare_params(self, kwargs: dict[str, Any]) -> tuple[str | None, str, dict[str, Any]]:
        """Extract channel, prompt_uid, and metadata from *kwargs*.

        ``channel`` is popped from *kwargs* (so it is no longer present for
        later consumers); ``prompt_uid`` is ``str(kwargs["uid"])`` and falls
        back to ``"1"``; ``metadata`` is every remaining entry whose key is
        not in ``_DEFAULT_SKIP_KEYS``.
        """
        channel = kwargs.pop("channel", None)
        prompt_uid = str(kwargs.get("uid", "1"))
        metadata = {k: v for k, v in kwargs.items() if k not in _DEFAULT_SKIP_KEYS}
        return channel, prompt_uid, metadata

    async def run(self, sampling_params: dict[str, Any], **kwargs) -> int:
        """Run one trajectory through the Gateway and return the turn count.

        Args:
            sampling_params: Accepted for interface compatibility; not read
                by this base implementation.
            **kwargs: Per-sample fields. ``channel`` and ``uid`` are consumed
                here; the rest is forwarded as metadata and to ``_run_agent``.

        Returns:
            The value returned by ``_run_agent``.

        Raises:
            httpx.HTTPStatusError: If ``init_trajectory`` or
                ``register_trajectory`` answers with an error status.
        """
        channel, prompt_uid, metadata = self._prepare_params(kwargs)

        async with httpx.AsyncClient(timeout=30.0) as http:
            # 1. Allocate trajectory — get base_url with trajectory_uid embedded.
            init_resp = await http.post(f"{self.gateway_url}/init_trajectory")
            init_resp.raise_for_status()
            init_data = init_resp.json()
            base_url_from_init = init_data["base_url"]
            # base_url_from_init is http://host:port/{traj_uid}/{default_prompt_uid}/v1
            # Replace the default prompt_uid with the actual one. A trailing
            # slash is stripped first so the split always lands on
            # [...base, prompt_uid, "v1"].
            parts = base_url_from_init.rstrip("/").rsplit("/", 2)
            base_url = f"{parts[0]}/{prompt_uid}/v1"

            # 2. Register channel + metadata via base_url.
            reg_body: dict[str, Any] = {}
            if channel:
                reg_body["channel"] = channel
            if metadata:
                reg_body["metadata"] = metadata
            # Serialize manually so numpy values in metadata survive encoding.
            payload = json.dumps(reg_body, cls=_NumpyEncoder).encode()
            reg_resp = await http.post(
                f"{base_url}/register_trajectory",
                content=payload,
                headers={"content-type": "application/json"},
            )
            # Fail loudly, as for init_trajectory: a rejected registration
            # would otherwise leave the trajectory without channel/metadata.
            reg_resp.raise_for_status()

        # 3. Run the concrete agent.
        try:
            num_turns = await self._run_agent(base_url, kwargs)
        finally:
            # 4. Mark trajectory complete, even when the agent raised.
            async with httpx.AsyncClient(timeout=httpx.Timeout(600.0)) as http:
                await http.post(f"{base_url}/complete_trajectory")

        return num_turns

    @abstractmethod
    async def _run_agent(self, base_url: str, kwargs: dict[str, Any]) -> int:
        """Create and run the concrete Agent. Subclasses implement this.

        Args:
            base_url: Per-trajectory endpoint built by ``run`` (ends in ``/v1``).
            kwargs: The per-sample fields passed to ``run``, minus ``channel``.

        Returns:
            Number of turns the agent used.
        """
        raise NotImplementedError
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
"""GSM8K black-box agent — fully independent of training internals.
2+
3+
This agent uses a standard OpenAI-compatible API to interact with the LLM,
4+
parses tool calls from raw text output (Qwen-style ``<tool_call>`` tags),
5+
and executes a local ``check_answer`` tool.
6+
7+
It knows nothing about trajectory UIDs, Steps, DataPool, or reward — all of
8+
those are transparently handled by the Gateway.
9+
"""
10+
11+
import json
12+
import logging
13+
14+
import regex
15+
16+
logger = logging.getLogger(__name__)
17+
18+
CHECK_ANSWER_TOOL = {
19+
"type": "function",
20+
"function": {
21+
"name": "check_answer",
22+
"description": "Check if your answer to the math problem is correct.",
23+
"parameters": {
24+
"type": "object",
25+
"properties": {
26+
"answer": {
27+
"type": "string",
28+
"description": "Your final numerical answer",
29+
}
30+
},
31+
"required": ["answer"],
32+
},
33+
},
34+
}
35+
36+
TOOL_CALL_REGEX = regex.compile(r"<tool_call>(.*?)</tool_call>", regex.DOTALL)
37+
38+
39+
def parse_tool_calls(content: str) -> tuple[str, list[dict]]:
    """Pull every ``<tool_call>`` block out of raw LLM output.

    Each block body is decoded as JSON; blocks that are not valid JSON, are
    not JSON objects, or lack a ``name`` / ``arguments`` key are dropped.

    Returns:
        ``(remaining_text, tool_calls)``. ``remaining_text`` is *content* with
        all ``<tool_call>`` blocks removed and surrounding whitespace stripped
        (or *content* untouched when it has no opening tag). Each entry of
        ``tool_calls`` is a dict with ``name`` and ``arguments`` keys.
    """
    if "<tool_call>" not in content:
        return content, []

    extracted: list[dict] = []
    for raw_block in TOOL_CALL_REGEX.findall(content):
        try:
            payload = json.loads(raw_block)
            if isinstance(payload, dict):
                extracted.append({"name": payload["name"], "arguments": payload["arguments"]})
        except (json.JSONDecodeError, KeyError, TypeError):
            # Malformed or incomplete tool call: ignore it and keep going.
            continue

    leftover = TOOL_CALL_REGEX.sub("", content).strip()
    return leftover, extracted
64+
65+
66+
def check_answer(answer: str, ground_truth: str) -> str:
    """Verify *answer* against *ground_truth* and return textual feedback.

    The answer is wrapped as ``#### {answer}`` and scored with verl's GSM8K
    ``compute_score`` in ``flexible`` mode; only a feedback sentence is
    returned, never the numeric score.
    """
    # Imported lazily so merely importing this module does not require verl.
    from verl.utils.reward_score.gsm8k import compute_score

    is_correct = (
        compute_score(
            f"#### {answer}",
            ground_truth,
            method="flexible",
            format_score=0.0,
            score=1.0,
        )
        > 0
    )
    if is_correct:
        return "Correct! Your answer is right."
    return "Incorrect. Your answer is wrong, please try again."
80+
81+
82+
class GSM8KAgent:
    """Stateless GSM8K solving agent that talks to an OpenAI-compatible API.

    The agent is completely unaware of training-side concepts such as
    ``trajectory_uid``, ``Step``, or ``DataPool``. All it needs is a
    ``base_url`` pointing to an OpenAI-compatible endpoint.

    Args:
        base_url: Base URL handed to the OpenAI SDK client unchanged (apart
            from stripping trailing slashes). It must already include any
            version suffix, e.g. ``http://host:port/{traj}/{prompt}/v1``.
    """

    def __init__(self, base_url: str):
        # Imported lazily so the module can be imported without the SDK.
        import openai

        self.base_url = base_url.rstrip("/")
        self.client = openai.AsyncOpenAI(
            base_url=self.base_url,
            api_key="not-needed",
            timeout=600.0,
        )

    async def solve(self, question: str, ground_truth: str, max_turns: int = 3) -> int:
        """Attempt to solve *question* in up to *max_turns* LLM interactions.

        Each turn sends the conversation to the model, parses ``<tool_call>``
        blocks from the reply, and answers every ``check_answer`` call with
        local feedback. The loop ends on the first reply without tool calls
        or when *max_turns* is reached.

        Returns the number of turns actually used. Trajectory completion is
        signaled by the caller (BlackBoxAgentFlowBase or online service entrypoint).
        """
        messages: list[dict] = [{"role": "user", "content": question}]

        turns_used = 0
        for turn in range(max_turns):
            turns_used = turn + 1

            resp = await self.client.chat.completions.create(
                model="default",
                messages=messages,
                tools=[CHECK_ANSWER_TOOL],
            )
            content = resp.choices[0].message.content or ""
            _, tool_calls = parse_tool_calls(content)

            # The assistant reply is recorded whether or not it used a tool.
            messages.append({"role": "assistant", "content": content})
            if not tool_calls:
                break

            for tc in tool_calls:
                if tc["name"] != "check_answer":
                    continue
                arguments = tc["arguments"]
                # The model may emit non-object arguments (string, list, ...);
                # treat that as an empty answer instead of crashing on .get().
                answer = arguments.get("answer", "") if isinstance(arguments, dict) else ""
                result = check_answer(answer, ground_truth)
                messages.append({"role": "tool", "content": result})

        return turns_used
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""GSM8K black-box agent flow — concrete strategy for GSM8K."""
2+
3+
from typing import Any
4+
5+
from claw_r1.agent_flow.agent_flow import register
6+
7+
from claw_r1.blackbox_agent.blackbox_agent_flow import BlackBoxAgentFlowBase
8+
9+
from claw_r1.blackbox_agent.gsm8k_agent import GSM8KAgent
10+
11+
12+
@register("blackbox_gsm8k_agent")
class BlackBoxGSM8KAgentFlow(BlackBoxAgentFlowBase):
    """Black-box flow whose agent step is carried out by :class:`GSM8KAgent`."""

    async def _run_agent(self, base_url: str, kwargs: dict[str, Any]) -> int:
        """Build the question and ground truth from *kwargs* and run the agent.

        The question is the content of the last ``user`` message in
        ``raw_prompt`` (or the stringified prompt when none is usable); the
        ground truth is read from ``reward_model``. Returns the number of
        turns reported by ``GSM8KAgent.solve``.
        """
        raw_prompt = kwargs.get("raw_prompt", [])
        if isinstance(raw_prompt, str):
            question = raw_prompt
        elif isinstance(raw_prompt, list) and raw_prompt:
            question = ""
            # Walk backwards so the most recent user message wins.
            for message in reversed(raw_prompt):
                if message.get("role") == "user":
                    question = message.get("content", "")
                    break
            if not question:
                # No user message, or its content was empty/None.
                question = str(raw_prompt)
        else:
            question = str(raw_prompt)

        reward_model = kwargs.get("reward_model", {})
        if isinstance(reward_model, dict):
            truth_value = reward_model.get("ground_truth", "")
        else:
            truth_value = getattr(reward_model, "ground_truth", "")
        ground_truth = str(truth_value)

        rollout_cfg = self.config.actor_rollout_ref.rollout
        max_turns = rollout_cfg.get("max_turns", 3)
        agent = GSM8KAgent(base_url=base_url)
        return await agent.solve(question=question, ground_truth=ground_truth, max_turns=max_turns)

claw_r1/data_pool/training_backend.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,12 @@ def _pad_single_step(self, step: Step) -> dict[str, Any]:
135135
136136
Returns a dict of tensors, each with a leading batch dim of 1.
137137
"""
138+
pad_token_id = self._tokenizer.pad_token_id or 0
139+
138140
self._tokenizer.padding_side = "left"
141+
prompt_ids = step.prompt_ids if step.prompt_ids else [pad_token_id]
139142
prompt_out = self._tokenizer.pad(
140-
{"input_ids": step.prompt_ids},
143+
{"input_ids": prompt_ids},
141144
padding="max_length",
142145
max_length=self._prompt_length,
143146
return_tensors="pt",
@@ -148,8 +151,9 @@ def _pad_single_step(self, step: Step) -> dict[str, Any]:
148151
prompt_out["attention_mask"] = prompt_out["attention_mask"].unsqueeze(0)
149152

150153
self._tokenizer.padding_side = "right"
154+
response_ids = step.response_ids if step.response_ids else [pad_token_id]
151155
response_out = self._tokenizer.pad(
152-
{"input_ids": step.response_ids},
156+
{"input_ids": response_ids},
153157
padding="max_length",
154158
max_length=self._response_length,
155159
return_tensors="pt",

0 commit comments

Comments
 (0)