Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions examples/train/tau_bench/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# tau-bench (retail) evaluation

Baseline evaluation of a policy model on the [tau-bench](https://github.com/sierra-research/tau-bench)
retail domain, using SkyRL's eval-only entrypoint (`skyrl.train.entrypoints.main_generate`)
and the `tau_bench` SkyRL-Gym environment.

The environment is multi-turn and tool-using: the agent either calls a retail tool
(`<tool_call>{"name": ..., "arguments": {...}}</tool_call>`) or sends a message to a
simulated user. The user is an LLM served separately over an OpenAI-compatible
endpoint. Reward is the upstream tau-bench retail reward (final DB-state match +
required outputs), so a run scores `pass@1` over the 115-task test split.

## Files

- `tau_bench_dataset.py` — writes `retail_test.parquet` / `retail_train.parquet`
(one row per task; the system prompt + opening user message are built at rollout
time in `TauBenchEnv.init`).
- `run_eval_taubench.sh` — launches a user-simulator vLLM OpenAI server, then runs
the eval-only generation. Single 8×H100 node: user-sim on GPUs 6,7, policy engines
on GPUs 0-5. Override `MODEL_NAME` for the policy under test.
- `anyscale_taubench_eval.yaml` — Anyscale job wrapper for the above.

## Run

```bash
# Local (single node):
bash examples/train/tau_bench/run_eval_taubench.sh MODEL_NAME=<hf-model-id>

# Anyscale:
anyscale job submit -f examples/train/tau_bench/anyscale_taubench_eval.yaml --env HF_TOKEN=$HF_TOKEN
```

Results (incl. `pass@1`) are logged and dumped under `trainer.export_path`.
43 changes: 43 additions & 0 deletions examples/train/tau_bench/anyscale_taubench_eval.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
name: taubench-eval
entrypoint: |
set -e
echo "===== [taubench-eval] ENTRYPOINT START $(date -u) ====="
bash examples/train/tau_bench/run_eval_taubench.sh
image_uri: novaskyai/skyrl-train-ray-2.51.1-py3.12-cu12.8
cloud: rkn-gpu-cloud
ray_version: "2.51.1"
working_dir: .
max_retries: 0

env_vars:
HF_HOME: "/mnt/cluster_storage/hf_cache"
HF_HUB_ENABLE_HF_TRANSFER: "1"
UV_CACHE_DIR: "/mnt/cluster_storage/.uv_cache"
UV_LINK_MODE: "copy"
VLLM_USE_V1: "1"
VLLM_ALLOW_INSECURE_SERIALIZATION: "1"
# SkyRL's ray.init sets some of the same env-var keys as this job-level runtime
# env; without this, Ray errors on the conflict instead of merging.
RAY_OVERRIDE_JOB_RUNTIME_ENV: "1"
RAY_worker_register_timeout_seconds: "1800"
SKYRL_RAY_PG_TIMEOUT_IN_S: "1800"
SKYRL_WAIT_UNTIL_INFERENCE_SERVER_HEALTHY_TIMEOUT_S: "3600"

compute_config:
advanced_instance_config:
metadata:
labels:
kueue.x-k8s.io/queue-name: default-queue
# Single-node eval: the user-simulator vLLM server (GPUs 6,7) and the SkyRL eval
# engines (GPUs 0-5) are partitioned via CUDA_VISIBLE_DEVICES inside the run script
# and run as co-resident node-local processes. The entrypoint runs on the head, so
# the head node must carry the GPUs -> all 8 H100s on the head, no worker pool.
# (If your kueue only grants GPUs to the worker pool, move these to worker_nodes
# with min_nodes/max_nodes: 1 and run the user-sim via Ray instead.)
head_node:
required_resources:
CPU: 92
memory: 960Gi
GPU: 8
required_labels:
ray.io/accelerator-type: H100
91 changes: 91 additions & 0 deletions examples/train/tau_bench/run_eval_taubench.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
set -x

# NO-TRAINING baseline eval on tau-bench retail (test split, 115 tasks).
# Default policy is Qwen/Qwen3.6-35B-A3B; override MODEL_NAME for any HF model.
# Single 8x H100 node, partitioned:
# GPUs 6,7 -> user-simulator vLLM OpenAI server (fixed model, talks to the env over HTTP)
# GPUs 0-5 -> SkyRL eval engines serving the policy (agent) under test
# Uses SkyRL's eval-only entrypoint (main_generate) with async inference so the many
# concurrent multi-turn conversations + user-sim HTTP calls overlap. Reports eval success rate.

export HF_HOME=${HF_HOME:-/mnt/cluster_storage/hf_cache}
export HF_HUB_ENABLE_HF_TRANSFER=${HF_HUB_ENABLE_HF_TRANSFER:-1}

# --- models (override MODEL_NAME to the exact HF id you want to eval) ---
MODEL_NAME=${MODEL_NAME:-Qwen/Qwen3.6-35B-A3B}
USER_SIM_MODEL=${USER_SIM_MODEL:-Qwen/Qwen2.5-7B-Instruct}

DATA_DIR=${DATA_DIR:-/mnt/cluster_storage/data/tau_bench}
VAL_FILE="$DATA_DIR/retail_test.parquet"
# Persist eval result dumps on shared cluster storage (survives job termination).
# Dumps land under $EXPORT_PATH/dumped_evals/eval_only/.
EXPORT_PATH=${EXPORT_PATH:-/mnt/cluster_storage/exports/taubench_eval}
LOGGER=${LOGGER:-console}
EVAL_N=${EVAL_N:-1} # pass@1; raise for pass^k
MAX_TURNS=${MAX_TURNS:-30}

# --- GPU partition ---
USER_SIM_GPUS=${USER_SIM_GPUS:-6,7}
USER_SIM_TP=${USER_SIM_TP:-2}
USER_SIM_PORT=${USER_SIM_PORT:-8001}
POLICY_GPUS=${POLICY_GPUS:-0,1,2,3,4,5}
NUM_ENGINES=${NUM_ENGINES:-3}
ENGINE_TP=${ENGINE_TP:-2} # NUM_ENGINES * ENGINE_TP must equal the #policy GPUs

USER_SIM_ENDPOINT="http://127.0.0.1:${USER_SIM_PORT}/v1"

# 1. Build the retail eval parquet (always rebuild so dataset-schema changes take effect;
# cluster_storage persists across jobs, so a stale parquet would otherwise be reused).
mkdir -p "$DATA_DIR"
uv run --isolated --extra fsdp python examples/train/tau_bench/tau_bench_dataset.py --output_dir "$DATA_DIR"

# 2. Launch the user-simulator vLLM OpenAI server (background) and wait until healthy.
CUDA_VISIBLE_DEVICES=$USER_SIM_GPUS uv run --isolated --extra fsdp \
vllm serve "$USER_SIM_MODEL" \
--tensor-parallel-size "$USER_SIM_TP" \
--host 127.0.0.1 --port "$USER_SIM_PORT" \
--dtype bfloat16 --gpu-memory-utilization 0.85 \
--max-model-len 32768 --enable-prefix-caching --trust-remote-code \
> /tmp/tau_user_sim.log 2>&1 &
USER_SIM_PID=$!
trap 'echo "[script] stopping user-sim pid=$USER_SIM_PID"; kill -TERM "$USER_SIM_PID" 2>/dev/null' EXIT INT TERM

echo "[script] waiting for user-sim server on :$USER_SIM_PORT ..."
for i in $(seq 1 120); do
if curl -sf "http://127.0.0.1:${USER_SIM_PORT}/v1/models" >/dev/null 2>&1; then
echo "[script] user-sim server is up."
break
fi
if ! kill -0 "$USER_SIM_PID" 2>/dev/null; then
echo "[script] user-sim server died; see /tmp/tau_user_sim.log"; tail -50 /tmp/tau_user_sim.log; exit 1
fi
sleep 10
done

# 3. Eval-only run (no training). Policy engines on the remaining GPUs.
CUDA_VISIBLE_DEVICES=$POLICY_GPUS uv run --isolated --extra fsdp -m skyrl.train.entrypoints.main_generate \
data.val_data="['$VAL_FILE']" \
environment.env_class=tau_bench \
environment.skyrl_gym.tau_bench.user_simulator_endpoint="$USER_SIM_ENDPOINT" \
environment.skyrl_gym.tau_bench.user_simulator_model="$USER_SIM_MODEL" \
trainer.policy.model.path="$MODEL_NAME" \
trainer.placement.colocate_all=false \
generator.inference_engine.backend=vllm \
generator.inference_engine.run_engines_locally=true \
generator.inference_engine.async_engine=true \
generator.inference_engine.num_engines=$NUM_ENGINES \
generator.inference_engine.tensor_parallel_size=$ENGINE_TP \
generator.inference_engine.gpu_memory_utilization=0.85 \
generator.max_turns=$MAX_TURNS \
generator.use_conversation_multi_turn=true \
generator.max_input_length=16384 \
generator.eval_n_samples_per_prompt=$EVAL_N \
generator.eval_sampling_params.temperature=0.0 \
generator.eval_sampling_params.max_generate_length=2048 \
trainer.max_prompt_length=8192 \
trainer.dump_eval_results=true \
trainer.export_path="$EXPORT_PATH" \
trainer.logger="$LOGGER" \
trainer.project_name="taubench-eval" \
trainer.run_name="qwen3.6_35b_retail_baseline" \
"$@"
72 changes: 72 additions & 0 deletions examples/train/tau_bench/tau_bench_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Build parquet datasets for the tau-bench retail environment.

Each row corresponds to one retail task. The (large, static) retail policy/wiki and
tool schemas are NOT stored per row — ``TauBenchEnv.init`` builds the system prompt
from the vendored domain data and obtains the opening user message from the user
simulator. So the dataset only needs to identify the task and (for record-keeping)
the gold trajectory.

Usage:
uv run --isolated python examples/train/tau_bench/tau_bench_dataset.py \
--output_dir /mnt/cluster_storage/data/tau_bench
"""

import argparse
import os

import pandas as pd

from skyrl_gym.envs.tau_bench.tau_core.retail.tasks_test import TASKS_TEST
from skyrl_gym.envs.tau_bench.tau_core.retail.tasks_train import TASKS_TRAIN


def build_rows(tasks, split: str):
rows = []
for task_index, task in enumerate(tasks):
reward_spec = {
"method": "rule",
"ground_truth": {
"actions": [{"name": a.name, "kwargs": a.kwargs} for a in task.actions],
"outputs": list(task.outputs),
},
}
rows.append(
{
"data_source": "tau_bench_retail",
# Placeholder only — the real prompt (system wiki + dynamic opening user
# message) is constructed in TauBenchEnv.init, which overrides this. It must
# NOT contain the hidden user instruction. A user turn is required so the
# chat template applies cleanly at dataset load (Qwen errors on system-only).
"prompt": [
{"role": "system", "content": "You are a retail customer service agent."},
{"role": "user", "content": "Hi."},
],
"env_class": "tau_bench",
"reward_spec": reward_spec,
"task_index": task_index,
"task_split": split,
}
)
return rows


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--output_dir", default=os.path.expanduser("~/data/tau_bench"))
args = parser.parse_args()
os.makedirs(args.output_dir, exist_ok=True)

val_rows = build_rows(TASKS_TEST, "test")
train_rows = build_rows(TASKS_TRAIN, "train")

val_path = os.path.join(args.output_dir, "retail_test.parquet")
train_path = os.path.join(args.output_dir, "retail_train.parquet")
pd.DataFrame(val_rows).to_parquet(val_path)
pd.DataFrame(train_rows).to_parquet(train_path)

print(f"Wrote {len(val_rows)} eval tasks -> {val_path}")
print(f"Wrote {len(train_rows)} train tasks -> {train_path}")


if __name__ == "__main__":
main()
7 changes: 6 additions & 1 deletion skyrl-gym/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ requires-python = ">=3.10"

dependencies = [
"func_timeout",
"pandas",
"pandas",
"requests",
"omegaconf",
]
Expand All @@ -27,6 +27,11 @@ Repository = "https://github.com/NovaSky-AI/SkyRL"
[tool.setuptools.packages.find]
include = ["skyrl_gym*"]

[tool.setuptools.package-data]
# Vendored tau-bench retail domain data (DB json + policy wiki) must ship in wheels.
"skyrl_gym.envs.tau_bench.tau_core.retail" = ["*.md"]
"skyrl_gym.envs.tau_bench.tau_core.retail.data" = ["*.json", "*.md"]

[project.optional-dependencies]
dev = [
"pytest"
Expand Down
5 changes: 5 additions & 0 deletions skyrl-gym/skyrl_gym/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@
entry_point="skyrl_gym.envs.searchcode.env:SearchCodeEnv",
)

register(
id="tau_bench",
entry_point="skyrl_gym.envs.tau_bench.env:TauBenchEnv",
)

__all__ = [
"deregister",
"register",
Expand Down
Empty file.
Loading
Loading